In [None]:
import torch

# Optional CUDA smoke test: allocate a tiny tensor on the GPU when one is visible.
# Skipped instead of crashing on CPU-only machines, so "Restart & Run All" works
# there too (the device used for training is chosen later by d2l.try_gpu()).
CUDA_AVAILABLE = torch.cuda.is_available()
if CUDA_AVAILABLE:
    torch.ones(1).to('cuda')

In [2]:
import os

In [None]:
# Environment check: print the working directory (the pendigits data paths used
# later are relative to it) and list the visible NVIDIA GPUs via the shell.
print(os.getcwd()) 
!nvidia-smi 

In [None]:
# !pip install d2l
# !pip install munkres

In [5]:
import os
import gzip
import torch
import random
#import lightly
import label_map
import numpy as np
from torch import nn
from scipy import stats
from d2l import torch as d2l
import matplotlib.pyplot as plt
from torch.nn import functional as F
#import seaborn as sns; sns.set_theme()
# from lightly.models.modules.heads import SimSiamProjectionHead
# from lightly.models.modules.heads import SimSiamPredictionHead
from sklearn.metrics.cluster import normalized_mutual_info_score as NMI

In [6]:
class Generator(nn.Module):
    """Conditional generator: [one-hot class | noise] -> 16-dim sample in (0, 1).

    Layers: (class_num + z_dim) -> 256 -> 256 -> 16. Each hidden layer is
    Linear + BatchNorm1d + LeakyReLU(0.2), and the output passes through a Sigmoid.
    """

    def __init__(self, class_num, z_dim):
        super().__init__()
        self.input_dim = class_num + z_dim
        self.output_dim = 16
        hidden = 256
        layers = []
        for fan_in in (self.input_dim, hidden):
            layers += [nn.Linear(fan_in, hidden), nn.BatchNorm1d(hidden), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(hidden, self.output_dim), nn.Sigmoid()]
        # A single ``fc`` Sequential keeps the state_dict keys (fc.0.weight, ...) unchanged.
        self.fc = nn.Sequential(*layers)

    def forward(self, X):
        return self.fc(X)

class Discriminator(nn.Module):
    """Discriminator over 16-dim samples with ``class_num + 1`` output logits.

    The first ``class_num`` logits score the clusters; the last one is the extra
    column that train_GAN targets for generated ("fake") samples.
    Layers: 16 -> 256 -> 256 -> class_num + 1 (Linear + BatchNorm1d + LeakyReLU(0.2) hidden blocks).
    """

    def __init__(self, class_num):
        super().__init__()
        self.class_num = class_num
        width = 256
        trunk = []
        for fan_in in (16, width):
            trunk.extend([nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.LeakyReLU(0.2)])
        self.fc = nn.Sequential(*trunk, nn.Linear(width, class_num + 1))

    def forward(self, X):
        return self.fc(X)

class E_net(nn.Module):
    """Encoder/classifier mapping a 16-dim sample to ``class_num`` cluster logits.

    Layers: 16 -> 256 -> 256 -> class_num, each hidden block Linear + BatchNorm1d + LeakyReLU(0.2).
    """

    def __init__(self, class_num):
        super().__init__()
        self.class_num = class_num
        width = 256
        blocks = []
        for fan_in in (16, width):
            blocks.extend([nn.Linear(fan_in, width), nn.BatchNorm1d(width), nn.LeakyReLU(0.2)])
        blocks.append(nn.Linear(width, class_num))
        self.fc = nn.Sequential(*blocks)

    def forward(self, X):
        return self.fc(X)

