<a href="https://colab.research.google.com/github/DVGR-CZSL/DVGR-CZSL/blob/main/SUN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torch.autograd import Variable
from copy import deepcopy
from sklearn.preprocessing import normalize
import glob, os


In [None]:

class encoder(nn.Module):
  """Three-layer MLP mapping a 2048-d input to a 100-d output.

  Hidden widths are 1000 and 500 with ReLU between layers; the final layer is
  linear (no activation).
  """
  def __init__(self):
    super().__init__()
    self.fc1 = torch.nn.Linear(2048, 1000)
    self.fc2 = torch.nn.Linear(1000, 500)
    self.fc3 = torch.nn.Linear(500, 100)
    self.rel = torch.nn.ReLU()

  def forward(self, x):
    hidden = self.rel(self.fc1(x))
    hidden = self.rel(self.fc2(hidden))
    return self.fc3(hidden)
class decoder(nn.Module):
  """Conditional MLP decoder producing a feature, a label vector and an attribute vector.

  Input is ``[z | one_hot | cls_att]`` of width ``z_dim + n_y + n_e``. The
  final layer's output of width ``x_dim + n_y + n_e`` is split into three
  consecutive pieces.

  Args:
    n_e: attribute-vector width (default 102).
    n_y: number of classes, i.e. one-hot width (default 708).
    z_dim: latent width (default 50).
    x_dim: reconstructed feature width (default 2048).
  """
  def __init__(self, n_e = 102, n_y = 708, z_dim = 50, x_dim = 2048):
    super(decoder, self).__init__()
    self.n_e = n_e
    self.n_y = n_y
    self.x_dim = x_dim
    self.fc1 = torch.nn.Linear(z_dim + self.n_e + self.n_y, 500)
    self.fc2 = torch.nn.Linear(500, 1000)
    # Output width derives from the stored sizes instead of repeating literals,
    # so the split in forward() always matches the layer.
    self.fc3 = torch.nn.Linear(1000, self.x_dim + self.n_y + self.n_e)
    self.rel = torch.nn.ReLU()

  def forward(self, x):
    """Return ``(x_out, y_out, em_out)`` with widths ``x_dim``, ``n_y``, ``n_e``."""
    x = self.rel(self.fc1(x))
    x = self.rel(self.fc2(x))
    x = self.fc3(x)
    x_out = x[:, :self.x_dim]
    y_out = x[:, self.x_dim:self.x_dim + self.n_y]
    em_out = x[:, self.x_dim + self.n_y:]
    return x_out, y_out, em_out

class VAE(nn.Module):
  """Conditional VAE: encode a feature, sample a latent, decode it with conditioning.

  Args:
    eps: stored as ``self.eps`` for backward compatibility with callers that
      pass it. It is not used as the sampling noise (see ``forward``).
  """
  def __init__(self, eps):
    super(VAE, self).__init__()
    self.en = encoder()
    self.de = decoder()
    self.eps = eps

  def forward(self, x, one_hot, cls_att):
    """Encode ``x``, sample ``z`` and decode ``[z | one_hot | cls_att]``.

    Returns:
      ``(self.de(...), mu, logvar)``. ``mu`` is the first 50 encoder output
      columns and ``logvar`` is the remaining columns.
    """
    h = self.en(x)
    mu = h[:, :50]
    logvar = h[:, 50:]
    std = torch.exp(0.5 * logvar)
    # Reparameterization trick: draw fresh N(0, I) noise on every call.
    # A stored noise tensor would repeat the same sample for every batch and
    # only broadcast against batches of exactly its own size.
    noise = torch.randn_like(std)
    z = mu + noise * std
    z1 = torch.cat((z, one_hot, cls_att), dim = 1)
    return self.de(z1), mu, logvar


class private(nn.Module):
  """Holds 15 independent VAEs, one per task slot, and dispatches by task index."""
  def __init__(self, eps):
    super().__init__()
    # Build all task VAEs up front; each receives the same eps object.
    self.task = torch.nn.ModuleList([VAE(eps) for _ in range(15)])
    self.eps = eps

  def forward(self, x, one_hot, cls_att, task_id):
    chosen = self.task[task_id]
    return chosen(x, one_hot, cls_att)

