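# MNIST Classification

Layer-wise training of a paired image/label encoder (`CFF`) on MNIST with a cosine-similarity loss, followed by one linear decoder per layer. The notebook then evaluates the result with t-SNE plots of the embeddings and per-layer train/test accuracy.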

# Import

In [1]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch.optim import Adam, SGD
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, Lambda
from torch.utils.data import DataLoader

import numpy as np
import random
import math

from sklearn.manifold import TSNE
import time

import altair as alt
alt.data_transformers.disable_max_rows()
import pandas as pd
GPU = True # Choose whether to use GPU
# CUDA is used only when it is both requested and actually available;
# every other combination falls back to the CPU.
use_cuda = GPU and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f'Using {device}')

Using cpu


# Network

In [2]:
class CFF(torch.nn.Module):
    """Parallel stacks of linear layers that are trained one layer at a time.

    Three ``Sequential`` containers are built with the same number of layers.
    They are used purely as indexable layer lists: every method calls a single
    layer via ``[layer_ind]`` and never runs a container end to end.

    * ``x_net[k]`` embeds the input features (``x_in`` wide for ``k == 0``,
      ``hid_dims[k-1]`` wide afterwards) into ``hid_dims[k]`` dimensions.
    * ``y_net[k]`` embeds the label vectors (``y_in`` wide for ``k == 0``)
      into the same ``hid_dims[k]`` dimensions.
    * ``d_net[k]`` maps an ``x_net[k]`` output back to ``y_in`` scores.

    Args:
        x_in: width of the input features.
        x_lr: Adam learning rate for ``x_net``.
        y_in: width of the label vectors (also the decoder output width).
        y_lr: Adam learning rate for ``y_net``.
        d_lr: Adam learning rate for ``d_net``.
        hid_dims: sequence of layer widths, one entry per layer.
        num_epoch: stored as ``self.num_epoch``; not read inside the class.
        num_epoch_D: stored as ``self.num_epoch_D``; not read inside the class.
        d_loss_func: loss applied to ``d_net`` outputs in ``d_forward``.
        enc_ABC: three weights ``(A, B, C)`` for the terms of the encoder
            loss. The default list is only unpacked, never mutated.

    Note:
        Layers are moved to the module-level ``device``.
    """

    def __init__(self,
                 x_in, x_lr,
                 y_in, y_lr, d_lr,
                 hid_dims,# 500, 300, 300
                 num_epoch=1, num_epoch_D=300,
                 d_loss_func=torch.nn.CrossEntropyLoss(),
                 enc_ABC=[1,1,1]):
        super().__init__()
        # First layer of every stack: x_in -> h0, y_in -> h0, h0 -> y_in.
        self.x_net = torch.nn.Sequential( torch.nn.Linear(x_in, hid_dims[0]) ).to(device)
        self.y_net = torch.nn.Sequential( torch.nn.Linear(y_in, hid_dims[0]) ).to(device)
        self.d_net = torch.nn.Sequential( torch.nn.Linear(hid_dims[0], y_in) ).to(device)
        # Remaining layers: h[i] -> h[i+1] for both encoders, h[i+1] -> y_in for the decoder.
        for i in range(len(hid_dims)-1):
            self.x_net.append( torch.nn.Linear( hid_dims[i], hid_dims[i+1]) ).to(device)
            self.y_net.append( torch.nn.Linear( hid_dims[i], hid_dims[i+1]) ).to(device)
            self.d_net.append( torch.nn.Linear( hid_dims[i+1], y_in) ).to(device)
        self.relu = torch.nn.ReLU()

        # One optimizer per stack; each covers every layer of that stack.
        self.x_opt = Adam(self.x_net.parameters(), lr=x_lr)
        self.y_opt = Adam(self.y_net.parameters(), lr=y_lr)
        self.d_opt = Adam(self.d_net.parameters(), lr=d_lr)

        self.d_loss_func = d_loss_func

        # Training-length settings, kept for the caller's training loop.
        self.num_epoch = num_epoch
        self.num_epoch_D = num_epoch_D

        self.A, self.B, self.C = enc_ABC

    def forward(self, x, y_pos, y_neg, layer_ind):
        """Run ONE optimisation step of encoder layer ``layer_ind``.

        Unlike a conventional ``forward`` this also back-propagates the loss
        and steps ``x_opt`` and ``y_opt``.

        The loss is the batch mean of
        ``-A*cos(Zx, Zy_pos) + B*cos(Zx, Zy_neg) + C*cos(Zy_neg, Zy_pos)``.

        Args:
            x: inputs for ``x_net[layer_ind]``, one row per sample.
            y_pos: label vectors matching ``x`` row by row.
            y_neg: label vectors of a wrong class for each row.
            layer_ind: index of the layer to train.

        Returns:
            ``(Zx, Zy_pos, Zy_neg, loss)``: the three embeddings computed
            before the parameter update, and the loss as a Python float.
        """
        num_neg = len(y_neg)
        Zx = self.x_net[layer_ind]( x ) # [batch, width]
        # Both label sets go through the layer in one call, negatives first.
        Zys = self.y_net[layer_ind]( torch.cat([y_neg, y_pos],0) ) # [num_neg + num_pos, width]
        # Split at the number of NEGATIVE rows, because those are stacked first.
        # Using len(y_pos) as the offset is only right for equal batch sizes.
        Zy_neg, Zy_pos = Zys[:num_neg], Zys[num_neg:]
        pos_sim = F.cosine_similarity(Zx, Zy_pos, dim=1) * self.A
        neg_sim = F.cosine_similarity(Zx, Zy_neg, dim=1) * self.B

        y_neg_similarity = F.cosine_similarity(Zy_neg, Zy_pos, dim=1) * self.C
        loss = (- pos_sim + neg_sim + y_neg_similarity).mean()
        self.x_opt.zero_grad()
        self.y_opt.zero_grad()
        loss.backward()
        self.x_opt.step()
        self.y_opt.step()
        return Zx, Zy_pos, Zy_neg, loss.item()

    def d_forward(self, Zx, true_y_pos, layer_ind):
        """Run ONE optimisation step of decoder layer ``layer_ind``.

        Args:
            Zx: inputs for ``d_net[layer_ind]``.
            true_y_pos: targets passed to ``d_loss_func``.
            layer_ind: index of the decoder layer to train.

        Returns:
            ``(y_pred, loss)``: the decoder output computed before the
            parameter update, and the loss as a Python float.
        """
        y_pred = self.d_net[layer_ind]( Zx )
        loss = self.d_loss_func(y_pred, true_y_pos)
        self.d_opt.zero_grad()
        loss.backward()
        self.d_opt.step()
        return y_pred, loss.item()

# Init_loaders

In [3]:
def MNIST_loaders(train_batch_size=60000, test_batch_size=10000):
    """Build the MNIST train and test ``DataLoader`` objects.

    Every image is converted to a tensor, normalised with mean 0.5 and
    std 0.5, and flattened to one dimension. The data is stored under
    ``./data/`` and downloaded if it is not already there.

    Args:
        train_batch_size: batch size of the (shuffled) training loader.
        test_batch_size: batch size of the (unshuffled) test loader.

    Returns:
        ``(train_loader, test_loader)``.
    """
    preprocessing = Compose([
        ToTensor(),
        Normalize((0.5,), (0.5,)),
        Lambda(lambda img: torch.flatten(img)),
    ])

    # Build both datasets first, then wrap each one in its loader.
    train_set = MNIST('./data/', train=True, download=True, transform=preprocessing)
    test_set = MNIST('./data/', train=False, download=True, transform=preprocessing)

    train_loader = DataLoader(train_set, batch_size=train_batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=test_batch_size, shuffle=False)
    return train_loader, test_loader
# Seed every RNG the notebook draws from, not only torch, so that a
# Restart & Run All is repeatable. The training cell samples its negative
# labels with np.random.randint, which torch.manual_seed does not control.
SEED = 1234
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
train_loader, test_loader = MNIST_loaders()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 34773523.26it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1178477.42it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 9823458.20it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 1273601.33it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# Train()

In [5]:
# ----------------------- Data & model ----------------------
# Take the first batch of each loader. With the loaders' default batch sizes
# this is expected to be the complete train / test split.
x_tr, y_tr = next(iter(train_loader))
x_tr, y_tr = x_tr.to(device), y_tr.to(device)

x_te, y_te = next(iter(test_loader))
x_te, y_te = x_te.to(device), y_te.to(device)
model = CFF(784, 5e-5,
            10, 5e-4, 5e-2,
            [500, 300, 300],
            num_epoch=1, enc_ABC=[0.6,0.4,0.4])
model.train()
num_classes = 10
batch_size = 240
x = x_tr.clone().to(device)
# Pass num_classes explicitly so the one-hot width does not depend on which
# labels happen to be present in the batch.
y_pos = F.one_hot(y_tr, num_classes).float().to(device)
# Negative label = true label shifted by a random offset in [1, num_classes-1]
# (mod num_classes), so it can never equal the true label.
y_tr_neg = (y_tr.cpu() + np.random.randint(1, num_classes, len(y_tr))) % num_classes
y_neg = F.one_hot(y_tr_neg, num_classes).float().to(device)

loss_dict = {'encoder':[], 'decoder':[]}
Zx_layers_tr = [] # store for Decoder training
# ----------------------- Encoder ---------------------------
# Layers are trained one after another. Once a layer is done, its normalised,
# ReLU-ed and detached outputs become the inputs of the next layer.
iter_batch = x.shape[0]//batch_size
for li in range(len(model.x_net)):
    losses = []
    Zx_layer_data_tr = []
    for _ni in tqdm(range(model.num_epoch),
                    desc=f'Encode {li+1}/{len(model.x_net)} layer; num_iter_batch={iter_batch}'):
        # Collect per-batch outputs in lists and concatenate once per epoch.
        # Calling torch.cat inside the batch loop re-copies everything
        # gathered so far on every step.
        zx_parts, zy_pos_parts, zy_neg_parts = [], [], []
        # Integer division: a trailing partial batch is never visited.
        for i in range(x.shape[0]//batch_size):
            start_i = i * batch_size
            end_i = start_i + batch_size
            xb, y_posb, y_negb = x[start_i:end_i], y_pos[start_i:end_i], y_neg[start_i:end_i]
            Zxb, Zy_posb, Zy_negb, _loss = model.forward(xb, y_posb, y_negb, li)
            losses.append(_loss)
            # Detach right away: only the values are needed downstream.
            zx_parts.append(Zxb.detach())
            zy_pos_parts.append(Zy_posb.detach())
            zy_neg_parts.append(Zy_negb.detach())
        Zx = torch.cat(zx_parts, 0)
        Zy_pos = torch.cat(zy_pos_parts, 0)
        Zy_neg = torch.cat(zy_neg_parts, 0)
        Zx_layer_data_tr = Zx # use only the last n_epoch's Zx to train decoder
    Zx_layers_tr.append(Zx_layer_data_tr)
    loss_dict['encoder'].append(losses)
    x, y_pos, y_neg = model.relu(F.normalize(Zx)).detach(), model.relu(F.normalize(Zy_pos)).detach(), model.relu(F.normalize(Zy_neg)).detach()
# ----------------------- Decoder ---------------------------
# Each decoder layer is fitted on the stored (detached) outputs of the
# encoder layer with the same index.
batch_size = 360
iter_batch = Zx_layers_tr[0].shape[0]//batch_size
d_num_epoch = 100
# The targets are identical for every layer and epoch, so build them once.
y_pos = F.one_hot(y_tr, num_classes).float().to(device)
for li in range(len(model.x_net)):
    losses = []
    Zx_tr = Zx_layers_tr[li]
    for _ni in tqdm(range(d_num_epoch),
                    desc=f'Decode {li+1}/{len(model.x_net)} layer; num_iter_batch={iter_batch}'):
        # Integer division: rows beyond the last full batch are skipped.
        for i in range(Zx_tr.shape[0]//batch_size):
            start_i = i * batch_size
            end_i = start_i + batch_size
            Zx_trb, y_posb = Zx_tr[start_i:end_i], y_pos[start_i:end_i]
            y_predb, _loss = model.d_forward(Zx_trb, y_posb, li)
            losses.append(_loss)
    loss_dict['decoder'].append(losses)

Encode 1/3 layer; num_iter_batch=250: 100%|██████████| 1/1 [00:42<00:00, 42.16s/it]
Encode 2/3 layer; num_iter_batch=250: 100%|██████████| 1/1 [00:28<00:00, 28.10s/it]
Encode 3/3 layer; num_iter_batch=250: 100%|██████████| 1/1 [00:26<00:00, 26.49s/it]
Decode 1/3 layer; num_iter_batch=166: 100%|██████████| 100/100 [00:36<00:00,  2.72it/s]
Decode 2/3 layer; num_iter_batch=166: 100%|██████████| 100/100 [00:20<00:00,  4.89it/s]
Decode 3/3 layer; num_iter_batch=166: 100%|██████████| 100/100 [00:20<00:00,  4.78it/s]


In [None]:
# # ---------------------- Loss Plot --------------------------
# n_col = len(model.x_net)
# fig, axes = plt.subplots(nrows=1, ncols=n_col, figsize=(10, 2))
# for i in range(n_col):
#     print('encode', len(loss_dict['encoder']), len(loss_dict['encoder'][i]))
#     axes[i].plot(loss_dict['encoder'][i])
#     axes[i].set_title('Encoder Loss L'+str(i))

# fig, axes = plt.subplots(nrows=1, ncols=n_col, figsize=(10, 2))
# for i in range(n_col):
#     axes[i].plot(loss_dict['decoder'][i])
#     axes[i].set_title('Decoder Loss L'+str(i))

# Eval()

In [10]:
do_tsne = True
perplexity = 10
model.eval()

# ============================== EVAL T-SNE ====================================
# Push the test inputs and the ten one-hot labels through the encoder stacks
# layer by layer and, if requested, embed each layer's outputs with t-SNE.
Zx_layers_te = [] # store for y prediction
embed_x, embed_y = [], [] # store for T-SNE plot
x = x_te.clone().to(device)
y_pos = F.one_hot(torch.tensor(np.arange(10))).float().to(device)
with torch.no_grad():
    for i, (x_layer, y_layer) in enumerate(zip(model.x_net, model.y_net)):
        zx_raw = x_layer(x)
        Zx_layers_te.append(zx_raw) # store Zx
        zx_list = F.normalize(zx_raw, p=2, dim=1)
        zy_list = F.normalize(y_layer(y_pos), p=2, dim=1)
        #........................ T-SNE ......................
        if do_tsne:
            num_emb_x = 10000 # or 5000 (time saving)
            # Input points first, label points last, so the 2-D result can be
            # split back at num_emb_x.
            z = torch.cat((zx_list[:num_emb_x], zy_list),0)
            print("--- start fit tsne (L",i,") ---: num_points=",len(z))
            start_time = time.time()
            tsne = TSNE(n_components=2, perplexity=perplexity)
            z_tsne = tsne.fit_transform(z.cpu())
            print("--- %s seconds ---" % (time.time() - start_time))
            embed_x.append(z_tsne[:num_emb_x])
            embed_y.append(z_tsne[num_emb_x:])
        # Inputs for the next layer.
        x = model.relu(F.normalize(zx_list))
        y_pos = model.relu(F.normalize(zy_list))

# =========================== EVAL Accuracy ====================================
# For every layer, decode that layer's encoder output and score the arg-max
# class against the true labels, for the train and the test split.
y_tr_eval, y_te_eval = y_tr.clone().to(device), y_te.clone().to(device)
x_tr_eval, x_te_eval = x_tr.clone().to(device), x_te.clone().to(device)

acc_tr, acc_te = [], [] # per-layer train / test accuracy
with torch.no_grad():
    for enc_layer, dec_layer in zip(model.x_net, model.d_net):
        Zx_tr_eval = enc_layer(x_tr_eval)
        y_pred_tr = dec_layer(Zx_tr_eval)
        hits_tr = (torch.argmax(y_pred_tr, dim=1) == y_tr_eval).sum().item() # Convert logits to labels
        acc_tr.append(hits_tr / y_tr_eval.size(0))
        x_tr_eval = model.relu(F.normalize(Zx_tr_eval))

        Zx_te_eval = enc_layer(x_te_eval)
        y_pred_te = dec_layer(Zx_te_eval)
        hits_te = (torch.argmax(y_pred_te, dim=1) == y_te_eval).sum().item() # Convert logits to labels
        acc_te.append(hits_te / y_te_eval.size(0))
        x_te_eval = model.relu(F.normalize(Zx_te_eval))

print('acc: tr',acc_tr,' te',acc_te)

--- start fit tsne (L 0 ) ---: num_points= 10010
--- 106.517653465271 seconds ---
--- start fit tsne (L 1 ) ---: num_points= 10010
--- 102.64879250526428 seconds ---
--- start fit tsne (L 2 ) ---: num_points= 10010
--- 102.05071711540222 seconds ---
acc: tr [0.90505, 0.9222833333333333, 0.9015166666666666]  te [0.8966, 0.9222, 0.8979]


In [11]:
# plot
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go
# One panel per encoder layer. Small dots are the embedded test inputs,
# coloured by their true label; large crosses are the embedded label vectors,
# drawn with the same palette so label k shares its colour with class k.
n_layers = len(model.x_net)
fig = make_subplots(rows=1, cols=n_layers,
                    subplot_titles=[f't-SNE, layer {k + 1}' for k in range(n_layers)])
cmap = np.array(px.colors.qualitative.Plotly)
labels_te = y_te.cpu().numpy()
for i, (zx_tsne, zy_tsne) in enumerate(zip(embed_x, embed_y)):
    # Only the first len(zx_tsne) test points were embedded, so take the same
    # number of labels to keep colours and points aligned.
    colors = cmap[labels_te[:len(zx_tsne)]]
    show_in_legend = (i == 0) # one legend entry per trace kind, not per panel
    traces = [
        go.Scatter(x=zx_tsne[:, 0], y=zx_tsne[:, 1], mode='markers', marker=dict(color=colors, size=3),
                   name='zx', legendgroup='zx', showlegend=show_in_legend),
        go.Scatter(x=zy_tsne[:, 0], y=zy_tsne[:, 1], mode='markers', marker=dict(color=cmap, size=20, symbol='x', line=dict(color='black', width=2)),
                   name='zy', legendgroup='zy', showlegend=show_in_legend)
    ]
    fig.add_trace(traces[0], row=1, col=i+1)
    fig.add_trace(traces[1], row=1, col=i+1)
fig.show()