class SimSiam(nn.Module):
    """SimSiam-style clusterer on 16-dim samples.

    ``encoder`` maps a sample to ``out_dim`` logits, which a softmax turns into
    soft cluster assignments ``z``. ``prediction_head`` maps ``z`` to ``p``.

    forward returns ``(z.detach(), p, H1, H2)``:
      * H1: mean per-sample entropy of z (small when assignments are confident).
      * H2: entropy of the batch-averaged assignment (large when clusters are balanced).
    """

    def __init__(
        self, pred_hidden_dim, out_dim
    ):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(16, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, out_dim),
        )

        self.prediction_head = nn.Sequential(
            nn.Linear(out_dim, pred_hidden_dim),
            nn.BatchNorm1d(pred_hidden_dim),
            nn.LeakyReLU(0.2),

            nn.Linear(pred_hidden_dim, out_dim),
        )

    def forward(self, x):
        logits = self.encoder(x)
        z = nn.Softmax(dim=1)(logits)

        # Entropies use log-probabilities instead of z.log(). When a softmax entry
        # underflows to exactly 0, z.log() is -inf and 0 * -inf = NaN, which
        # poisons the loss and every gradient. log_softmax / logsumexp stay finite.
        log_z = torch.log_softmax(logits, dim=1)
        H1 = (-1 * z * log_z).sum() / z.shape[0]  # average per-sample entropy

        empirical_z = z.mean(0)
        # log(mean_i z_i) computed stably as logsumexp_i(log z_i) - log(N).
        log_empirical_z = torch.logsumexp(log_z, dim=0) - log_z.new_tensor(float(z.shape[0])).log()
        H2 = (-1 * empirical_z * log_empirical_z).sum()

        # Predictions from the (non-detached) assignments.
        p = self.prediction_head(z)
        # Stop-gradient on the target branch.
        z = z.detach()
        return z, p, H1, H2


