In [1]:
import torch
from torchvision import datasets
from torchvision.models.optical_flow.raft import ResidualBlock
from torchvision.transforms import ToTensor
from torch import nn
from torch import Tensor
import time
import copy
from torch.utils.data import DataLoader
import pandas as pd

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Network definition

In [3]:
class NetA1(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA1, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(676, num_classes)
        self.softmax = nn.Softmax(dim=0)
            
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.softmax(x)
        return x

In [4]:
class NetA2(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA2, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(216, 260)
        self.linear2 = nn.Linear(260, 160)
        self.linear3 = nn.Linear(160, num_classes)
        self.softmax = nn.Softmax(dim=0)
        
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.softmax(x)
        return x

### Network initialization

In [5]:
initialization_weights = torch.tensor([
    [[[1, 0, 1], [0, 1, 0], [1, 0, 1]]],
    [[[1, 1, 0], [0, 0, 1], [1, 1, 0]]],
    [[[0, 1, 1], [1, 0, 0], [0, 1, 1]]],
    [[[0, 1, 0], [1, 1, 0], [0, 1, 0]]]
            ], dtype=torch.float32)

initialization_biases = torch.tensor([0,0,0,0], dtype=torch.float32)
initialization_weights.shape

torch.Size([4, 1, 3, 3])

In [6]:
net_a2_hf = NetA2(10)
net_a2_ht = NetA2(10)
net_a2_dt = NetA2(10)

#set conv1 initialization of net_a2_hf
net_a2_hf.conv1.weight = nn.Parameter(copy.deepcopy(initialization_weights))
net_a2_hf.conv1.bias = nn.Parameter(copy.deepcopy(initialization_biases))

# set same weights and bias to each layer of each network
net_a2_ht.load_state_dict(net_a2_hf.state_dict())
for name, param in net_a2_hf.state_dict().items():
    if "conv1" not in name:
        net_a2_dt.state_dict()[name].copy_(param)

#set conv1 initialization
#net_a2_dt.conv1.load_state_dict(net_a1_dt.conv1.state_dict())

#freeze conv1 layer of net_a2_hf
net_a2_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a2_hf.state_dict()}, 'NetA2HF_init.pt')
torch.save({'initialization': net_a2_ht.state_dict()}, 'NetA2HT_init.pt')
torch.save({'initialization': net_a2_dt.state_dict()}, 'NetA2DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a2_hf.state_dict())
print("Net_A1_HT: \n \t", net_a2_ht.state_dict())
print("Net_A1_DT: \n \t", net_a2_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[ 0.1285,  0.1293, -0.1389],
          [-0.0664,  0.0077, -0.1630],
          [-0.0279,  0.0455,  0.1343]],

         [[ 0.1068, -0.1197,  0.0090],
          [-0.0064, -0.0420, -0.0895],
          [-0.1413,  0.0500,  0.0897]],

         [[-0.0711, -0.1296, -0.1002],
          [-0.0239,  0.0477,  0.1364],
          [-0.0958, -0.1656, -0.1276]],

         [[-0.1230, -0.0910,  0.0032],
          [-0.0958, -0.1549, -0.1324],
          [ 0.1377, -0.1392,  0.1484]]],


        [[[ 0.0510,  0.0343,  0.1636],
          [ 0.0559,  0.0953, -0.1076],
          [ 0.0381,  0.0672, -0.126

In [7]:
net_a1_hf = NetA1(10)
net_a1_ht = NetA1(10)
net_a1_dt = NetA1(10)

#set conv1 initialization of net_a1_hf
net_a1_hf.conv1.weight = nn.Parameter(copy.deepcopy(initialization_weights))
net_a1_hf.conv1.bias = nn.Parameter(copy.deepcopy(initialization_biases))

# set same weights and bias to each layer of each network
net_a1_ht.load_state_dict(net_a1_hf.state_dict())
net_a1_dt.load_state_dict(net_a1_hf.state_dict())

#set conv1 initialization
net_a1_dt.conv1.load_state_dict(net_a2_dt.conv1.state_dict())

#freeze conv1 layer of net_a2_hf
net_a1_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1HF_init.pt')
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA1HT_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a1_hf.state_dict())
print("Net_A1_HT: \n \t", net_a1_ht.state_dict())
print("Net_A1_DT: \n \t", net_a1_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[ 0.0178,  0.0251, -0.0195,  ...,  0.0242,  0.0078,  0.0143],
        [ 0.0378, -0.0285, -0.0059,  ...,  0.0324, -0.0183, -0.0011],
        [ 0.0119,  0.0369,  0.0041,  ...,  0.0214,  0.0243, -0.0055],
        ...,
        [ 0.0301, -0.0059, -0.0222,  ...,  0.0063, -0.0187, -0.0121],
        [-0.0203, -0.0247,  0.0333,  ...,  0.0254, -0.0090,  0.0344],
        [-0.0116, -0.0273, -0.0160,  ..., -0.0193,  0.0077, -0.0220]])), ('linear1.bias', tensor([-0.0202,  0.0237,  0.0262,  0.0373, -0.0302, -0.0359, -0.0348, -0.0348,
        -0.0305, -0.0113]))])
