In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
from tqdm import tqdm
import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
import sys
sys.path.append("../../src")
from ContrastiveModels import *
from visualization import *

In [2]:
# Pick the compute device: the first CUDA GPU when one is visible, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 
                                            torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,))])

mnist_dset_train = torchvision.datasets.MNIST('../../data', train=True, transform=transform, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, batch_size=20, shuffle=True, num_workers=0)

mnist_dset_test = torchvision.datasets.MNIST('../../data', train=False, transform=transform, target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, batch_size=20, shuffle=False, num_workers=0)

In [4]:
# Reproducibility: seed every RNG this run draws from. The sources are weight
# initialisation and DataLoader shuffling (torch global RNG), plus the
# np.random.randint sign draw in the training loop.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

activation = hard_sigmoid
criterion = torch.nn.MSELoss(reduction='none').to(device)

# 784 = 28*28 MNIST pixels -> 500 hidden units -> 10 classes.
architecture = [784, 500, 10]
# One entry per W layer (2 feedforward Linear layers) and per lateral M layer (1).
# NOTE(review): presumably per-layer learning rates for W and M; confirm in ContrastiveModels.
alphas_W = [0.5, 0.375]
alphas_M = [-0.01]

model = CSM(architecture, activation, alphas_W, alphas_M).to(device)

# scheduler.step() is called once per epoch in the training loop, so the
# optimizer's learning rate is multiplied by 0.75 every 15 epochs.
scheduler = torch.optim.lr_scheduler.StepLR(model.optimizer, step_size=15, gamma=0.75)

In [5]:
# Baseline before any training: with 10 classes, chance accuracy is about 0.1.
# The positional arguments mirror the per-epoch evaluation in the training cell,
# where 20 is T1 and 0.5 is neural_lr.
# `model` already lives on `device` (moved in the model-setup cell).
eval_T, eval_neural_lr = 20, 0.5
evaluateEP(model, test_loader, eval_T, eval_neural_lr, device)

Test accuracy :	 0.1114


0.1114

In [6]:
# Mini-batch bookkeeping.
mbs = train_loader.batch_size
start = time.time()
iter_per_epochs = math.ceil(len(train_loader.dataset) / mbs)

# Relaxation hyper-parameters. betas = (beta for the first phase, beta for the
# second phase); the training loop may randomise the second one's sign per batch.
betas = (0.0, 1.0)
beta_1, beta_2 = betas
neural_lr = 0.5

# NOTE(review): none of the names below are read by any visible cell; candidates for removal.
train_acc, test_acc = [10.0], [10.0]
best = 0.0
epoch_sofar = 0

model.train()

CSM(
  (W): ModuleList(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): Linear(in_features=500, out_features=10, bias=True)
  )
  (M): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
  (M_copy): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
)

In [7]:
# Training loop. Each mini-batch runs two relaxation phases of the network, then
# hands both settled states to the model, which computes the synaptic update from
# their contrast.
trn_acc_list = []
tst_acc_list = []
T1 = 20             # relaxation steps for the first phase (beta = beta_1)
T2 = 4              # relaxation steps for the second phase (beta = +/- beta_2)
random_sign = True  # randomise the sign of the second-phase beta on every batch
n_epochs = 30

