In [1]:
import torch
from torchvision import datasets
from torchvision.models.optical_flow.raft import ResidualBlock
from torchvision.transforms import ToTensor
from torch import nn
from torch import Tensor
from torch.utils.data import DataLoader

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Network definition

In [3]:
class NetA1(nn.Module):
    def __init__(self, num_classes: int, freeze: bool = False):
        super(NetA1, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.flatten = nn.Flatten(start_dim=-3)
        self.linear1 = nn.Linear(676, num_classes)
        self.softmax = nn.Softmax(dim=0)
        if freeze:
            for param in self.conv1.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.softmax(x)
        return x

In [4]:
class NetA2(nn.Module):
    def __init__(self, num_classes: int, freeze: bool = False):
        super(NetA2, self).__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3, stride=2)
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=6, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.flatten = nn.Flatten(start_dim=-5)
        self.linear1 = nn.Linear(150, 200)
        self.linear2 = nn.Linear(200, 100)
        self.linear3 = nn.Linear(100, num_classes)
        self.softmax = nn.Softmax(dim=0)
        if freeze:
            for param in self.conv1.parameters():
                param.requires_grad = False

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.pool1(x))
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.softmax(x)
        return x

### Network initialization

In [5]:
initialization_weights = torch.tensor([
    [[[1, 0, 1], [0, 1, 0], [1, 0, 1]]],
    [[[1, 1, 0], [0, 0, 1], [1, 1, 0]]],
    [[[0, 1, 1], [1, 0, 0], [0, 1, 1]]],
    [[[0, 1, 0], [1, 1, 0], [0, 1, 0]]]
            ], dtype=torch.float32)

initialization_biases = torch.tensor([0,0,0,0], dtype=torch.float32)
initialization_weights.shape

torch.Size([4, 1, 3, 3])

In [6]:
net_a1_hf = NetA1(10, True)
net_a1_ht = NetA1(10)
net_a1_dt = NetA1(10)

#set conv1 initialization
net_a1_hf.conv1.weight = nn.Parameter(initialization_weights)
net_a1_hf.conv1.bias = nn.Parameter(initialization_biases)

# set same weights and bias to each layer of each network (except for cov1 of net_a1_dt)
net_a1_ht.load_state_dict(net_a1_hf.state_dict())
net_a1_dt.linear1.load_state_dict(net_a1_hf.linear1.state_dict())

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA1H+_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA1DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a1_hf.state_dict())
print("Net_A1_HT: \n \t", net_a1_ht.state_dict())
print("Net_A1_DT: \n \t", net_a1_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('linear1.weight', tensor([[-0.0263, -0.0308,  0.0384,  ...,  0.0129, -0.0315,  0.0162],
        [-0.0317,  0.0038, -0.0353,  ...,  0.0243,  0.0187, -0.0363],
        [ 0.0232,  0.0346, -0.0316,  ...,  0.0108, -0.0291, -0.0165],
        ...,
        [-0.0332, -0.0109,  0.0028,  ..., -0.0177,  0.0285, -0.0358],
        [-0.0114, -0.0106, -0.0180,  ..., -0.0125, -0.0106, -0.0105],
        [ 0.0046, -0.0014, -0.0378,  ..., -0.0194,  0.0318,  0.0112]])), ('linear1.bias', tensor([ 0.0020, -0.0011, -0.0044, -0.0213,  0.0053,  0.0065,  0.0043, -0.0234,
         0.0307, -0.0009]))])
Net_A1_HT: 
 	 OrderedDict

In [11]:
net_a2_hf = NetA2(10, True)
net_a2_ht = NetA2(10)
net_a2_dt = NetA2(10)

# set same weights and bias to each layer of each network (except for cov1 of net_a1_dt)
net_a2_ht.load_state_dict(net_a2_hf.state_dict())
net_a2_dt.load_state_dict(net_a2_hf.state_dict())

#set conv1 initialization
net_a2_hf.conv1.load_state_dict(net_a1_hf.conv1.state_dict())
net_a2_ht.conv1.load_state_dict(net_a1_hf.conv1.state_dict())
net_a2_dt.conv1.load_state_dict(net_a1_dt.conv1.state_dict())

#save weights and bias of nat_a1_h* and net_a1_dt
torch.save({'initialization': net_a1_hf.state_dict()}, 'NetA2H+_init.pt')
torch.save({'initialization': net_a1_dt.state_dict()}, 'NetA2DT_init.pt')


# print weights and bias
print("Net_A1_HF: \n \t", net_a2_hf.state_dict())
print("Net_A1_HT: \n \t", net_a2_ht.state_dict())
print("Net_A1_DT: \n \t", net_a2_dt.state_dict())

