<a href="https://colab.research.google.com/github/PyxAI/Data-Science-Notebooks/blob/master/The_OneNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this experiment, I'm testing if a few pretrained networks can achieve better results together.

The approach is a combination of feature extraction and fine-tunning.

So, The idea here is the use pretrained networks, and not change the weights in the main body of the CNN, but rather, train the (augmented) classifier at the end. 

Backpropagation will take into consideration both networks for minimizing the loss, thus adjusting the weights of the fully connected layer(s) accordingly.

---
---
Bottom line conclusion:

**The new network demonstrates superior results over the best child pretrained network used in training.**

The most significant improvement can be seen on **unseen** data, which makes it interesting to think how we can build a network that uses a lot of pretrained knowledge, to quickly converge on new data. 

That is what I aimed for in this experiment

In [0]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random
import os

import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models
from torchvision import datasets
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
from google.colab import drive

if torch.cuda.is_available():
  device = torch.device('cuda')
  dtype = torch.cuda.FloatTensor
else:
  device = torch.device('cpu')
  dtype = torch.FloatTensor


!git clone https://github.com/aaron-xichen/pytorch-playground
os.chdir('pytorch-playground')
!python3 setup.py develop --user
os.chdir('/content/')

!pip install timm
import timm

drive.mount('/content/gdrive')
!mkdir /content/weights
!cp /content/gdrive/My\ Drive/weights/resnet18_cifar.pt /content/weights/

In [0]:

# Checking the accuracy and loss of the network on the giving dataset
def validate(network, dset):
  network.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    loss = 0
    for x, y in dset:
      x = x.to(device)
      y = y.to(device)
      output = network(x)
      loss += loss_fn(output, y)
      _, predicted = torch.max(output.data, 1)
      total += y.size(0)
      correct += (predicted == y).sum().item()
    loss/= len(dset)
    print ("validation loss: {:.4f}".format(loss))
    print ("accuracy here is: {}".format(100 * correct / total))
  return loss.detach().item()

# To ignore the last FC layer of the network
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
    def forward(self, x):
        return x

#Our very simple trainning loop
def run_train(network):
  loss_arr = []
  try:
    for epoch in range(epochs):
      for iteration, (x, y) in enumerate(train):
        network.train()
        x = x.to(device)
        y = y.to(device)
        optim.zero_grad()
        output = network(x)
        loss = loss_fn(output, y)
        loss.backward()
        if (iteration % print_every == 0):
          print ("epoch: {}, iter:{}".format(epoch, iteration))
          loss_arr.append(validate(network, val))
        optim.step()
    plt.plot(loss_arr)
    plt.show()
  except KeyboardInterrupt:
    plt.plot(loss_arr)
    plt.show()



##Settings

In [0]:
#Training
batch_size = 64
lr = 1e-4
epochs=20
print_every = 100

#Dataset
cifar_root = '/content/data/cifar10'
cifar_transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # As recommended for torchvision
    ])


#Dataset

In [6]:
#Take Cifar10
cifar = datasets.CIFAR10(cifar_root, train=True, download=True, transform=cifar_transform)

#split to train - val
train_set, val_set = torch.utils.data.random_split(cifar, [int(len(cifar)*0.9), int(len(cifar)*0.1)])
train = DataLoader(train_set, batch_size = batch_size, shuffle=True, num_workers = 4)
val = DataLoader(val_set, batch_size = batch_size, shuffle=True, num_workers = 4)

#test set
cifar = datasets.CIFAR10(cifar_root, train=False, download=True, transform=cifar_transform)
test = DataLoader(cifar, batch_size = batch_size, shuffle=True, num_workers = 4)

loss_fn = nn.CrossEntropyLoss()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/data/cifar10/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting /content/data/cifar10/cifar-10-python.tar.gz to /content/data/cifar10
Files already downloaded and verified


##Model setup

Source: https://github.com/aaron-xichen/pytorch-playground

Pretrained weight on cifar100