In [7]:
def init_weights(m):
    """Initialise conv / transposed-conv / linear layers: N(0, 0.02) weights, zero bias.

    Intended for ``module.apply(init_weights)``; other module types are left as-is.
    Layers built with ``bias=False`` are supported (the previous version raised
    AttributeError on ``None.data``).
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        with torch.no_grad():
            m.weight.normal_(0, 0.02)
            if m.bias is not None:
                m.bias.zero_()

def gen_cond_label(batch_size, class_num, z_dim):
    """Build a generator input batch: one-hot condition labels followed by uniform noise.

    Rows are assigned to classes in contiguous, balanced blocks: row ``r`` gets
    class ``r * class_num // batch_size``. Every class therefore receives
    floor or ceil of ``batch_size / class_num`` rows; a class is empty only when
    ``batch_size < class_num``. When ``batch_size`` is a multiple of
    ``class_num``, this is exactly the layout ``[i*k, (i+1)*k)`` with
    ``k = batch_size // class_num`` that train_simsiam indexes into.

    The previous version used ``round(batch_size / class_num)`` as the block size
    and gave the remainder to the last class. That over-represented the last class
    (10 of 64 rows instead of 6-7). When rounding up, it left trailing classes
    empty (batch_size=15, class_num=10 gave classes 8 and 9 no rows).

    Returns:
        G_input: float tensor (batch_size, class_num + z_dim) = [one-hot | U[0, 1) noise].
        conditional_label: float one-hot tensor (batch_size, class_num).
    """
    rows = torch.arange(batch_size)
    class_idx = rows * class_num // max(batch_size, 1)
    conditional_label = torch.zeros(batch_size, class_num)
    conditional_label[rows, class_idx] = 1
    G_input = torch.cat([conditional_label, torch.rand(batch_size, z_dim)], 1)
    return G_input, conditional_label

def Accuracy(true_label, pred_label):
  """Cluster purity.

  For each ground-truth class, count how many of its members fall into the most
  frequent predicted label. Return the summed counts divided by the number of
  samples, as a Python float. ``pred_label`` values must be non-negative
  integers (they go through ``torch.bincount``).
  """
  hits = sum(
      torch.bincount(pred_label[true_label == cls].int()).max()
      for cls in torch.unique(true_label)
  )
  return float(hits / len(true_label))

In [8]:
def train_GAN(D, G, E, optimizer_D, optimizer_G, torch_dataset, class_num, z_dim, first_epoch, m, epoch, device,
              gan_epochs=5, batch_size=64):
    """Run ``gan_epochs`` passes of conditional-GAN training for one (D, G) pair.

    D emits ``class_num + 1`` logits: one per cluster plus a final "fake" logit.

    Discriminator step, per real batch X:
      * real: BCE-with-logits against target 1 on the cluster logits and 0 on the
        fake logit. Each element is weighted by soft cluster labels (uniform
        ``1/class_num`` when ``first_epoch``, otherwise softmax(E(X))), and the fake
        column has weight 2.
      * fake: G samples (no gradient into G) with target 1 on the fake logit only.
    Generator step: fresh G samples, trained so D assigns them to their
    conditioning cluster (target = one-hot condition, 0 on the fake logit).

    Args:
        D, G, E: discriminator, generator and encoder of this ensemble member.
        optimizer_D, optimizer_G: stepped here; E's parameters are not optimised.
        torch_dataset: dataset yielding ``(X, y)``; ``y`` is ignored.
        class_num: number of clusters.
        z_dim: noise size passed to gen_cond_label.
        first_epoch: use uniform soft labels instead of E's predictions.
        m, epoch: unused; kept so existing call sites stay valid.
        device: device for batches and generator inputs.
        gan_epochs: passes over ``torch_dataset`` (previously hard-coded to 5).
        batch_size: real batch size (previously hard-coded to 64).
    """
    bce = torch.nn.BCEWithLogitsLoss()
    for _pass in range(gan_epochs):
        train_iter = torch.utils.data.DataLoader(dataset = torch_dataset, batch_size = batch_size,
                                                 shuffle = True, num_workers = 4)
        for X, _ in train_iter:
            # ---- discriminator step ----
            optimizer_D.zero_grad()
            X = X.to(device)
            X_size = X.shape[0]

            # Real samples, weighted by soft cluster labels.
            D_real = D(X)
            if first_epoch:
                Y_real = 1/class_num * torch.ones(X_size, class_num, device = device)
            else:
                # E only supplies targets: no_grad avoids building a graph that
                # would be detached anyway.
                # NOTE(review): if E is in train mode, its BatchNorm running stats
                # are still updated by these real batches.
                with torch.no_grad():
                    Y_real = nn.Softmax(dim = 1)(E(X))
            Y_real = torch.cat([Y_real, 2 * torch.ones(X_size, 1, device = device)], 1)
            ones = torch.ones_like(Y_real)
            ones[:, -1] = 0
            D_real_loss = torch.nn.BCEWithLogitsLoss(weight = Y_real)(D_real, ones)

            # Fake samples: only the last ("fake") logit is a positive target.
            G_input, _ = gen_cond_label(X_size, class_num, z_dim)
            with torch.no_grad():
                X_fake = G(G_input.to(device))
            D_fake = D(X_fake)
            Y_fake = torch.zeros_like(Y_real)
            Y_fake[:, -1] = 1
            D_fake_loss = bce(D_fake, Y_fake)

            (D_real_loss + D_fake_loss).backward()
            optimizer_D.step()

            # ---- generator step ----
            # Gradients also reach D's parameters here; they are cleared by
            # optimizer_D.zero_grad() on the next batch.
            optimizer_G.zero_grad()
            G_input, cond_label = gen_cond_label(X_size, class_num, z_dim)
            D_fake = D(G(G_input.to(device)))
            Y_fake = torch.cat([cond_label, torch.zeros(X_size, 1)], 1).to(device)
            G_loss = bce(D_fake, Y_fake)
            G_loss.backward()
            optimizer_G.step()
      


def train_simsiam(Nets, my_SimSiam, optimizer_my_SimSiam, cluster_size, class_num, iterations, z_dim, M, epoch, animator, device,
                  torch_dataset=None):
    """Train the shared SimSiam clusterer on samples drawn from every generator.

    Each iteration draws ``cluster_size`` samples per class from each generator
    ``Nets['G_{m}']`` and fuses them into one batch. It then minimises

        -sum_{m, class} mean cosine(z_row, p_row) over the (m, class) block
        + 15 * M * (H1 - H2)

    where H1/H2 are the entropies returned by the SimSiam forward.

    Side effects:
      * ``Nets['fake_samples_{m}']`` / ``Nets['fake_labels_{m}']`` are replaced by
        all samples drawn from G_m in this call and their conditioning labels
        (stored as float, as before).
      * The animator receives running-mean losses at iteration 1 and every 100.
      * my_SimSiam is evaluated on ``torch_dataset`` via eval_by_E at the end.

    Args:
        torch_dataset: evaluation dataset. Defaults to the notebook-global
            ``torch_dataset``, which is what the previous version read implicitly.
    """
    if torch_dataset is None:
        torch_dataset = globals()['torch_dataset']

    fake_num = cluster_size * class_num
    # Collect per-generator chunks and concatenate once after the loop. The
    # previous version re-concatenated the growing buffers every iteration
    # (quadratic copying over ``iterations`` steps).
    sample_chunks = [[] for _ in range(M)]
    label_chunks = [[] for _ in range(M)]

    metric = d2l.Accumulator(4)
    for it in range(iterations):
        # Draw cluster_size samples per class from every generator.
        fused_chunks = []
        for m in range(M):
            G_input, cond_label = gen_cond_label(fake_num, class_num, z_dim)
            pseudo_label = torch.argmax(cond_label, 1).to(device)
            X_fake = Nets[f'G_{m}'](G_input.to(device)).detach()
            fused_chunks.append(X_fake)
            sample_chunks[m].append(X_fake)
            label_chunks[m].append(pseudo_label)
        X_to_fuse = torch.cat(fused_chunks, 0)

        # SimSiam step on the fused batch.
        my_SimSiam.train()
        optimizer_my_SimSiam.zero_grad()
        z, p, H1, H2 = my_SimSiam(X_to_fuse)
        my_SimSiam_loss = 0
        for m in range(M):
            for i in range(class_num):
                # Rows of (generator m, class i) are contiguous because
                # gen_cond_label lays each class out as one block.
                idx1 = fake_num * m + cluster_size * i
                idx2 = idx1 + cluster_size
                z_i_norm = F.normalize(z[idx1:idx2])
                p_i_norm = F.normalize(p[idx1:idx2])
                my_SimSiam_loss -= torch.mm(z_i_norm, p_i_norm.T).sum() / (cluster_size ** 2)
        total_loss = my_SimSiam_loss + 15 * M * (H1 - H2)
        total_loss.backward()
        optimizer_my_SimSiam.step()

        # Running means for plotting.
        with torch.no_grad():
            metric.add(my_SimSiam_loss, H1, H2, 1)
        if (it + 1) % 100 == 0 or it == 0:
            animator.add(5 * (epoch + (it + 1) / iterations),
                         (metric[0] / metric[3], metric[1] / metric[3], metric[2] / metric[3]))

    for m in range(M):
        Nets[f'fake_samples_{m}'] = (torch.cat(sample_chunks[m], 0) if sample_chunks[m]
                                     else torch.zeros(0, device = device))
        # Float labels kept: the previous version promoted them to float by
        # concatenating onto a float buffer.
        Nets[f'fake_labels_{m}'] = (torch.cat(label_chunks[m], 0).float() if label_chunks[m]
                                    else torch.zeros(0, device = device))

    # Evaluate SimSiam itself (SimSiam output convention -> flag True).
    eval_by_E(my_SimSiam, 'my_SimSiam', epoch, torch_dataset, True, device)
          
def eval_by_E(E, m, epoch, torch_dataset, SIMSIAM_falg, device):
    """Cluster the whole dataset with ``E`` and report purity accuracy and NMI.

    Args:
        E: model to evaluate. When ``SIMSIAM_falg`` is true, ``E(X)[0]`` is used
            (for SimSiam, the softmaxed assignments ``z``); otherwise ``E(X)`` is
            treated as logits.
        m: tag used in the printed/logged name ``E_{m}``.
        epoch: zero-based round; logged as ``epoch + 1``.
        torch_dataset: dataset yielding ``(X, y)`` with ground-truth labels ``y``.
        SIMSIAM_falg: selects the SimSiam output convention. The misspelt name is
            kept because it is part of the signature.
        device: device the model lives on.

    Side effects: leaves ``E`` in eval mode, prints one line and appends one
    line to ``performance.txt``.
    """
    E.eval()
    pred_chunks, true_chunks = [], []
    eval_iter = torch.utils.data.DataLoader(dataset = torch_dataset, batch_size = 2000,
                                            shuffle = False, num_workers = 4)
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for X, y in eval_iter:
            out = E(X.to(device))
            pred = out[0] if SIMSIAM_falg else nn.Softmax(dim =1)(out)
            pred_chunks.append(torch.argmax(pred, 1).cpu())
            true_chunks.append(y)
    # Float labels, as in the previous implementation.
    pred_label = torch.cat(pred_chunks).float()
    true_label = torch.cat(true_chunks).float()
    accuracy = Accuracy(true_label, pred_label)
    nmi = NMI(true_label, pred_label)
    print(f'E_{m}: {accuracy}, {nmi}')
    with open('performance.txt', 'a') as f:
        f.write(f'E_base, round = {(epoch + 1)}, E_{m}, Accuracy = {accuracy}, NMI = {nmi}\n')

def filter_images(my_SimSiam, fake_samples, fake_labels, class_num, cluster_size_chosen, device, batch_size=500):
    """For each conditioning class, keep the generated samples SimSiam is most sure about.

    For class ``cls`` (rows where ``fake_labels == cls``):
      1. score every sample with ``my_SimSiam(x)[0]`` (soft cluster assignments);
      2. keep only samples whose argmax equals the class's most common argmax;
      3. if more than ``cluster_size_chosen`` remain, keep the
         ``cluster_size_chosen`` with the highest max probability.

    The previous version scored ``int(n / 500)`` full batches only. Any remainder
    (or a whole class smaller than 500) kept label 0 and zero confidence, and
    could still be selected. An empty class also crashed in ``mode()``.

    Args:
        batch_size: scoring batch size (previously hard-coded to 500).

    Returns:
        (filtered_samples, new_labels): samples concatenated class by class, and
        a long tensor on ``device`` holding each sample's conditioning class.
    """
    my_SimSiam.eval()
    kept_samples = []
    cluster_size = []
    with torch.no_grad():
        for cls in range(class_num):
            images_cluster_i = fake_samples[fake_labels == cls]
            n_i = images_cluster_i.shape[0]
            if n_i == 0:
                cluster_size.append(0)
                continue

            # Score every sample, including the final partial batch.
            pred_conf = torch.cat([
                my_SimSiam(images_cluster_i[start:start + batch_size])[0]
                for start in range(0, n_i, batch_size)
            ], 0)
            pred_label = torch.argmax(pred_conf, 1)

            # Keep samples agreeing with the class's majority SimSiam cluster.
            label_mode, _ = pred_label.mode()
            idx_chosen = pred_label == label_mode
            pred_conf = pred_conf[idx_chosen]
            images_cluster_i = images_cluster_i[idx_chosen]
            mode_num = int(idx_chosen.sum())

            # Cap the class at the cluster_size_chosen most confident samples.
            if mode_num > cluster_size_chosen:
                v, _ = pred_conf.max(dim = 1)
                _, idx_high = v.sort(descending = True)
                images_cluster_i = images_cluster_i[idx_high[:cluster_size_chosen]]
                mode_num = cluster_size_chosen
            kept_samples.append(images_cluster_i)
            cluster_size.append(mode_num)

    fake_samples_filtered = torch.cat(kept_samples, 0) if kept_samples else torch.zeros(0, device = device)
    new_labels = torch.tensor(np.repeat(range(class_num), cluster_size), device = device).long()
    return fake_samples_filtered, new_labels

def setup_seed(seed):
    """Seed every RNG this notebook draws from and force deterministic cuDNN kernels.

    Seeds, in order: torch (CPU), torch (all CUDA devices), NumPy's global RNG,
    Python's ``random``.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


