In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
from tqdm import tqdm
import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
import sys
sys.path.append("../src")
from ContrastiveModels import *
from visualization import *

In [2]:
# Use the first CUDA GPU when one is available; otherwise run on the CPU.
device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# Images become float tensors via ToTensor(). Normalize with mean 0 and std 1 is an
# identity transform; it is kept only as an explicit (currently no-op) normalisation hook.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,)),
    ]
)

# Training split: mini-batches of 20, reshuffled each time the loader is iterated.
mnist_dset_train = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=transform,
    target_transform=None,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    mnist_dset_train,
    batch_size=20,
    shuffle=True,
    num_workers=0,
)

# Test split: mini-batches of 20 in a fixed order.
mnist_dset_test = torchvision.datasets.MNIST(
    './data',
    train=False,
    transform=transform,
    target_transform=None,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_dset_test,
    batch_size=20,
    shuffle=False,
    num_workers=0,
)

In [4]:
# Layer widths: 784 inputs (flattened 28x28 MNIST image), 500 hidden units, 10 outputs.
architecture = [784, 500, 10]
activation = hard_sigmoid

# Per-layer coefficients handed to CSM: one per feed-forward layer in `W` (two here)
# and one per lateral layer in `M` (one here). Presumably learning rates — confirm
# against ContrastiveModels.
alphas_W = [0.5, 0.375]
alphas_M = [-0.01]

# Element-wise squared error (reduction='none'); passed to the model as `criterion`.
criterion = torch.nn.MSELoss(reduction='none').to(device)

model = CSM(architecture, activation, alphas_W, alphas_M).to(device)

In [5]:
# Baseline: test accuracy of the untrained model (the recorded run gives ~0.12, near chance).
# `model` was already moved to `device` when it was built, so no extra `.to(device)` is needed.
evaluateEP(model, test_loader, 20, 0.5, device)

Test accuracy :	 0.1189


0.1189

In [6]:
# Settings and bookkeeping for the training loop that follows.
mbs = train_loader.batch_size                                   # samples per mini-batch
start = time.time()                                             # wall-clock reference point
iter_per_epochs = math.ceil(len(train_loader.dataset) / mbs)    # mini-batches per epoch (last may be partial)

betas = (0.0, 1.0)   # (beta for the first phase, beta for the second phase)
beta_1 = betas[0]
beta_2 = betas[1]
neural_lr = 0.5      # passed to model(...) and evaluateEP; presumably the neuron-update step size

# NOTE(review): these [10.0] seeds look like percentages (chance level), whereas
# evaluateEP reports fractions (e.g. 0.1189) — confirm the intended units.
train_acc = [10.0]
test_acc = [10.0]
best = 0.0
epoch_sofar = 0

# Switch the model to training mode; the returned module is shown as the cell output.
model.train()

CSM(
  (W): ModuleList(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): Linear(in_features=500, out_features=10, bias=True)
  )
  (M): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
  (M_copy): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
)

