In [1]:
import torch
from torchvision import datasets
from torchvision.models.optical_flow.raft import ResidualBlock
from torchvision.transforms import ToTensor
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader
import pandas as pd

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Network definition

In [3]:
class NetA1(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA1, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(676, num_classes)
        self.softmax = nn.Softmax(dim=0)
            
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.softmax(x)
        return x

In [4]:
class NetA2(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA2, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(216, 260)
        self.linear2 = nn.Linear(260, 160)
        self.linear3 = nn.Linear(160, num_classes)
        self.softmax = nn.Softmax(dim=0)
        
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.softmax(x)
        return x

### Network initialization

In [5]:
initialization_weights = torch.tensor([
    [[[1, 0, 1], [0, 1, 0], [1, 0, 1]]],
    [[[1, 1, 0], [0, 0, 1], [1, 1, 0]]],
    [[[0, 1, 1], [1, 0, 0], [0, 1, 1]]],
    [[[0, 1, 0], [1, 1, 0], [0, 1, 0]]]
            ], dtype=torch.float32)

initialization_biases = torch.tensor([0,0,0,0], dtype=torch.float32)
initialization_weights.shape

torch.Size([4, 1, 3, 3])

In [6]:
net_a1_hf = NetA1(10)
net_a1_ht = NetA1(10)
net_a1_dt = NetA1(10)

#set conv1 initialization of net_a1_hf
net_a1_hf.conv1.weight = nn.Parameter(initialization_weights)
net_a1_hf.conv1.bias = nn.Parameter(initialization_biases)

# set same weights and bias to each layer of each network (except for cov1 of net_a1_dt)
net_a1_ht.load_state_dict(net_a1_hf.state_dict())
net_a1_dt.linear1.load_state_dict(net_a1_hf.linear1.state_dict())

#freeze conv1 layer of net_a1_hf
net_a1_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1HF_init.pt')
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA1HT_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a1_hf.state_dict())
print("Net_A1_HT: \n \t", net_a1_ht.state_dict())
print("Net_A1_DT: \n \t", net_a1_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[-0.0285,  0.0300,  0.0261,  ...,  0.0316, -0.0054, -0.0235],
        [-0.0001,  0.0382, -0.0200,  ...,  0.0139, -0.0020,  0.0334],
        [-0.0203, -0.0349,  0.0247,  ..., -0.0347,  0.0140, -0.0029],
        ...,
        [ 0.0017, -0.0161,  0.0051,  ..., -0.0248,  0.0294, -0.0144],
        [ 0.0372,  0.0311, -0.0312,  ..., -0.0147,  0.0037,  0.0027],
        [-0.0196,  0.0066, -0.0079,  ...,  0.0210, -0.0169,  0.0009]])), ('linear1.bias', tensor([ 0.0192,  0.0374,  0.0258, -0.0071,  0.0304, -0.0120, -0.0228,  0.0204,
        -0.0246, -0.0212]))])
Net_A1_HT: 
 	 OrderedDict

In [7]:
net_a2_hf = NetA2(10)
net_a2_ht = NetA2(10)
net_a2_dt = NetA2(10)

#set conv1 initialization of net_a2_hf
net_a2_hf.conv1.weight = nn.Parameter(initialization_weights)
net_a2_hf.conv1.bias = nn.Parameter(initialization_biases)

# set same weights and bias to each layer of each network (except for cov1 of net_a1_dt)
net_a2_ht.load_state_dict(net_a2_hf.state_dict())
net_a2_dt.load_state_dict(net_a2_hf.state_dict())

#set conv1 initialization
net_a2_dt.conv1.load_state_dict(net_a1_dt.conv1.state_dict())