In [9]:
def firt_filter(my_SimSiam, Nets, torch_dataset, cluster_size_chosen, M, epoch, device, class_num=None,
                e_passes=3, batch_size=256):
    """First filtering stage (name kept as ``firt_filter`` for existing callers).

    For each ensemble member m:
      1. filter G_m's samples from train_simsiam with filter_images;
      2. train E_m for ``e_passes`` passes (cross-entropy) on the filtered samples,
         labelled with their conditioning class;
      3. evaluate E_m on ``torch_dataset`` via eval_by_E.

    Args:
        class_num: number of clusters. Defaults to the notebook-global
            ``class_num``, which the previous version read implicitly (ignoring
            the value train_CCEGAN was called with).
        e_passes, batch_size: E training passes / batch size (previously 3 / 256).

    Returns:
        (all filtered samples concatenated over members, per-member sample counts).

    The previous version divided an unused loss sum by the batch count, which
    raised ZeroDivisionError whenever a member's filtered set was empty.
    """
    if class_num is None:
        class_num = globals()['class_num']

    filtered_chunks = []
    fake_samples_num_ls = []
    for m in range(M):
        fake_samples_filtered, fake_labels_filtered = filter_images(
            my_SimSiam, Nets[f'fake_samples_{m}'], Nets[f'fake_labels_{m}'],
            class_num, cluster_size_chosen, device)
        filtered_chunks.append(fake_samples_filtered)
        fake_samples_num_ls.append(fake_samples_filtered.shape[0])

        E, optimizer_E = Nets[f'E_{m}'], Nets[f'optimizer_E_{m}']
        torch_fake = torch.utils.data.TensorDataset(fake_samples_filtered, fake_labels_filtered)
        for _ in range(e_passes):
            train_fake = torch.utils.data.DataLoader(dataset = torch_fake, batch_size = batch_size,
                                                     shuffle = True, num_workers = 0)
            for X_fake, y_fake in train_fake:
                optimizer_E.zero_grad()
                E_loss = torch.nn.CrossEntropyLoss()(E(X_fake), y_fake)
                E_loss.backward()
                optimizer_E.step()

        # Evaluate the base E (plain-logits convention -> flag False).
        eval_by_E(E, m, epoch, torch_dataset, False, device)

    fake_samples_filtered_all = (torch.cat(filtered_chunks, 0) if filtered_chunks
                                 else torch.zeros(0, device = device))
    return fake_samples_filtered_all, fake_samples_num_ls

