<a href="https://colab.research.google.com/github/DVAEsCL/DVAEsCL/blob/main/EMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Download the EMNIST data

In [None]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import EMNIST
from torchvision.transforms import ToTensor
from torchvision.utils import make_grid
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import random_split
%matplotlib inline

# Very large batch sizes: a single draw from each loader below returns up to
# this many samples at once.
batch_size_train = 100000
batch_size_test = 50000

# Convert images to tensors, then standardise with mean 0.1307 / std 0.3081.
# NOTE(review): these look like the usual MNIST statistics - confirm they are
# the intended values for the EMNIST 'letters' split.
emnist_transform = torchvision.transforms.Compose([
    ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])
dataset = EMNIST(root='data/', split='letters', download=True, transform=emnist_transform)

# Hold out `val_size` samples for validation; everything else is training data.
val_size = 12000
train_size = len(dataset) - val_size
train_ds, val_ds = random_split(dataset, [train_size, val_size])
print(len(train_ds), len(val_ds))

train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True)
test_loader = DataLoader(val_ds, batch_size=batch_size_test, shuffle=True)

# Pull the first batch out of each loader and report its shape.
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)

examples_t = enumerate(test_loader)
batch_idx_t, (example_data_t, example_targets_t) = next(examples_t)
print(example_data_t.shape)

Downloading and extracting zip archive
Downloading http://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip to data/EMNIST/raw/emnist.zip


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/EMNIST/raw/emnist.zip to data/EMNIST/raw
Processing byclass
Processing bymerge
Processing balanced
Processing letters
Processing digits
Processing mnist
Done!
112800 12000
torch.Size([100000, 1, 28, 28])
torch.Size([12000, 1, 28, 28])


# If you downloaded the zip file, start from here
Please download the zip file, extract the four files, and upload them to a drive folder (your_path). Click [here](https://drive.google.com/drive/folders/1GThh-QaZeF-b6J59khdScz-XK8FeXuYy?usp=sharing) to download the data.


In [5]:
import torch
import torch.nn as nn
class encoder(nn.Module):
  """Convolutional encoder made of four stacked ``nn.Conv2d`` layers.

  Channels go 1 -> 3 -> 6 -> 12 -> 24. The first layer is a 3x3 convolution
  with stride 1 and padding 1; the other three are 2x2 convolutions with
  stride 2 and no padding. No activation is applied between layers.
  """

  def __init__(self):
    super().__init__()
    # Bookkeeping values; nothing inside this class reads them.
    self.nc_mnist = 1
    self.nc_cifar10 = 3
    self.conv1 = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=2, stride=2, padding=0)
    self.conv3 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=2, stride=2, padding=0)
    self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=2, stride=2, padding=0)

  def forward(self, x):
    """Apply conv1..conv4 in order and return the resulting feature map."""
    for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
      x = layer(x)
    return x
class decoder(nn.Module):
  """Transposed-convolution decoder turning a 134-value code into one image channel.

  The input is reshaped to ``(batch, 134, 1, 1)`` and passed through four
  ``nn.ConvTranspose2d`` layers with channels 134 -> 24 -> 12 -> 6 -> 1 and no
  activations in between. Elsewhere in this file the 134 values are a
  108-wide latent vector concatenated with a 26-wide one-hot label.
  """

  def __init__(self):
    super().__init__()
    # Bookkeeping values; nothing inside this class reads them.
    self.nc_mnist = 118
    self.nc_cifar10 = 202
    self.nk_mnist = 3
    self.nk_cifar10 = 4
    self.decon1 = nn.ConvTranspose2d(in_channels=134, out_channels=24, kernel_size=3, stride=1, padding=0)
    self.decon2 = nn.ConvTranspose2d(in_channels=24, out_channels=12, kernel_size=3, stride=2, padding=0)
    self.decon3 = nn.ConvTranspose2d(in_channels=12, out_channels=6, kernel_size=2, stride=2, padding=0)
    self.decon4 = nn.ConvTranspose2d(in_channels=6, out_channels=1, kernel_size=2, stride=2, padding=0)

  def forward(self, x):
    """Reshape ``x`` to a 1x1 spatial map and upsample it through decon1..decon4."""
    batch = x.shape[0]
    feat = x.view(batch, 134, 1, 1)
    for layer in (self.decon1, self.decon2, self.decon3, self.decon4):
      feat = layer(feat)
    return feat

class VAE(nn.Module):
  """Label-conditioned VAE built from ``encoder`` and ``decoder``.

  The flattened encoder output is split into a mean (first ``mnist_z``
  columns) and a log-variance (remaining columns). A latent sample is formed
  with the reparameterisation ``z = mu + eps * exp(0.5 * logvar)``, the
  one-hot label is appended, and the result is decoded.

  Args:
    eps: Optional fixed noise tensor reused on every forward pass (the
      behaviour callers in this file rely on). When ``None`` (new default),
      fresh standard-normal noise is drawn on each call, which also removes
      the restriction that the batch size must match ``eps``.
  """

  def __init__(self, eps=None):
    super(VAE, self).__init__()
    self.en = encoder()
    self.de = decoder()
    self.eps = eps
    self.mnist_z = 108
    self.cifar10_z = 192

  def _noise(self, std):
    """Return the noise used for the reparameterisation of ``std``."""
    if self.eps is None:
      # No fixed noise supplied: sample per call, matching std's shape/device.
      return torch.randn_like(std)
    eps = self.eps
    # A stored eps with more rows than the current batch (e.g. a final partial
    # batch) is truncated instead of failing on a shape mismatch. When the
    # sizes already agree this is a no-op.
    if eps.dim() == std.dim() and eps.shape[0] > std.shape[0]:
      eps = eps[:std.shape[0]]
    return eps

  def forward(self, x, one_hot):
    """Encode ``x``, sample a latent, append ``one_hot`` and decode.

    Returns:
      Tuple ``(reconstruction, mu, logvar)``.
    """
    h = self.en(x)
    h = h.view(h.shape[0], -1)
    mu = h[:, :self.mnist_z]
    logvar = h[:, self.mnist_z:]
    std = torch.exp(0.5 * logvar)
    z = mu + self._noise(std) * std
    # Condition the decoder on the label by concatenating it to the latent.
    z1 = torch.cat((z, one_hot), dim=1)
    return self.de(z1), mu, logvar

class private(nn.Module):
  """Container holding 13 independent ``VAE`` modules, one per task.

  Every VAE is constructed with the same ``eps`` object.
  """

  def __init__(self, eps):
    super().__init__()
    self.eps = eps
    self.task = torch.nn.ModuleList([VAE(self.eps) for _ in range(13)])

  def forward(self, x, one_hot, task_id):
    """Run the VAE stored at index ``task_id`` on ``(x, one_hot)``."""
    selected = self.task[task_id]
    return selected.forward(x, one_hot)

class NET(nn.Module):
  """Shared VAE, per-task private VAEs and per-task classification heads.

  For a given task the shared and private reconstructions are concatenated
  along the channel axis (two channels in total) and classified by that
  task's head into 26 outputs. There are 13 heads.
  """

  def __init__(self, eps):
    super().__init__()
    self.eps = eps
    self.shared = VAE(self.eps)
    self.private = private(self.eps)
    self.head = torch.nn.ModuleList()
    # Bookkeeping values; nothing inside this class reads them.
    self.mnist = 216
    self.cifar10 = 384
    self.in_mnist = 2
    self.in_cifar10 = 6
    for _ in range(13):
      self.head.append(self._build_head())

  @staticmethod
  def _build_head():
    """Create one classifier head: four convolutions, flatten, linear 216 -> 26."""
    layers = [
        nn.Conv2d(2, 3, 3, 1, 1),
        nn.Conv2d(3, 6, 2, 2, 0),
        nn.Conv2d(6, 12, 2, 2, 0),
        nn.Conv2d(12, 24, 2, 2, 0),
        nn.Flatten(1, -1),
        #nn.Linear(24*16*16, 100), #for cifar10 only
        nn.Linear(216, 26),
    ]
    return nn.Sequential(*layers)

  def forward(self, x, one_hot, task_id):
    """Return ``(logits, shared_outputs, private_outputs)`` for one task.

    ``shared_outputs`` and ``private_outputs`` are each a
    ``(reconstruction, mu, logvar)`` tuple.
    """
    s_x, s_mu, s_logvar = self.shared(x, one_hot)
    p_x, p_mu, p_logvar = self.private(x, one_hot, task_id)
    # Stack the two reconstructions as channels for the classifier head.
    stacked = torch.cat([s_x, p_x], dim=1)
    logits = self.head[task_id].forward(stacked)
    shared_out = (s_x, s_mu, s_logvar)
    private_out = (p_x, p_mu, p_logvar)
    return logits, shared_out, private_out


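If you wish to change the number of epochs or the amount of synthetic data used for generative replay, edit `num_epochs` in `train_task` and `num_gen_samples` in `train` (around lines 100 and 57 of the next cell, respectively). Change them according to your requirements.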

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from collections import deque
from torch.autograd import grad as torch_grad

import torchvision.utils as vutils

import os
import os.path

class CL_VAE():
  """Trainer that fits ``NET`` on a sequence of tasks with generative replay.

  Tasks are trained one after another. Before training task ``i`` the private
  decoders of all earlier tasks are sampled to synthesise a few labelled
  images per previously seen class; these are prepended to the new task's
  training data. Test data of earlier tasks is appended to the test set.

  The model and the fixed noise tensor are moved to CUDA unconditionally, so
  a GPU is required to construct and train this class.
  """

  def __init__(self):
    super(CL_VAE, self).__init__()

    self.batch_size = 64
    self.mnist_z = 108
    self.cifar10_z = 192
    # Task count after which the model is checkpointed; `train` overwrites
    # this with its `total_tasks` argument.
    self.total_tasks = 13
    self.build_model()
    self.set_cuda()
    self.criterion = torch.nn.CrossEntropyLoss()
    self.recon = torch.nn.MSELoss()
    self.net_path = 'path/EMNIST.pth'

  def build_model(self):
    """Create the fixed noise tensor and the network, and print its size."""
    # One noise tensor is sampled here and shared by every VAE in the network.
    self.eps = torch.randn(self.batch_size, self.mnist_z)
    self.eps = self.eps.cuda()
    self.net = NET(self.eps)
    pytorch_total_params = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
    print('pytorch_total_params:', pytorch_total_params)

  def set_cuda(self):
    """Move the network to the GPU."""
    self.net.cuda()

  def VAE_loss(self, recon, mu, sigma):
    """Return ``recon`` plus the summed KL term.

    Args:
      recon: Reconstruction loss (already reduced).
      mu: Latent means.
      sigma: Used as a log-variance in the formula below; callers in this
        class pass the ``logvar`` output of the network.
    """
    kl_div = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return recon + kl_div

  def train(self, all_traindata, all_trainlabels, all_testdata, all_testlabels, total_tasks):
    """Train on ``total_tasks`` tasks in order, replaying earlier tasks.

    Args:
      all_traindata, all_trainlabels, all_testdata, all_testlabels:
        Per-task sequences; element ``i`` is converted with ``torch.tensor``.
      total_tasks: Number of tasks to train. The model is saved when the
        last one is evaluated.
    """
    self.total_tasks = total_tasks
    # Synthetic samples generated per previously seen class.
    num_gen_samples = 4
    replay_classes = []
    for i in range(total_tasks):
      traindata = torch.tensor(all_traindata[i])
      trainlabels = torch.tensor(all_trainlabels[i])
      testdata = torch.tensor(all_testdata[i])
      testlabels = torch.tensor(all_testlabels[i])
      replay_classes.append(sorted(list(set(trainlabels.numpy().tolist()))))

      # For the first task this loop is empty and the task is trained on its
      # own data only.
      for m in range(i):
        replay_trainlabels = torch.tensor(
            [cls for cls in replay_classes[m] for _ in range(num_gen_samples)])
        replay_trainlabels_onehot = self.one_hot(replay_trainlabels)

        # One latent per replay label. Sizing by the label count (instead of
        # a hard-coded "2 classes per task") keeps this valid for any task.
        z = torch.randn(len(replay_trainlabels), self.mnist_z)
        z_one_hot = torch.cat((z, replay_trainlabels_onehot), dim=1)
        z_one_hot = z_one_hot.cuda()

        # Generation only; no gradients are needed for replay samples.
        with torch.no_grad():
          replay_data = self.net.private.task[m].de(z_one_hot).detach().cpu()

        traindata = torch.cat((replay_data, traindata), dim=0)
        trainlabels = torch.cat((replay_trainlabels, trainlabels))
        testdata = torch.cat((testdata, torch.tensor(all_testdata[m])), dim=0)
        testlabels = torch.cat((testlabels, torch.tensor(all_testlabels[m])))

      self.train_task(traindata, trainlabels, testdata, testlabels, i)

  def one_hot(self, labels):
    """Return a ``(len(labels), 26)`` float matrix with a 1 at each label index."""
    matrix = torch.zeros(len(labels), 26)
    rows = np.arange(len(labels))
    matrix[rows, labels] = 1
    return matrix

  def model_save(self):
    """Write the network's ``state_dict`` to ``self.net_path``.

    The parent directory is created first so that a missing folder does not
    make the save fail after training has finished.
    """
    save_dir = os.path.dirname(self.net_path)
    if save_dir:
      os.makedirs(save_dir, exist_ok=True)
    torch.save(self.net.state_dict(), self.net_path)

  def train_task(self, traindata, trainlabels, testdata, testlabels, task_id):
    """Optimise the network on one task and evaluate every 25 epochs.

    Data is consumed in order in chunks of ``self.batch_size``; a trailing
    partial chunk is skipped (the fixed noise tensor has ``batch_size`` rows).
    """
    net_opti = torch.optim.Adam(self.net.parameters(), lr = 1e-4)
    num_iterations = traindata.shape[0] // self.batch_size
    num_epochs = 50
    for e in range(num_epochs):
      # `evall` switches to eval mode, so re-enter train mode every epoch.
      self.net.train()
      for i in range(num_iterations):
        self.net.zero_grad()

        lo = i * self.batch_size
        hi = (i + 1) * self.batch_size
        batch_data = traindata[lo:hi]
        batch_label = trainlabels[lo:hi]
        batch_label_one_hot = self.one_hot(batch_label)
        batch_data = batch_data.cuda()
        batch_label = batch_label.cuda()
        batch_label_one_hot = batch_label_one_hot.cuda()

        out, shared_out, private_out = self.net(batch_data, batch_label_one_hot, task_id)
        s_x, s_mu, s_logvar = shared_out
        p_x, p_mu, p_logvar = private_out

        cross_en_loss = self.criterion(out, batch_label)

        s_recon = self.recon(batch_data, s_x)
        p_recon = self.recon(batch_data, p_x)

        s_VAE_loss = self.VAE_loss(s_recon, s_mu, s_logvar)
        p_VAE_loss = self.VAE_loss(p_recon, p_mu, p_logvar)

        all_loss = cross_en_loss + s_VAE_loss + p_VAE_loss

        # The graph is rebuilt on every forward pass, so it does not need to
        # be retained after backward.
        all_loss.backward()
        net_opti.step()

      if (e + 1) % 25 == 0:
        acc1 = self.evall(testdata, testlabels, task_id)
        print('Task:', task_id + 1, 'acc', acc1)

        # Checkpoint once the final task has been reached.
        if task_id + 1 == self.total_tasks:
          self.model_save()

  def ev(self, test, task_id):
    """Return the accuracy (in percent) on a single ``(data, labels)`` pair.

    No device transfer of ``data``/``labels`` is done here; the caller must
    supply tensors on the same device as the network.
    """
    data, labels = test
    # The network's forward takes (x, one_hot, task_id); build the one-hot
    # labels and place them next to the data.
    one_hot = self.one_hot(labels.cpu()).to(data.device)
    out, _, _ = self.net(data, one_hot, task_id)
    pred_labels = torch.argmax(out, dim=1)
    return (torch.sum(labels == pred_labels)/data.shape[0] * 100).detach().cpu().numpy().tolist()

  def evall(self, testdata, testlabels, task_id):
    """Return the mean per-batch accuracy (in percent) over the test set.

    Samples are processed in chunks of ``self.batch_size``; a trailing
    partial chunk is skipped. With fewer than ``batch_size`` samples the
    result is ``nan`` (mean of an empty array). The network is left in eval
    mode.

    NOTE(review): the ground-truth one-hot labels are fed to the network
    during evaluation as well - confirm this is intended.
    """
    self.net.eval()
    num_iterations = testdata.shape[0] // self.batch_size
    acc = []
    # Evaluation needs no gradients; skip building autograd graphs.
    with torch.no_grad():
      for i in range(num_iterations):
        lo = i * self.batch_size
        hi = (i + 1) * self.batch_size
        batch_data = testdata[lo:hi]
        batch_labels = testlabels[lo:hi]
        batch_label_one_hot = self.one_hot(batch_labels)
        batch_data = batch_data.cuda()
        batch_labels = batch_labels.cuda()
        batch_label_one_hot = batch_label_one_hot.cuda()
        out, _, _ = self.net(batch_data, batch_label_one_hot, task_id)
        pred_labels = torch.argmax(out, dim=1)
        acc.append((torch.sum(batch_labels == pred_labels)/batch_data.shape[0] * 100).detach().cpu().numpy().tolist())
    return np.mean(np.array(acc))





In [3]:
import json
traindata_path = '/content/drive/MyDrive/ColabNotebooks/ContinualLearning/traindata_EMNIST.json'
trainlabels_path = '/content/drive/MyDrive/ColabNotebooks/ContinualLearning/trainlabels_EMNIST.json'
testdata_path = '/content/drive/MyDrive/ColabNotebooks/ContinualLearning/testdata_EMNIST.json'
testlabels_path = '/content/drive/MyDrive/ColabNotebooks/ContinualLearning/testlabels_EMNIST.json'

with open(traindata_path) as f:
  traindata = json.load(f)
with open(trainlabels_path) as f:
  trainlabels = json.load(f)
with open(testdata_path) as f:
  testdata = json.load(f)
with open(testlabels_path) as f:
  testlabels = json.load(f)


In [6]:
import time
aaa = CL_VAE()
sta = time.time()
aaa.train(traindata, trainlabels, testdata, testlabels, 13)
end = time.time()
print("It takes:", end - sta)

pytorch_total_params: 562939
