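# CIFAR-10 Classification

Layer-by-layer training of an MFAN network on CIFAR-10. At each layer, an image encoder and a label encoder are trained with a cosine-similarity contrastive loss. A linear decoder then predicts the class from that layer's image embedding. Per-layer train and test accuracies are reported at the end.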

# Imports

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.optim import Adam, SGD
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

import numpy as np
import random
import math

from sklearn.manifold import TSNE
import time

import altair as alt
alt.data_transformers.disable_max_rows()
import pandas as pd
GPU = True  # Choose whether to use GPU (falls back to CPU when CUDA is unavailable)
device = torch.device("cuda" if GPU and torch.cuda.is_available() else "cpu")
print(f'Using {device}')

# Seed every RNG the notebook can draw from so Restart & Run All is reproducible:
# torch (Linear weight init, DataLoader shuffling), numpy (negative-label
# sampling in playMFAN) and Python's `random` (imported above).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

Using cpu


# Network

In [None]:
class MFAN(torch.nn.Module):
    """Layer-wise network: an x encoder, a y encoder and one linear decoder per layer.

    ``x_net`` and ``y_net`` are stacks of ``torch.nn.Linear`` layers with no
    activations inside the ``Sequential``; any non-linearity between layers is
    applied by the caller (``self.relu`` is provided for that). ``d_net[i]`` is a
    linear classifier on the output of ``x_net[i]``. Each layer is trained on its
    own: :meth:`forward` takes one optimiser step on encoder layer ``i`` and
    :meth:`d_forward` takes one step on decoder ``i``.

    Args:
        x_in: input feature size of ``x_net``.
        x_lr: Adam learning rate for ``x_net``.
        y_in: input size of ``y_net`` and output size of every decoder layer.
        y_lr: Adam learning rate for ``y_net``.
        d_lr: Adam learning rate for ``d_net``.
        hid_dims: output width of each encoder layer, e.g. ``[500, 300, 300]``.
        num_epoch, num_epoch_D: stored as attributes; not read by this class.
        d_loss_func: decoder loss ``loss(y_pred, target)``; defaults to a new
            ``torch.nn.CrossEntropyLoss()`` per instance.
        enc_ABC: weights ``(A, B, C)`` of the encoder loss terms (see :meth:`forward`).
    """

    def __init__(self,
                 x_in, x_lr,
                 y_in, y_lr, d_lr,
                 hid_dims,  # e.g. 500, 300, 300
                 num_epoch=1, num_epoch_D=300,
                 d_loss_func=None,
                 enc_ABC=(1, 1, 1)):
        super().__init__()
        # Layers are created interleaved (x0, y0, d0, x1, y1, d1, ...); this fixes
        # the order in which their initial weights are drawn from the torch RNG.
        x_layers = [torch.nn.Linear(x_in, hid_dims[0])]
        y_layers = [torch.nn.Linear(y_in, hid_dims[0])]
        d_layers = [torch.nn.Linear(hid_dims[0], y_in)]
        for h_prev, h_next in zip(hid_dims[:-1], hid_dims[1:]):
            x_layers.append(torch.nn.Linear(h_prev, h_next))
            y_layers.append(torch.nn.Linear(h_prev, h_next))
            d_layers.append(torch.nn.Linear(h_next, y_in))
        self.x_net = torch.nn.Sequential(*x_layers).to(device)
        self.y_net = torch.nn.Sequential(*y_layers).to(device)
        self.d_net = torch.nn.Sequential(*d_layers).to(device)
        self.relu = torch.nn.ReLU()

        # optimizers (one per sub-network so layers can be stepped independently)
        self.x_opt = Adam(self.x_net.parameters(), lr=x_lr)
        self.y_opt = Adam(self.y_net.parameters(), lr=y_lr)
        self.d_opt = Adam(self.d_net.parameters(), lr=d_lr)

        # A fresh loss module per instance instead of one shared default object.
        self.d_loss_func = torch.nn.CrossEntropyLoss() if d_loss_func is None else d_loss_func

        # train iter (kept for callers that read them; unused here)
        self.num_epoch = num_epoch
        self.num_epoch_D = num_epoch_D

        self.A, self.B, self.C = enc_ABC

    def forward(self, x, y_pos, y_neg, layer_ind):
        """Take one optimiser step on encoder layer ``layer_ind``.

        Loss, averaged over the batch::

            -A * cos(Zx, Zy_pos) + B * cos(Zx, Zy_neg) + C * cos(Zy_neg, Zy_pos)

        Args:
            x: batch fed to ``x_net[layer_ind]``.
            y_pos, y_neg: batches fed to ``y_net[layer_ind]``; rows are compared
                pairwise with the rows of ``x`` (cosine similarity over dim 1),
                so all three must have the same number of rows.
            layer_ind: index of the layer being trained.

        Returns:
            ``(Zx, Zy_pos, Zy_neg, loss)``: the layer outputs computed before the
            optimiser step, detached from the autograd graph, and the loss as a
            Python float.
        """
        Zx = self.x_net[layer_ind](x)
        # Negatives and positives go through y_net in a single forward pass.
        Zys = self.y_net[layer_ind](torch.cat([y_neg, y_pos], 0))
        n_neg = len(y_neg)  # split point: y_neg was concatenated first
        Zy_neg, Zy_pos = Zys[:n_neg], Zys[n_neg:]
        pos_sim = F.cosine_similarity(Zx, Zy_pos, dim=1) * self.A
        neg_sim = F.cosine_similarity(Zx, Zy_neg, dim=1) * self.B
        y_neg_similarity = F.cosine_similarity(Zy_neg, Zy_pos, dim=1) * self.C
        loss = (-pos_sim + neg_sim + y_neg_similarity).mean()

        self.x_opt.zero_grad()
        self.y_opt.zero_grad()
        loss.backward()
        self.x_opt.step()
        self.y_opt.step()
        # Detach so callers accumulating these outputs do not keep autograd history.
        return Zx.detach(), Zy_pos.detach(), Zy_neg.detach(), loss.item()

    def d_forward(self, Zx, true_y_pos, layer_ind):
        """Take one optimiser step on decoder ``layer_ind``.

        Args:
            Zx: embeddings produced by ``x_net[layer_ind]``.
            true_y_pos: targets accepted by ``self.d_loss_func``.
            layer_ind: index of the decoder to train.

        Returns:
            ``(y_pred, loss)``: decoder output before the step (detached) and the
            loss as a Python float.
        """
        y_pred = self.d_net[layer_ind](Zx)
        loss = self.d_loss_func(y_pred, true_y_pos)
        self.d_opt.zero_grad()
        loss.backward()
        self.d_opt.step()
        return y_pred.detach(), loss.item()

# Initialize data loaders

In [None]:
# ToTensor + Normalize(0.5, 0.5) maps each channel to [-1, 1]; the image is then
# flattened to a 1-D vector so it can feed the Linear layers of MFAN.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
     Lambda(lambda x: torch.flatten(x))])

# Batch sizes are intended to equal the split sizes: _load() takes only the
# first batch of each loader and treats it as the whole split.
batch_size = 50000
test_batch_size = 10000

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                           shuffle=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=test_batch_size,
                                          shuffle=False)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 55.3MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
def _load(num_classes=10):
    """Fetch the first batch of ``train_loader`` and ``test_loader`` onto ``device``.

    Only the first batch of each loader is read, so this returns a whole split
    only when the loader's batch size covers it.

    Args:
        num_classes: width of the one-hot training targets. It is passed
            explicitly so the width does not depend on which labels happen to
            appear in the batch.

    Returns:
        ``(x_tr, y_tr, x_te, y_te, y_true)``: train inputs, one-hot float train
        targets, test inputs, integer test labels and integer train labels, all
        on ``device``.
    """
    x_tr, y_true = next(iter(train_loader))
    x_tr, y_true = x_tr.to(device), y_true.to(device)

    x_te, y_te = next(iter(test_loader))
    x_te, y_te = x_te.to(device), y_te.to(device)

    # one_hot keeps y_true's device, so no extra .to(device) is needed.
    y_tr = F.one_hot(y_true, num_classes).float()
    return x_tr, y_tr, x_te, y_te, y_true

# Train and evaluate: `playMFAN()`

In [None]:
def _train_encoders(model, x, y_pos, y_neg, batch_size, num_epoch):
    """Greedily train every encoder layer of ``model``, one layer after another.

    Each layer is trained for ``num_epoch`` passes over the data in contiguous
    mini-batches of ``batch_size`` rows. A trailing partial batch is dropped.
    The next layer's inputs are ``relu(normalize(.))`` of this layer's
    final-epoch outputs.

    Returns:
        ``(Zx_layers, losses)``: ``Zx_layers[li]`` holds the detached x
        embeddings produced by layer ``li`` during its final epoch, and
        ``losses[li]`` holds that layer's per-batch losses over all epochs.
    """
    if num_epoch < 1:
        raise ValueError(f"num_epoch (EE) must be >= 1, got {num_epoch}")
    Zx_layers, all_losses = [], []
    for li in range(len(model.x_net)):
        n_batches = x.shape[0] // batch_size
        if n_batches == 0:
            raise ValueError(
                f"encoder batch size {batch_size} exceeds the {x.shape[0]} available samples")
        losses = []
        for _epoch in range(num_epoch):
            Zx_parts, Zy_pos_parts, Zy_neg_parts = [], [], []
            for i in range(n_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                Zxb, Zy_posb, Zy_negb, loss = model.forward(x[rows], y_pos[rows], y_neg[rows], li)
                losses.append(loss)
                # Detach so the concatenation below carries no autograd history.
                Zx_parts.append(Zxb.detach())
                Zy_pos_parts.append(Zy_posb.detach())
                Zy_neg_parts.append(Zy_negb.detach())
            # Concatenate once per epoch; only the last epoch's outputs are kept.
            Zx = torch.cat(Zx_parts, 0)
            Zy_pos = torch.cat(Zy_pos_parts, 0)
            Zy_neg = torch.cat(Zy_neg_parts, 0)
        Zx_layers.append(Zx)
        all_losses.append(losses)
        x = model.relu(F.normalize(Zx))
        y_pos = model.relu(F.normalize(Zy_pos))
        y_neg = model.relu(F.normalize(Zy_neg))
    return Zx_layers, all_losses


def _train_decoders(model, Zx_layers, y_true, num_classes, batch_size, num_epoch):
    """Train decoder ``li`` on ``Zx_layers[li]`` against one-hot ``y_true``.

    Returns the per-batch losses of each decoder (one list per layer).
    """
    # Loop-invariant: the one-hot targets are the same for every layer and epoch.
    y_onehot = F.one_hot(y_true, num_classes).float()
    all_losses = []
    for li, Zx_tr in enumerate(Zx_layers):
        n_batches = Zx_tr.shape[0] // batch_size  # trailing partial batch dropped
        losses = []
        for _epoch in range(num_epoch):
            for i in range(n_batches):
                rows = slice(i * batch_size, (i + 1) * batch_size)
                _, loss = model.d_forward(Zx_tr[rows], y_onehot[rows], li)
                losses.append(loss)
        all_losses.append(losses)
    return all_losses


def _layer_accuracies(model, x, labels):
    """Accuracy of every decoder when ``x`` is pushed through the encoder stack.

    The inter-layer transform is the same ``relu(normalize(.))`` used in training.
    ``labels`` are integer class indices.
    """
    accs = []
    with torch.no_grad():
        for li in range(len(model.d_net)):
            Zx = model.x_net[li](x)
            y_pred = model.d_net[li](Zx)
            accs.append((torch.argmax(y_pred, dim=1) == labels).sum().item() / labels.size(0))
            x = model.relu(F.normalize(Zx))
    return accs


def playMFAN(lr_x=1e-5, lr_y=1e-2, lr_d=1e-2, dim=(64, 32, 32), ABC=(0.6, 0.3, 0.6),
             EB=64, DB=64, EE=30, DE=64):
    """Train an MFAN on the loaded CIFAR-10 data layer by layer, then print accuracies.

    Args:
        lr_x, lr_y, lr_d: Adam learning rates for ``x_net``, ``y_net`` and ``d_net``.
        dim: output width of each layer (``MFAN``'s ``hid_dims``).
        ABC: weights ``(A, B, C)`` of the encoder loss terms.
        EB, EE: encoder batch size and epochs per layer.
        DB, DE: decoder batch size and epochs per layer.

    Prints the training and evaluation wall-clock times and the per-layer train
    and test accuracies. Returns ``None``.

    Raises:
        ValueError: if ``EE < 1`` or ``EB`` exceeds the number of training samples.
    """
    num_classes = 10
    x_tr, y_tr, x_te, y_te, y_true = _load()

    # The caller's learning rates are passed through to the optimisers.
    model = MFAN(x_tr.shape[1], lr_x,
                 num_classes, lr_y, lr_d,
                 dim,
                 enc_ABC=ABC)

    # One random wrong label per sample: adding an offset in [1, num_classes)
    # modulo num_classes can never give back the true label.
    neg_offsets = torch.as_tensor(np.random.randint(1, num_classes, len(y_true)),
                                  device=y_true.device)
    y_neg = F.one_hot((y_true + neg_offsets) % num_classes, num_classes).float()

    # ----------------------- Train -----------------------------
    start_time = time.time()
    model.train()
    Zx_layers_tr, _enc_losses = _train_encoders(model, x_tr, y_tr, y_neg, EB, EE)
    _dec_losses = _train_decoders(model, Zx_layers_tr, y_true, num_classes, DB, DE)
    print()
    print("MFAN Train: --- %s seconds ---" % (time.time() - start_time))

    # ----------------------- Eval accuracy ---------------------
    start_time = time.time()
    model.eval()
    acc_tr = _layer_accuracies(model, x_tr, y_true)
    acc_te = _layer_accuracies(model, x_te, y_te)
    print("MFAN Eval: --- %s seconds ---" % (time.time() - start_time))
    print('acc: tr', acc_tr, ' te', acc_te)

In [None]:
playMFAN(
    dim=[256, 128, 64],               # output width of each layer
    ABC=[0.6, 0.4, 0.6],              # weights A, B, C of the encoder loss terms
    EB=256, EE=1,                     # encoder: batch size, epochs per layer
    DB=256, DE=5,                     # decoder: batch size, epochs per layer
    lr_x=1e-2, lr_y=1e-3, lr_d=1e-2,  # learning rates intended for x_net, y_net, d_net
)



MFAN Train: --- 29.704814195632935 seconds ---
MFAN Eval: --- 2.272536277770996 seconds ---
acc: tr [0.3977, 0.3327, 0.2776]  te [0.3775, 0.3277, 0.2757]