#freeze conv1 layer of net_a2_hf
net_a2_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA2HF_init.pt')
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA2HT_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA2DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a2_hf.state_dict())
print("Net_A1_HT: \n \t", net_a2_ht.state_dict())
print("Net_A1_DT: \n \t", net_a2_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[ 0.0325,  0.0284, -0.0623],
          [-0.1249, -0.0975, -0.0574],
          [-0.1139,  0.1273,  0.1448]],

         [[-0.0991, -0.1663,  0.0272],
          [-0.0730, -0.1493, -0.0154],
          [-0.0410, -0.0499, -0.0529]],

         [[-0.0919, -0.1395, -0.0721],
          [ 0.0850, -0.0454, -0.1195],
          [-0.1009, -0.1657, -0.0756]],

         [[-0.1144,  0.1369, -0.1300],
          [ 0.0900,  0.1615, -0.0008],
          [ 0.0127, -0.0586, -0.0175]]],


        [[[ 0.0809,  0.0132, -0.1606],
          [ 0.0781, -0.1317, -0.0225],
          [ 0.1474, -0.1250,  0.012

### Preliminary Analysys

In [8]:
print( "Net_A1: \n",
       "\t|W_{conv_a1_hf} - W_{conv_a1_ht}| =", torch.norm(net_a1_hf.conv1.weight - net_a1_ht.conv1.weight),"\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_ht}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_ht.linear1.weight), "\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_dt}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_dt.linear1.weight), "\n")

print( "Net_A2: \n",
       "\t|W_{conv1_a2_hf} - W_{conv1_a2_ht}| =", torch.norm(net_a2_hf.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv2_a2_hf} - W_{conv2_a2_ht}| =", torch.norm(net_a2_hf.conv2.weight - net_a2_ht.conv2.weight),"\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_ht}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_ht.linear1.weight), "\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_dt}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_dt.linear1.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_ht}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_ht.linear2.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_dt}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_dt.linear2.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_ht}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_ht.linear3.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_dt}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_dt.linear3.weight), "\n")

print( "Net_A1 Vs Net_A2: \n",
       "\t|W_{conv1_a1_hf} - W_{conv1_a2_hf}| =", torch.norm(net_a1_hf.conv1.weight - net_a2_hf.conv1.weight),"\n",
       "\t|W_{conv1_a1_ht} - W_{conv2_a2_ht}| =", torch.norm(net_a1_ht.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv1_a1_dt} - W_{conv2_a2_dt}| =", torch.norm(net_a1_dt.conv1.weight - net_a2_dt.conv1.weight),"\n")

Net_A1: 
 	|W_{conv_a1_hf} - W_{conv_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 

Net_A2: 
 	|W_{conv1_a2_hf} - W_{conv1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{conv2_a2_hf} - W_{conv2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_dt}| = tensor(0., grad_fn=<Linal

### Data Loading

In [9]:
train_data= datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor(),)

test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor(),)

In [10]:
labels_map={
    0: 'T-shirt',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot',
}
sample_idx = torch.randint(len(train_data), size = (1,)).item()
image, label = train_data[sample_idx]
image.shape

torch.Size([1, 28, 28])

In [11]:
batch_size = 128

train_dataloader= DataLoader(train_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

### Training/Test Loop

In [12]:
def train_loop(device, dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return model.state_dict()


def test_loop(device, dataloader, model, loss_fn):
      size = len(dataloader.dataset)
      num_batches = len(dataloader)
      test_loss, correct = 0, 0

      with torch.no_grad():
        for X, y in dataloader:
          X, y = X.to(device), y.to(device)
          pred = model(X)
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item()

      test_loss /= num_batches
      correct /= size
      print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
      return 100*correct, test_loss

### Training

In [13]:
learning_rate = 1e-4
epochs = 50

In [14]:
def train_test(device, train_dataloader, test_dataloader, net, learning_rate, epochs):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    
    accuracies = []
    losses = []
    
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(device, train_dataloader, net, loss_fn, optimizer)
        acc, loss = test_loop(device, test_dataloader, net, loss_fn)
        accuracies.append(acc)
        losses.append(loss)
    print("Done!")
    return pd.DataFrame(
        {
            "epoch": [ i for i in range(epochs)],
            "loss": losses,
            "accuracy": accuracies
        }
    )

NetA1 -> HF Train 

In [15]:
print(net_a1_hf.state_dict())
df_net_a1_hf = train_test(device, train_dataloader, test_dataloader, net_a1_hf, learning_rate, epochs)
df_net_a1_hf.to_csv('NetA1HF_results.csv', index=False)
net_a1_hf.state_dict()

OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[ 0.0104, -0.0263, -0.0227,  ..., -0.0365,  0.0185,  0.0133],
        [-0.0147, -0.0146,  0.0140,  ..., -0.0009,  0.0057, -0.0075],
        [-0.0164, -0.0220,  0.0229,  ...,  0.0109,  0.0119,  0.0116],
        ...,
        [ 0.0013, -0.0225, -0.0024,  ..., -0.0003, -0.0223, -0.0032],
        [ 0.0315,  0.0013,  0.0172,  ...,  0.0241,  0.0051,  0.0278],
        [ 0.0120,  0.0103, -0.0237,  ..., -0.0320, -0.0291, -0.0278]])), ('linear1.bias', tensor([-0.0365,  0.0243, -0.0267, -0.0056, -0.0325,  0.0328,  0.0208,  0.0131,
        -0.0117,  0.0210]))])