In [0]:
import torch
from torch.autograd import Variable
os.chdir('pytorch-playground')
from utee import selector
cifar100, ds_fetcher, is_imagenet = selector.select('cifar100')
ds_val = ds_fetcher(batch_size=10, train=False, val=True)
for idx, (data, target) in enumerate(ds_val):
    data =  Variable(torch.FloatTensor(data)).cuda()
    output = cifar100(data)

cifar100.classifier = nn.Linear(cifar100.classifier[0].in_features, 10, bias=True)

Source: https://github.com/rwightman/pytorch-image-models

Pretrained on ImageNet

In [11]:
mixnet = timm.create_model('mixnet_xl', pretrained=True)
mixnet.classifier = nn.Linear(mixnet.classifier.in_features, 10, bias=True)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth" to /root/.cache/torch/checkpoints/mixnet_xl_ra-aac3c00c.pth


Pretrained on CIFAR-10

In [12]:
resnet18A = models.resnet18(pretrained=False)
resnet18A._modules['fc']= nn.Linear(512, 10) #To make it CIFARable
resnet18A.load_state_dict(torch.load('/content/gdrive/My Drive/weights/ready_resnet18_cifar.pt', map_location=torch.device(device)))


<All keys matched successfully>

Torch model trained on ImageNet

In [13]:
resnet18B = models.resnet18(pretrained=True)
resnet18B.fc = nn.Linear(in_features=resnet18B.fc.in_features, out_features=10, bias=True)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))




A basic idea here is to only train the classifier, while relying on the trained networks to extract features as best they can.

So here we are shutting down grad on everything but FC layer


In [0]:
#Shutting down grad on everything but FC layer
for net in [cifar100,mixnet,resnet18B,resnet18A]:
  for param in net.named_parameters():
    if not any(['body' in param[0], 'fc' in param[0], 'classifier' in param[0]]):
      param[1].requires_grad=False

We will now train the classifier of each network on CIFAR10,  under the same conditions, for the same amount of epochs, to see how good it can get on it's own.

# Banchmark runs

The first network, that was trained on the CIFAR100 achieved these results:

 > Validation loss: 0.5378
  
 > Validation accuracy: 81.2%




In [0]:
# Training code

optim = torch.optim.Adam(cifar100.parameters(), lr=lr)
cifar100.to(device)
run_train(cifar100)

The second network, mixnet I stopped after 10 epochs since it wasn't even close to good results:


> Validation loss: 1.7438

> Validation accuracy: 40.86 %

In [0]:
# Training code

optim = torch.optim.Adam(mixnet.parameters(), lr=lr)
mixnet.to(device)
run_train(mixnet)

The following network, resnet18B, was trained on ImageNet and achieved:

> Validation loss: 1.5685

> Validation accuracy: 45.8 %

In [0]:
# Training code

optim = torch.optim.Adam(resnet18B.parameters(), lr=lr)
resnet18B.to(device)
run_train(resnet18B)

The following network, Renset18A, The network that is pretrained on CIFAR10 got:

> Validation loss: 0.1200

> Validation accuracy: 97.22 %

In [0]:
# Training code

optim = torch.optim.Adam(resnet18A.parameters(), lr=lr)
resnet18A.to(device)
run_train(resnet18A)

#Real testing begins:

mixing the networks.

Starting from setting the last layer to output 512

In [0]:
resnet18A.fc = Identity()
resnet18B.fc = Identity()
cifar100.classifier = nn.Linear(cifar100.classifier.in_features, 512, bias=True)
mixnet.classifier = nn.Linear(mixnet.classifier.in_features, 512, bias=True)

In [0]:

"""
Our unifying class.
The class unified networks by concatenating the extracted features of the networks
And uses a (Linear) classifier on top of that.
"""
class Onet(nn.Module):
  def __init__(self, networks):
    super(Onet, self).__init__()
    #Here we should take as input all the different pretrained networks.
    self.networks = networks
    self.pilesize = len(self.networks)
    for net in self.networks:
      net.to(device)

    #Classifier
    self.body = nn.Sequential(
        nn.Linear(512*self.pilesize, 512),
        nn.LeakyReLU(),
        nn.Linear(512, 10)
    )

  def forward(self, image):
    joined = torch.Tensor(0).to(device)
    # Run a forward pass in each of the networks, then concatenate them.
    for net in self.networks:
      x = net(image)
      joined = torch.cat((x, joined), dim=1)
    # Activation layer, and send to classification.
    x = F.leaky_relu(joined)
    x = self.body(x)
    return x