Net_A1_HF: 
 	 OrderedDict([('conv1.weight', tensor([[[[1., 0., 1.],
          [0., 1., 0.],
          [1., 0., 1.]]],


        [[[1., 1., 0.],
          [0., 0., 1.],
          [1., 1., 0.]]],


        [[[0., 1., 1.],
          [1., 0., 0.],
          [0., 1., 1.]]],


        [[[0., 1., 0.],
          [1., 1., 0.],
          [0., 1., 0.]]]])), ('conv1.bias', tensor([0., 0., 0., 0.])), ('conv2.weight', tensor([[[[ 0.0277,  0.1149, -0.1208],
          [ 0.1161,  0.1605, -0.0957],
          [-0.1079, -0.1183, -0.0668]],

         [[-0.1547, -0.0691, -0.1027],
          [-0.0794, -0.0236,  0.0004],
          [-0.0838,  0.0406, -0.0066]],

         [[ 0.0825, -0.1004,  0.0103],
          [ 0.0283,  0.0741, -0.0702],
          [-0.0884,  0.0566, -0.1151]],

         [[ 0.0678,  0.1228, -0.0446],
          [ 0.1102,  0.0567, -0.1177],
          [-0.0948,  0.1466,  0.1233]]],


        [[[ 0.0481, -0.1236, -0.0548],
          [-0.1211,  0.0455, -0.0900],
          [-0.0881, -0.0630,  0.012

### Preliminary Analysys

In [14]:
print( "Net_A1: \n",
       "\t|W_{conv_a1_hf} - W_{conv_a1_ht}| =", torch.norm(net_a1_hf.conv1.weight - net_a1_ht.conv1.weight),"\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_ht}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_ht.linear1.weight), "\n",
      "\t|W_{linear_a1_hf} - W_{linear_a1_dt}| =", torch.norm(net_a1_hf.linear1.weight - net_a1_dt.linear1.weight), "\n")

print( "Net_A2: \n",
       "\t|W_{conv1_a2_hf} - W_{conv1_a2_ht}| =", torch.norm(net_a2_hf.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv2_a2_hf} - W_{conv2_a2_ht}| =", torch.norm(net_a2_hf.conv2.weight - net_a2_ht.conv2.weight),"\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_ht}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_ht.linear1.weight), "\n",
       "\t|W_{linear1_a2_hf} - W_{linear1_a2_dt}| =", torch.norm(net_a2_hf.linear1.weight - net_a2_dt.linear1.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_ht}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_ht.linear2.weight), "\n",
       "\t|W_{linear2_a2_hf} - W_{linear2_a2_dt}| =", torch.norm(net_a2_hf.linear2.weight - net_a2_dt.linear2.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_ht}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_ht.linear3.weight), "\n",
       "\t|W_{linear3_a2_hf} - W_{linear3_a2_dt}| =", torch.norm(net_a2_hf.linear3.weight - net_a2_dt.linear3.weight), "\n")

print( "Net_A1 Vs Net_A2: \n",
       "\t|W_{conv1_a1_hf} - W_{conv1_a2_hf}| =", torch.norm(net_a1_hf.conv1.weight - net_a2_hf.conv1.weight),"\n",
       "\t|W_{conv1_a1_ht} - W_{conv2_a2_ht}| =", torch.norm(net_a1_ht.conv1.weight - net_a2_ht.conv1.weight),"\n",
       "\t|W_{conv1_a1_dt} - W_{conv2_a2_dt}| =", torch.norm(net_a1_dt.conv1.weight - net_a2_dt.conv1.weight),"\n")

Net_A1: 
 	|W_{conv_a1_hf} - W_{conv_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear_a1_hf} - W_{linear_a1_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 

Net_A2: 
 	|W_{conv1_a2_hf} - W_{conv1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{conv2_a2_hf} - W_{conv2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear1_a2_hf} - W_{linear1_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear2_a2_hf} - W_{linear2_a2_dt}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_ht}| = tensor(0., grad_fn=<LinalgVectorNormBackward0>) 
 	|W_{linear3_a2_hf} - W_{linear3_a2_dt}| = tensor(0., grad_fn=<Linal

### Data Loading

In [6]:
train_data= datasets.FashionMNIST(root='data', train=True, download=True, transform=ToTensor(),)

test_data = datasets.FashionMNIST(root='data', train=False, download=True, transform=ToTensor(),)

In [6]:
labels_map={
    0: 'T-shirt',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle Boot',
}
sample_idx = torch.randint(len(train_data), size = (1,)).item()
image, label = train_data[sample_idx]
image.shape

torch.Size([1, 28, 28])

In [23]:
batch_size = 128

train_dataloader= DataLoader(train_data, batch_size = batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

### Training/Test Loop

In [ ]:
def train_loop(device, dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 1000 == 0:
            loss, current = loss.item(), (batch + 1) * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

    return model.state_dict()


def test_loop(device, dataloader, model, loss_fn):
      size = len(dataloader.dataset)
      num_batches = len(dataloader)
      test_loss, correct = 0, 0

      with torch.no_grad():
        for X, y in dataloader:
          X, y = X.to(device), y.to(device)
          pred = model(X)
          test_loss += loss_fn(pred, y).item()
          correct += (pred.argmax(1) == y).type(torch.float).sum().item()

      test_loss /= num_batches
      correct /= size
      print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
      return (100*correct)

### Training

### Testing