Net_A1_HT: 
 	 OrderedDict

### Preliminary Analysys

In [8]:
print( "Net_A1: \n",
       "\t|W_{conv_a1_hf} - W_{conv_a1_ht}| =", torch.norm(net_a1_hf.conv1.weight - net_a1_ht.conv1.weight),"\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_ht}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_ht.linear1.weight), "\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_dt}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_dt.linear1.weight), "\n")

print( "Net_A2: \n",
       "\t|W_{conv1_a2_hf} - W_{conv1_a2_ht}| =", torch.norm(net_a2_hf.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv2_a2_hf} - W_{conv2_a2_ht}| =", torch.norm(net_a2_hf.conv2.weight - net_a2_ht.conv2.weight),"\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_ht}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_ht.linear1.weight), "\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_dt}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_dt.linear1.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_ht}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_ht.linear2.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_dt}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_dt.linear2.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_ht}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_ht.linear3.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_dt}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_dt.linear3.weight), "\n")

print( "Net_A1 Vs Net_A2: \n",
       "\t|W_{conv1_a1_hf} - W_{conv1_a2_hf}| =", torch.norm(net_a1_hf.conv1.weight - net_a2_hf.conv1.weight),"\n",
       "\t|W_{conv1_a1_ht} - W_{conv2_a2_ht}| =", torch.norm(net_a1_ht.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv1_a1_dt} - W_{conv2_a2_dt}| =", torch.norm(net_a1_dt.conv1.weight - net_a2_dt.conv1.weight),"\n")

Net_A1: 
 	|W_{conv_a1_hf} - W_{conv_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 

Net_A2: 
 	|W_{conv1_a2_hf} - W_{conv1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{conv2_a2_hf} - W_{conv2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_dt}| = tensor(0., grad_fn=<Linal

In [9]:
print("Net_A1HF:")
for param in net_a1_hf.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2HF:")
for param in net_a2_hf.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A1HT:")
for param in net_a1_ht.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2HT:")
for param in net_a2_ht.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A1DT:")
for param in net_a1_dt.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2DT:")
for param in net_a2_dt.conv1.parameters():
    print("\t",param.requires_grad)

Net_A1HF:
	 False
	 False
Net_A2HF:
	 False
	 False
Net_A1HT:
	 True
	 True
Net_A2HT:
	 True
	 True
Net_A1DT:
	 True
	 True
Net_A2DT:
	 True
	 True


In [10]:
print(net_a1_ht.conv1.weight is net_a1_hf.conv1.weight)
print(net_a1_ht.linear1.weight is net_a1_hf.linear1.weight)
print(net_a1_hf.linear1.weight is net_a1_dt.linear1.weight)
print(net_a2_hf.conv1.weight is net_a1_hf.conv1.weight)
print(net_a2_ht.conv1.weight is net_a1_ht.conv1.weight)
print(net_a2_dt.conv1.weight is net_a1_dt.conv1.weight)
print(net_a2_hf.linear1.weight is net_a2_ht.linear1.weight)
print(net_a2_hf.linear1.weight is net_a2_dt.linear1.weight)

False
False
False
False
False
False
False
False


### Data Loading

In [11]:
train_data= datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor(),)

test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor(),)

In [12]:
labels_map={
    0: 'T-shirt',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot',
}
sample_idx = torch.randint(len(train_data), size = (1,)).item()
image, label = train_data[sample_idx]
image.shape

torch.Size([1, 28, 28])

In [13]:
batch_size = 160

train_dataloader= DataLoader(train_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

### Training/Test Loop

In [14]:
def train_loop(device, dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return model.state_dict()


def test_loop(device, dataloader, model, loss_fn):
      size = len(dataloader.dataset)
      num_batches = len(dataloader)
      test_loss, correct = 0, 0

      with torch.no_grad():
        for X, y in dataloader:
          X, y = X.to(device), y.to(device)
          pred = model(X)
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item()

      test_loss /= num_batches
      correct /= size
      print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
      return 100*correct, test_loss

### Training

In [15]:
learning_rate = 3e-6#4-7
epochs = 30

In [16]:
def train_test(device, train_dataloader, test_dataloader, net, learning_rate, epochs):
    net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    
    accuracies = []
    losses = []
    times=[]
    
    time_s = time.time()
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(device, train_dataloader, net, loss_fn, optimizer)
        acc, loss = test_loop(device, test_dataloader, net, loss_fn)
        accuracies.append(acc)
        losses.append(loss)
        times.append(time.time() - time_s)
        time_s = time.time()
    print("Done!")
    return pd.DataFrame(
        {
            "epoch": [ i for i in range(epochs)],
            "times": times,
            "loss": losses,
            "accuracy": accuracies
        }
    )

NetA1 -> HF Train 

In [None]:
print(net_a1_hf.state_dict())
df_net_a1_hf = train_test(device, train_dataloader, test_dataloader, net_a1_hf, learning_rate, epochs)
df_net_a1_hf.to_csv('NetA1HF_results.csv', index=False)
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1HF_trained.pt')
net_a1_hf.state_dict()

NetA1-> HT train

In [None]:
print(net_a1_ht.state_dict())
df_net_a1_ht = train_test(device, train_dataloader, test_dataloader, net_a1_ht, learning_rate, epochs)
df_net_a1_ht.to_csv('NetA1HT_results.csv', index=False)
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA1HT_trained.pt')
net_a1_ht.state_dict()

NetA1-> DT train

In [None]:
print(net_a1_dt.state_dict())
df_net_a1_dt = train_test(device, train_dataloader, test_dataloader, net_a1_dt, learning_rate, epochs)
df_net_a1_dt.to_csv('NetA1DT_results.csv', index=False)
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_trained.pt')
net_a1_dt.state_dict()

NetA2-> HF Train

In [None]:
print(net_a2_hf.state_dict())
df_net_a2_hf = train_test(device, train_dataloader, test_dataloader, net_a2_hf, learning_rate, epochs)
df_net_a2_hf.to_csv('NetA2HF_results.csv', index=False)
torch.save({'initialization': net_a2_hf.state_dict()}, 'NetA2HF_trained.pt')
net_a2_hf.state_dict() 

NetA2-> HT Train

In [18]:
print(net_a2_ht.state_dict())
df_net_a2_ht = train_test(device, train_dataloader, test_dataloader, net_a2_ht, learning_rate, epochs)
df_net_a2_ht.to_csv('NetA2HT_results.csv', index=False)
torch.save({'initialization': net_a2_ht.state_dict()}, 'NetA2HT_trained.pt')
net_a2_ht.state_dict()

OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[ 0.1285,  0.1293, -0.1389],
          [-0.0664,  0.0077, -0.1630],
          [-0.0279,  0.0455,  0.1343]],

         [[ 0.1068, -0.1197,  0.0090],
          [-0.0064, -0.0420, -0.0895],
          [-0.1413,  0.0500,  0.0897]],

         [[-0.0711, -0.1296, -0.1002],
          [-0.0239,  0.0477,  0.1364],
          [-0.0958, -0.1656, -0.1276]],

         [[-0.1230, -0.0910,  0.0032],
          [-0.0958, -0.1549, -0.1324],
          [ 0.1377, -0.1392,  0.1484]]],


        [[[ 0.0510,  0.0343,  0.1636],
          [ 0.0559,  0.0953, -0.1076],
          [ 0.0381,  0.0672, -0.1269]],

         

OrderedDict([('conv1.weight',
              tensor([[[[ 0.9899,  0.0101,  1.0153],
                        [-0.0017,  1.0102,  0.0125],
                        [ 1.0012,  0.0098,  1.0119]]],
              
              
                      [[[ 1.0421,  1.0405,  0.0387],
                        [ 0.0417,  0.0404,  1.0390],
                        [ 1.0410,  1.0397,  0.0394]]],
              
              
                      [[[ 0.0202,  1.0297,  1.0338],
                        [ 1.0238,  0.0348,  0.0375],
                        [ 0.0243,  1.0374,  1.0400]]],
              
              
                      [[[ 0.0407,  1.0414,  0.0342],
                        [ 1.0378,  1.0341,  0.0162],
                        [ 0.0230,  1.0151,  0.0099]]]])),
             ('conv1.bias', tensor([ 0.0121,  0.0182, -0.0047, -0.0044])),
             ('conv2.weight',
              tensor([[[[ 0.1527,  0.1468, -0.1090],
                        [-0.0415,  0.0346, -0.1262],
                      

NetA2-> DT Train

In [17]:
print(net_a2_dt.state_dict())
df_net_a2_dt = train_test(device, train_dataloader, test_dataloader, net_a2_dt, learning_rate, epochs)
df_net_a2_dt.to_csv('NetA2DT_results.csv', index=False)
torch.save({'initialization': net_a2_dt.state_dict()}, 'NetA2HT_trained.pt')
net_a2_dt.state_dict()

OrderedDict([('conv1.weight', tensor([[[[-0.2328, -0.1575,  0.1404],
          [-0.3092, -0.2124, -0.0475],
          [-0.0328,  0.1321, -0.1510]]],


        [[[-0.2308, -0.2374,  0.0224],
          [-0.1765,  0.0793,  0.1701],
          [ 0.3186, -0.2032, -0.2326]]],


        [[[-0.2162,  0.2119,  0.2384],
          [ 0.2922,  0.2603, -0.1914],
          [ 0.0441, -0.1089, -0.0154]]],


        [[[ 0.1941, -0.2157, -0.1246],
          [-0.2838,  0.3148, -0.3312],
          [ 0.2337,  0.1579,  0.2093]]]])), ('conv1.bias', tensor([-0.2088,  0.1795,  0.3233, -0.0254])), ('conv2.weight', tensor([[[[ 0.1285,  0.1293, -0.1389],
          [-0.0664,  0.0077, -0.1630],
          [-0.0279,  0.0455,  0.1343]],

         [[ 0.1068, -0.1197,  0.0090],
          [-0.0064, -0.0420, -0.0895],
          [-0.1413,  0.0500,  0.0897]],

         [[-0.0711, -0.1296, -0.1002],
          [-0.0239,  0.0477,  0.1364],
          [-0.0958, -0.1656, -0.1276]],

         [[-0.1230, -0.0910,  0.0032],
          

OrderedDict([('conv1.weight',
              tensor([[[[-0.2328, -0.1575,  0.1404],
                        [-0.3092, -0.2124, -0.0475],
                        [-0.0328,  0.1321, -0.1510]]],
              
              
                      [[[-0.1892, -0.2603,  0.0035],
                        [-0.1642,  0.0485,  0.1463],
                        [ 0.3193, -0.2339, -0.2553]]],
              
              
                      [[[-0.1930,  0.1941,  0.2052],
                        [ 0.2881,  0.2280, -0.2327],
                        [ 0.0260, -0.1460, -0.0579]]],
              
              
                      [[[ 0.2465, -0.1635, -0.0718],
                        [-0.2308,  0.3675, -0.2780],
                        [ 0.2867,  0.2115,  0.2629]]]])),
             ('conv1.bias',
              tensor([-2.0875e-01,  2.3210e-01,  3.4513e-01, -1.1122e-07])),
             ('conv2.weight',
              tensor([[[[ 0.1285,  0.1293, -0.1389],
                        [-0.0664,  0.0077, -0