In [18]:
import torch
from torchvision import datasets
from torchvision.models.optical_flow.raft import ResidualBlock
from torchvision.transforms import ToTensor
from torch import nn
from torch import Tensor
import time
import copy
from torch.utils.data import DataLoader
import pandas as pd

In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Network definition

In [20]:
class NetA1(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA1, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(576, num_classes)
        self.softmax = nn.Softmax(dim=0)
            
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.softmax(x)
        return x

In [21]:
class NetA2(nn.Module):
    def __init__(self, num_classes: int):
        super(NetA2, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=12, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(300, num_classes)
        self.softmax = nn.Softmax(dim=0)
        
    def freeze(self, layer: str):
        for param in getattr(self, layer).parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        return x

### Network initialization

In [22]:
initialization_weights = torch.tensor([
    [[
        [1, 0, 0, 0, 1],
        [0, 1, 0, 1, 0], 
        [0, 0, 1, 0, 0], 
        [0, 1, 0, 1, 0],
        [1, 0, 0, 0, 1]
    ]],
    [[
        [0, 0, 1, 0, 0],
        [1, 1, 0, 1, 1], 
        [0, 0, 1, 0, 0], 
        [1, 1, 0, 1, 1],
        [0, 0, 1, 0, 0]
    ]],
    [[
        [0, 1, 1, 1, 0],
        [1, 1, 0, 1, 1], 
        [1, 0, 0, 0, 1], 
        [1, 1, 0, 1, 1],
        [0, 1, 1, 1, 0]
    ]],
    [[
        [1, 1, 0, 1, 1], 
        [0, 1, 0, 1, 0], 
        [0, 0, 1, 0, 0], 
        [1, 1, 0, 1, 1],
        [1, 1, 0, 1, 1]
    ]]], dtype=torch.float32)

initialization_biases = torch.tensor([0,0,0,0], dtype=torch.float32)
initialization_weights.shape

torch.Size([4, 1, 5, 5])

In [23]:
net_a1_hf = NetA1(10)
net_a1_ht = NetA1(10)
net_a1_dt = NetA1(10)

#set conv1 initialization of net_a1_hf
net_a1_hf.conv1.weight = nn.Parameter(copy.deepcopy(initialization_weights))
net_a1_hf.conv1.bias = nn.Parameter(copy.deepcopy(initialization_biases))

# set same weights and bias to each layer of each network
net_a1_ht.load_state_dict(net_a1_hf.state_dict())
for name, param in net_a1_hf.state_dict().items():
    if "conv1" not in name:
        net_a1_dt.state_dict()[name].copy_(param)

#set conv1 initialization
#net_a1_dt.conv1.load_state_dict(net_a2_dt.conv1.state_dict())

#freeze conv1 layer of net_a2_hf
net_a1_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1HF_init.pt')
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA1HT_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a1_hf.state_dict())
print("Net_A1_HT: \n \t", net_a1_ht.state_dict())
print("Net_A1_DT: \n \t", net_a1_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[-0.0308,  0.0350, -0.0375,  ...,  0.0210, -0.0087, -0.0375],
        [-0.0086, -0.0217, -0.0280,  ..., -0.0164,  0.0151,  0.0401],
        [-0.0174, -0.0276,  0.0234,  ..., -0.0107, -0.0063,  0.0163],
        ...,
        [ 0.00

In [24]:
net_a2_hf = NetA2(10)
net_a2_ht = NetA2(10)
net_a2_dt = NetA2(10)

#set conv1 initialization of net_a2_hf
net_a2_hf.conv1.weight = nn.Parameter(copy.deepcopy(initialization_weights))
net_a2_hf.conv1.bias = nn.Parameter(copy.deepcopy(initialization_biases))

# set same weights and bias to each layer of each network
net_a2_ht.load_state_dict(net_a2_hf.state_dict())
for name, param in net_a2_hf.state_dict().items():
    if "conv1" not in name:
        net_a2_dt.state_dict()[name].copy_(param)

#set conv1 initialization
net_a2_dt.conv1.load_state_dict(net_a1_dt.conv1.state_dict())

#freeze conv1 layer of net_a2_hf
net_a2_hf.freeze("conv1")

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a2_hf.state_dict()}, 'NetA2HF_init.pt')
torch.save({'initialization': net_a2_ht.state_dict()}, 'NetA2HT_init.pt')
torch.save({'initialization': net_a2_dt.state_dict()}, 'NetA2DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a2_hf.state_dict())
print("Net_A1_HT: \n \t", net_a2_ht.state_dict())
print("Net_A1_DT: \n \t", net_a2_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[-1.6052e-01, -6.2000e-02,  1.0837e-01],
          [-5.0771e-02, -8.0880e-02,  2.3348e-02],
          [-6.8565e-02, -1.7490e-03,  3.6397e-03]],

         [[-3.7457e-03,  4.4198e-02,  1.4329e-01],
          [ 1.4342e-03,  9.4503e-

### Preliminary Analysys

In [45]:
print( "Net_A1: \n",
       "\t|W_{conv_a1_hf} - W_{conv_a1_ht}| =", torch.norm(net_a1_hf.conv1.weight - net_a1_ht.conv1.weight),"\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_ht}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_ht.linear1.weight), "\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_dt}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_dt.linear1.weight), "\n")

print( "Net_A2: \n",
       "\t|W_{conv1_a2_hf} - W_{conv1_a2_ht}| =", torch.norm(net_a2_hf.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv2_a2_hf} - W_{conv2_a2_ht}| =", torch.norm(net_a2_hf.conv2.weight - net_a2_ht.conv2.weight),"\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_ht}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_ht.linear1.weight), "\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_dt}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_dt.linear1.weight), "\n")

print( "Net_A1 Vs Net_A2: \n",
       "\t|W_{conv1_a1_hf} - W_{conv1_a2_hf}| =", torch.norm(net_a1_hf.conv1.weight - net_a2_hf.conv1.weight),"\n",
       "\t|W_{conv1_a1_ht} - W_{conv2_a2_ht}| =", torch.norm(net_a1_ht.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv1_a1_dt} - W_{conv2_a2_dt}| =", torch.norm(net_a1_dt.conv1.weight - net_a2_dt.conv1.weight),"\n")

Net_A1: 
 	|W_{conv_a1_hf} - W_{conv_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 

Net_A2: 
 	|W_{conv1_a2_hf} - W_{conv1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{conv2_a2_hf} - W_{conv2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 

Net_A1 Vs Net_A2: 
 	|W_{conv1_a1_hf} - W_{conv1_a2_hf}| = tensor(0.) 
 	|W_{conv1_a1_ht} - W_{conv2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{conv1_a1_dt} - W_{conv2_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 


In [26]:
print("Net_A1HF:")
for param in net_a1_hf.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2HF:")
for param in net_a2_hf.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A1HT:")
for param in net_a1_ht.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2HT:")
for param in net_a2_ht.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A1DT:")
for param in net_a1_dt.conv1.parameters():
    print("\t",param.requires_grad)
print("Net_A2DT:")
for param in net_a2_dt.conv1.parameters():
    print("\t",param.requires_grad)

Net_A1HF:
	 False
	 False
Net_A2HF:
	 False
	 False
Net_A1HT:
	 True
	 True
Net_A2HT:
	 True
	 True
Net_A1DT:
	 True
	 True
Net_A2DT:
	 True
	 True


In [27]:
print(net_a1_ht.conv1.weight is net_a1_hf.conv1.weight)
print(net_a1_ht.linear1.weight is net_a1_hf.linear1.weight)
print(net_a1_hf.linear1.weight is net_a1_dt.linear1.weight)
print(net_a2_hf.conv1.weight is net_a1_hf.conv1.weight)
print(net_a2_ht.conv1.weight is net_a1_ht.conv1.weight)
print(net_a2_dt.conv1.weight is net_a1_dt.conv1.weight)
print(net_a2_hf.linear1.weight is net_a2_ht.linear1.weight)
print(net_a2_hf.linear1.weight is net_a2_dt.linear1.weight)

False
False
False
False
False
False
False
False


### Data Loading

In [28]:
train_data= datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor(),)

test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor(),)

In [29]:
labels_map={
    0: 'T-shirt',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot',
}
sample_idx = torch.randint(len(train_data), size = (1,)).item()
image, label = train_data[sample_idx]
image.shape

torch.Size([1, 28, 28])

In [30]:
batch_size = 128

train_dataloader= DataLoader(train_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

### Training/Test Loop

In [31]:
def train_loop(device, dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return model.state_dict()


def test_loop(device, dataloader, model, loss_fn):
      size = len(dataloader.dataset)
      num_batches = len(dataloader)
      test_loss, correct = 0, 0

      with torch.no_grad():
        for X, y in dataloader:
          X, y = X.to(device), y.to(device)
          pred = model(X)
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item()

      test_loss /= num_batches
      correct /= size
      print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
      return 100*correct, test_loss

### Training

In [32]:
learning_rate = 3.35e-4#4.1
epochs = 60

In [33]:
def train_test(device, train_dataloader, test_dataloader, net, learning_rate, epochs):
    net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)
    
    accuracies = []
    losses = []
    times=[]
    
    time_s = time.time()
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------------------------")
        train_loop(device, train_dataloader, net, loss_fn, optimizer)
        acc, loss = test_loop(device, test_dataloader, net, loss_fn)
        accuracies.append(acc)
        losses.append(loss)
        times.append(time.time() - time_s)
    print("Done!")
    return pd.DataFrame(
        {
            "epoch": [ i for i in range(epochs)],
            "times": times,
            "loss": losses,
            "accuracy": accuracies
        }
    )

NetA1 -> HF Train 

In [34]:
print(net_a1_hf.state_dict())
df_net_a1_hf = train_test(device, train_dataloader, test_dataloader, net_a1_hf, learning_rate, epochs)
df_net_a1_hf.to_csv('NetA1HF_results.csv', index=False)
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1HF_trained.pt')
net_a1_hf.state_dict()

OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[-0.0308,  0.0350, -0.0375,  ...,  0.0210, -0.0087, -0.0375],
        [-0.0086, -0.0217, -0.0280,  ..., -0.0164,  0.0151,  0.0401],
        [-0.0174, -0.0276,  0.0234,  ..., -0.0107, -0.0063,  0.0163],
        ...,
        [ 0.0014,  0.0093,  0

OrderedDict([('conv1.weight',
              tensor([[[[1., 0., 0., 0., 1.],
                        [0., 1., 0., 1., 0.],
                        [0., 0., 1., 0., 0.],
                        [0., 1., 0., 1., 0.],
                        [1., 0., 0., 0., 1.]]],
              
              
                      [[[0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],
                        [0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],
                        [0., 0., 1., 0., 0.]]],
              
              
                      [[[0., 1., 1., 1., 0.],
                        [1., 1., 0., 1., 1.],
                        [1., 0., 0., 0., 1.],
                        [1., 1., 0., 1., 1.],
                        [0., 1., 1., 1., 0.]]],
              
              
                      [[[1., 1., 0., 1., 1.],
                        [0., 1., 0., 1., 0.],
                        [0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],


NetA1-> HT train

In [35]:
print(net_a1_ht.state_dict())
df_net_a1_ht = train_test(device, train_dataloader, test_dataloader, net_a1_ht, learning_rate, epochs)
df_net_a1_ht.to_csv('NetA1HT_results.csv', index=False)
torch.save({'initialization': net_a1_ht.state_dict()}, 'NetA1HT_trained.pt')
net_a1_ht.state_dict()

OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[-0.0308,  0.0350, -0.0375,  ...,  0.0210, -0.0087, -0.0375],
        [-0.0086, -0.0217, -0.0280,  ..., -0.0164,  0.0151,  0.0401],
        [-0.0174, -0.0276,  0.0234,  ..., -0.0107, -0.0063,  0.0163],
        ...,
        [ 0.0014,  0.0093,  0

OrderedDict([('conv1.weight',
              tensor([[[[ 1.2183e+00,  2.8325e-02, -1.2216e-01,  2.1136e-01,  1.4437e+00],
                        [ 2.2626e-01,  1.2044e+00,  7.0670e-02,  1.4468e+00,  1.1033e-01],
                        [ 1.5126e-01,  2.6986e-01,  1.0204e+00,  4.6302e-01, -5.9271e-02],
                        [ 1.4950e-01,  1.0262e+00, -2.2274e-02,  1.0609e+00,  9.0414e-02],
                        [ 1.1168e+00,  6.7376e-02, -1.4932e-01,  1.9719e-01,  1.2310e+00]]],
              
              
                      [[[-1.4701e-03,  7.3060e-02,  1.8331e+00,  1.3713e-01, -1.5949e-01],
                        [ 1.1946e+00,  1.2364e+00,  8.1732e-01,  1.3197e+00,  9.9091e-01],
                        [ 2.4730e-01,  1.9423e-01,  1.5210e+00,  1.7718e-01, -3.1371e-01],
                        [ 1.4889e+00,  9.1130e-01,  7.5617e-01,  1.0226e+00,  9.4003e-01],
                        [-1.2306e-01, -2.3071e-01,  1.3387e+00, -1.7354e-01, -2.2742e-01]]],
              
           

NetA1-> DT train

In [36]:
print(net_a1_dt.state_dict())
df_net_a1_dt = train_test(device, train_dataloader, test_dataloader, net_a1_dt, learning_rate, epochs)
df_net_a1_dt.to_csv('NetA1DT_results.csv', index=False)
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_trained.pt')
net_a1_dt.state_dict()

OrderedDict([('conv1.weight', tensor([[[[ 0.0275,  0.0715,  0.1256, -0.1479, -0.1567],
          [-0.0226,  0.0100, -0.1715, -0.1725, -0.1208],
          [-0.0613,  0.0349,  0.0108,  0.1250, -0.1452],
          [ 0.1867,  0.1224,  0.1494, -0.0769,  0.1487],
          [ 0.1965,  0.1156, -0.1832,  0.0777, -0.0907]]],


        [[[ 0.1958, -0.1091, -0.0394,  0.0192,  0.1636],
          [ 0.0413,  0.0136, -0.0527,  0.0598, -0.0349],
          [-0.0775,  0.1340,  0.0160,  0.0956, -0.0025],
          [-0.0257,  0.1061,  0.1439, -0.0596,  0.0039],
          [-0.0475, -0.0550, -0.1336,  0.0186, -0.1608]]],


        [[[-0.1140,  0.0929, -0.1603, -0.1158, -0.0963],
          [ 0.0612,  0.1323, -0.0751, -0.0008,  0.1730],
          [-0.0210,  0.0880,  0.1101,  0.0027, -0.0185],
          [-0.1288,  0.1883,  0.1363, -0.1108, -0.1461],
          [-0.1278,  0.1464, -0.0103,  0.1723,  0.1606]]],


        [[[-0.0337, -0.1431, -0.0953, -0.1032, -0.0572],
          [ 0.1921, -0.1182, -0.0249, -0.1109,

OrderedDict([('conv1.weight',
              tensor([[[[ 0.5147,  0.6134,  0.4386, -0.1886, -0.3646],
                        [ 0.3639,  0.3088,  0.0538, -0.0511, -0.3185],
                        [ 0.3375,  0.2911,  0.2094,  0.4347, -0.1082],
                        [ 0.9686,  0.5961,  0.6051,  0.4283,  0.3701],
                        [ 0.9524,  0.5788, -0.0121,  0.3675,  0.0136]]],
              
              
                      [[[ 0.3201, -0.1353, -0.0719,  0.2254,  0.6204],
                        [ 0.4320,  0.2959,  0.3680,  0.5170,  0.2406],
                        [ 0.2979,  0.5338,  0.3017,  0.3671,  0.0988],
                        [ 0.3818,  0.3478,  0.3470,  0.1499,  0.0558],
                        [ 0.1031, -0.0164, -0.1185,  0.0867, -0.2251]]],
              
              
                      [[[-0.1509,  0.2358, -0.3273, -0.3818,  0.0629],
                        [ 0.3563,  0.6370,  0.3187,  0.2370,  0.6795],
                        [-0.0414,  0.6113,  0.2466, -0

NetA2-> HF Train

In [37]:
print(net_a2_hf.state_dict())
df_net_a2_hf = train_test(device, train_dataloader, test_dataloader, net_a2_hf, learning_rate, epochs)
df_net_a2_hf.to_csv('NetA2HF_results.csv', index=False)
torch.save({'initialization': net_a2_hf.state_dict()}, 'NetA2HF_trained.pt')
net_a2_hf.state_dict() 

OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[-1.6052e-01, -6.2000e-02,  1.0837e-01],
          [-5.0771e-02, -8.0880e-02,  2.3348e-02],
          [-6.8565e-02, -1.7490e-03,  3.6397e-03]],

         [[-3.7457e-03,  4.4198e-02,  1.4329e-01],
          [ 1.4342e-03,  9.4503e-02,  9.8810e-02

OrderedDict([('conv1.weight',
              tensor([[[[1., 0., 0., 0., 1.],
                        [0., 1., 0., 1., 0.],
                        [0., 0., 1., 0., 0.],
                        [0., 1., 0., 1., 0.],
                        [1., 0., 0., 0., 1.]]],
              
              
                      [[[0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],
                        [0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],
                        [0., 0., 1., 0., 0.]]],
              
              
                      [[[0., 1., 1., 1., 0.],
                        [1., 1., 0., 1., 1.],
                        [1., 0., 0., 0., 1.],
                        [1., 1., 0., 1., 1.],
                        [0., 1., 1., 1., 0.]]],
              
              
                      [[[1., 1., 0., 1., 1.],
                        [0., 1., 0., 1., 0.],
                        [0., 0., 1., 0., 0.],
                        [1., 1., 0., 1., 1.],


NetA2-> HT Train

In [38]:
print(net_a2_ht.state_dict())
df_net_a2_ht = train_test(device, train_dataloader, test_dataloader, net_a2_ht, learning_rate, epochs)
df_net_a2_ht.to_csv('NetA2HT_results.csv', index=False)
torch.save({'initialization': net_a2_ht.state_dict()}, 'NetA2HT_trained.pt')
net_a2_ht.state_dict()

OrderedDict([('conv1.weight', tensor([[[[1., 0., 0., 0., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [0., 1., 0., 1., 0.],
          [1., 0., 0., 0., 1.]]],


        [[[0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [0., 0., 1., 0., 0.]]],


        [[[0., 1., 1., 1., 0.],
          [1., 1., 0., 1., 1.],
          [1., 0., 0., 0., 1.],
          [1., 1., 0., 1., 1.],
          [0., 1., 1., 1., 0.]]],


        [[[1., 1., 0., 1., 1.],
          [0., 1., 0., 1., 0.],
          [0., 0., 1., 0., 0.],
          [1., 1., 0., 1., 1.],
          [1., 1., 0., 1., 1.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[-1.6052e-01, -6.2000e-02,  1.0837e-01],
          [-5.0771e-02, -8.0880e-02,  2.3348e-02],
          [-6.8565e-02, -1.7490e-03,  3.6397e-03]],

         [[-3.7457e-03,  4.4198e-02,  1.4329e-01],
          [ 1.4342e-03,  9.4503e-02,  9.8810e-02

OrderedDict([('conv1.weight',
              tensor([[[[ 1.1262,  0.1075, -0.0870, -0.0846,  1.8457],
                        [-0.2603,  1.1281,  0.7913,  1.4658,  0.2752],
                        [-0.0234,  0.3841,  1.5265,  0.7119,  0.4967],
                        [-0.0956,  0.9792,  0.3302,  1.0249,  0.0470],
                        [ 0.9983, -0.0494,  0.2457,  0.0681,  1.0438]]],
              
              
                      [[[-0.3275,  0.2842,  1.5510,  0.5833, -0.3892],
                        [ 1.2335,  1.5805,  1.3342,  2.0544,  1.1795],
                        [-0.0717,  0.2067,  1.9478,  0.6418,  0.1125],
                        [-0.1573,  0.6388,  0.5613,  1.0572,  0.9055],
                        [-0.8954, -0.0262,  1.4276,  0.3624, -0.0451]]],
              
              
                      [[[ 0.2069,  1.3344,  1.7686,  1.3148, -0.1821],
                        [ 0.7095,  1.3760,  1.0920,  1.3641,  1.2512],
                        [ 1.0053,  0.1303,  0.4392,  0

NetA2-> DT Train

In [39]:
print(net_a2_dt.state_dict())
df_net_a2_dt = train_test(device, train_dataloader, test_dataloader, net_a2_dt, learning_rate, epochs)
df_net_a2_dt.to_csv('NetA2DT_results.csv', index=False)
torch.save({'initialization': net_a2_dt.state_dict()}, 'NetA2DT_trained.pt')
net_a2_dt.state_dict()

OrderedDict([('conv1.weight', tensor([[[[ 0.0275,  0.0715,  0.1256, -0.1479, -0.1567],
          [-0.0226,  0.0100, -0.1715, -0.1725, -0.1208],
          [-0.0613,  0.0349,  0.0108,  0.1250, -0.1452],
          [ 0.1867,  0.1224,  0.1494, -0.0769,  0.1487],
          [ 0.1965,  0.1156, -0.1832,  0.0777, -0.0907]]],


        [[[ 0.1958, -0.1091, -0.0394,  0.0192,  0.1636],
          [ 0.0413,  0.0136, -0.0527,  0.0598, -0.0349],
          [-0.0775,  0.1340,  0.0160,  0.0956, -0.0025],
          [-0.0257,  0.1061,  0.1439, -0.0596,  0.0039],
          [-0.0475, -0.0550, -0.1336,  0.0186, -0.1608]]],


        [[[-0.1140,  0.0929, -0.1603, -0.1158, -0.0963],
          [ 0.0612,  0.1323, -0.0751, -0.0008,  0.1730],
          [-0.0210,  0.0880,  0.1101,  0.0027, -0.0185],
          [-0.1288,  0.1883,  0.1363, -0.1108, -0.1461],
          [-0.1278,  0.1464, -0.0103,  0.1723,  0.1606]]],


        [[[-0.0337, -0.1431, -0.0953, -0.1032, -0.0572],
          [ 0.1921, -0.1182, -0.0249, -0.1109,

OrderedDict([('conv1.weight',
              tensor([[[[ 0.4309,  0.7635,  0.4031, -0.1265, -0.1777],
                        [ 0.2771,  0.5113,  0.4370, -0.1511, -0.3245],
                        [ 0.2340,  0.2092,  0.3924,  0.2479, -0.0751],
                        [ 0.4789,  0.5039,  0.4232, -0.1991,  0.0963],
                        [ 0.6423,  0.5881,  0.2937,  0.1533, -0.1147]]],
              
              
                      [[[ 0.7746,  0.6257,  0.5472,  0.6948,  0.3142],
                        [ 0.5473,  0.7029,  0.8824,  0.9372,  0.2170],
                        [-0.0467,  0.0691,  0.2509,  0.3820,  0.1022],
                        [-0.4274,  0.0692,  0.3995,  0.0965, -0.1380],
                        [-0.3820, -0.0621,  0.0978,  0.1411, -0.2133]]],
              
              
                      [[[ 0.1741,  0.2850,  0.1097, -0.0692, -0.2193],
                        [ 0.1892,  0.3428,  0.3135,  0.3275,  0.3277],
                        [ 0.0565,  0.2706,  0.4631,  0