class NET(nn.Module):
  """Per-task VAEs (``self.private``) plus one classification head per task.

  Each of the 15 heads maps a 2048-d input to 708 logits through three
  Linear layers. There is no activation between those layers.
  """
  def __init__(self, eps):
    super().__init__()
    self.eps = eps
    self.private = private(self.eps)
    self.head = torch.nn.ModuleList([
        nn.Sequential(
            nn.Linear(2048, 1000),
            nn.Linear(1000, 500),
            nn.Linear(500, 708),
        )
        for _ in range(15)
    ])

  def forward(self, x, one_hot, cls_att, task_id):
    """Return ``(head logits, (decoder outputs, mu, logvar))`` for task ``task_id``."""
    decoded, mu, logvar = self.private(x, one_hot, cls_att, task_id)
    logits = self.head[task_id](x)
    return logits, (decoded, mu, logvar)

  def common_features(self, z, task_id):
    """Decode ``z`` with task ``task_id``'s decoder and return only the feature part."""
    features, _unused_y, _unused_em = self.private.task[task_id].de(z)
    return features


In [None]:
# Root folder containing the data files; edit before running.
path = 'FolderPath'
train_data_path = f'{path}/trainData'
train_label_path = f'{path}/trainLabels'
train_attr_path = f'{path}/trainAttributes'
test_data_path = f'{path}/testData'
test_label_path = f'{path}/testLabels'
test_attr_path = f'{path}/testAttributes'
attributes_path = f'{path}/dataAttributes'


In [None]:
def dataprocess(data_path):
  """Load one data file (pickle, ``.npy`` or ``.npz``) with ``np.load``.

  The path itself is given to ``np.load`` rather than an already-open handle.
  For a ``.npz`` archive, ``np.load`` returns a lazy ``NpzFile`` that reads
  members on access. If the caller closes the handle first, that access
  fails; passing the path lets the ``NpzFile`` own the file.

  ``encoding='latin1'`` is numpy's documented setting for loading pickles
  written by Python 2.

  SECURITY: ``allow_pickle=True`` unpickles the file, which can execute
  arbitrary code. Only load files from a trusted source.
  """
  return np.load(data_path, allow_pickle=True, encoding='latin1')

# Load every split in a fixed order; each path is read once by dataprocess().
(trainData1, trainLabels1, trainLabelsVectors1,
 testData1, testLabels1, testlabelsvectors1, ATTR) = (
    dataprocess(split_path) for split_path in (
        train_data_path, train_label_path, train_attr_path,
        test_data_path, test_label_path, test_attr_path,
        attributes_path,
    )
)


In [None]:
class CLASSIFIER(nn.Module):
  """MLP classifier: 2048-d feature -> 708 logits.

  Hidden widths are 1500 and 1000 with ReLU. Dropout (p=0.2) is applied to the
  second hidden layer before the output layer.
  """
  def __init__(self):
    super().__init__()
    self.fc1 = torch.nn.Linear(2048, 1500)
    self.fc2 = torch.nn.Linear(1500, 1000)
    self.fc3 = torch.nn.Linear(1000, 708)
    self.drop = nn.Dropout(p = 0.2)
    self.rel = torch.nn.ReLU()

  def forward(self, x):
    first = self.rel(self.fc1(x))
    second = self.drop(self.rel(self.fc2(first)))
    return self.fc3(second)


In [None]:
from sklearn.preprocessing import normalize
from sklearn.preprocessing import StandardScaler
import random