def second_filter(Nets, fake_samples_filtered_all, fake_samples_num_ls, first_epoch, M, epoch, device,
                  class_num=None, torch_dataset=None):
    """Second filtering stage: ensemble vote, then retrain each E on agreed samples.

    1. Every base E_m labels all pooled filtered samples.
    2. Labels are expressed in E_0's (the anchor's) label space. In the first
       round this uses ``label_map.label_map``, whose maps are cached in Nets as
       ``Base_to_anchor_{m}`` / ``Anchor_to_base_{m}``; later rounds reuse the
       cached ``Base_to_anchor`` maps.
    3. A majority vote (scipy ``stats.mode`` across members) gives a fused label.
       Samples with more than M/2 votes are kept.
    4. Each E_m is trained for 2 passes on its own segment of kept samples
       (segment sizes from ``fake_samples_num_ls``). The fused label is mapped back
       into E_m's label space. E_m is then evaluated.

    Args:
        class_num, torch_dataset: default to the notebook globals of the same
            names, which the previous version read implicitly.
    """
    if class_num is None:
        class_num = globals()['class_num']
    if torch_dataset is None:
        torch_dataset = globals()['torch_dataset']

    # 1. (N, M) long matrix: column m = hard labels from E_m.
    fake_iter = torch.utils.data.DataLoader(dataset = fake_samples_filtered_all, batch_size = 2000,
                                            shuffle = False, num_workers = 0)
    base_partitions = torch.stack(
        [label_assignment_by_base_E(Nets[f'E_{m}'], fake_iter, device) for m in range(M)], 1
    ).type(torch.LongTensor)

    # 2. Label alignment to the anchor (E_0).
    base_fake_aligned = torch.zeros_like(base_partitions)
    anchor = base_partitions[:, 0]
    base_fake_aligned[:, 0] = anchor
    if first_epoch:
        for m in range(1, M):
            base = base_partitions[:, m]
            base_aligned, base_to_anchor, anchor_to_base = label_map.label_map(anchor, base)
            base_fake_aligned[:, m] = torch.tensor(base_aligned)
            Nets[f'Base_to_anchor_{m}'] = base_to_anchor  # save map
            Nets[f'Anchor_to_base_{m}'] = anchor_to_base  # save map
    else:
        for m in range(1, M):
            base = base_partitions[:, m]
            for i in range(class_num):
                idx_base = base == i
                base_fake_aligned[:, m][idx_base] = Nets[f'Base_to_anchor_{m}'][i]

    # 3. Majority vote; keep samples more than half of the members agree on.
    label_fused, count_vote = stats.mode(base_fake_aligned, axis = 1)
    idx_high_confidence = (count_vote > M/2).squeeze()

    # 4. Retrain each E_m on its own segment of high-confidence samples.
    idx_start, idx_end = 0, 0
    for m in range(M):
        Nets[f'E_{m}'].train()
        idx_end += fake_samples_num_ls[m]
        fake_samples_chosen = fake_samples_filtered_all[idx_start : idx_end]
        label_fused_chosen = label_fused[idx_start : idx_end]
        label_reset = np.zeros_like(label_fused_chosen)
        if m == 0:
            label_reset = label_fused_chosen.squeeze()
        else:
            # Map fused (anchor-space) labels back into E_m's own label space.
            for i in range(class_num):
                idx_cluster = label_fused_chosen == i
                label_reset[idx_cluster] = Nets[f'Anchor_to_base_{m}'][i]
            label_reset = label_reset.squeeze()
        idx_high_confidence_chosen = idx_high_confidence[idx_start : idx_end]
        fake_dataset = torch.utils.data.TensorDataset(fake_samples_chosen[idx_high_confidence_chosen],
                                                      torch.tensor(label_reset)[idx_high_confidence_chosen])
        for _ in range(2):
            fake_iter = torch.utils.data.DataLoader(dataset = fake_dataset,
                                                    batch_size = 128, shuffle = True, num_workers = 0)
            for X, y in fake_iter:
                X, y = X.to(device), y.to(device)
                Nets[f'optimizer_E_{m}'].zero_grad()
                pred = Nets[f'E_{m}'](X)
                E_loss = torch.nn.CrossEntropyLoss()(pred, y)
                E_loss.backward()
                Nets[f'optimizer_E_{m}'].step()
                # The previous version did ``count += E_loss`` here outside
                # no_grad and never read it. That chained every batch loss into a
                # growing autograd graph, so the accumulator was dropped.
        idx_start = idx_end
        eval_by_E(Nets[f'E_{m}'], m, epoch, torch_dataset, False, device)
    
