<a href="https://colab.research.google.com/github/DVAEsCL/DVAEsCL/blob/main/MiniImageNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Download the miniImageNet dataset
Please download the miniImageNet dataset from this [link](https://drive.google.com/drive/folders/1-kwTWeyrw3uetDV_Y2xRuH5fJI3qq26h?usp=sharing) before running the cells below.

In [1]:
import pickle
import numpy as np

def dataprocess(data_path):
  """Load and return the object stored in the file at ``data_path``.

  The file is opened in binary mode and handed to ``np.load`` with pickling
  enabled and ``latin1`` decoding.

  NOTE: ``allow_pickle=True`` can execute arbitrary code while loading, so
  only use this on files that come from a trusted source.
  """
  with open(data_path, 'rb') as handle:
    return np.load(handle, allow_pickle=True, encoding='latin1')
# Path of the pickled miniImageNet archive that is loaded into `test_in`.
data_file = "/content/drive/MyDrive/ColabNotebooks/ContinualLearning/miniImageNet/miniImageNet_full.pickle"
test_in = dataprocess(data_file)

In [2]:
# Pull the image and label arrays out of the loaded archive and print their shapes.
data, labels = test_in['images'], test_in['labels']
for array in (data, labels):
  print(array.shape)

(60000, 84, 84, 3)
(60000,)


In [3]:
# indices[c] holds the positions (as a plain list) of every sample labelled c,
# for the class ids 0..99.
indices = [np.where(labels == class_id)[0].tolist() for class_id in range(100)]

In [4]:
# Split the 100 classes into 20 consecutive groups of 5 classes (one group per
# task). Within each class the first 500 sample positions go to the training
# split and the remaining positions go to the test split.
new_indices_train = []
new_indices_test = []
for first_class in range(0, 100, 5):
  task_train = []
  task_test = []
  for class_id in range(first_class, first_class + 5):
    task_train.extend(indices[class_id][:500])
    task_test.extend(indices[class_id][500:])
  new_indices_train.append(task_train)
  new_indices_test.append(task_test)

In [5]:
import torch
# Per-task tensors: entry t of each list belongs to task t.
traindata = []
trainlabels = []
testdata = []
testlabels = []


def _images_to_nchw(images):
  """Convert a (N, H, W, C) image array to a float32 (N, C, H, W) tensor.

  Values are divided by 255 (presumably mapping 0..255 pixels to [0, 1] --
  TODO confirm the stored dtype). The channel axis is moved with ``permute``;
  a plain ``view`` would only reinterpret the memory and scramble the pixels.
  """
  scaled = torch.from_numpy(images.astype(np.float32) / 255)
  return scaled.permute(0, 3, 1, 2).contiguous()


for train_idx, test_idx in zip(new_indices_train, new_indices_test):
  traindata.append(_images_to_nchw(data[train_idx]))
  trainlabels.append(torch.from_numpy(labels[train_idx]))
  testdata.append(_images_to_nchw(data[test_idx]))
  testlabels.append(torch.from_numpy(labels[test_idx]))


In [6]:
# Sanity check: show the tensor shape of the first task's test and train splits.
for first_task_split in (testdata[0], traindata[0]):
  print(first_task_split.shape)

torch.Size([500, 3, 84, 84])
torch.Size([2500, 3, 84, 84])


In [7]:
import torch
import torch.nn as nn
class encoder(nn.Module):
  """Stack of five strided convolutions applied one after another.

  No activation is applied between the layers. For a 3x84x84 input the
  output is 48x2x2 per sample.
  """

  def __init__(self):
    super().__init__()
    self.conv1 = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1)
    self.conv2 = nn.Conv2d(6, 12, kernel_size=2, stride=2, padding=0)
    self.conv3 = nn.Conv2d(12, 24, kernel_size=2, stride=2, padding=0)
    self.conv4 = nn.Conv2d(24, 48, kernel_size=2, stride=2, padding=0)
    self.conv5 = nn.Conv2d(48, 48, kernel_size=2, stride=2, padding=0)

  def forward(self, x):
    """Run ``x`` through conv1..conv5 in order and return the result."""
    for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
      x = conv(x)
    return x
class decoder(nn.Module):
  """Stack of five transposed convolutions that expands a 196-wide vector.

  The input is reshaped to 196x1x1 and passed through decon1..decon5 with no
  activation in between, giving a 3x84x84 output per sample.
  """

  def __init__(self):
    super().__init__()
    # Constants that are stored on the module but not read by forward().
    self.nc_mnist = 118
    self.nc_cifar10 = 202
    self.nc_cifar100 = 292
    self.nk_mnist = 3
    self.nk_cifar10 = 4
    self.decon1 = nn.ConvTranspose2d(196, 24, kernel_size=3, stride=1, padding=0)
    self.decon2 = nn.ConvTranspose2d(24, 12, kernel_size=4, stride=2, padding=0)
    self.decon3 = nn.ConvTranspose2d(12, 6, kernel_size=7, stride=2, padding=0)
    self.decon4 = nn.ConvTranspose2d(6, 3, kernel_size=2, stride=2, padding=0)
    self.decon5 = nn.ConvTranspose2d(3, 3, kernel_size=2, stride=2, padding=0)

  def forward(self, x):
    """Reshape ``x`` to (batch, 196, 1, 1) and apply decon1..decon5 in order."""
    out = x.view(x.size(0), 196, 1, 1)
    for deconv in (self.decon1, self.decon2, self.decon3, self.decon4, self.decon5):
      out = deconv(out)
    return out

class VAE(nn.Module):
  """Label-conditioned encoder/decoder pair.

  ``eps`` is stored as given and reused on every forward pass as the noise
  term. It is not resampled here, so its first dimension must be compatible
  with the batch size of the inputs.
  """

  def __init__(self, eps):
    super().__init__()
    self.en = encoder()
    self.de = decoder()
    self.eps = eps
    # Constants that are stored on the module but not read by forward().
    self.mnist_z = 108
    self.cifar10_z = 192

  def forward(self, x, one_hot):
    """Encode ``x``, build a latent sample and decode it together with ``one_hot``.

    Returns a tuple ``(decoded, mu, logvar)``.
    """
    encoded = self.en(x)
    flat = encoded.view(encoded.size(0), -1)
    # The first 96 columns are used as the mean, the rest as the log-variance.
    mu, logvar = flat[:, :96], flat[:, 96:]
    z = mu + self.eps * torch.exp(0.5 * logvar)
    decoder_input = torch.cat((z, one_hot), dim=1)
    return self.de(decoder_input), mu, logvar

class private(nn.Module):
  """Holds 20 independent ``VAE`` modules, selected by task index."""

  def __init__(self, eps):
    super().__init__()
    self.eps = eps
    # One VAE per task; every one receives the same eps object.
    self.task = torch.nn.ModuleList([VAE(self.eps) for _ in range(20)])

  def forward(self, x, one_hot, task_id):
    """Return the output of the VAE stored at index ``task_id``."""
    task_vae = self.task[task_id]
    return task_vae.forward(x, one_hot)

class NET(nn.Module):
  """One shared VAE, a bank of per-task VAEs and one classifier head per task.

  The shared and the task-specific reconstructions are concatenated along the
  channel axis and fed to the head that belongs to the requested task.
  """

  def __init__(self, eps):
    super().__init__()
    self.eps = eps
    self.shared = VAE(self.eps)
    self.private = private(self.eps)
    # Constants that are stored on the module but not read by forward().
    self.mnist = 216
    self.cifar10 = 384
    self.in_mnist = 2
    self.in_cifar10 = 6
    self.head = torch.nn.ModuleList([self._build_head() for _ in range(20)])

  @staticmethod
  def _build_head():
    """Create one head: four strided convolutions, a flatten and a 100-way linear layer."""
    return nn.Sequential(
        nn.Conv2d(6, 12, 3, 2, 1),
        nn.Conv2d(12, 24, 2, 2, 0),
        nn.Conv2d(24, 48, 2, 2, 0),
        nn.Conv2d(48, 48, 2, 2, 0),
        nn.Flatten(1, -1),
        nn.Linear(48*5*5, 100),
    )

  def forward(self, x, one_hot, task_id):
    """Return ``(head_output, shared_triplet, private_triplet)``.

    Each triplet is the ``(reconstruction, mu, logvar)`` tuple produced by the
    shared VAE and by the private VAE of ``task_id`` respectively.
    """
    shared_out = self.shared(x, one_hot)
    private_out = self.private(x, one_hot, task_id)
    stacked = torch.cat([shared_out[0], private_out[0]], dim = 1)
    return self.head[task_id].forward(stacked), tuple(shared_out), tuple(private_out)


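# Number of epochs and synthetic data
To change the number of training epochs, edit `num_epochs` in `CL_VAE.train_task`; to change how many synthetic samples are generated per class for generative replay, edit `num_gen_samples` in `CL_VAE.train` (around lines 170 and 70 of the next cell, respectively). Adjust both according to your requirements.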

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from collections import deque
from torch.autograd import grad as torch_grad

import torchvision.utils as vutils

import os
import os.path

import numpy as np
np.set_printoptions(threshold=np.inf)


class CL_VAE():
  """Trains a ``NET`` model task by task with generative replay.

  For every task after the first, a few samples per previously seen class are
  generated with the stored task-specific decoders and prepended to the
  training data of the current task. After each task the per-task accuracies
  are appended to ``accuracy_matrix``; after the last task the forgetting
  values are stored in ``forget_mat``.

  The model and all batches are moved with ``.cuda()``, so a CUDA device is
  required.
  """

  NUM_CLASSES = 100  # width of the one-hot label vectors
  LATENT_DIM = 96    # width of the noise vectors (matches the 96-wide mean in VAE)

  def __init__(self):
    super(CL_VAE, self).__init__()

    self.batch_size = 64
    self.mnist_z = 108
    self.cifar10_z = 192
    self.num_class_cifar100 = 100
    self.num_class_cifar10 = 10
    self.build_model()
    self.set_cuda()
    self.criterion = torch.nn.CrossEntropyLoss()
    self.recon = torch.nn.MSELoss()
    self.net_path = 'path/trial.pth'
    # accuracy_matrix[t][k]: accuracy on task k measured after training task t.
    self.accuracy_matrix = [[] for kk in range(20)]
    self.acc_matr = []
    self.forget_mat = []
    # Accuracies recorded after epoch 25 and epoch 50 of every task.
    self.acc_25 = []
    self.acc_50 = []

  def build_model(self):
    """Create the fixed noise tensor and the network, and print the parameter count."""
    # The noise is drawn once and shared by every VAE inside the network, so
    # every batch passed to the network must contain exactly batch_size samples.
    self.eps = torch.randn(self.batch_size, self.LATENT_DIM)
    self.eps = self.eps.cuda()
    self.net = NET(self.eps)
    pytorch_total_params = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
    print('pytorch_total_params:', pytorch_total_params)

  def set_cuda(self):
    """Move the network to the GPU."""
    self.net.cuda()

  def VAE_loss(self, recon, mu, sigma):
    """Return ``recon`` plus the KL term computed from ``mu`` and ``sigma``.

    ``sigma`` is used as a log-variance (it appears both linearly and inside
    ``exp``). The KL term is summed over all elements.
    """
    kl_div = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return recon + kl_div

  def train(self, all_traindata, all_trainlabels, all_testdata, all_testlabels, total_tasks):
    """Train on tasks ``0 .. total_tasks - 1`` in order.

    Args:
      all_traindata, all_trainlabels: per-task training tensors.
      all_testdata, all_testlabels: per-task test tensors.
      total_tasks: number of tasks to run.
    """
    replay_classes = []
    for i in range(total_tasks):
      traindata = all_traindata[i]
      trainlabels = all_trainlabels[i]
      testdata = all_testdata[i]
      testlabels = all_testlabels[i]
      # Classes of the real data of this task (recorded before replay is mixed in).
      replay_classes.append(sorted(set(trainlabels.numpy().tolist())))
      if i > 0:
        num_gen_samples = 4
        for m in range(i):
          # num_gen_samples copies of every class label seen in task m.
          replay_trainlabels = torch.tensor(
              [cls for cls in replay_classes[m] for _ in range(num_gen_samples)])
          replay_trainlabels_onehot = self.one_hot(replay_trainlabels)

          # One noise vector per replay label (no longer assumes 5 classes per task).
          z = torch.randn(len(replay_trainlabels), self.LATENT_DIM)
          z_one_hot = torch.cat((z, replay_trainlabels_onehot), dim = 1).cuda()
          with torch.no_grad():
            replay_data = self.net.private.task[m].de(z_one_hot).cpu()

          traindata = torch.cat((replay_data, traindata), dim = 0)
          trainlabels = torch.cat((replay_trainlabels, trainlabels))
          # as_tensor avoids the copy (and warning) of torch.tensor(tensor).
          testdata = torch.cat((testdata, torch.as_tensor(all_testdata[m])), dim = 0)
          testlabels = torch.cat((testlabels, torch.as_tensor(all_testlabels[m])))

      self.train_task(traindata, trainlabels, testdata, testlabels, i)
      self.acc_mat(all_testdata, all_testlabels, total_tasks, i)

    self.forgetting_measure(self.accuracy_matrix, total_tasks)

  def one_hot(self, labels):
    """Return a float ``(len(labels), 100)`` matrix with a 1 at each label index."""
    matrix = torch.zeros(len(labels), self.NUM_CLASSES)
    rows = np.arange(len(labels))
    matrix[rows, labels] = 1
    return matrix

  def forgetting_measure(self, accuracy_matrix, num_tasks):
    """Compute the forgetting after each task from the second one onwards.

    For task index ``t >= 1`` the value is the mean, over tasks ``k < t``, of
    the best accuracy on ``k`` obtained before ``t`` minus the accuracy on
    ``k`` after ``t``. The list is stored in ``self.forget_mat`` and returned.
    """
    accuracy_matrix = np.array(accuracy_matrix)
    forgetting = []
    for after_task_idx in range(1, num_tasks):
      earlier = accuracy_matrix[:after_task_idx, :after_task_idx]
      drop = earlier.max(axis=0) - accuracy_matrix[after_task_idx, :after_task_idx]
      forgetting.append(np.mean(drop).item())
    self.forget_mat = forgetting
    return forgetting

  @staticmethod
  def _mean_per_class_accuracy(true_labels, pred_labels, classes):
    """Average the per-class accuracies over the classes that have samples.

    Returns 0.0 when none of ``classes`` occurs in ``true_labels``.
    """
    correct = {cls: 0 for cls in classes}
    total = {cls: 0 for cls in classes}
    pairs = zip(np.asarray(true_labels).tolist(), np.asarray(pred_labels).tolist())
    for truth, pred in pairs:
      total[truth] += 1
      if truth == pred:
        correct[truth] += 1
    # Divide once after summing; classes without samples are left out so that
    # they cannot trigger a division by zero.
    per_class = [correct[cls] / total[cls] for cls in classes if total[cls] > 0]
    return sum(per_class) / len(per_class) if per_class else 0.0

  def acc_mat(self, testData1, testLabels1, num_tasks, t):
    """Append the mean per-class accuracy of every task to ``accuracy_matrix[t]``."""
    for kk in range(num_tasks):
      testLabels_tw = testLabels1[kk]
      testLabels_tw_classes = sorted(set(testLabels_tw.detach().numpy().tolist()))
      _, pred_tw = self.evall(testData1[kk], testLabels_tw, kk)
      # evall skips the trailing partial batch, so trim the labels to match.
      true_tw = testLabels_tw.detach().numpy()[:pred_tw.shape[0]]
      avgAcc_tw = self._mean_per_class_accuracy(true_tw, pred_tw, testLabels_tw_classes)
      self.accuracy_matrix[t].append(avgAcc_tw)

  def model_save(self):
    """Write the network's state dict to ``self.net_path``."""
    save_dir = os.path.dirname(self.net_path)
    if save_dir:
      # Create the target folder first so a missing directory cannot make the
      # save fail after training has finished.
      os.makedirs(save_dir, exist_ok=True)
    torch.save(self.net.state_dict(), self.net_path)

  def train_task(self, traindata, trainlabels, testdata, testlabels, task_id):
    """Train the network on one task for ``num_epochs`` epochs.

    Batches are taken in order without shuffling, and samples beyond the last
    full batch are not used. The test accuracy is recorded in ``acc_25`` and
    ``acc_50`` after epochs 25 and 50; the model is saved after epoch 50 of
    task 20.
    """
    net_opti = torch.optim.Adam(self.net.parameters(), lr = 1e-4)
    num_iterations = traindata.shape[0] // self.batch_size
    num_epochs = 50
    for e in range(num_epochs):
      # evall() switches the network to eval mode, so re-enable training here.
      self.net.train()
      for i in range(num_iterations):
        self.net.zero_grad()

        start = i * self.batch_size
        stop = start + self.batch_size
        batch_label = trainlabels[start:stop]
        batch_label_one_hot = self.one_hot(batch_label).cuda()
        batch_data = traindata[start:stop].cuda()
        batch_label = batch_label.cuda()

        out, shared_out, private_out = self.net(batch_data, batch_label_one_hot, task_id)
        s_x, s_mu, s_logvar = shared_out
        p_x, p_mu, p_logvar = private_out

        cross_en_loss = self.criterion(out, batch_label)

        s_recon = self.recon(batch_data, s_x)
        p_recon = self.recon(batch_data, p_x)

        s_VAE_loss = self.VAE_loss(s_recon, s_mu, s_logvar)
        p_VAE_loss = self.VAE_loss(p_recon, p_mu, p_logvar)

        all_loss = cross_en_loss + s_VAE_loss + p_VAE_loss

        # The graph is rebuilt on every iteration, so it need not be retained.
        all_loss.backward()
        net_opti.step()

      if (e + 1) % 25 == 0:
        acc1, _ = self.evall(testdata, testlabels, task_id)

      if e + 1 == 25:
        self.acc_25.append(acc1)
      if e + 1 == 50:
        self.acc_50.append(acc1)

        if task_id + 1 == 20:
          self.model_save()

  def evall(self, testdata, testlabels, task_id):
    """Evaluate the head of ``task_id`` on full batches of ``testdata``.

    Returns ``(mean batch accuracy in percent, flat array of predicted labels)``.
    Samples beyond the last full batch are not evaluated.
    """
    self.net.eval()
    num_iterations = testdata.shape[0] // self.batch_size
    acc = []
    pred_labels_list = []
    # No gradients are needed for evaluation.
    with torch.no_grad():
      for i in range(num_iterations):
        start = i * self.batch_size
        stop = start + self.batch_size
        batch_labels = testlabels[start:stop]
        batch_label_one_hot = self.one_hot(batch_labels).cuda()
        batch_data = testdata[start:stop].cuda()
        batch_labels = batch_labels.cuda()
        out, _, _ = self.net(batch_data, batch_label_one_hot, task_id)
        pred_labels = torch.argmax(out, dim = 1)
        pred_labels_list.append(pred_labels.cpu().numpy().tolist())
        acc.append((torch.sum(batch_labels == pred_labels)/batch_data.shape[0] * 100).cpu().numpy().tolist())

    return np.mean(np.array(acc)), np.array(pred_labels_list).flatten()





In [11]:
import time
# Build the trainer and run it over all 20 tasks in sequence.
aaa = CL_VAE()
aaa.train(
    all_traindata=traindata,
    all_trainlabels=trainlabels,
    all_testdata=testdata,
    all_testlabels=testlabels,
    total_tasks=20,
)


pytorch_total_params: 4107152