class CL_VAE():
  """Continual zero-shot learner built from per-task conditional VAEs.

  Tasks are processed in order. For each task, ``self.net`` is trained on the
  task's real features plus features replayed from earlier tasks' decoders
  (generative replay). A fresh ``CLASSIFIER`` is then trained only on
  decoder-generated pseudo-features for every class (seen and not yet seen)
  and evaluated on the real test features. Per-task accuracies accumulate in
  ``seen_acc``, ``unseen_acc``, ``hm_acc`` (harmonic mean) and
  ``overall_acc``.

  Requires a CUDA device (tensors and models are moved with ``.cuda()``).
  """
  def __init__(self):
    super(CL_VAE, self).__init__()

    self.batch_size = 64
    self.num_classes = 708
    self.build_model()
    self.set_cuda()
    self.criterion = torch.nn.CrossEntropyLoss()
    self.recon = torch.nn.MSELoss()
    # NOTE: despite its name, self.L1 is an MSE loss; the current training
    # objective does not use it.
    self.L1 = torch.nn.MSELoss()
    self.seen_acc = []
    self.unseen_acc = []
    self.hm_acc = []
    self.overall_acc = []
    # Destination used by model_save() when no explicit path is given.
    self.net_path = None

  def build_model(self):
    """Create ``self.net`` and print its number of trainable parameters.

    ``self.eps`` is a (batch_size, 50) standard-normal tensor placed on the
    GPU and passed to ``NET``.
    """
    self.eps = torch.randn(self.batch_size, 50).cuda()
    self.net = NET(self.eps)
    pytorch_total_params = sum(p.numel() for p in self.net.parameters() if p.requires_grad)
    print('pytorch_total_params:', pytorch_total_params)

  def set_cuda(self):
    """Move ``self.net`` to the GPU."""
    self.net.cuda()

  def VAE_loss(self, recon, mu, sigma):
    """Return ``recon`` plus the KL divergence of N(mu, exp(sigma)) from N(0, I).

    ``sigma`` is the log-variance, not a standard deviation. The KL term is
    summed over all elements of the batch.
    """
    kl_div = -0.5 * torch.sum(1 + sigma - mu.pow(2) - sigma.exp())
    return recon + kl_div

  def _sample_features(self, labels, attr, task_id, z_dim = 50):
    """Decode random latents conditioned on ``labels`` with task ``task_id``'s decoder.

    Builds ``[z | one_hot(labels) | attr[labels]]`` with ``z ~ N(0, I)``,
    feeds it to ``self.net.common_features``, and returns the result on the
    GPU without an autograd graph.
    """
    cond = torch.cat((self.one_hot(labels), attr[labels]), dim = 1)
    z = torch.randn(labels.shape[0], z_dim)
    z_cond = torch.cat((z, cond), dim = 1).float().cuda()
    with torch.no_grad():
      return self.net.common_features(z_cond, task_id)

  def train(self, all_traindata, all_trainlabels, all_testdata, all_testlabels, all_train_attr, all_test_attr, all_attr, total_tasks):
    """Run the continual-learning protocol over ``total_tasks`` tasks.

    Args:
      all_traindata, all_trainlabels, all_train_attr: per-task training
        features, labels and per-sample attribute vectors (indexed by task).
      all_testdata, all_testlabels: per-task test features and labels.
      all_test_attr: accepted for interface compatibility; not used.
      all_attr: class-attribute table; row ``c`` is indexed by class label ``c``.
      total_tasks: number of tasks to process, in order.
    """
    attr = torch.tensor(all_attr)
    replay_classes = []
    num_gen_samples = 50  # replayed samples per class of each earlier task
    num_samples = 150     # pseudo-samples per class for the evaluation classifier
    for i in range(total_tasks):
      traindata = torch.tensor(all_traindata[i])
      trainlabels = torch.tensor(all_trainlabels[i])
      train_attr = torch.tensor(all_train_attr[i], dtype = torch.float32)
      replay_classes.append(sorted(set(trainlabels.numpy().tolist())))

      # Generative replay: for each earlier task m, decode num_gen_samples
      # features per class with task m's decoder and prepend them.
      for m in range(i):
        replay_trainlabels = torch.tensor(
            [c for c in replay_classes[m] for _ in range(num_gen_samples)])
        replay_data = self._sample_features(replay_trainlabels, attr, m).cpu()
        train_attr = torch.cat((attr[replay_trainlabels], train_attr), dim = 0)
        traindata = torch.cat((replay_data, traindata), dim = 0)
        trainlabels = torch.cat((replay_trainlabels, trainlabels))
      self.train_task(traindata.float(), trainlabels, train_attr.float(), i)

      # Tasks 0..i count as "seen", tasks i+1.. as "unseen" (zero-shot).
      # `+` concatenates, so per-task test entries are presumably Python
      # lists -- confirm against the data files.
      testdata_seen, testlabels_seen = [], []
      for j in range(i + 1):
        testdata_seen = testdata_seen + all_testdata[j]
        testlabels_seen = testlabels_seen + all_testlabels[j]
      testdata_unseen, testlabels_unseen = [], []
      for k in range(i + 1, total_tasks):
        testdata_unseen = testdata_unseen + all_testdata[k]
        testlabels_unseen = testlabels_unseen + all_testlabels[k]

      # Synthetic training set for the classifier: num_samples decoded
      # features for every seen and unseen class, using the current decoder.
      all_labels = sorted(set(testlabels_seen)) + sorted(set(testlabels_unseen))
      labels_list = torch.tensor(
          [label for label in all_labels for _ in range(num_samples)], dtype = torch.int64)
      pseudodata = self._sample_features(labels_list, attr, i)

      test_seen = torch.tensor(testdata_seen)
      testlabels_s = torch.tensor(testlabels_seen)
      testlabels_us = torch.tensor(testlabels_unseen)
      # Standardize everything with statistics of the synthetic set only.
      scaler = StandardScaler()
      pseudodata = torch.from_numpy(scaler.fit_transform(pseudodata.cpu().numpy())).cuda()
      test_seen = torch.from_numpy(scaler.transform(test_seen.numpy()))
      if i < total_tasks - 1:
        test_unseen = torch.tensor(testdata_unseen)
        test_unseen = torch.from_numpy(scaler.transform(test_unseen.numpy()))
      else:
        # Last task: no unseen classes remain.
        test_unseen = None
        testlabels_us = None

      self.class_train(i, pseudodata, labels_list.cuda(), test_seen, testlabels_s, test_unseen, testlabels_us)

  def dataloader(self, x, y, attr = None):
    """Return ``x``, ``y`` (and ``attr`` if given) jointly shuffled along dim 0."""
    indices = np.arange(x.size()[0])
    random.shuffle(indices)
    if attr is None:
      return x[indices], y[indices]
    return x[indices], y[indices], attr[indices]

  @staticmethod
  def _per_class_accuracy(pred, labels):
    """Mean over classes present in ``labels`` of each class's accuracy."""
    correct = {}
    total = {}
    for p, t in zip(pred.tolist(), labels.tolist()):
      total[t] = total.get(t, 0) + 1
      correct[t] = correct.get(t, 0) + (p == t)
    return sum(correct[c] / total[c] for c in sorted(total)) / len(total)

  def class_train(self, task_id, pseudodata, labels_list, test_seen, testlabels_s, test_unseen = None, testlabels_us = None):
    """Train a fresh ``CLASSIFIER`` on pseudo-features and record test accuracies.

    Args:
      task_id: index of the current task; kept for interface compatibility.
      pseudodata, labels_list: synthetic training features / labels (GPU).
      test_seen, testlabels_s: real test data of the seen classes.
      test_unseen, testlabels_us: real test data of the unseen classes, or
        ``None`` when there are none (the harmonic/overall metrics are then
        not recorded for this task).
    """
    pseudodata, labels_list = self.dataloader(pseudodata, labels_list)
    self.CLASS = CLASSIFIER().cuda()
    class_opti = torch.optim.Adam(self.CLASS.parameters(), lr = 1e-4)
    num_epochs = 25
    batch_s = 64
    num_iter = pseudodata.shape[0] // batch_s  # trailing partial batch is dropped
    self.CLASS.train()
    for _ in range(num_epochs):
      for b in range(num_iter):
        self.CLASS.zero_grad()
        batch_data = pseudodata[b * batch_s : (b + 1) * batch_s]
        batch_label = labels_list[b * batch_s : (b + 1) * batch_s]
        loss = self.criterion(self.CLASS(batch_data), batch_label)
        loss.backward()
        class_opti.step()

    # Evaluate in eval mode (dropout off) and without building autograd graphs.
    self.CLASS.eval()
    with torch.no_grad():
      pred_s = self.CLASS(test_seen.float().cuda()).argmax(dim = 1).cpu()
      pred_us = None
      if testlabels_us is not None:
        pred_us = self.CLASS(test_unseen.float().cuda()).argmax(dim = 1).cpu()

    acc_s = self._per_class_accuracy(pred_s, testlabels_s)
    self.seen_acc.append(acc_s)

    if testlabels_us is not None:
      acc_us = self._per_class_accuracy(pred_us, testlabels_us)
      self.unseen_acc.append(acc_us)
      denom = acc_s + acc_us
      # Guard: both accuracies 0 would otherwise raise ZeroDivisionError.
      self.hm_acc.append(2 * acc_s * acc_us / denom if denom > 0 else 0.0)
      n_s, n_us = len(testlabels_s), len(testlabels_us)
      self.overall_acc.append((n_s * acc_s + n_us * acc_us) / (n_s + n_us))

    print('self.seen_acc:', np.mean(self.seen_acc))
    print('self.unseen_acc:', np.mean(self.unseen_acc))
    print('self.hm_acc:', np.mean(self.hm_acc))

  def one_hot(self, labels):
    """Return a float (len(labels), num_classes) one-hot matrix for integer ``labels``."""
    matrix = torch.zeros(len(labels), self.num_classes)
    rows = np.arange(len(labels))
    matrix[rows, labels] = 1
    return matrix

  def model_save(self, net_path = None):
    """Save ``self.net``'s state dict to ``net_path`` (default: ``self.net_path``).

    Raises:
      ValueError: if neither ``net_path`` nor ``self.net_path`` is set.
    """
    if net_path is None:
      net_path = getattr(self, 'net_path', None)
    if net_path is None:
      raise ValueError('model_save needs a path: pass net_path or set self.net_path')
    torch.save(self.net.state_dict(), net_path)

  def train_task(self, traindata, trainlabels, train_attr, task_id):
    """Train ``self.net`` on one task's (real + replayed) data.

    Per-batch objective: cross-entropy of task ``task_id``'s head, plus the
    VAE loss (MSE reconstruction + KL) of task ``task_id``'s VAE. Only full
    batches are used; a trailing partial batch is dropped each epoch.
    """
    traindata, trainlabels, train_attr = self.dataloader(traindata, trainlabels, train_attr)
    net_opti = torch.optim.Adam(self.net.parameters(), lr = 1e-4)
    num_iterations = traindata.shape[0] // self.batch_size
    num_epochs = 101
    self.net.train()
    for _ in range(num_epochs):
      for b in range(num_iterations):
        self.net.zero_grad()
        window = slice(b * self.batch_size, (b + 1) * self.batch_size)
        batch_label = trainlabels[window]
        batch_label_one_hot = self.one_hot(batch_label).cuda()
        batch_data = traindata[window].cuda()
        batch_label = batch_label.cuda()
        batch_train_attr = train_attr[window].cuda()

        out, (p_out, p_mu, p_logvar) = self.net(batch_data, batch_label_one_hot, batch_train_attr, task_id)
        # The decoder's label and attribute outputs (p_out[1], p_out[2]) are
        # not part of this objective, so those output slices get no direct
        # loss signal.
        p_x = p_out[0]

        cross_en_loss = self.criterion(out, batch_label)
        p_recon = self.recon(batch_data, p_x)
        all_loss = cross_en_loss + self.VAE_loss(p_recon, p_mu, p_logvar)

        # Each iteration builds a fresh graph, so retain_graph is unnecessary.
        all_loss.backward()
        net_opti.step()
    



In [None]:
import time
# Train and evaluate over all 15 tasks, timing the whole run.
model = CL_VAE()
st = time.time()
model.train(
    all_traindata=trainData1,
    all_trainlabels=trainLabels1,
    all_testdata=testData1,
    all_testlabels=testLabels1,
    all_train_attr=trainLabelsVectors1,
    all_test_attr=testlabelsvectors1,
    all_attr=ATTR,
    total_tasks=15,
)
en = time.time()
print("It takes:", en - st, 'seconds')