In [None]:
from torch import nn
from collections import OrderedDict
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
import torchvision
import random
from torch.utils.data import Subset
from matplotlib import pyplot as plt
from torchsummary import summary
from torchvision import transforms
import progressbar as pb
import numpy as np

In [None]:
def SUM(x, y):
    """Return ``x + y`` (works for anything supporting ``+``, e.g. numbers or tensors)."""
    return x + y

In [None]:
def check_equity(property,a,b):
    """Assert that ``a`` and ``b`` agree on one attribute and return its value.

    Args:
        property: Name of the attribute to look up on both objects.
        a, b: Objects exposing that attribute.

    Returns:
        The attribute value read from ``a``.

    Raises:
        AssertionError: If the two attribute values compare unequal.
        AttributeError: If either object lacks the attribute.
    """
    value_a, value_b = (getattr(obj, property) for obj in (a, b))
    assert value_a == value_b, "Different {}: {}!={}".format(property, value_a, value_b)

    return value_a

In [None]:
def module_unwrap(mod:nn.Module,recursive=False):
    """Collect the children of ``mod`` into an ordered ``name -> module`` mapping.

    Args:
        mod: Module whose ``named_children()`` are enumerated.
        recursive: When true, children that themselves have children are
            replaced by their (recursively flattened) descendants, keyed as
            ``"<parent>_<child>"``; childless modules keep their own name.

    Returns:
        An ``OrderedDict`` in traversal order. It is empty when ``mod`` has no
        children, or when ``mod`` does not provide ``named_children`` at all
        (the resulting ``AttributeError`` is deliberately swallowed).
    """
    flat = OrderedDict()
    try:
        for child_name, child in mod.named_children():
            # Only descend when asked to; a falsy result means "treat as leaf".
            nested = module_unwrap(child, recursive=True) if recursive else None
            if nested:
                flat.update(
                    (child_name + "_" + sub_name, sub_module)
                    for sub_name, sub_module in nested.items()
                )
            else:
                flat[child_name] = child
    except AttributeError:
        pass

    return flat

