<a href="https://colab.research.google.com/github/armando-larocca/Project-IL-/blob/master/FT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import os
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from PIL import Image
import torch.optim as optim 
import torchvision.datasets as dsets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import Subset, DataLoader
import sys
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# ---- Run configuration ------------------------------------------------------
# Device string used for the model and the batches.
DEVICE = 'cuda'

# Total number of classes in the dataset.
NUM_CLASSES = 100

# Mini-batch size handed to the DataLoaders in the training cell.
batch_size = 128

# ---- SGD settings -----------------------------------------------------------
# NOTE(review): the training cell passes lr=0.02 as a literal rather than this
# constant (0.2) -- confirm which value is the intended one.
LR = 0.2
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-5

# ---- Schedule ---------------------------------------------------------------
# Epochs run for every incremental stage.
NUM_EPOCHS = 70
# Epoch milestones for the multi-step learning-rate decay and its factor;
# these mirror the values given to MultiStepLR in the training cell.
STEP_SIZE = [49, 63]
GAMMA = 0.2
# Not referenced by the cells visible in this notebook.
LOG_FREQUENCY = 10

In [None]:
if not os.path.isdir('./Project-dir'):
  !git clone https://github.com/armando-larocca/Project-IL-

if not os.path.isdir('content/cifar100.py'):
  !mv '/content/Project-IL-/cifar100.py' '/content'  
  !mv '/content/Project-IL-/utils.py' '/content'  

from cifar100 import * 

In [None]:
# Training pipeline: random 32x32 crop taken after padding each border by 4
# pixels, random horizontal flip with probability 0.5, conversion to a tensor
# and per-channel normalisation.
transform = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761),
    ),
])

# Evaluation pipeline: no augmentation, same tensor conversion and the same
# normalisation statistics as the training pipeline.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.5071, 0.4867, 0.4408),
        std=(0.2675, 0.2565, 0.2761),
    ),
])


In [None]:
import random

# Random permutation of the class ids; the same permutation is applied to the
# train and the test split so that both agree on the class order.
# NOTE: the shuffle is unseeded, so the class order differs between runs.
shf = list(range(NUM_CLASSES))
random.shuffle(shf)

# Build the dataset root with os.path.join instead of the literal ".\Data":
# the "\D" in that literal is an invalid escape sequence and the backslash is
# only a path separator on Windows (elsewhere it produced a directory whose
# name literally contains a backslash).
data_root = os.path.join(".", "Data")

train_dataset = Cifar100(data_root, train=True, transform=transform)
test_dataset = Cifar100(data_root, train=False, transform=transform_test)

train_dataset._shuffle_(shf)
test_dataset._shuffle_(shf)

# Per-stage sample indices. The meaning of the integer arguments is defined in
# cifar100.py, which is not visible from this notebook.
incr_train = train_dataset.__incremental_train_indexes__(1)
incr_val = test_dataset.__incremental_val_indexes__(0)

# One training subset and one evaluation subset per incremental stage (10 stages).
decine_train = [Subset(train_dataset, incr_train[i]) for i in range(10)]
decine_val = [Subset(test_dataset, incr_val[i]) for i in range(10)]

In [None]:
if not os.path.isdir('content/cifarResnet.py'):
  !mv '/content/Project-IL-/cifarResnet.py' '/content'

from cifarResnet import resnet32

criterion = nn.CrossEntropyLoss()

In [None]:
class FT(nn.Module):
    """Fine-tuning model: a ``resnet32`` feature extractor plus a growable linear head.

    Attributes:
        feature_extractor: network returned by ``resnet32()``; its output is
            fed to ``fc``, which is built with 64 input features.
        fc: linear classifier with one output per class seen so far.
        n_classes: current number of outputs of ``fc``.
        n_known: initialised to 0; this class never updates it.
    """

    def __init__(self, n_classes):
        """Build the network with ``n_classes`` initial output units."""
        super().__init__()
        self.feature_extractor = resnet32()
        self.fc = nn.Linear(64, n_classes, bias=True)

        self.n_classes = n_classes
        self.n_known = 0

    @property
    def p(self):
        """Return a fresh iterator over the current parameters.

        Kept for backward compatibility with the former ``self.p`` attribute.
        Storing the generator on the instance made it single-use, left it
        pointing at the old head after ``increment_classes`` and prevented the
        module from being pickled or deep-copied.
        """
        return self.parameters()

    def forward(self, x):
        """Return the class scores for the batch ``x``."""
        x = self.feature_extractor(x)
        x = self.fc(x)
        return x

    def increment_classes(self, n):
        """Grow the classifier by ``n`` outputs, keeping the learned weights.

        The existing rows of the weight matrix and bias are copied into a new
        linear layer; the ``n`` additional rows keep the default
        ``nn.Linear`` initialisation. The new layer is created on the same
        device and with the same dtype as the one it replaces.

        ``self.fc`` is replaced by a new module, so any optimizer built before
        this call still references the old parameters and must be recreated.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError("n must be non-negative, got %r" % (n,))

        old_fc = self.fc
        n_old = old_fc.out_features

        new_fc = nn.Linear(old_fc.in_features, n_old + n, bias=True)
        new_fc = new_fc.to(device=old_fc.weight.device, dtype=old_fc.weight.dtype)

        # Copy without recording the operation in the autograd graph.
        with torch.no_grad():
            new_fc.weight[:n_old].copy_(old_fc.weight)
            new_fc.bias[:n_old].copy_(old_fc.bias)

        self.fc = new_fc
        self.n_classes += n

**Train**

In [None]:
def _train_one_epoch(net, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns the loss of the last mini-batch (as a float) and the training
    accuracy accumulated over the epoch.
    """
    # Evaluation switches the network to eval mode after every epoch, so
    # training mode has to be restored at the start of each epoch.
    net.train()

    correct = 0
    seen = 0
    loss = None
    for _, images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = net(images)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        correct += (logits.argmax(dim=1) == labels).sum().item()
        seen += labels.size(0)

    if seen == 0:
        raise ValueError("the training loader produced no samples")
    return loss.item(), correct / seen