for epoch_ in range(n_epochs):
    model.train()
    for x, y in tqdm(train_loader, total=len(train_loader),
                     desc="epoch {}/{}".format(epoch_ + 1, n_epochs)):
        x, y = x.to(device), y.to(device)

        # Phase 1: relax from fresh neurons with beta_1.
        # NOTE(review): `copy` comes from the star imports. Confirm it copies deeply
        # enough if the model updates neuron tensors in place.
        neurons = model.init_neurons(x.size(0), device)
        neurons = model(x, y, neurons, T1, neural_lr, beta=beta_1, criterion=criterion)
        neurons_1 = copy(neurons)

        # Draw this batch's sign fresh from the base beta_2. Do not overwrite the
        # notebook-level beta_2/betas: doing so compounds signs across batches and
        # leaves the globals flipped after the cell finishes, which breaks re-runs.
        if random_sign and beta_1 == 0.0:
            sign = 2 * np.random.randint(2) - 1
        else:
            sign = 1
        batch_betas = (beta_1, sign * beta_2)

        # Phase 2: continue relaxing from the phase-1 state with the signed beta.
        neurons = model(x, y, neurons, T2, neural_lr, beta=batch_betas[1], criterion=criterion)
        neurons_2 = copy(neurons)

        # No explicit optimizer.step() here. Accuracy rises over epochs, so
        # compute_syn_grads presumably applies the update itself; confirm in ContrastiveModels.
        model.compute_syn_grads(x, y, neurons_1, neurons_2, batch_betas, alphas_M, criterion)

    scheduler.step()
    model.eval()
    # Evaluate with the same number of relaxation steps as the training phase 1.
    trn_acc = evaluateEP(model, train_loader, T1, neural_lr, device, False)
    tst_acc = evaluateEP(model, test_loader, T1, neural_lr, device, False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [02:46, 18.04it/s]
4it [00:00, 34.25it/s]

Epoch : 1, Train Accuracy : 0.9630166666666666, Test Accuracy : 0.9605


3000it [01:27, 34.26it/s]
4it [00:00, 34.21it/s]

Epoch : 2, Train Accuracy : 0.9779333333333333, Test Accuracy : 0.9724


3000it [01:23, 35.90it/s]
3it [00:00, 22.86it/s]

Epoch : 3, Train Accuracy : 0.9810333333333333, Test Accuracy : 0.9706


3000it [01:58, 25.23it/s]
4it [00:00, 36.32it/s]

Epoch : 4, Train Accuracy : 0.9848, Test Accuracy : 0.9732


3000it [01:32, 32.50it/s]
4it [00:00, 35.98it/s]

Epoch : 5, Train Accuracy : 0.9876, Test Accuracy : 0.975


3000it [01:21, 36.73it/s]
3it [00:00, 21.99it/s]

Epoch : 6, Train Accuracy : 0.98755, Test Accuracy : 0.9742


3000it [01:33, 32.23it/s]
4it [00:00, 35.72it/s]

Epoch : 7, Train Accuracy : 0.9915666666666667, Test Accuracy : 0.9767


3000it [01:21, 36.81it/s]
4it [00:00, 35.19it/s]

Epoch : 8, Train Accuracy : 0.9918833333333333, Test Accuracy : 0.976


3000it [01:22, 36.58it/s]
4it [00:00, 35.29it/s]

Epoch : 9, Train Accuracy : 0.99395, Test Accuracy : 0.9764


3000it [01:21, 36.86it/s]
4it [00:00, 35.17it/s]

Epoch : 10, Train Accuracy : 0.9952833333333333, Test Accuracy : 0.9796


3000it [01:21, 36.86it/s]
4it [00:00, 34.65it/s]

Epoch : 11, Train Accuracy : 0.9818166666666667, Test Accuracy : 0.9661


3000it [01:21, 36.71it/s]
4it [00:00, 34.68it/s]

Epoch : 12, Train Accuracy : 0.9923666666666666, Test Accuracy : 0.9764


3000it [01:42, 29.17it/s]
4it [00:00, 35.33it/s]

Epoch : 13, Train Accuracy : 0.9958666666666667, Test Accuracy : 0.9782


3000it [01:21, 36.90it/s]
4it [00:00, 36.13it/s]

Epoch : 14, Train Accuracy : 0.9941166666666666, Test Accuracy : 0.9756


3000it [01:21, 36.88it/s]
4it [00:00, 33.60it/s]

Epoch : 15, Train Accuracy : 0.9966666666666667, Test Accuracy : 0.9778


3000it [01:21, 36.72it/s]
4it [00:00, 34.74it/s]

Epoch : 16, Train Accuracy : 0.9952333333333333, Test Accuracy : 0.979


3000it [01:21, 36.83it/s]
4it [00:00, 35.78it/s]

Epoch : 17, Train Accuracy : 0.9974833333333334, Test Accuracy : 0.9799


3000it [01:23, 35.87it/s]
4it [00:00, 36.09it/s]

Epoch : 18, Train Accuracy : 0.99785, Test Accuracy : 0.9796


3000it [01:43, 28.93it/s]
4it [00:00, 35.24it/s]

Epoch : 19, Train Accuracy : 0.99835, Test Accuracy : 0.9799


3000it [01:28, 33.90it/s]
4it [00:00, 35.94it/s]

Epoch : 20, Train Accuracy : 0.99725, Test Accuracy : 0.9766


3000it [01:23, 35.94it/s]
4it [00:00, 35.82it/s]

Epoch : 21, Train Accuracy : 0.99675, Test Accuracy : 0.9761


3000it [01:21, 36.70it/s]
4it [00:00, 35.53it/s]

Epoch : 22, Train Accuracy : 0.9974, Test Accuracy : 0.9773


3000it [01:21, 36.69it/s]
4it [00:00, 35.26it/s]

Epoch : 23, Train Accuracy : 0.9985833333333334, Test Accuracy : 0.9794


3000it [01:30, 33.31it/s]
4it [00:00, 35.51it/s]

Epoch : 24, Train Accuracy : 0.99865, Test Accuracy : 0.9791


3000it [01:21, 36.85it/s]
4it [00:00, 36.35it/s]

Epoch : 25, Train Accuracy : 0.9988833333333333, Test Accuracy : 0.9797


3000it [01:21, 36.84it/s]
4it [00:00, 35.01it/s]

Epoch : 26, Train Accuracy : 0.9989666666666667, Test Accuracy : 0.9781


3000it [01:21, 36.98it/s]
4it [00:00, 34.55it/s]

Epoch : 27, Train Accuracy : 0.9996, Test Accuracy : 0.9802


3000it [01:21, 36.93it/s]
4it [00:00, 35.10it/s]

Epoch : 28, Train Accuracy : 0.9997166666666667, Test Accuracy : 0.9814


3000it [01:21, 36.79it/s]
4it [00:00, 35.83it/s]

Epoch : 29, Train Accuracy : 0.9997833333333334, Test Accuracy : 0.9812


3000it [01:33, 32.02it/s]


Epoch : 30, Train Accuracy : 0.99975, Test Accuracy : 0.9799


In [None]:
# evaluateEP returns accuracy as a fraction in [0, 1] (e.g. 0.9997 in the training
# log). Scale it to percent so the values match the 'Accuracy %' axis label.
# NOTE(review): assumes plot_convergence_plot plots the values as given.
plot_convergence_plot([100 * acc for acc in trn_acc_list], xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Train Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3);

In [None]:
# evaluateEP returns accuracy as a fraction in [0, 1] (e.g. 0.98 in the training
# log). Scale it to percent so the values match the 'Accuracy %' axis label.
# NOTE(review): assumes plot_convergence_plot plots the values as given.
plot_convergence_plot([100 * acc for acc in tst_acc_list], xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Test Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3);