In [None]:
# Train the CSM with a two-phase contrastive scheme and record train/test accuracy
# (fractions in [0, 1], as returned by evaluateEP) after every epoch.
# NOTE(review): re-running this cell resets the accuracy lists but keeps training the
# same `model` instance, which is not re-initialised here.
trn_acc_list = []
tst_acc_list = []
T1 = 20             # iterations passed to model(...) in the first phase (run with beta_1)
T2 = 4              # extra iterations in the second phase, continuing from the first-phase state
random_sign = True  # when beta_1 == 0, draw a random sign for beta_2 on every batch
n_epochs = 50
for epoch_ in range(n_epochs):
    # Iterate the DataLoader directly so tqdm can read its length and show a real
    # progress bar; wrapping enumerate(...) hid the total.
    for x, y in tqdm(train_loader, desc="epoch {}".format(epoch_ + 1)):
        x, y = x.to(device), y.to(device)
        neurons = model.init_neurons(x.size(0), device)

        # First phase: T1 iterations with beta_1.
        neurons = model(x, y, neurons, T1, neural_lr, beta=beta_1, criterion=criterion)
        # NOTE(review): `copy` is not imported explicitly in this notebook; it is presumably
        # re-exported by `from ContrastiveModels import *` — confirm.
        neurons_1 = copy(neurons)

        # Draw this batch's sign locally. Do NOT rebind the globals beta_2 / betas
        # (as the original loop did): that made the cell non-idempotent and changed
        # `betas` for later cells. A fresh random sign per batch gives the same sign
        # distribution as the old accumulated product of random signs.
        if random_sign and (beta_1 == 0.0):
            rnd_sgn = 2 * np.random.randint(2) - 1
        else:
            rnd_sgn = 1
        beta_2_step = rnd_sgn * beta_2
        betas_step = (beta_1, beta_2_step)

        # Second phase: T2 further iterations from the first-phase state with the signed beta.
        neurons = model(x, y, neurons, T2, neural_lr, beta=beta_2_step, criterion=criterion)
        neurons_2 = copy(neurons)
        # No optimizer is used here; the weight update presumably happens inside
        # compute_syn_grads (accuracy improves across epochs) — confirm.
        model.compute_syn_grads(x, y, neurons_1, neurons_2, betas_step, alphas_M, criterion)

    # Evaluate with the same number of iterations as the first phase (T1 == 20, the
    # value that was previously hard-coded here).
    trn_acc = evaluateEP(model, train_loader, T1, neural_lr, device, False)
    tst_acc = evaluateEP(model, test_loader, T1, neural_lr, device, False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [03:09, 15.86it/s]
2it [00:00, 13.13it/s]

Epoch : 1, Train Accuracy : 0.9645666666666667, Test Accuracy : 0.9608


3000it [03:18, 15.08it/s]
2it [00:00, 13.80it/s]

Epoch : 2, Train Accuracy : 0.975, Test Accuracy : 0.964


3000it [03:17, 15.18it/s]
2it [00:00, 13.92it/s]

Epoch : 3, Train Accuracy : 0.9809333333333333, Test Accuracy : 0.9694


3000it [03:18, 15.12it/s]
2it [00:00, 14.15it/s]

Epoch : 4, Train Accuracy : 0.9859833333333333, Test Accuracy : 0.9742


3000it [03:18, 15.09it/s]
2it [00:00, 13.81it/s]

Epoch : 5, Train Accuracy : 0.98915, Test Accuracy : 0.9759


3000it [03:21, 14.91it/s]
2it [00:00, 13.19it/s]

Epoch : 6, Train Accuracy : 0.9888833333333333, Test Accuracy : 0.9747


3000it [03:26, 14.54it/s]
2it [00:00, 13.82it/s]

Epoch : 7, Train Accuracy : 0.9927666666666667, Test Accuracy : 0.977


3000it [03:22, 14.78it/s]
2it [00:00, 13.35it/s]

Epoch : 8, Train Accuracy : 0.9937166666666667, Test Accuracy : 0.9781


3000it [03:20, 14.98it/s]
2it [00:00, 13.67it/s]

Epoch : 9, Train Accuracy : 0.9946833333333334, Test Accuracy : 0.9767


3000it [03:19, 15.01it/s]
2it [00:00, 13.54it/s]

Epoch : 10, Train Accuracy : 0.99355, Test Accuracy : 0.9766


3000it [03:19, 15.05it/s]
2it [00:00, 13.75it/s]

Epoch : 11, Train Accuracy : 0.9964666666666666, Test Accuracy : 0.9811


3000it [03:20, 14.98it/s]
2it [00:00, 14.09it/s]

Epoch : 12, Train Accuracy : 0.9949833333333333, Test Accuracy : 0.9792


3000it [03:18, 15.09it/s]
2it [00:00, 13.64it/s]

Epoch : 13, Train Accuracy : 0.9952833333333333, Test Accuracy : 0.9777


3000it [03:19, 15.04it/s]
2it [00:00, 14.04it/s]

Epoch : 14, Train Accuracy : 0.9846833333333334, Test Accuracy : 0.9705


3000it [03:19, 15.04it/s]
2it [00:00, 13.57it/s]

Epoch : 15, Train Accuracy : 0.9832166666666666, Test Accuracy : 0.9694


3000it [03:20, 14.95it/s]
2it [00:00, 13.69it/s]

Epoch : 16, Train Accuracy : 0.9900833333333333, Test Accuracy : 0.9746


3000it [03:23, 14.71it/s]
2it [00:00, 13.96it/s]

Epoch : 17, Train Accuracy : 0.9913, Test Accuracy : 0.9735


3000it [03:21, 14.90it/s]
2it [00:00, 14.10it/s]

Epoch : 18, Train Accuracy : 0.9924833333333334, Test Accuracy : 0.9758


3000it [03:22, 14.84it/s]
2it [00:00, 13.02it/s]

Epoch : 19, Train Accuracy : 0.9948166666666667, Test Accuracy : 0.9754


3000it [03:22, 14.79it/s]
2it [00:00, 13.60it/s]

Epoch : 20, Train Accuracy : 0.9933666666666666, Test Accuracy : 0.9725


3000it [03:22, 14.83it/s]
2it [00:00, 13.39it/s]

Epoch : 21, Train Accuracy : 0.9943833333333333, Test Accuracy : 0.9773


3000it [03:23, 14.76it/s]
2it [00:00, 13.05it/s]

Epoch : 22, Train Accuracy : 0.9921166666666666, Test Accuracy : 0.973


3000it [03:23, 14.71it/s]
2it [00:00, 14.31it/s]

Epoch : 23, Train Accuracy : 0.9966166666666667, Test Accuracy : 0.9789


3000it [03:22, 14.80it/s]
2it [00:00, 13.15it/s]

Epoch : 24, Train Accuracy : 0.99075, Test Accuracy : 0.9738


3000it [02:51, 17.47it/s]
2it [00:00, 15.89it/s]

Epoch : 25, Train Accuracy : 0.9953166666666666, Test Accuracy : 0.9758


3000it [02:49, 17.65it/s]
2it [00:00, 16.50it/s]

Epoch : 26, Train Accuracy : 0.9961, Test Accuracy : 0.9746


3000it [02:46, 17.97it/s]
3it [00:00, 27.72it/s]

Epoch : 27, Train Accuracy : 0.9976666666666667, Test Accuracy : 0.9787


3000it [02:02, 24.50it/s]
2it [00:00, 15.57it/s]

Epoch : 28, Train Accuracy : 0.9964333333333333, Test Accuracy : 0.9754


2676it [01:46, 29.35it/s]

In [None]:
# Per-epoch training accuracy. evaluateEP returns fractions in [0, 1] (e.g. 0.9645),
# so scale by 100 to match the 'Accuracy %' axis label.
plot_convergence_plot([100 * acc for acc in trn_acc_list], xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Train Accuracy w.r.t. Epochs', 
                      figsize = (12,8), fontsize = 25, linewidth = 3);

In [None]:
# Per-epoch test accuracy. evaluateEP returns fractions in [0, 1] (e.g. 0.9608),
# so scale by 100 to match the 'Accuracy %' axis label.
plot_convergence_plot([100 * acc for acc in tst_acc_list], xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Test Accuracy w.r.t. Epochs', 
                      figsize = (12,8), fontsize = 25, linewidth = 3);