def _evaluate(net, loader, device):
    """Evaluate ``net`` on ``loader``.

    Returns the accuracy together with the predicted and the true labels as
    plain Python ints, in the order the loader produced them.
    """
    net.eval()

    correct = 0
    seen = 0
    predictions = []
    targets = []
    # No gradients are needed here; skipping the autograd graph saves memory.
    with torch.no_grad():
        for _, images, labels in loader:
            preds = net(images.to(device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
            seen += labels.size(0)
            # Store ints rather than per-sample tensors so the history kept
            # for every epoch does not pin GPU memory.
            predictions.extend(preds.tolist())
            targets.extend(labels.tolist())

    if seen == 0:
        raise ValueError("the evaluation loader produced no samples")
    return correct / seen, predictions, targets


best_acc = []     # best evaluation accuracy of every stage
tot_matrix = []   # [stage][epoch] -> predicted labels on the evaluation subset
tot_labe = []     # [stage][epoch] -> true labels on the evaluation subset

net = FT(10)
criterion = nn.CrossEntropyLoss()

for stage, (train_subset, val_subset) in enumerate(zip(decine_train, decine_val)):

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=4)
    # Evaluation does not depend on the sample order, so it is left unshuffled.
    test_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=4)

    # Every stage after the first adds 10 output units to the classifier.
    if stage != 0:
        net.increment_classes(10)
    net.to(DEVICE)

    # The optimizer is rebuilt at every stage because increment_classes
    # replaces the classifier layer and therefore its parameters.
    # NOTE(review): the LR constant is 0.2 while 0.02 has always been used
    # here; the literal is kept to preserve the current behaviour -- confirm
    # which value is intended.
    optimizer = optim.SGD(net.parameters(), lr=0.02, weight_decay=WEIGHT_DECAY, momentum=MOMENTUM)
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer, STEP_SIZE, gamma=GAMMA)

    matrix = []
    labe = []
    b_ac = 0

    for epoch in range(NUM_EPOCHS):

        #### TRAIN ####
        last_loss, accuracy = _train_one_epoch(net, train_loader, criterion, optimizer, DEVICE)
        scheduler.step()
        print('Epoch [%d/%d], Loss: %.4f, Acc: %.2f' % (epoch + 1, NUM_EPOCHS, last_loss, accuracy))

        #### TEST ####
        accuracy, predictions, targets = _evaluate(net, test_loader, DEVICE)
        matrix.append(predictions)
        labe.append(targets)
        print('Test Accuracy', accuracy)

        b_ac = max(b_ac, accuracy)

    tot_matrix.append(matrix)
    tot_labe.append(labe)
    best_acc.append(b_ac)

**Confusion matrix**

In [None]:
from sklearn.metrics import confusion_matrix
import numpy as np 
import matplotlib 
import matplotlib.pyplot as plt

# Tick positions: one every 10 classes.
tacche = list(range(10, 100, 10))

# Predictions and true labels of the last epoch of the last stage. Indexing
# from the end keeps the cell working when the number of epochs or stages
# changes (the previous [9][69] only matched 10 stages of 70 epochs).
y_pred = [int(v) for v in tot_matrix[-1][-1]]
y_true = [int(v) for v in tot_labe[-1][-1]]

# Rows are the true labels, columns the predicted ones.
cf = confusion_matrix(y_true, y_pred)

fig, ax = plt.subplots(figsize=(15, 15))
im = ax.imshow(cf, cmap='plasma')

ax.set_yticks(tacche)
ax.set_xticks(tacche)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
ax.set_title("image classification")
fig.tight_layout()
plt.show()

**Accuracy trend**

In [None]:
# Best evaluation accuracy of each stage against the number of classes the
# model has seen at that stage (10, 20, ..., 100).
fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(list(range(10, 101, 10)), best_acc, 'k-o')
ax.set_xlim(10, 100)
ax.set_ylim(0, 1)
ax.set_xlabel('Classes')
ax.set_ylabel('Accuracy')
plt.show()

print(best_acc)