def label_assignment_by_base_E(E, train_iter, device):
    """Return E's hard (argmax) labels for every batch yielded by ``train_iter``.

    Puts ``E`` in eval mode (not restored). The result is a float32 CPU tensor of
    labels in iteration order (float, as in the previous implementation). It is
    empty when ``train_iter`` yields nothing. Runs under ``torch.no_grad``; the
    previous version built an autograd graph per batch only to detach it.
    """
    E.eval()
    batch_labels = []
    with torch.no_grad():
        for X in train_iter:
            probs = nn.Softmax(dim =1)(E(X.to(device)))
            batch_labels.append(torch.argmax(probs, 1))
    if not batch_labels:
        return torch.zeros(0)
    return torch.cat(batch_labels, 0).float().cpu()

In [10]:
def train_CCEGAN(torch_dataset, M, class_num, z_dim, num_epochs, device,
                 simsiam_cluster_size=4, simsiam_iterations=1000, cluster_size_chosen=500 * 5):
    """Train the CCEGAN ensemble: M (G, D, E) triples plus one shared SimSiam clusterer.

    Each of the ``num_epochs`` rounds:
      1. trains every (D_m, G_m) pair with train_GAN (soft labels come from E_m after the first round);
      2. trains SimSiam on samples from all generators (train_simsiam);
      3. trains each E_m on SimSiam-filtered generator samples (firt_filter);
      4. retrains each E_m on samples where a majority of aligned E votes agree (second_filter).

    Args:
        simsiam_cluster_size: samples per class per generator in each SimSiam step
            (original value 4; the original comment ties it to M = 5).
        simsiam_iterations: SimSiam optimisation steps per round (original 1000).
        cluster_size_chosen: max samples kept per class per generator in the
            first filter (original 500 * 5).

    NOTE(review): the SimSiam/filter helpers may read ``class_num`` and
    ``torch_dataset`` from the notebook namespace rather than from these
    arguments. Keep those globals consistent with the values passed here.
    For reproducible runs, call setup_seed(...) before this function.
    """
    out_dim = class_num
    pred_hidden_dim = 64
    my_SimSiam = SimSiam(pred_hidden_dim, out_dim)
    my_SimSiam.to(device)
    optimizer_my_SimSiam = torch.optim.Adam(my_SimSiam.parameters(), lr = 6e-5, betas = (0.5, 0.99), weight_decay = 2.5 * 1e-5)

    # Registry of per-member models and optimizers; the helpers later add samples
    # and label maps to it. This is a plain dict: the previous version aliased
    # locals(), whose returned mapping the Python docs say must not be modified.
    Nets = {}
    for m in range(M):
        Nets[f'G_{m}'] = Generator(class_num, z_dim)
        Nets[f'D_{m}'] = Discriminator(class_num)
        Nets[f'E_{m}'] = E_net(class_num)

        Nets[f'G_{m}'].apply(init_weights)
        Nets[f'D_{m}'].apply(init_weights)
        Nets[f'E_{m}'].apply(init_weights)

        Nets[f'G_{m}'].to(device)
        Nets[f'D_{m}'].to(device)
        Nets[f'E_{m}'].to(device)

        Nets[f'optimizer_G_{m}'] = torch.optim.Adam(Nets[f'G_{m}'].parameters(),
                                                    lr = 3e-4, betas = (0.5, 0.9), weight_decay = 2.5 * 1e-5)
        Nets[f'optimizer_D_{m}'] = torch.optim.Adam(Nets[f'D_{m}'].parameters(),
                                                    lr = 6e-4, betas = (0.5, 0.9), weight_decay = 2.5 * 1e-5)
        Nets[f'optimizer_E_{m}'] = torch.optim.Adam(Nets[f'E_{m}'].parameters(),
                                                    lr = 3e-5, betas = (0.5, 0.9), weight_decay = 2.5 * 1e-5)

    first_epoch = True
    animator = d2l.Animator(xlabel = 'iter', xlim = [1, num_epochs * 5],
                            legend = ['my_SimSiam_loss', 'H1', 'H2'])
    for epoch in range(num_epochs):
        # 1. GAN training per member.
        for m in range(M):
            Nets[f'G_{m}'].train()
            Nets[f'D_{m}'].train()
            Nets[f'E_{m}'].train()
            train_GAN(Nets[f'D_{m}'], Nets[f'G_{m}'], Nets[f'E_{m}'], Nets[f'optimizer_D_{m}'], Nets[f'optimizer_G_{m}'],
                      torch_dataset, class_num, z_dim, first_epoch, m, epoch, device)

        # 2. Shared SimSiam clusterer.
        train_simsiam(Nets, my_SimSiam, optimizer_my_SimSiam, simsiam_cluster_size, class_num,
                      simsiam_iterations, z_dim, M, epoch, animator, device)

        # 3. First filter + E training.
        fake_samples_filtered_all, fake_samples_num_ls = firt_filter(
            my_SimSiam, Nets, torch_dataset, cluster_size_chosen, M, epoch, device)

        # 4. Ensemble vote + E retraining.
        second_filter(Nets, fake_samples_filtered_all, fake_samples_num_ls, first_epoch, M, epoch, device)

        first_epoch = False

