In [None]:
from torch import nn
from collections import OrderedDict
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader
import torchvision
import random
from torch.utils.data import Subset
from matplotlib import pyplot as plt
from torchsummary import summary
from torchvision import transforms
import progressbar as pb
import numpy as np

In [None]:
def SUM(x, y):
    """Combine two values (numbers, tensors, ...) with the ``+`` operator."""
    return x + y

In [None]:
def check_equity(property,a,b):
    """Return the attribute named `property` that `a` and `b` have in common.

    Args:
        property: name of the attribute to read from both objects.
        a, b: objects exposing that attribute.

    Returns:
        The attribute value read from `a`.

    Raises:
        AssertionError: when the two objects hold different values.
        AttributeError: when either object lacks the attribute.
    """
    value_a, value_b = (getattr(obj, property) for obj in (a, b))
    assert value_a == value_b, "Different {}: {}!={}".format(property, value_a, value_b)

    return value_a

In [None]:
def module_unwrap(mod:nn.Module,recursive=False):
    """Map the children of `mod` to an ordered ``{name: module}`` dictionary.

    Args:
        mod: module whose ``named_children()`` are listed.
        recursive: when true, children that have children of their own are
            expanded in turn; nested names are joined with ``"_"`` so the
            result only contains leaf modules.

    Returns:
        An ``OrderedDict`` in traversal order. It is empty when `mod` has no
        children or does not provide ``named_children`` at all.
    """
    flat = OrderedDict()
    try:
        for child_name, child in mod.named_children():
            # Only descend when asked to; an empty/None result means "treat as leaf".
            nested = module_unwrap(child, recursive=True) if recursive else None
            if nested:
                flat.update(
                    (child_name + "_" + sub_name, sub_module)
                    for sub_name, sub_module in nested.items()
                )
            else:
                flat[child_name] = child
    except AttributeError:
        # Objects without `named_children` simply contribute nothing.
        pass

    return flat