Epoch 1
-------------------------------
l

KeyboardInterrupt: 

NetA1-> HT train

In [None]:
print(net_a1_ht.state_dict())
df_net_a1_ht = train_test(device, train_dataloader, test_dataloader, net_a1_ht, learning_rate, epochs)
df_net_a1_ht.to_csv('NetA1HT_results.csv', index=False)
net_a1_ht.state_dict()

NetA1-> DT train

In [15]:
print(net_a1_dt.state_dict())
df_net_a1_dt = train_test(device, train_dataloader, test_dataloader, net_a1_dt, learning_rate, epochs)
df_net_a1_dt.to_csv('NetA1DT_results.csv', index=False)
net_a1_dt.state_dict()

OrderedDict([('conv1.weight', tensor([[[[-0.1554,  0.0492,  0.3141],
          [-0.1258, -0.3250, -0.0167],
          [ 0.0201, -0.1168,  0.3246]]],


        [[[-0.0118,  0.2718, -0.2517],
          [-0.0388, -0.2801,  0.1957],
          [ 0.1790,  0.2293,  0.0635]]],


        [[[ 0.1669,  0.1119,  0.1287],
          [ 0.0206, -0.3305, -0.0627],
          [ 0.2295, -0.1768,  0.1198]]],


        [[[ 0.1811, -0.2077, -0.1208],
          [ 0.0919, -0.1740,  0.1239],
          [-0.3191, -0.0752, -0.1332]]]])), ('conv1.bias', tensor([-0.1342,  0.2487,  0.1029, -0.3292])), ('linear1.weight', tensor([[-0.0285,  0.0300,  0.0261,  ...,  0.0316, -0.0054, -0.0235],
        [-0.0001,  0.0382, -0.0200,  ...,  0.0139, -0.0020,  0.0334],
        [-0.0203, -0.0349,  0.0247,  ..., -0.0347,  0.0140, -0.0029],
        ...,
        [ 0.0017, -0.0161,  0.0051,  ..., -0.0248,  0.0294, -0.0144],
        [ 0.0372,  0.0311, -0.0312,  ..., -0.0147,  0.0037,  0.0027],
        [-0.0196,  0.0066, -0.0079,  ...,

OrderedDict([('conv1.weight',
              tensor([[[[-0.0564,  0.3738,  1.0812],
                        [ 0.1055,  0.0993,  0.9112],
                        [ 0.2583,  0.3009,  1.1307]]],
              
              
                      [[[ 0.1899,  0.5847, -0.1145],
                        [ 0.2957,  0.2213,  0.6387],
                        [ 0.5168,  0.6410,  0.4810]]],
              
              
                      [[[ 0.7863,  0.4414,  0.6641],
                        [ 0.8376,  0.0954,  0.6608],
                        [ 0.7549, -0.0214,  0.6964]]],
              
              
                      [[[ 0.1811, -0.2077, -0.1208],
                        [ 0.0919, -0.1740,  0.1239],
                        [-0.3191, -0.0752, -0.1332]]]])),
             ('conv1.bias',
              tensor([-6.8523e-02,  1.0596e+00,  4.8831e-04, -3.2920e-01])),
             ('linear1.weight',
              tensor([[ 0.0663, -0.0055, -0.0764,  ...,  0.0316, -0.0054, -0.0235],
            

NetA2-> HF Train

In [None]:
print(net_a2_hf.state_dict())
df_net_a2_hf = train_test(device, train_dataloader, test_dataloader, net_a2_hf, learning_rate, epochs)
df_net_a2_hf.to_csv('NetA2HF_results.csv', index=False)
net_a2_hf.state_dict() 

NetA2-> HT Train

In [None]:
print(net_a2_ht.state_dict())
df_net_a2_ht = train_test(device, train_dataloader, test_dataloader, net_a2_ht, learning_rate, epochs)
df_net_a2_ht.to_csv('NetA2HT_results.csv', index=False)
net_a2_ht.state_dict()

NetA2-> DT Train

In [None]:
print(net_a2_dt.state_dict())
df_net_a2_dt = train_test(device, train_dataloader, test_dataloader, net_a2_dt, learning_rate, epochs)
df_net_a2_dt.to_csv('NetA2DT_results.csv', index=False)
net_a2_dt.state_dict()