In [0]:
# Creating the network instance

networks = [resnet18B, mixnet, cifar100]
onet = Onet(networks)
onet.to(device)

# Always Adam.
optim = torch.optim.Adam(onet.parameters(), lr=lr)

# Cross Entropy loss being the natural selection for the classification task
loss_fn = nn.CrossEntropyLoss()


In [0]:
run_train(onet)

on resnet18B + mixnet + cifar100
we got:

>validation loss: 0.6551

>accuracy here is: 77.18

In [0]:
networks = [resnet18B, mixnet]
onet = Onet(networks)
onet.to(device)
optim = torch.optim.Adam(onet.parameters(), lr=lr)


In [0]:
run_train(onet)

resnet18B, mixnet:

> Validation loss: 1.3916

> Validation accuracy: 50.98 %

In [0]:
networks = [resnet18B, cifar100]
onet = Onet(networks)
onet.to(device)
optim = torch.optim.Adam(onet.parameters(), lr=lr)
run_train(onet)

resnet18B, cifar100:

>validation loss: 0.6431

>validation accuracy: 77.74

In [0]:
networks = [mixnet, cifar100]
onet = Onet(networks)
onet.to(device)
optim = torch.optim.Adam(onet.parameters(), lr=lr)
run_train(onet)

Using all 4 networks demonstrated a slight improvment in performance from the best network:

> Validation loss: 0.11

> Validation accuracy: 97.46 %

#Summary:

model, loss, accuracy

CIFAR100, 0.5378, 81.2

mixnet, 1.7438, 40.86

resnetB (imagenet), 1.5685, 45.8

resnetA (cifar10), 0.12, 97.22

resnet18B + mixnet + cifar100, 0.6551, 77.18

resnet18B + mixnet, 1.3916, 50.98

resnet180B + cifar100, 0.6431, 77.74

resnet18A + resnet18B + mixnet + cifar100, 0.11, 97.46

mixenet + cifar10, 0.6510, 77.3

In [0]:
columns = ["model", "loss", "accuracy", "Accuracy diff from best model [%]"]
data =[
["cifar-100", 0.5378, 81.2, 0], 
["mixnet", 1.7438, 40.86, 0],
["resnet18", 1.5685, 45.8, 0],
["cifar-10", 0.12, 97.22, 0],
["resnet18 + mixnet", 1.3916, 50.98, "+11.3"],
["resnet18 + cifar-100", 0.6431, 77.74, "-4.6"],
["mixnet + cifar-100", 0.6510, 77.3, "-5"],
["resnet18 + mixnet + cifar-100", 0.6551, 77.18, "-5"],
["cifar-10 + resnet18 + mixnet + cifar-100", 0.11, 97.46, "+0.01"],
["resnet18B + resnet18A + cifar100",0.1079, 97.06, ]
]

In [5]:
pd.DataFrame(data, columns=columns)

Unnamed: 0,model,loss,accuracy,Accuracy diff from best model [%]
0,cifar-100,0.5378,81.2,0.0
1,mixnet,1.7438,40.86,0.0
2,resnet18,1.5685,45.8,0.0
3,cifar-10,0.12,97.22,0.0
4,resnet18 + mixnet,1.3916,50.98,11.3
5,resnet18 + cifar-100,0.6431,77.74,-4.6
6,mixnet + cifar-100,0.651,77.3,-5.0
7,resnet18 + mixnet + cifar-100,0.6551,77.18,-5.0
8,cifar-10 + resnet18 + mixnet + cifar-100,0.11,97.46,0.01
9,resnet18B + resnet18A + cifar100,0.1079,97.06,


In [0]:
#Save model
torch.save(onet.state_dict(), '/content/weights/onet_resnet18_4legs.pt')
!cp /content/weights/onet_resnet18_4legs.pt /content/gdrive/My\ Drive/weights/

In [0]:
#Checking results on the test set
validate(onet, test) 