In [11]:
# Load the pendigits train/test splits (comma-separated rows: 16 feature columns,
# label in the last column). Scale the features by 1/100 in place and pool both
# splits into one dataset; the clustering uses all samples.
finp_tr = 'pendigits/pendigits.tra.txt'
finp_tes = 'pendigits/pendigits.tes.txt'

data_tr = np.loadtxt(finp_tr, delimiter=',')
X_train = data_tr[:, :16]
np.divide(X_train, 100.0, out=X_train)  # in place: X_train is a view into data_tr
Y_train = data_tr[:, -1].astype(int)

data_tes = np.loadtxt(finp_tes, delimiter=',')
X_test = data_tes[:, :16]
np.divide(X_test, 100.0, out=X_test)  # in place: X_test is a view into data_tes
Y_test = data_tes[:, -1].astype(int)

X = np.concatenate((X_train, X_test))
Y = np.concatenate((Y_train, Y_test)).astype(int)
torch_dataset = torch.utils.data.TensorDataset(
    torch.tensor(X).float(),
    torch.tensor(Y),
)

In [12]:
# Run configuration.
z_dim = 5          # noise dimensions appended to the one-hot condition
class_num = 10     # number of clusters
num_epochs = 12    # outer CCEGAN rounds
M = 5              # ensemble size: number of (G, D, E) triples
device = d2l.try_gpu()

In [None]:
# Train the CCEGAN ensemble end to end. eval_by_E prints each round's
# accuracy/NMI and appends it to performance.txt.
train_CCEGAN(torch_dataset, M, class_num, z_dim, num_epochs, device)