In [None]:
class VGGBlock(nn.Module):
    """Two 3x3 convolutions, each followed by optional normalisation and ReLU, then 2x2 max pooling.

    The convolutions use stride 1 and padding 1, so they keep the spatial size;
    the final max pooling (kernel 2, stride 2) halves it.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
        batch_norm: When true a normalisation layer follows each convolution.
            Despite the name, the layer is ``nn.GroupNorm`` (not BatchNorm).
        num_groups: Number of groups for ``nn.GroupNorm``; ``out_channels``
            must be divisible by it. Ignored when ``batch_norm`` is false.
    """

    def __init__(self, in_channels, out_channels, batch_norm=False, num_groups=16):

        super().__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride'     : (1, 1),
                        'padding'   : 1
                        }

        self._batch_norm = batch_norm

        def make_norm():
            # nn.Identity instead of a local lambda: a lambda stored on the module
            # cannot be pickled, which breaks torch.save(model) / multiprocessing.
            # Identity has no parameters, so state_dict keys are unchanged.
            return nn.GroupNorm(num_groups, out_channels) if batch_norm else nn.Identity()

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **conv2_params)
        self.bn1 = make_norm()

        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **conv2_params)
        self.bn2 = make_norm()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        """Whether normalisation layers were requested at construction time."""
        return self._batch_norm

    def forward(self,x):
        """Apply conv -> norm -> ReLU twice, then max-pool."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x

In [None]:
class Classifier(nn.Module):
    """Three-layer MLP head: ``in_features -> 2048 -> 512 -> num_classes``.

    ReLU and Dropout(p=0.5) follow each of the first two linear layers; the
    last layer returns raw scores (no activation).

    Args:
        num_classes: Size of the output layer.
        in_features: Width of the flattened input. The default of 2048 matches
            the 512x2x2 features that this notebook's ``VGG16`` produces for
            32x32 inputs; pass another value for other input sizes.
    """

    def __init__(self,num_classes=1, in_features=2048):
        super().__init__()

        # Layer positions (0, 3, 6 for the Linear layers) are part of the
        # state_dict keys, so the order below must not change.
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 2048),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(2048, 512),
            nn.ReLU(True),
            nn.Dropout(p=0.5),
            nn.Linear(512, num_classes)
        )

    def forward(self,x):
        """Map features of shape ``(batch, in_features)`` to ``(batch, num_classes)`` scores."""
        return self.classifier(x)

In [None]:
class VGG16(nn.Module):
    """Feature extractor made of four stacked ``VGGBlock`` stages (64, 128, 256, 512 channels).

    Every stage ends in a 2x2 max pooling, so the spatial size is divided by 16
    overall; the result is flattened to ``(batch, features)``.

    Args:
        input_size: ``(channels, width, height)`` of the expected input. Only
            the channel count is used to build the layers; the other two
            values are stored and reported by ``input_size``.
        batch_norm: Forwarded to every ``VGGBlock``.
    """

    def __init__(self, input_size, batch_norm=False):
        super().__init__()

        self.in_channels, self.in_width, self.in_height = input_size

        # Attribute names are kept explicit because they are state_dict key prefixes.
        self.block_1 = VGGBlock(self.in_channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)

    @property
    def input_size(self):
        """The ``(channels, width, height)`` tuple given at construction."""
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        """Run the four stages in order and flatten everything but the batch dimension."""
        for stage in (self.block_1, self.block_2, self.block_3, self.block_4):
            x = stage(x)

        return torch.flatten(x, 1)

In [None]:
class CombinedLoss(nn.Module):
    """Sum of two per-branch losses plus a weighted loss on a combined prediction.

    ``forward`` computes
    ``loss_a(y_hat[0], y[0]) + loss_b(y_hat[1], y[1]) + _lambda * loss_combo(y_hat[2], cat(y))``.

    Args:
        loss_a: Callable applied to the first prediction/target pair.
        loss_b: Callable applied to the second prediction/target pair.
        loss_combo: Callable applied to the third prediction and the
            concatenation (along dim 0) of all targets in ``y``.
        _lambda: Weight of the combined term; stored as a float32 buffer so it
            moves with the module and is saved in its state_dict.
    """

    def __init__(self, loss_a, loss_b, loss_combo, _lambda=1.0):
        super().__init__()
        self.loss_a, self.loss_b, self.loss_combo = loss_a, loss_b, loss_combo

        weight = torch.tensor(float(_lambda), dtype=torch.float32)
        self.register_buffer('_lambda', weight)

    def forward(self,y_hat,y):
        """Return the combined scalar loss.

        Args:
            y_hat: Indexable with at least three predictions.
            y: Sequence of target tensors; ``y[0]`` and ``y[1]`` pair with the
                first two predictions, and the whole sequence is concatenated
                for the third.
        """
        term_a = self.loss_a(y_hat[0], y[0])
        term_b = self.loss_b(y_hat[1], y[1])
        term_combo = self.loss_combo(y_hat[2], torch.cat(y, 0))

        return term_a + term_b + self._lambda * term_combo

----------------------------------------------------------------------

In [None]:
# Run mode read by the cell that calls train(): 'TRAIN' trains and saves the
# networks; any other value takes the load-checkpoint-and-test branch instead.
DO='TRAIN'

In [None]:
# Seed every RNG this notebook draws from. Seeding only Python's `random`
# leaves the NumPy permutation used for the train/val split and torch's
# weight initialisation / DataLoader shuffling different on every run.
SEED = 47
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Rule used to combine the corresponding tensors of the first two networks
# (read as a global by train() and by the checkpoint-testing cells).
combo_fn = SUM

In [None]:
# Weight of the convolution-weight regularisation terms added to the loss in
# train() (read there as a global).
lambda_reg = 1

In [None]:
def train(nets, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
    """Jointly train two feature extractors, a third "combined" extractor and one shared classifier.

    Every step takes one batch from loader A and one from loader B. ``nets[0]``
    sees batch A, ``nets[1]`` sees batch B, ``nets[2]`` sees their
    concatenation, and ``nets[3]`` classifies all three feature tensors. The
    loss is the sum of the three ``criterion`` terms plus, for every
    ``nn.Conv2d`` layer, ``lambda_reg * MSE(combo_fn(w_a, w_b), w_c)`` on the
    weights and (when present) the biases.

    The validation and test splits are evaluated after the training split in
    every epoch, without gradient tracking and without parameter updates.

    NOTE: ``lambda_reg``, ``combo_fn`` and ``module_unwrap`` are notebook-level
    names that must be defined before this function is called.

    Args:
        nets: ``[extractor_a, extractor_b, extractor_combined, classifier]``.
            The three extractors must expose the same layer names.
        loaders: ``{"train" | "val" | "test": [loader_a, loader_b]}``; the two
            loaders of a split are iterated in lock-step (``zip``).
        optimizer: Optimizer over the parameters of all networks.
        criterion: Loss on raw logits vs. float labels (thresholding at 0
            below assumes a logit-based binary loss such as BCEWithLogitsLoss).
        epochs: Number of passes over the training split.
        dev: Device the networks, losses and batches are moved to.
        save_param: When true, the final state_dicts of all four networks are
            written to ``f"{model_name}.pth"`` after the last epoch.
        model_name: Stem of the checkpoint file name.

    Returns:
        ``(history_loss, history_accuracy)``: dictionaries keyed by split. Each
        loss entry is the epoch's mean batch loss; each accuracy entry is
        ``[acc_a, acc_b, acc_combined]`` (means of per-batch accuracies).
    """
    splits = ("train", "val", "test")

    nets = [n.to(dev) for n in nets]

    # Flattened name -> leaf-module views, used to pair up corresponding layers.
    model_a = module_unwrap(nets[0], True)
    model_b = module_unwrap(nets[1], True)
    model_c = module_unwrap(nets[2], True)

    reg_loss = nn.MSELoss()

    criterion.to(dev)
    reg_loss.to(dev)

    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}

    for epoch in range(epochs):
        sum_loss = {split: 0 for split in splits}
        sum_accuracy = {split: [0, 0, 0] for split in splits}
        # zip() stops at the shorter loader, so count the batches actually
        # processed instead of trusting len(loader_a) when averaging.
        n_batches = {split: 0 for split in splits}

        progbar = None
        for split in splits:
            is_train = split == "train"
            loader_a, loader_b = loaders[split][0], loaders[split][1]

            if is_train:
                for n in nets:
                    n.train()
                widgets = [
                    ' [', pb.Timer(), '] ',
                    pb.Bar(),
                    ' [', pb.ETA(), '] ', pb.Variable('ta','[Train Acc: {formatted_value}]')
                ]
                progbar = pb.ProgressBar(max_value=len(loader_a), widgets=widgets, redirect_stdout=True)
            else:
                for n in nets:
                    n.eval()

            for j, ((input_a, labels_a), (input_b, labels_b)) in enumerate(zip(loader_a, loader_b)):

                input_a = input_a.to(dev)
                input_b = input_b.to(dev)

                labels_a = labels_a.float().to(dev)
                labels_b = labels_b.float().to(dev)

                inputs = torch.cat([input_a, input_b], dim=0)
                labels = torch.cat([labels_a, labels_b])

                # Track gradients only while training; evaluation does not
                # need the autograd graph and would only waste memory on it.
                with torch.set_grad_enabled(is_train):
                    features_a = nets[0](input_a)
                    features_b = nets[1](input_b)
                    features_c = nets[2](inputs)

                    # squeeze(-1) drops only the class dimension, so a final
                    # batch holding a single sample keeps its batch dimension.
                    pred_a = nets[3](features_a).squeeze(-1)
                    pred_b = nets[3](features_b).squeeze(-1)
                    pred_c = nets[3](features_c).squeeze(-1)

                    loss = criterion(pred_a, labels_a) + criterion(pred_b, labels_b) + criterion(pred_c, labels)

                    # Pull each conv layer of the combined network towards
                    # combo_fn(layer of A, layer of B).
                    for layer_name in model_a:
                        layer_a = model_a[layer_name]
                        layer_b = model_b[layer_name]
                        layer_c = model_c[layer_name]
                        if isinstance(layer_a, nn.Conv2d):
                            loss += lambda_reg * reg_loss(combo_fn(layer_a.weight, layer_b.weight), layer_c.weight)
                            if layer_a.bias is not None:
                                loss += lambda_reg * reg_loss(combo_fn(layer_a.bias, layer_b.bias), layer_c.bias)

                sum_loss[split] += loss.item()
                n_batches[split] += 1

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # A logit >= 0 is a sigmoid probability >= 0.5.
                # https://discuss.pytorch.org/t/bcewithlogitsloss-and-model-accuracy-calculation/59293/2
                pred_labels_a = (pred_a >= 0.0).long()
                pred_labels_b = (pred_b >= 0.0).long()
                pred_labels_c = (pred_c >= 0.0).long()

                batch_accuracy_a = (pred_labels_a == labels_a).sum().item() / len(labels_a)
                batch_accuracy_b = (pred_labels_b == labels_b).sum().item() / len(labels_b)
                batch_accuracy_c = (pred_labels_c == labels).sum().item() / len(labels)

                sum_accuracy[split][0] += batch_accuracy_a
                sum_accuracy[split][1] += batch_accuracy_b
                sum_accuracy[split][2] += batch_accuracy_c

                if is_train:
                    progbar.update(j, ta=batch_accuracy_c)

        if progbar is not None:
            progbar.finish()

        # max(..., 1) keeps an empty split from raising ZeroDivisionError.
        epoch_loss = {split: sum_loss[split] / max(n_batches[split], 1) for split in splits}
        epoch_accuracy = {split: [acc / max(n_batches[split], 1) for acc in sum_accuracy[split]]
                          for split in splits}

        print(f"Epoch {epoch + 1}:")
        for split in splits:
            history_loss[split].append(epoch_loss[split])
            history_accuracy[split].append(epoch_accuracy[split])
            print(f"\t{split}\tLoss: {epoch_loss[split]:0.5}\tVGG 1:{epoch_accuracy[split][0]:0.5}"
                  f"\tVGG 2:{epoch_accuracy[split][1]:0.5}\tVGG *:{epoch_accuracy[split][2]:0.5}")

    # Only the weights reached after the last epoch are saved (no best-on-validation checkpoint).
    if save_param:
        torch.save({'vgg_a':nets[0].state_dict(),'vgg_b':nets[1].state_dict(),'vgg_star':nets[2].state_dict(),'classifier':nets[3].state_dict()},f'{model_name}.pth')

    return history_loss, history_accuracy

In [None]:
def test(net,classifier, loader):

      net.to(dev)
      classifier.to(dev)

      net.eval()

      sum_accuracy = 0

      # Process each batch
      for j, (input, labels) in enumerate(loader):

        input = input.to(dev)
        labels = labels.float().to(dev)

        features = net(input)

        pred = torch.squeeze(classifier(features))

        # https://discuss.pytorch.org/t/bcewithlogitsloss-and-model-accuracy-calculation/59293/ 2
        pred_labels = (pred >= 0.0).long()  # Binarize predictions to 0 and 1

        batch_accuracy = (pred_labels == labels).sum().item() / len(labels)

        # Update accuracy
        sum_accuracy += batch_accuracy

      epoch_accuracy = sum_accuracy / len(loader)

      print(f"Accuracy: {epoch_accuracy:0.5}")

In [None]:
def parse_dataset(dataset):
    """Replace ``dataset.targets`` by their parity (``targets % 2``) in place.

    Returns:
        The same dataset object, to allow ``ds = parse_dataset(ds)``.
    """
    parity_targets = dataset.targets % 2
    dataset.targets = parity_targets
    return dataset

In [None]:
# Directory passed as `root` to the MNIST dataset constructors (download/cache location).
root_dir = './'

In [None]:
# transforms.ToTensor() already converts 8-bit PIL images to float tensors in
# [0.0, 1.0]. Dividing by 255 again squeezed every input into [0, 1/255], so
# the extra rescale is no longer part of the pipelines below.
# NOTE(review): checkpoints trained with the old double scaling expect the tiny
# input range and must be retrained before being evaluated with these transforms.
# `rescale_data` stays defined only so any other reference to the name keeps working.
rescale_data = transforms.Lambda(lambda x : x/255)

# Training pipeline: resize to 32 px, random horizontal flip, tensor in [0, 1].
data_transform = transforms.Compose([
  transforms.Resize(32),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
])

# Evaluation pipeline: same as training but without the random flip.
test_transform = transforms.Compose([
  transforms.Resize(32),
  transforms.ToTensor(),
])

In [None]:
# MNIST train/test splits (downloaded into root_dir when missing), each with its own transform.
_mnist_common = dict(root=root_dir, download=True)
train_set = torchvision.datasets.MNIST(train=True, transform=data_transform, **_mnist_common)
test_set = torchvision.datasets.MNIST(train=False, transform=test_transform, **_mnist_common)

In [None]:
# Turn the digit labels of both splits into their parity (see parse_dataset).
train_set, test_set = map(parse_dataset, (train_set, test_set))

In [None]:
val_frac = 0.1

# Shuffle the training indices once; the test indices stay in dataset order.
shuffled_idx = np.random.permutation(np.arange(len(train_set)))
test_idx = np.arange(len(test_set))

# The first `val_frac` of the shuffled indices become validation, the rest training.
n_val = int(len(shuffled_idx) * val_frac)
val_idx, train_idx = shuffled_idx[:n_val], shuffled_idx[n_val:]

# Every split is cut in two halves ("a" gets the first len // 2 indices, "b" the rest).
h = len(train_idx) // 2
train_set_a, train_set_b = Subset(train_set, train_idx[:h]), Subset(train_set, train_idx[h:])

h = len(val_idx) // 2
val_set_a, val_set_b = Subset(train_set, val_idx[:h]), Subset(train_set, val_idx[h:])

h = len(test_idx) // 2
test_set_a, test_set_b = Subset(test_set, test_idx[:h]), Subset(test_set, test_idx[h:])

In [None]:
# Loader settings: training loaders shuffle and drop the incomplete last batch,
# evaluation loaders keep dataset order and every sample.
_train_loader_kwargs = dict(batch_size=128, num_workers=0, shuffle=True, drop_last=True)
_eval_loader_kwargs = dict(batch_size=128, num_workers=0, shuffle=False, drop_last=False)

train_loader_a = DataLoader(train_set_a, **_train_loader_kwargs)
train_loader_b = DataLoader(train_set_b, **_train_loader_kwargs)

val_loader_a = DataLoader(val_set_a, **_eval_loader_kwargs)
val_loader_b = DataLoader(val_set_b, **_eval_loader_kwargs)

test_loader_a = DataLoader(test_set_a, **_eval_loader_kwargs)
test_loader_b = DataLoader(test_set_b, **_eval_loader_kwargs)

# Whole (unsplit) test set, used when evaluating single networks.
test_loader_all = DataLoader(test_set, **_eval_loader_kwargs)

# Per split: [loader for half "a", loader for half "b"], as expected by train().
loaders = {
    "train": [train_loader_a, train_loader_b],
    "val":   [val_loader_a, val_loader_b],
    "test":  [test_loader_a, test_loader_b],
}

In [None]:
# Three identically configured extractors for 1x32x32 inputs (built in order
# model1, model2, model3) and one shared single-logit classifier head.
model1, model2, model3 = (VGG16((1, 32, 32), batch_norm=True) for _ in range(3))
classifier = Classifier(num_classes=1)

In [None]:
# Order matters: train() indexes this list as [extractor A, extractor B, combined extractor, classifier].
nets = [model1,model2,model3,classifier]

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Container for the parameters of all networks; filled in the next cell and
# then handed to the optimizer.
parameters = set()

In [None]:
# Collect the parameters of all networks as an ordered, de-duplicated list.
# torch.optim asks for collections with a deterministic order; a set of
# tensors is ordered by object identity and changes from run to run.
parameters = []
_seen_param_ids = set()
for n in nets:
  for p in n.parameters():
    if id(p) not in _seen_param_ids:  # keep shared parameters only once
      _seen_param_ids.add(id(p))
      parameters.append(p)

In [None]:
# Plain SGD over everything collected in `parameters`.
optimizer = torch.optim.SGD(parameters, lr = 0.01)
# Define a loss
criterion = nn.BCEWithLogitsLoss()  # binary cross-entropy applied to raw logits
# NOTE(review): n_params is not read anywhere in the visible part of the notebook.
n_params = 0

In [None]:
if (DO=='TRAIN'):
  train(nets, loaders, optimizer, criterion, epochs=50, dev=dev,save_param=True)
else:
  # train() writes f'{model_name}.pth' and is called above with the default
  # model_name ("valerio"), so that is the checkpoint to read back.
  # map_location lets a checkpoint saved on a GPU be loaded on the current device.
  state_dicts = torch.load('valerio.pth', map_location=dev)
  model1.load_state_dict(state_dicts['vgg_a'])  # these state_dicts come from the training function
  model2.load_state_dict(state_dicts['vgg_b'])
  model3.load_state_dict(state_dicts['vgg_star'])
  classifier.load_state_dict(state_dicts['classifier'])

  # Accuracy of each trained extractor with the shared classifier.
  test(model1,classifier,test_loader_all)
  test(model2, classifier, test_loader_all)
  test(model3, classifier, test_loader_all)

  # Rebuild the combined extractor from A and B: convolution tensors are
  # combo_fn(A, B); every other entry is kept from the trained combined network.
  summed_state_dict = OrderedDict()

  for key in state_dicts['vgg_star']:
    if 'conv' in key:
      print(key)
      summed_state_dict[key] = combo_fn(state_dicts['vgg_a'][key],state_dicts['vgg_b'][key])
    else:
      summed_state_dict[key] = state_dicts['vgg_star'][key]

  model3.load_state_dict(summed_state_dict)
  test(model3, classifier, test_loader_all)

 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.59]


Epoch 1:
	train	Loss: 2.2846	VGG 1:0.49706	VGG 2:0.4981	VGG *:0.50346
	val	Loss: 2.258	VGG 1:0.50284	VGG 2:0.55752	VGG *:0.6276
	test	Loss: 2.258	VGG 1:0.50859	VGG 2:0.57402	VGG *:0.60625


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.469]


Epoch 2:
	train	Loss: 2.2467	VGG 1:0.50387	VGG 2:0.50621	VGG *:0.50945
	val	Loss: 2.2306	VGG 1:0.50284	VGG 2:0.51014	VGG *:0.51446
	test	Loss: 2.2306	VGG 1:0.50859	VGG 2:0.50293	VGG *:0.51387


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.52]


Epoch 3:
	train	Loss: 2.2217	VGG 1:0.50539	VGG 2:0.50201	VGG *:0.52238
	val	Loss: 2.2022	VGG 1:0.50284	VGG 2:0.48986	VGG *:0.70219
	test	Loss: 2.2018	VGG 1:0.50859	VGG 2:0.49707	VGG *:0.71035


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.574]


Epoch 4:
	train	Loss: 2.1906	VGG 1:0.50461	VGG 2:0.50365	VGG *:0.58023
	val	Loss: 2.1993	VGG 1:0.50284	VGG 2:0.48986	VGG *:0.49351
	test	Loss: 2.1986	VGG 1:0.50859	VGG 2:0.49707	VGG *:0.49424


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.676]


Epoch 5:
	train	Loss: 2.1524	VGG 1:0.50167	VGG 2:0.50461	VGG *:0.59714
	val	Loss: 2.0624	VGG 1:0.50284	VGG 2:0.51014	VGG *:0.71038
	test	Loss: 2.0531	VGG 1:0.50859	VGG 2:0.50293	VGG *:0.71689


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.66]


Epoch 6:
	train	Loss: 2.0141	VGG 1:0.5035	VGG 2:0.50614	VGG *:0.73756
	val	Loss: 2.1296	VGG 1:0.50284	VGG 2:0.48986	VGG *:0.76809
	test	Loss: 2.1277	VGG 1:0.50859	VGG 2:0.49707	VGG *:0.77402


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.875]


Epoch 7:
	train	Loss: 1.9429	VGG 1:0.49929	VGG 2:0.50967	VGG *:0.78503
	val	Loss: 1.8425	VGG 1:0.49716	VGG 2:0.48986	VGG *:0.84431
	test	Loss: 1.8378	VGG 1:0.49141	VGG 2:0.49707	VGG *:0.8459


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.703]


Epoch 8:
	train	Loss: 1.9094	VGG 1:0.50525	VGG 2:0.53315	VGG *:0.79621
	val	Loss: 2.1154	VGG 1:0.50284	VGG 2:0.51014	VGG *:0.5924
	test	Loss: 2.105	VGG 1:0.50859	VGG 2:0.50293	VGG *:0.60049


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.91]


Epoch 9:
	train	Loss: 1.8778	VGG 1:0.49896	VGG 2:0.57277	VGG *:0.80002
	val	Loss: 1.6439	VGG 1:0.49716	VGG 2:0.70512	VGG *:0.88642
	test	Loss: 1.6052	VGG 1:0.49141	VGG 2:0.7459	VGG *:0.90479


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.832]


Epoch 10:
	train	Loss: 1.7233	VGG 1:0.50848	VGG 2:0.65852	VGG *:0.85645
	val	Loss: 1.6934	VGG 1:0.50284	VGG 2:0.50223	VGG *:0.89609
	test	Loss: 1.6727	VGG 1:0.50859	VGG 2:0.51074	VGG *:0.90459


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.602]


Epoch 11:
	train	Loss: 1.6386	VGG 1:0.5109	VGG 2:0.80167	VGG *:0.78538
	val	Loss: 1.6764	VGG 1:0.50284	VGG 2:0.82603	VGG *:0.76004
	test	Loss: 1.6085	VGG 1:0.50859	VGG 2:0.84824	VGG *:0.77559


 [Elapsed Time: 0:00:59] |###############| [Time:  0:00:59] [Train Acc:  0.938]


Epoch 12:
	train	Loss: 1.3621	VGG 1:0.51842	VGG 2:0.85309	VGG *:0.90409
	val	Loss: 1.1565	VGG 1:0.49716	VGG 2:0.91002	VGG *:0.94489
	test	Loss: 1.1457	VGG 1:0.49141	VGG 2:0.91465	VGG *:0.94189


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.965]


Epoch 13:
	train	Loss: 1.1836	VGG 1:0.52746	VGG 2:0.88698	VGG *:0.95039
	val	Loss: 1.0382	VGG 1:0.51125	VGG 2:0.9421	VGG *:0.96131
	test	Loss: 0.99586	VGG 1:0.49922	VGG 2:0.95586	VGG *:0.96436


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.957]


Epoch 14:
	train	Loss: 1.0375	VGG 1:0.55208	VGG 2:0.93765	VGG *:0.95884
	val	Loss: 0.95135	VGG 1:0.49781	VGG 2:0.95773	VGG *:0.97096
	test	Loss: 0.9268	VGG 1:0.49277	VGG 2:0.9666	VGG *:0.97275


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.949]


Epoch 15:
	train	Loss: 0.98504	VGG 1:0.58259	VGG 2:0.94498	VGG *:0.96568
	val	Loss: 1.0518	VGG 1:0.49716	VGG 2:0.93941	VGG *:0.96131
	test	Loss: 1.0514	VGG 1:0.49141	VGG 2:0.94453	VGG *:0.96152


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.98]


Epoch 16:
	train	Loss: 0.80943	VGG 1:0.72184	VGG 2:0.95822	VGG *:0.96948
	val	Loss: 0.64794	VGG 1:0.79813	VGG 2:0.97289	VGG *:0.97573
	test	Loss: 0.64864	VGG 1:0.7877	VGG 2:0.98086	VGG *:0.97695


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.973]


Epoch 17:
	train	Loss: 1.0971	VGG 1:0.77154	VGG 2:0.76916	VGG *:0.97307
	val	Loss: 0.69527	VGG 1:0.83761	VGG 2:0.92718	VGG *:0.97768
	test	Loss: 0.66795	VGG 1:0.81758	VGG 2:0.95352	VGG *:0.97988


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.984]


Epoch 18:
	train	Loss: 0.69982	VGG 1:0.80606	VGG 2:0.94498	VGG *:0.97563
	val	Loss: 0.48525	VGG 1:0.87998	VGG 2:0.96987	VGG *:0.98196
	test	Loss: 0.47239	VGG 1:0.87461	VGG 2:0.98105	VGG *:0.98018


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 19:
	train	Loss: 0.61609	VGG 1:0.82634	VGG 2:0.96901	VGG *:0.97959
	val	Loss: 0.47802	VGG 1:0.87788	VGG 2:0.97205	VGG *:0.98196
	test	Loss: 0.46425	VGG 1:0.87031	VGG 2:0.98477	VGG *:0.98184


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.52]


Epoch 20:
	train	Loss: 0.79113	VGG 1:0.90465	VGG 2:0.97221	VGG *:0.72431
	val	Loss: 1.0368	VGG 1:0.92932	VGG 2:0.96554	VGG *:0.50649
	test	Loss: 1.0249	VGG 1:0.91133	VGG 2:0.9793	VGG *:0.50576


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.691]


Epoch 21:
	train	Loss: 1.0415	VGG 1:0.91254	VGG 2:0.97522	VGG *:0.5189
	val	Loss: 1.1521	VGG 1:0.849	VGG 2:0.97819	VGG *:0.73686
	test	Loss: 1.1444	VGG 1:0.8375	VGG 2:0.98867	VGG *:0.75127


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.809]


Epoch 22:
	train	Loss: 0.97288	VGG 1:0.90856	VGG 2:0.98073	VGG *:0.62442
	val	Loss: 0.67251	VGG 1:0.95219	VGG 2:0.97484	VGG *:0.83812
	test	Loss: 0.65279	VGG 1:0.94531	VGG 2:0.9832	VGG *:0.84316


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 23:
	train	Loss: 0.41871	VGG 1:0.95856	VGG 2:0.97954	VGG *:0.91994
	val	Loss: 0.2924	VGG 1:0.9548	VGG 2:0.9834	VGG *:0.97761
	test	Loss: 0.28794	VGG 1:0.95254	VGG 2:0.98574	VGG *:0.97793


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.977]


Epoch 24:
	train	Loss: 0.35771	VGG 1:0.96399	VGG 2:0.94304	VGG *:0.97675
	val	Loss: 0.30303	VGG 1:0.96717	VGG 2:0.97224	VGG *:0.97187
	test	Loss: 0.2704	VGG 1:0.97051	VGG 2:0.97891	VGG *:0.975


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:   0.98]


Epoch 25:
	train	Loss: 0.31483	VGG 1:0.96812	VGG 2:0.95257	VGG *:0.97945
	val	Loss: 0.26029	VGG 1:0.97075	VGG 2:0.97656	VGG *:0.97891
	test	Loss: 0.23802	VGG 1:0.96758	VGG 2:0.98672	VGG *:0.98066


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.984]


Epoch 26:
	train	Loss: 0.50724	VGG 1:0.841	VGG 2:0.96953	VGG *:0.98196
	val	Loss: 0.56897	VGG 1:0.8271	VGG 2:0.97028	VGG *:0.9774
	test	Loss: 0.53319	VGG 1:0.83945	VGG 2:0.98594	VGG *:0.98018


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.688]


Epoch 27:
	train	Loss: 0.42657	VGG 1:0.94576	VGG 2:0.98333	VGG *:0.89501
	val	Loss: 0.80707	VGG 1:0.96856	VGG 2:0.98275	VGG *:0.62865
	test	Loss: 0.80222	VGG 1:0.95469	VGG 2:0.98906	VGG *:0.64014


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.965]


Epoch 28:
	train	Loss: 0.69238	VGG 1:0.97169	VGG 2:0.98609	VGG *:0.70277
	val	Loss: 0.25535	VGG 1:0.97173	VGG 2:0.98493	VGG *:0.96819
	test	Loss: 0.22826	VGG 1:0.97324	VGG 2:0.99023	VGG *:0.97471


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.992]


Epoch 29:
	train	Loss: 0.33731	VGG 1:0.97593	VGG 2:0.98411	VGG *:0.92472
	val	Loss: 0.21661	VGG 1:0.97931	VGG 2:0.98503	VGG *:0.98249
	test	Loss: 0.2149	VGG 1:0.97324	VGG 2:0.99082	VGG *:0.9792


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.977]


Epoch 30:
	train	Loss: 0.28386	VGG 1:0.97586	VGG 2:0.98713	VGG *:0.94609
	val	Loss: 0.21505	VGG 1:0.97866	VGG 2:0.98558	VGG *:0.9774
	test	Loss: 0.22101	VGG 1:0.97461	VGG 2:0.98828	VGG *:0.97617


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.977]


Epoch 31:
	train	Loss: 0.24733	VGG 1:0.97939	VGG 2:0.96164	VGG *:0.98253
	val	Loss: 0.23815	VGG 1:0.97377	VGG 2:0.98275	VGG *:0.97847
	test	Loss: 0.23426	VGG 1:0.96582	VGG 2:0.98867	VGG *:0.97637


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.992]


Epoch 32:
	train	Loss: 0.17611	VGG 1:0.98162	VGG 2:0.98791	VGG *:0.98426
	val	Loss: 0.19921	VGG 1:0.97931	VGG 2:0.98721	VGG *:0.98147
	test	Loss: 0.19805	VGG 1:0.97324	VGG 2:0.99023	VGG *:0.98369


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 33:
	train	Loss: 0.15984	VGG 1:0.98411	VGG 2:0.98925	VGG *:0.98629
	val	Loss: 0.1773	VGG 1:0.97954	VGG 2:0.9873	VGG *:0.98537
	test	Loss: 0.18808	VGG 1:0.97227	VGG 2:0.99004	VGG *:0.98359


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.984]


Epoch 34:
	train	Loss: 0.14495	VGG 1:0.98471	VGG 2:0.99085	VGG *:0.98676
	val	Loss: 0.19941	VGG 1:0.97368	VGG 2:0.98633	VGG *:0.98391
	test	Loss: 0.21055	VGG 1:0.96582	VGG 2:0.98926	VGG *:0.98242


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 35:
	train	Loss: 0.13653	VGG 1:0.98564	VGG 2:0.99022	VGG *:0.98778
	val	Loss: 0.17728	VGG 1:0.97954	VGG 2:0.98656	VGG *:0.98656
	test	Loss: 0.17906	VGG 1:0.97891	VGG 2:0.98926	VGG *:0.98418


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.969]


Epoch 36:
	train	Loss: 0.12649	VGG 1:0.98638	VGG 2:0.99237	VGG *:0.98876
	val	Loss: 0.19363	VGG 1:0.982	VGG 2:0.98363	VGG *:0.98347
	test	Loss: 0.1899	VGG 1:0.97695	VGG 2:0.98438	VGG *:0.98379


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 37:
	train	Loss: 0.22364	VGG 1:0.9885	VGG 2:0.93281	VGG *:0.98958
	val	Loss: 0.65787	VGG 1:0.98103	VGG 2:0.73419	VGG *:0.98689
	test	Loss: 0.65417	VGG 1:0.97578	VGG 2:0.71816	VGG *:0.98633


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 38:
	train	Loss: 0.39013	VGG 1:0.98865	VGG 2:0.859	VGG *:0.98977
	val	Loss: 0.1771	VGG 1:0.98256	VGG 2:0.98307	VGG *:0.9851
	test	Loss: 0.15979	VGG 1:0.98105	VGG 2:0.98984	VGG *:0.98818


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 39:
	train	Loss: 0.45697	VGG 1:0.98862	VGG 2:0.81399	VGG *:0.99036
	val	Loss: 0.25552	VGG 1:0.98298	VGG 2:0.95285	VGG *:0.98624
	test	Loss: 0.24121	VGG 1:0.97949	VGG 2:0.9541	VGG *:0.98691


 [Elapsed Time: 0:00:59] |###############| [Time:  0:00:59] [Train Acc:  0.988]


Epoch 40:
	train	Loss: 0.13408	VGG 1:0.99077	VGG 2:0.98289	VGG *:0.99109
	val	Loss: 0.2083	VGG 1:0.97466	VGG 2:0.98168	VGG *:0.98342
	test	Loss: 0.19734	VGG 1:0.97324	VGG 2:0.98691	VGG *:0.98418


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.984]


Epoch 41:
	train	Loss: 0.11707	VGG 1:0.99096	VGG 2:0.98687	VGG *:0.99154
	val	Loss: 0.15875	VGG 1:0.98256	VGG 2:0.98633	VGG *:0.98672
	test	Loss: 0.15994	VGG 1:0.97695	VGG 2:0.98984	VGG *:0.98545


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.992]


Epoch 42:
	train	Loss: 0.10322	VGG 1:0.99122	VGG 2:0.99077	VGG *:0.9917
	val	Loss: 0.16106	VGG 1:0.98158	VGG 2:0.98568	VGG *:0.98835
	test	Loss: 0.1736	VGG 1:0.97227	VGG 2:0.99023	VGG *:0.98594


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 43:
	train	Loss: 0.092217	VGG 1:0.99007	VGG 2:0.9933	VGG *:0.99317
	val	Loss: 0.1544	VGG 1:0.98461	VGG 2:0.98819	VGG *:0.98493
	test	Loss: 0.17427	VGG 1:0.9793	VGG 2:0.99062	VGG *:0.98154


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 44:
	train	Loss: 0.087574	VGG 1:0.99249	VGG 2:0.99304	VGG *:0.9931
	val	Loss: 0.16519	VGG 1:0.98191	VGG 2:0.98763	VGG *:0.98803
	test	Loss: 0.16748	VGG 1:0.97734	VGG 2:0.99023	VGG *:0.9873


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.977]


Epoch 45:
	train	Loss: 0.082136	VGG 1:0.99297	VGG 2:0.99371	VGG *:0.99358
	val	Loss: 0.1697	VGG 1:0.98396	VGG 2:0.98828	VGG *:0.98514
	test	Loss: 0.15436	VGG 1:0.98301	VGG 2:0.99082	VGG *:0.98496


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.992]


Epoch 46:
	train	Loss: 0.076632	VGG 1:0.99196	VGG 2:0.9949	VGG *:0.99362
	val	Loss: 0.16375	VGG 1:0.98451	VGG 2:0.98568	VGG *:0.98656
	test	Loss: 0.17649	VGG 1:0.97813	VGG 2:0.99062	VGG *:0.9834


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 47:
	train	Loss: 0.077059	VGG 1:0.9939	VGG 2:0.99304	VGG *:0.99436
	val	Loss: 0.17609	VGG 1:0.97972	VGG 2:0.982	VGG *:0.98954
	test	Loss: 0.17591	VGG 1:0.97539	VGG 2:0.98887	VGG *:0.9873


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.988]


Epoch 48:
	train	Loss: 0.072806	VGG 1:0.99345	VGG 2:0.99449	VGG *:0.99446
	val	Loss: 0.15718	VGG 1:0.98363	VGG 2:0.99056	VGG *:0.9864
	test	Loss: 0.15707	VGG 1:0.98047	VGG 2:0.99277	VGG *:0.98633


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:    1.0]


Epoch 49:
	train	Loss: 0.061804	VGG 1:0.99438	VGG 2:0.99624	VGG *:0.99503
	val	Loss: 0.16968	VGG 1:0.97996	VGG 2:0.98819	VGG *:0.98717
	test	Loss: 0.14496	VGG 1:0.98008	VGG 2:0.99277	VGG *:0.98887


 [Elapsed Time: 0:00:58] |###############| [Time:  0:00:58] [Train Acc:  0.996]


Epoch 50:
	train	Loss: 0.072728	VGG 1:0.99371	VGG 2:0.9923	VGG *:0.99494
	val	Loss: 0.16869	VGG 1:0.98451	VGG 2:0.98275	VGG *:0.989
	test	Loss: 0.17232	VGG 1:0.97656	VGG 2:0.98945	VGG *:0.98623


In [None]:
# Switch to evaluation mode for this cell: load the checkpoint written by train().
DO = 'TEST'
if (DO=='TRAIN'):
  train(nets, loaders, optimizer, criterion, epochs=50, dev=dev,save_param=True)
else:
  # map_location lets a checkpoint saved on a GPU be loaded on the current device.
  state_dicts = torch.load('valerio.pth', map_location=dev)
  model1.load_state_dict(state_dicts['vgg_a'])  # these state_dicts come from the training function
  model2.load_state_dict(state_dicts['vgg_b'])
  model3.load_state_dict(state_dicts['vgg_star'])
  classifier.load_state_dict(state_dicts['classifier'])

  # Accuracy of each trained extractor with the shared classifier.
  test(model1,classifier,test_loader_all)
  test(model2, classifier, test_loader_all)
  test(model3, classifier, test_loader_all)

  # Rebuild the combined extractor from A and B: convolution tensors are
  # combo_fn(A, B); every other entry is kept from the trained combined network.
  summed_state_dict = OrderedDict()

  for key in state_dicts['vgg_star']:
    if 'conv' in key:
      print(key)
      summed_state_dict[key] = combo_fn(state_dicts['vgg_a'][key],state_dicts['vgg_b'][key])
    else:
      summed_state_dict[key] = state_dicts['vgg_star'][key]

  model3.load_state_dict(summed_state_dict)
  test(model3, classifier, test_loader_all)

Accuracy: 0.98576
Accuracy: 0.98289
Accuracy: 0.98754
block_1.conv1.weight
block_1.conv1.bias
block_1.conv2.weight
block_1.conv2.bias
block_2.conv1.weight
block_2.conv1.bias
block_2.conv2.weight
block_2.conv2.bias
block_3.conv1.weight
block_3.conv1.bias
block_3.conv2.weight
block_3.conv2.bias
block_4.conv1.weight
block_4.conv1.bias
block_4.conv2.weight
block_4.conv2.bias
Accuracy: 0.83416


In [None]:
# Parameters of model1.block_1 as a list (in .parameters() order); the bare
# name on the last line makes the notebook display the full tensors.
weights11 = list(model1.block_1.parameters())
weights11

[Parameter containing:
 tensor([[[[ 0.2350, -0.0563,  0.2557],
           [ 0.2926,  0.0502, -0.1769],
           [ 0.0748, -0.2991, -0.2421]]],
 
 
         [[[ 0.2821,  0.2767,  0.2640],
           [ 0.2872,  0.3097,  0.3331],
           [ 0.3125,  0.2188, -0.1556]]],
 
 
         [[[-0.0940,  0.3259,  0.0999],
           [-0.0204, -0.0185, -0.3087],
           [ 0.1126, -0.1424,  0.2864]]],
 
 
         [[[-0.1262,  0.1147, -0.1848],
           [ 0.3077, -0.2682, -0.0260],
           [ 0.1535,  0.2444,  0.3458]]],
 
 
         [[[ 0.0804, -0.1650,  0.3128],
           [ 0.0133, -0.2200, -0.2818],
           [ 0.0911, -0.3181, -0.0082]]],
 
 
         [[[-0.1636,  0.2045, -0.1861],
           [ 0.2362,  0.1350,  0.0661],
           [ 0.0854, -0.0678,  0.2648]]],
 
 
         [[[-0.1364, -0.3079, -0.1902],
           [ 0.1071, -0.0757, -0.1022],
           [ 0.0473, -0.0123, -0.2513]]],
 
 
         [[[ 0.0335, -0.1298, -0.2520],
           [-0.0122,  0.0713,  0.1048],
           [ 0.

In [None]:
# Parameters of model1.block_2 as a list (in .parameters() order); the bare
# name on the last line makes the notebook display the full tensors.
weights12 = list(model1.block_2.parameters())
weights12

[Parameter containing:
 tensor([[[[ 0.0017,  0.0386,  0.0080],
           [-0.0293,  0.0165, -0.0059],
           [-0.0128, -0.0406, -0.0389]],
 
          [[-0.0088,  0.0060,  0.0299],
           [ 0.0203, -0.0067, -0.0045],
           [-0.0345, -0.0107,  0.0178]],
 
          [[ 0.0165, -0.0107,  0.0072],
           [-0.0007,  0.0179,  0.0081],
           [-0.0401,  0.0289,  0.0044]],
 
          ...,
 
          [[-0.0094, -0.0105,  0.0364],
           [-0.0299,  0.0352, -0.0391],
           [ 0.0220, -0.0302, -0.0217]],
 
          [[ 0.0094,  0.0154, -0.0339],
           [-0.0341,  0.0258, -0.0418],
           [ 0.0383,  0.0278,  0.0044]],
 
          [[-0.0385, -0.0238, -0.0054],
           [ 0.0331, -0.0320,  0.0387],
           [ 0.0281,  0.0272,  0.0227]]],
 
 
         [[[-0.0253,  0.0239,  0.0352],
           [ 0.0194,  0.0360,  0.0153],
           [-0.0170,  0.0350, -0.0175]],
 
          [[ 0.0330, -0.0340,  0.0234],
           [ 0.0207,  0.0316,  0.0363],
           [-0.0

In [None]:
# Parameters of model1.block_3 as a list (in .parameters() order); the bare
# name on the last line makes the notebook display the full tensors.
weights13 = list(model1.block_3.parameters())
weights13

[Parameter containing:
 tensor([[[[ 2.6235e-02, -2.4213e-02,  1.0596e-02],
           [ 6.1124e-03, -7.4373e-03,  1.5477e-02],
           [-2.3593e-02,  4.1108e-03,  2.6329e-03]],
 
          [[-2.7877e-02, -1.3510e-02, -2.2814e-02],
           [-8.9939e-03,  1.6797e-02, -2.5395e-02],
           [-2.7805e-03,  4.3453e-03, -3.4963e-02]],
 
          [[-3.0643e-02,  1.9780e-02,  2.0662e-03],
           [ 2.8571e-03,  1.7534e-02, -3.2839e-02],
           [ 1.7093e-02, -3.0366e-03, -2.4389e-02]],
 
          ...,
 
          [[ 3.2849e-03,  6.4357e-04,  4.5100e-03],
           [ 1.3417e-02, -1.2704e-02,  2.2582e-02],
           [ 9.5408e-03,  2.0843e-02,  2.4465e-02]],
 
          [[ 1.2890e-02,  2.3796e-02,  2.8769e-02],
           [-4.8637e-03, -8.6722e-03, -8.1848e-03],
           [ 2.5876e-03, -8.1970e-03, -1.6028e-02]],
 
          [[ 3.8722e-03, -6.2281e-03, -1.9739e-02],
           [ 4.1624e-03, -1.4322e-02,  1.3055e-02],
           [-1.8113e-02,  2.2856e-02, -5.4694e-03]]],
 
 
   

In [None]:
# Collect everything yielded by model1.block_4.parameters() into a list
# (presumably block_4 is an nn.Module sub-block of the model -- confirm against the model class).
weights14 = [*model1.block_4.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights14

[Parameter containing:
 tensor([[[[-0.0114, -0.0203, -0.0200],
           [-0.0168,  0.0132,  0.0102],
           [-0.0179, -0.0109, -0.0197]],
 
          [[-0.0196,  0.0201,  0.0178],
           [ 0.0038,  0.0151,  0.0050],
           [-0.0012,  0.0146, -0.0174]],
 
          [[-0.0073,  0.0124, -0.0004],
           [ 0.0044, -0.0110,  0.0031],
           [ 0.0089,  0.0076, -0.0090]],
 
          ...,
 
          [[-0.0053, -0.0115, -0.0018],
           [-0.0073,  0.0044, -0.0158],
           [ 0.0203, -0.0024,  0.0184]],
 
          [[-0.0127,  0.0004, -0.0144],
           [ 0.0015, -0.0085, -0.0040],
           [ 0.0001,  0.0079, -0.0092]],
 
          [[ 0.0003, -0.0154, -0.0065],
           [-0.0163,  0.0092,  0.0133],
           [-0.0213, -0.0132,  0.0116]]],
 
 
         [[[ 0.0123, -0.0013,  0.0103],
           [-0.0168,  0.0045,  0.0102],
           [ 0.0139,  0.0200, -0.0169]],
 
          [[ 0.0153, -0.0129, -0.0138],
           [-0.0014,  0.0056, -0.0093],
           [ 0.0

In [None]:
# Collect everything yielded by model2.block_1.parameters() into a list
# (presumably block_1 is an nn.Module sub-block of the model -- confirm against the model class).
weights21 = [*model2.block_1.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights21

[Parameter containing:
 tensor([[[[-0.1137,  0.1984,  0.0562],
           [ 0.0733,  0.1061, -0.1401],
           [ 0.2302, -0.2168, -0.2712]]],
 
 
         [[[ 0.0218, -0.2155, -0.0485],
           [ 0.0386,  0.2016,  0.1191],
           [-0.0608,  0.0082,  0.1475]]],
 
 
         [[[ 0.2700, -0.1646, -0.1243],
           [-0.2086,  0.2193,  0.3348],
           [ 0.0780,  0.3147,  0.0025]]],
 
 
         [[[-0.1970,  0.3185, -0.1783],
           [ 0.3422, -0.2180,  0.3274],
           [ 0.1561,  0.1767, -0.2409]]],
 
 
         [[[-0.1918,  0.2150,  0.1131],
           [-0.2910, -0.2684,  0.0046],
           [-0.2926, -0.0305, -0.0767]]],
 
 
         [[[ 0.1740,  0.3101,  0.2530],
           [ 0.1718, -0.0903, -0.2646],
           [-0.1042,  0.1935,  0.0049]]],
 
 
         [[[-0.1393,  0.2496, -0.1802],
           [-0.1760, -0.2235,  0.2493],
           [-0.2152, -0.1251, -0.0585]]],
 
 
         [[[ 0.0912,  0.0043,  0.2825],
           [ 0.0306, -0.2676, -0.1096],
           [ 0.

In [None]:
# Collect everything yielded by model2.block_2.parameters() into a list
# (presumably block_2 is an nn.Module sub-block of the model -- confirm against the model class).
weights22 = [*model2.block_2.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights22

[Parameter containing:
 tensor([[[[-0.0162, -0.0067,  0.0125],
           [ 0.0268,  0.0329, -0.0126],
           [-0.0006, -0.0216, -0.0235]],
 
          [[-0.0242,  0.0009, -0.0003],
           [-0.0283,  0.0337, -0.0057],
           [ 0.0255, -0.0222, -0.0148]],
 
          [[ 0.0003,  0.0006,  0.0313],
           [-0.0090,  0.0228,  0.0137],
           [-0.0247,  0.0317, -0.0302]],
 
          ...,
 
          [[ 0.0170,  0.0206,  0.0409],
           [-0.0333,  0.0035, -0.0166],
           [ 0.0028, -0.0236,  0.0194]],
 
          [[ 0.0052, -0.0024,  0.0301],
           [ 0.0072, -0.0156, -0.0159],
           [ 0.0233,  0.0265, -0.0310]],
 
          [[-0.0223,  0.0394,  0.0113],
           [ 0.0370, -0.0315, -0.0203],
           [-0.0009,  0.0061,  0.0180]]],
 
 
         [[[ 0.0278, -0.0340, -0.0155],
           [-0.0236,  0.0017, -0.0133],
           [ 0.0168,  0.0294, -0.0365]],
 
          [[-0.0320, -0.0257, -0.0025],
           [-0.0340, -0.0230,  0.0136],
           [ 0.0

In [None]:
# Collect everything yielded by model2.block_3.parameters() into a list
# (presumably block_3 is an nn.Module sub-block of the model -- confirm against the model class).
weights23 = [*model2.block_3.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights23

[Parameter containing:
 tensor([[[[ 2.5545e-02,  1.1165e-03,  1.5954e-02],
           [-2.8370e-02,  5.8127e-03, -1.1172e-02],
           [-1.1929e-03,  2.4560e-02, -1.8545e-03]],
 
          [[ 2.7708e-02, -1.2897e-02,  7.6667e-03],
           [-2.4900e-02, -5.7964e-03,  9.6179e-03],
           [ 1.1133e-02,  2.7113e-02,  1.5912e-02]],
 
          [[ 2.8086e-02, -5.8985e-03, -7.5760e-03],
           [-2.2659e-02, -2.1737e-02,  1.1539e-03],
           [ 7.6775e-03, -2.3518e-02, -1.6213e-02]],
 
          ...,
 
          [[ 9.4731e-03, -2.0602e-02,  2.2784e-02],
           [-8.7087e-03,  2.3731e-02, -2.8401e-02],
           [-6.0912e-03,  1.3045e-02, -1.3235e-02]],
 
          [[-1.5281e-02, -1.7205e-02,  1.9188e-03],
           [-3.0274e-02, -2.4663e-02,  8.7642e-03],
           [-7.7695e-03,  2.3470e-03,  1.7835e-05]],
 
          [[ 7.2075e-03, -1.9011e-02,  2.8136e-04],
           [ 2.1840e-02,  1.9834e-02, -1.6850e-02],
           [-4.3692e-03,  2.6741e-02,  2.8380e-02]]],
 
 
   

In [None]:
# Collect everything yielded by model2.block_4.parameters() into a list
# (presumably block_4 is an nn.Module sub-block of the model -- confirm against the model class).
weights24 = [*model2.block_4.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights24

[Parameter containing:
 tensor([[[[-0.0209,  0.0085,  0.0085],
           [ 0.0052, -0.0073,  0.0149],
           [-0.0208, -0.0099, -0.0068]],
 
          [[ 0.0118, -0.0055, -0.0018],
           [-0.0182, -0.0064,  0.0065],
           [-0.0159,  0.0005,  0.0023]],
 
          [[ 0.0099, -0.0183, -0.0178],
           [ 0.0130, -0.0201,  0.0080],
           [ 0.0083,  0.0037,  0.0008]],
 
          ...,
 
          [[-0.0029, -0.0197, -0.0217],
           [-0.0006, -0.0157,  0.0216],
           [-0.0064,  0.0169,  0.0110]],
 
          [[ 0.0005,  0.0180, -0.0005],
           [-0.0045,  0.0013, -0.0120],
           [-0.0028, -0.0101,  0.0167]],
 
          [[ 0.0174,  0.0091, -0.0148],
           [-0.0130, -0.0113, -0.0109],
           [ 0.0181, -0.0086, -0.0023]]],
 
 
         [[[-0.0039, -0.0122, -0.0022],
           [-0.0161,  0.0003,  0.0089],
           [ 0.0131, -0.0151,  0.0162]],
 
          [[ 0.0204,  0.0046, -0.0044],
           [ 0.0018,  0.0095, -0.0012],
           [ 0.0

In [None]:
# Collect everything yielded by model3.block_1.parameters() into a list
# (presumably block_1 is an nn.Module sub-block of the model -- confirm against the model class).
weights31 = [*model3.block_1.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights31

[Parameter containing:
 tensor([[[[-0.2023,  0.2848, -0.2433],
           [ 0.1093, -0.1119,  0.0604],
           [ 0.2954, -0.3100, -0.0798]]],
 
 
         [[[ 0.2415,  0.3000, -0.1064],
           [ 0.0850,  0.0231,  0.2242],
           [-0.0234,  0.2347,  0.0524]]],
 
 
         [[[ 0.1845,  0.1452, -0.2903],
           [-0.1575,  0.1178, -0.0657],
           [-0.2589, -0.1802,  0.2113]]],
 
 
         [[[ 0.0312,  0.3346, -0.0653],
           [ 0.2532, -0.2167,  0.3153],
           [-0.0280, -0.1071,  0.2815]]],
 
 
         [[[ 0.2872, -0.2546,  0.3250],
           [-0.1138, -0.0413,  0.2351],
           [ 0.1593, -0.2558, -0.1132]]],
 
 
         [[[-0.2219, -0.1422,  0.1550],
           [ 0.1517,  0.1587, -0.1126],
           [-0.2909, -0.1791, -0.1040]]],
 
 
         [[[-0.1606, -0.0760, -0.0538],
           [-0.1752, -0.0946,  0.1787],
           [-0.0515, -0.1587,  0.1573]]],
 
 
         [[[-0.1059,  0.1066,  0.1627],
           [ 0.1950, -0.2875, -0.0921],
           [-0.

In [None]:
# Collect everything yielded by model3.block_2.parameters() into a list
# (presumably block_2 is an nn.Module sub-block of the model -- confirm against the model class).
weights32 = [*model3.block_2.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights32

[Parameter containing:
 tensor([[[[ 1.2416e-02,  2.5708e-02, -3.8679e-02],
           [-2.6562e-02,  2.6579e-02,  7.8719e-03],
           [-2.4055e-02, -2.0759e-02,  2.3100e-02]],
 
          [[ 2.3006e-03, -2.6198e-02,  8.3492e-03],
           [ 1.0510e-02,  2.5981e-02,  4.7630e-02],
           [ 3.0636e-02, -2.8290e-02,  2.5406e-02]],
 
          [[ 9.1573e-03,  4.2503e-03,  3.7169e-02],
           [ 3.4620e-02, -1.4993e-02,  2.1113e-02],
           [ 2.4074e-02, -2.7500e-03,  2.7897e-02]],
 
          ...,
 
          [[-1.0810e-02,  2.9404e-03, -1.9244e-02],
           [ 4.1073e-02, -5.6443e-03,  4.3717e-02],
           [ 2.2588e-02,  2.0390e-02, -1.5162e-02]],
 
          [[-4.7891e-03,  4.9989e-03, -2.2920e-03],
           [-2.4839e-02, -3.2393e-02, -3.2858e-02],
           [ 1.7898e-02, -3.0133e-02,  1.8548e-02]],
 
          [[-3.6746e-02,  2.4915e-02,  1.7301e-02],
           [ 1.6966e-02,  2.4759e-02, -7.7586e-03],
           [ 4.0273e-02,  7.6281e-03, -3.0942e-02]]],
 
 
   

In [None]:
# Collect everything yielded by model3.block_3.parameters() into a list
# (presumably block_3 is an nn.Module sub-block of the model -- confirm against the model class).
weights33 = [*model3.block_3.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights33

[Parameter containing:
 tensor([[[[-2.8350e-02,  2.6731e-02,  2.7146e-02],
           [ 1.8557e-02, -2.2121e-02,  2.3460e-02],
           [ 1.9065e-02,  6.3231e-03,  5.1209e-04]],
 
          [[-1.2441e-02, -1.8517e-02,  1.8330e-02],
           [ 1.6244e-02, -2.2109e-02,  1.3770e-02],
           [-1.3615e-02, -1.3823e-02, -9.4690e-04]],
 
          [[ 2.1933e-03,  1.2861e-02, -7.7686e-03],
           [ 2.2564e-02,  1.8144e-02, -1.3935e-02],
           [-1.4047e-02,  9.5706e-03,  1.6686e-02]],
 
          ...,
 
          [[ 2.5194e-02, -2.8444e-02,  3.5629e-03],
           [-2.2153e-02, -1.1606e-02,  2.5151e-02],
           [ 2.2010e-03,  2.0204e-02,  1.8645e-02]],
 
          [[ 2.7313e-02, -2.7937e-02, -5.6176e-03],
           [ 8.3533e-03, -2.8730e-02,  2.5881e-02],
           [-1.5982e-02,  1.4420e-02, -1.0478e-02]],
 
          [[-3.3624e-04,  1.2877e-02,  1.9532e-02],
           [ 7.6554e-03, -2.8169e-02, -2.0093e-02],
           [-1.2148e-03, -1.8436e-03,  1.6385e-02]]],
 
 
   

In [None]:
# Collect everything yielded by model3.block_4.parameters() into a list
# (presumably block_4 is an nn.Module sub-block of the model -- confirm against the model class).
weights34 = [*model3.block_4.parameters()]
# Bare name as the last expression so the notebook echoes the list.
weights34

[Parameter containing:
 tensor([[[[ 4.8594e-03, -1.3169e-04,  9.3526e-04],
           [ 1.5966e-02, -9.8638e-03, -5.5669e-03],
           [-7.8802e-03,  1.8352e-03, -1.7840e-02]],
 
          [[ 1.7543e-02,  1.6971e-02,  1.4636e-02],
           [-1.4488e-02, -5.9967e-03,  1.9790e-02],
           [ 2.0465e-02, -2.3789e-03,  1.3944e-02]],
 
          [[ 1.0343e-02,  4.7584e-03,  7.9877e-03],
           [-9.8885e-03, -1.1414e-04, -1.0168e-02],
           [ 7.7177e-03,  1.7085e-02, -1.2655e-02]],
 
          ...,
 
          [[ 1.2597e-02, -1.9097e-02, -8.4095e-03],
           [-8.6651e-05,  1.1941e-03,  5.7456e-03],
           [-1.6083e-02, -1.1878e-02,  1.1799e-02]],
 
          [[ 7.6769e-03, -6.5895e-04,  1.3349e-02],
           [ 3.3684e-03, -9.1500e-03,  1.2509e-02],
           [-1.3822e-02,  1.6248e-02, -9.6858e-03]],
 
          [[-1.6219e-02,  1.2469e-02,  9.8419e-03],
           [ 1.2138e-02, -2.3683e-03,  4.9098e-04],
           [ 1.5800e-02,  4.8392e-03, -1.0130e-02]]],
 
 
   

In [None]:
!pip install --upgrade progressbar2

Collecting progressbar2
  Downloading https://files.pythonhosted.org/packages/25/8c/d28cd70b6e0b870a2d2a151bdbecf4c678199d31731edb44fc8035d3bb6d/progressbar2-3.53.1-py2.py3-none-any.whl
Installing collected packages: progressbar2
  Found existing installation: progressbar2 3.38.0
    Uninstalling progressbar2-3.38.0:
      Successfully uninstalled progressbar2-3.38.0
Successfully installed progressbar2-3.53.1