In [None]:
class VGGBlock(nn.Module):
    """Two 3x3 convolutions, each followed by optional normalisation and ReLU, then a 2x2 max-pool.

    The convolutions use stride 1 and padding 1, so they keep the spatial size;
    the final max-pool (kernel 2, stride 2) halves it.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.
        batch_norm: when true, a normalisation layer follows each convolution.
            Despite the name, the layer is ``nn.GroupNorm`` with 32 groups, so
            `out_channels` must be divisible by 32 in that case.
    """

    def __init__(self, in_channels, out_channels,batch_norm=False):

        super().__init__()

        conv2_params = {'kernel_size': (3, 3),
                        'stride'     : (1, 1),
                        'padding'   : 1
                        }

        self._batch_norm = batch_norm

        self.conv1 = nn.Conv2d(in_channels=in_channels,out_channels=out_channels , **conv2_params)
        # nn.Identity keeps the "no normalisation" path a proper (parameter-free)
        # module instead of a bare lambda stored on the module.
        self.bn1 = nn.GroupNorm(32, out_channels) if batch_norm else nn.Identity()

        self.conv2 = nn.Conv2d(in_channels=out_channels,out_channels=out_channels, **conv2_params)
        self.bn2 = nn.GroupNorm(32, out_channels) if batch_norm else nn.Identity()

        self.max_pooling = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

    @property
    def batch_norm(self):
        """Whether normalisation layers were requested at construction time."""
        return self._batch_norm

    def forward(self,x):
        """Apply conv -> norm -> ReLU twice, then max-pool."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)

        x = self.max_pooling(x)

        return x

In [None]:
class Classifier(nn.Module):
    """Fully connected head: 2048 -> 2048 -> 512 -> `num_classes`.

    The two hidden layers are each followed by an in-place ReLU and a dropout
    with probability 0.5. The last layer returns raw scores (no activation).
    """

    def __init__(self,num_classes=10):
        super().__init__()

        layers = [
            nn.Linear(2048, 2048), nn.ReLU(True), nn.Dropout(p=0.5),
            nn.Linear(2048, 512), nn.ReLU(True), nn.Dropout(p=0.5),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self,x):
        """Run `x` (last dimension 2048) through the stack and return the scores."""
        scores = self.classifier(x)
        return scores

In [None]:
class VGG16(nn.Module):
    """Feature extractor built from four VGGBlocks with 64, 128, 256 and 512 output channels.

    Args:
        input_size: triple whose first entry is the number of input channels;
            the other two entries are stored as `in_width` and `in_height`.
        batch_norm: forwarded unchanged to every VGGBlock.
    """

    def __init__(self, input_size, batch_norm=False):
        super().__init__()

        channels, width, height = input_size
        self.in_channels = channels
        self.in_width = width
        self.in_height = height

        self.block_1 = VGGBlock(channels, 64, batch_norm=batch_norm)
        self.block_2 = VGGBlock(64, 128, batch_norm=batch_norm)
        self.block_3 = VGGBlock(128, 256, batch_norm=batch_norm)
        self.block_4 = VGGBlock(256, 512, batch_norm=batch_norm)

    @property
    def input_size(self):
        """The ``(in_channels, in_width, in_height)`` triple given at construction."""
        return self.in_channels, self.in_width, self.in_height

    def forward(self, x):
        """Run the four blocks in order and flatten everything but the batch dimension."""
        for stage in (self.block_1, self.block_2, self.block_3, self.block_4):
            x = stage(x)

        return torch.flatten(x, 1)

In [None]:
class CombinedLoss(nn.Module):
    """Sum of two per-branch losses plus a weighted loss on a combined prediction.

    Args:
        loss_a: loss applied to the first prediction/target pair.
        loss_b: loss applied to the second prediction/target pair.
        loss_combo: loss applied to the third prediction against both targets
            concatenated along dimension 0.
        _lambda: weight of the combined term; stored as a float32 buffer.
    """

    def __init__(self, loss_a, loss_b, loss_combo, _lambda=1.0):
        super().__init__()
        self.loss_a, self.loss_b, self.loss_combo = loss_a, loss_b, loss_combo

        weight = torch.tensor(float(_lambda), dtype=torch.float32)
        self.register_buffer('_lambda', weight)

    def forward(self,y_hat,y):
        """Compute the loss.

        `y_hat` holds three predictions (indices 0, 1, 2); `y` is a list/tuple
        of the two target tensors.
        """
        term_a = self.loss_a(y_hat[0], y[0])
        term_b = self.loss_b(y_hat[1], y[1])
        joint_target = torch.cat(y, 0)
        term_combo = self._lambda * self.loss_combo(y_hat[2], joint_target)

        return term_a + term_b + term_combo

----------------------------------------------------------------------

In [None]:
# Run-mode switch read by the final run cell: 'TRAIN' trains the networks, any
# other value loads a saved checkpoint and evaluates it.
# NOTE: the run cell assigns DO again right before checking it.
DO='TRAIN'

In [None]:
# Seed every random number generator this notebook draws from: Python's
# `random`, NumPy (used to shuffle the training indices) and PyTorch (weight
# initialisation, dropout, DataLoader shuffling). Seeding only `random` leaves
# the data split and the initial weights different on every run.
SEED = 47
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Function that combines the matching weights of two networks; train() and the
# run cell apply it to convolution weights/biases (here: element-wise sum).
combo_fn = SUM

In [None]:
# Weight of the MSE penalty that train() adds between combo_fn(weights of net 0,
# weights of net 1) and the weights of net 2, for every Conv2d layer.
lambda_reg = 1

In [None]:
def test(net,classifier, loader):

      net.to(dev)
      classifier.to(dev)

      net.eval()

      sum_accuracy = 0

      # Process each batch
      for j, (input, labels) in enumerate(loader):

        input = input.to(dev)
        labels = labels.float().to(dev)

        features = net(input)

        pred = torch.squeeze(classifier(features))

        # https://discuss.pytorch.org/t/bcewithlogitsloss-and-model-accuracy-calculation/59293/ 2
        #pred_labels = (pred >= 0.0).long()  # Binarize predictions to 0 and 1
        _,pred_label = torch.max(pred, dim = 1)
        pred_labels = (pred_label == labels).float()

        batch_accuracy = pred_labels.sum().item() / len(labels)

        # Update accuracy
        sum_accuracy += batch_accuracy

      epoch_accuracy = sum_accuracy / len(loader)

      print(f"Accuracy test: {epoch_accuracy:0.5}")
      return epoch_accuracy

In [None]:
def train(nets, loaders, optimizer, criterion, epochs=20, dev=None, save_param=False, model_name="valerio"):
    """Jointly train two feature extractors, a third "combined" extractor and a shared classifier.

    Every batch contributes the classification loss of all three extractors plus,
    for each Conv2d layer, an MSE penalty (weighted by the notebook-level
    `lambda_reg`) between `combo_fn(weights of net 0, weights of net 1)` and the
    weights of net 2.

    Args:
        nets: sequence ``[net_a, net_b, net_c, classifier]``. `net_a` sees the
            batches of the first loader, `net_b` those of the second, `net_c`
            the concatenation of both; `classifier` is shared by all three.
        loaders: dict with keys "train", "val" and "test", each holding a pair
            ``[loader_a, loader_b]`` that is iterated in lock-step.
        optimizer: optimizer over the parameters of all four networks.
        criterion: classification loss taking ``(scores, integer labels)``.
        epochs: number of passes over the three splits.
        dev: device the networks and batches are moved to.
        save_param: when true, a checkpoint ``f"{model_name}.pth"`` with the four
            state dicts is written after every epoch.
        model_name: file stem of that checkpoint.

    Returns:
        ``(history_loss, history_accuracy)``: dicts keyed by split. Each holds one
        entry per epoch; an accuracy entry is the list ``[net_a, net_b, net_c]``.

    Besides its arguments the function relies on these notebook-level names:
    `module_unwrap`, `combo_fn`, `lambda_reg`, `test`, `test_loader_all`, `pb`.
    """
    nets = [n.to(dev) for n in nets]

    # Flat {name: leaf module} views, used to pair up corresponding layers.
    model_a = module_unwrap(nets[0], True)
    model_b = module_unwrap(nets[1], True)
    model_c = module_unwrap(nets[2], True)

    reg_loss = nn.MSELoss()

    criterion.to(dev)
    reg_loss.to(dev)

    splits = ["train", "val", "test"]

    # Initialize history
    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}

    # Process each epoch
    for epoch in range(epochs):
        # Initialize epoch variables
        sum_loss = {split: 0 for split in splits}
        sum_accuracy = {split: [0, 0, 0] for split in splits}

        progbar = None
        # Process each split
        for split in splits:
            is_train = split == "train"

            # Training mode for the train split, eval mode for val/test.
            for n in nets:
                n.train(is_train)

            if is_train:
                widgets = [
                    ' [', pb.Timer(), '] ',
                    pb.Bar(),
                    ' [', pb.ETA(), '] ', pb.Variable('ta', '[Train Acc: {formatted_value}]')
                ]
                progbar = pb.ProgressBar(max_value=len(loaders[split][0]), widgets=widgets, redirect_stdout=True)

            # Process each batch; zip stops at the shorter of the two loaders.
            for j, ((input_a, labels_a), (input_b, labels_b)) in enumerate(zip(loaders[split][0], loaders[split][1])):

                input_a = input_a.to(dev)
                input_b = input_b.to(dev)

                labels_a = labels_a.long().to(dev)
                labels_b = labels_b.long().to(dev)

                inputs = torch.cat([input_a, input_b], axis=0)
                labels = torch.cat([labels_a, labels_b])

                if is_train:
                    # Reset gradients
                    optimizer.zero_grad()

                # Gradients are only tracked while training; val/test passes
                # do not need an autograd graph.
                with torch.set_grad_enabled(is_train):
                    features_a = nets[0](input_a)
                    features_b = nets[1](input_b)
                    features_c = nets[2](inputs)

                    # Scores keep their (batch, classes) shape; squeezing them
                    # would break on a batch that holds a single sample.
                    pred_a = nets[3](features_a)
                    pred_b = nets[3](features_b)
                    pred_c = nets[3](features_c)

                    loss = criterion(pred_a, labels_a) + criterion(pred_b, labels_b) + criterion(pred_c, labels)

                    # Tie the conv weights of net 2 to the combination of nets 0 and 1.
                    for name, layer_a in model_a.items():
                        if not isinstance(layer_a, nn.Conv2d):
                            continue
                        layer_b = model_b[name]
                        layer_c = model_c[name]
                        loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.weight, layer_b.weight), layer_c.weight)
                        if layer_a.bias is not None:
                            loss = loss + lambda_reg * reg_loss(combo_fn(layer_a.bias, layer_b.bias), layer_c.bias)

                # Update loss
                sum_loss[split] += loss.item()

                if is_train:
                    # Compute gradients
                    loss.backward()
                    # Optimize
                    optimizer.step()

                # Compute accuracy: the predicted class is the highest-scoring one.
                batch_accuracy_a = (pred_a.argmax(dim=1) == labels_a).float().sum().item() / len(labels_a)
                batch_accuracy_b = (pred_b.argmax(dim=1) == labels_b).float().sum().item() / len(labels_b)
                batch_accuracy_c = (pred_c.argmax(dim=1) == labels).float().sum().item() / len(labels)

                # Update accuracy
                sum_accuracy[split][0] += batch_accuracy_a
                sum_accuracy[split][1] += batch_accuracy_b
                sum_accuracy[split][2] += batch_accuracy_c

                if is_train:
                    progbar.update(j, ta=batch_accuracy_c)

        if progbar is not None:
            progbar.finish()

        # Compute epoch loss/accuracy as the mean over the batches of the first loader.
        n_batches = {split: len(loaders[split][0]) for split in splits}
        epoch_loss = {split: sum_loss[split] / n_batches[split] for split in splits}
        epoch_accuracy = {split: [total / n_batches[split] for total in sum_accuracy[split]] for split in splits}

        print(f"Epoch {epoch + 1}:")
        # Update history
        for split in splits:
            history_loss[split].append(epoch_loss[split])
            history_accuracy[split].append(epoch_accuracy[split])
            # Print info
            print(f"\t{split}\tLoss: {epoch_loss[split]:0.5}\tVGG 1:{epoch_accuracy[split][0]:0.5}"
                  f"\tVGG 2:{epoch_accuracy[split][1]:0.5}\tVGG *:{epoch_accuracy[split][2]:0.5}")

        if save_param:
            torch.save({'vgg_a':nets[0].state_dict(),'vgg_b':nets[1].state_dict(),'vgg_star':nets[2].state_dict(),'classifier':nets[3].state_dict()},f'{model_name}.pth')

        # Accuracy of each extractor (with the shared classifier) on the full test set.
        test(nets[0], nets[3], test_loader_all)
        test(nets[1], nets[3], test_loader_all)
        test(nets[2], nets[3], test_loader_all)

        # Replace the conv entries of net 2 by the combination of nets 0 and 1
        # and evaluate the result. Each state dict is built once, not per key.
        # NOTE(review): the combined weights stay loaded in nets[2], so its own
        # conv weights are overwritten after every epoch - confirm this is intended.
        state_a = nets[0].state_dict()
        state_b = nets[1].state_dict()
        state_c = nets[2].state_dict()

        summed_state_dict = OrderedDict()
        for key, own_value in state_c.items():
            if 'conv' in key:
                summed_state_dict[key] = combo_fn(state_a[key], state_b[key])
            else:
                summed_state_dict[key] = own_value

        nets[2].load_state_dict(summed_state_dict)
        test(nets[2], nets[3], test_loader_all)

    return history_loss, history_accuracy


In [None]:
# Directory passed as `root` to the torchvision MNIST datasets (download target / cache).
root_dir = './'

In [None]:
# Divides a tensor by 255. It is deliberately NOT part of the pipelines below:
# `ToTensor` already converts 8-bit images to floats in [0, 1], so dividing
# again would squeeze every input into [0, 1/255]. Kept only so the name stays
# defined for raw 0-255 tensors.
rescale_data = transforms.Lambda(lambda x : x/255)

# Compose transformations
data_transform = transforms.Compose([
  transforms.Resize(32),
  transforms.RandomHorizontalFlip(),
  transforms.ToTensor(),
  #transforms.Normalize((-0.7376), (0.5795))
])

test_transform = transforms.Compose([
  transforms.Resize(32),
  transforms.ToTensor(),
  #transforms.Normalize((0.1327), (0.2919))
])

In [None]:
# Load the MNIST train and test splits (downloaded into `root_dir` when missing),
# each with its own transform pipeline.
mnist_options = dict(root=root_dir, download=True)
train_set = torchvision.datasets.MNIST(train=True, transform=data_transform, **mnist_options)
test_set = torchvision.datasets.MNIST(train=False, transform=test_transform, **mnist_options)

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# Shuffle the training indices once; the test indices stay in dataset order.
train_idx = np.random.permutation(np.arange(len(train_set)))
test_idx = np.arange(len(test_set))

# The first `val_frac` of the shuffled indices becomes the validation data.
val_frac = 0.1
n_val = int(len(train_idx) * val_frac)
val_idx, train_idx = train_idx[:n_val], train_idx[n_val:]

# Each split is cut in two halves (a / b), one per feature extractor.
h = len(train_idx) // 2
train_set_a, train_set_b = Subset(train_set, train_idx[:h]), Subset(train_set, train_idx[h:])

# The validation halves are subsets of `train_set`, so they use its transform.
h = len(val_idx) // 2
val_set_a, val_set_b = Subset(train_set, val_idx[:h]), Subset(train_set, val_idx[h:])

h = len(test_idx) // 2
test_set_a, test_set_b = Subset(test_set, test_idx[:h]), Subset(test_set, test_idx[h:])

In [None]:
# Define loaders.
# Training loaders shuffle and drop the incomplete last batch; evaluation
# loaders keep the dataset order and every sample.
train_options = dict(batch_size=64, num_workers=0, shuffle=True, drop_last=True)
eval_options = dict(batch_size=64, num_workers=0, shuffle=False, drop_last=False)

train_loader_a = DataLoader(train_set_a, **train_options)
val_loader_a = DataLoader(val_set_a, **eval_options)
test_loader_a = DataLoader(test_set_a, **eval_options)

train_loader_b = DataLoader(train_set_b, **train_options)
val_loader_b = DataLoader(val_set_b, **eval_options)
test_loader_b = DataLoader(test_set_b, **eval_options)

# Loader over the complete, unsplit test set.
test_loader_all = DataLoader(test_set, **eval_options)

# Define dictionary of loaders: one [a, b] pair per split.
loaders = {
    "train": [train_loader_a, train_loader_b],
    "val": [val_loader_a, val_loader_b],
    "test": [test_loader_a, test_loader_b],
}

In [None]:
class VGG16INIT(VGG16):
    """VGG16 feature extractor whose parameters start as the element-wise sum of two models.

    The architecture, `input_size` property and `forward` are inherited from
    VGG16; only the initial parameter values differ.

    Args:
        input_size: triple ``(channels, width, height)`` as for VGG16.
        model1, model2: modules whose parameters are summed pair-wise, in
            `parameters()` order, to initialise this model. They must have the
            same number and shapes of parameters as this model.
        batch_norm: forwarded to VGG16.

    Raises:
        ValueError: when the parameter lists differ in length or shape.
    """

    def __init__(self, input_size, model1, model2, batch_norm=False):
        super().__init__(input_size, batch_norm=batch_norm)

        own = list(self.parameters())
        first = list(model1.parameters())
        second = list(model2.parameters())
        # zip() would silently ignore surplus parameters, so check the counts first.
        if not (len(own) == len(first) == len(second)):
            raise ValueError(
                "Parameter count mismatch: {} (new) vs {} (model1) vs {} (model2)".format(
                    len(own), len(first), len(second)))

        # The sum is a plain initial value and must not be tracked by autograd.
        with torch.no_grad():
            for p_out, p_in1, p_in2 in zip(own, first, second):
                if not (p_out.shape == p_in1.shape == p_in2.shape):
                    raise ValueError(
                        "Parameter shape mismatch: {} (new) vs {} (model1) vs {} (model2)".format(
                            tuple(p_out.shape), tuple(p_in1.shape), tuple(p_in2.shape)))
                # Assigning to `.data` makes the parameter adopt the summed values
                # (and the device they live on).
                p_out.data = p_in1 + p_in2

In [None]:
# Three identically configured feature extractors for single-channel 32x32
# images (normalisation layers enabled) and one classifier head for 10 classes.
input_shape = (1, 32, 32)
model1, model2, model3 = (VGG16(input_shape, batch_norm=True) for _ in range(3))
classifier = Classifier(num_classes=10)

In [None]:
# Display every parameter tensor of model1 (very long output).
list(model1.parameters())

[Parameter containing:
 tensor([[[[ 0.2834, -0.1472, -0.0348],
           [ 0.0590,  0.2581,  0.1013],
           [ 0.0590, -0.1911,  0.2689]]],
 
 
         [[[-0.1490,  0.1402,  0.2304],
           [-0.3101,  0.2301, -0.3187],
           [ 0.1751, -0.2327,  0.2277]]],
 
 
         [[[ 0.1853,  0.1020, -0.1432],
           [ 0.2849, -0.3255, -0.0924],
           [-0.1881,  0.0302,  0.1723]]],
 
 
         [[[-0.2434, -0.1134,  0.1910],
           [ 0.2529,  0.3095, -0.2746],
           [ 0.0288,  0.0870, -0.1498]]],
 
 
         [[[-0.0865,  0.1618, -0.0064],
           [-0.3176, -0.0916,  0.0861],
           [-0.2164, -0.2241,  0.2743]]],
 
 
         [[[-0.2670, -0.1545,  0.2010],
           [ 0.3005,  0.1249,  0.3007],
           [-0.1765, -0.1321,  0.2782]]],
 
 
         [[[-0.1846, -0.1596,  0.1079],
           [ 0.0192,  0.1429, -0.2773],
           [ 0.0624,  0.2661,  0.0460]]],
 
 
         [[[ 0.1518, -0.0537, -0.0188],
           [ 0.2441,  0.1352, -0.2837],
           [-0.

In [None]:
# Display every parameter tensor of model2 (very long output).
list(model2.parameters())

[Parameter containing:
 tensor([[[[-7.0041e-02, -1.6367e-01, -2.8129e-01],
           [-1.5276e-01, -9.7146e-02,  5.3524e-02],
           [ 4.7512e-02, -3.0981e-01,  2.0391e-01]]],
 
 
         [[[ 2.8695e-01,  1.1308e-01, -3.0961e-02],
           [-1.7771e-01, -1.4862e-02, -1.7841e-01],
           [ 5.2227e-02,  1.1971e-01, -2.9420e-01]]],
 
 
         [[[-2.2275e-01, -2.2183e-01, -3.1980e-01],
           [ 3.0135e-01,  2.6588e-01,  1.1242e-01],
           [ 6.0518e-02,  1.9556e-01,  1.6368e-01]]],
 
 
         [[[-1.4872e-01,  2.6460e-01, -2.4543e-01],
           [ 3.4530e-02,  2.5496e-01,  2.2250e-01],
           [ 5.2097e-02, -1.2688e-01, -1.6287e-02]]],
 
 
         [[[ 4.8323e-02, -2.2569e-01, -1.0288e-01],
           [ 3.0980e-01, -3.0779e-01, -2.0215e-02],
           [-1.7668e-01, -1.4180e-01,  5.1765e-02]]],
 
 
         [[[-9.7850e-02, -4.9340e-03, -1.8995e-02],
           [-2.7239e-01,  1.8023e-01, -1.8863e-01],
           [ 1.6508e-01,  2.7278e-01, -6.2506e-02]]],
 
 
     

In [None]:
# Display every parameter tensor of model3 (very long output).
list(model3.parameters())

[Parameter containing:
 tensor([[[[ 0.2133, -0.3109, -0.3161],
           [-0.0937,  0.1609,  0.1548],
           [ 0.1065, -0.5009,  0.4728]]],
 
 
         [[[ 0.1379,  0.2533,  0.1995],
           [-0.4878,  0.2152, -0.4971],
           [ 0.2273, -0.1130, -0.0665]]],
 
 
         [[[-0.0374, -0.1199, -0.4630],
           [ 0.5863, -0.0596,  0.0201],
           [-0.1276,  0.2258,  0.3359]]],
 
 
         [[[-0.3921,  0.1512, -0.0544],
           [ 0.2875,  0.5644, -0.0521],
           [ 0.0809, -0.0399, -0.1661]]],
 
 
         [[[-0.0381, -0.0639, -0.1093],
           [-0.0078, -0.3993,  0.0659],
           [-0.3930, -0.3659,  0.3260]]],
 
 
         [[[-0.3648, -0.1594,  0.1820],
           [ 0.0282,  0.3052,  0.1121],
           [-0.0114,  0.1407,  0.2157]]],
 
 
         [[[ 0.0411, -0.4718,  0.0862],
           [ 0.3030,  0.3903, -0.0332],
           [ 0.2096,  0.5540, -0.2158]]],
 
 
         [[[ 0.1099, -0.1492, -0.1545],
           [ 0.2474, -0.0334, -0.2353],
           [-0.

In [None]:
# Order matters: train() uses indices 0-2 as the three feature extractors and
# index 3 as the classifier shared by all of them.
nets = [model1,model2,model3,classifier]

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without CUDA.
dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Container for the parameters of all networks, later handed to the optimizer.
# A set removes duplicates but has no stable iteration order.
parameters = set()

In [None]:
# Gather every parameter of every network exactly once, in a deterministic
# order. dict.fromkeys keeps first-seen order and drops duplicates; a set would
# give the optimizer its parameters in an order that can change between runs.
parameters = list(dict.fromkeys(p for net in nets for p in net.parameters()))

In [None]:
# One Adam optimizer updates all collected parameters with the same learning rate.
optimizer = torch.optim.Adam(parameters, lr = 0.005)
# Define a loss: multi-class cross-entropy on raw scores and integer labels.
# (An earlier binary variant is kept below for reference.)
#criterion = nn.BCEWithLogitsLoss()#,nn.BCEWithLogitsLoss(),nn.BCEWithLogitsLoss(),_lambda = 1)
criterion = nn.CrossEntropyLoss()
# NOTE(review): n_params is not read anywhere in the visible part of the notebook.
n_params = 0

In [None]:
DO = 'TRAIN'
# File stem of the checkpoint. train() writes f'{model_name}.pth', so the same
# name must be used for saving and loading (the previous 'model.pth' was never
# written by train()).
MODEL_NAME = 'valerio'

if DO == 'TRAIN':
    train_result = train(nets, loaders, optimizer, criterion, epochs=35, dev=dev,
                         save_param=True, model_name=MODEL_NAME)
    # Keep the per-epoch histories for plotting whenever train() returns them.
    if train_result is not None:
        history_loss, history_accuracy = train_result
else:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    state_dicts = torch.load(f'{MODEL_NAME}.pth', map_location=dev)
    # These state dicts come from the training function.
    model1.load_state_dict(state_dicts['vgg_a'])
    model2.load_state_dict(state_dicts['vgg_b'])
    model3.load_state_dict(state_dicts['vgg_star'])
    classifier.load_state_dict(state_dicts['classifier'])

    test(model1, classifier, test_loader_all)
    test(model2, classifier, test_loader_all)
    test(model3, classifier, test_loader_all)

    # Replace the conv entries of the third model by the combination of the
    # first two and evaluate the result.
    summed_state_dict = OrderedDict()

    for key in state_dicts['vgg_star']:
        if key.find('conv') >= 0:
            print(key)
            summed_state_dict[key] = combo_fn(state_dicts['vgg_a'][key], state_dicts['vgg_b'][key])
        else:
            summed_state_dict[key] = state_dicts['vgg_star'][key]

    model3.load_state_dict(summed_state_dict)
    test(model3, classifier, test_loader_all)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
 [Elapsed Time: 0:02:40] |###############| [Time:  0:02:40] [Train Acc:  0.102]


Epoch 1:
	train	Loss: 7.2188	VGG 1:0.10603	VGG 2:0.11082	VGG *:0.10921
	val	Loss: 6.9039	VGG 1:0.1202	VGG 2:0.12039	VGG *:0.1203
	test	Loss: 6.9076	VGG 1:0.11294	VGG 2:0.11294	VGG *:0.11294
Accuracy test: 0.11355
Accuracy test: 0.11355
Accuracy test: 0.11355
Accuracy test: 0.11355


 [Elapsed Time: 0:00:10] |#              | [ETA:   0:02:17] [Train Acc:  0.133]

KeyboardInterrupt: ignored

In [None]:
# One figure per metric (loss first, then accuracy); every split is drawn with
# its own label.
for title, history in (("Loss", history_loss), ("Accuracy", history_accuracy)):
    plt.title(title)
    for split in ("train", "val", "test"):
        plt.plot(history[split], label=split)
    plt.legend()
    plt.show()

In [None]:
# Install progressbar2 into the environment of the running kernel (%pip, not
# !pip) and pin the version so the notebook stays reproducible.
# NOTE(review): `import progressbar as pb` happens in the first cell, so this
# cell has to run before the imports (or the kernel restarted afterwards);
# presumably the upgrade is needed for `pb.Variable` - confirm.
%pip install --upgrade progressbar2==3.53.1

Collecting progressbar2
  Downloading progressbar2-3.53.1-py2.py3-none-any.whl (25 kB)
Installing collected packages: progressbar2
  Attempting uninstall: progressbar2
    Found existing installation: progressbar2 3.38.0
    Uninstalling progressbar2-3.38.0:
      Successfully uninstalled progressbar2-3.38.0
Successfully installed progressbar2-3.53.1
