In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
from tqdm import tqdm
import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
import sys
sys.path.append("../../src")
from ContrastiveModels import *
from visualization import *

In [2]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs this notebook draws from, so Restart & Run All is reproducible:
# - torch: DataLoader shuffling, and presumably the CSM weight init.
# - numpy: the random beta sign drawn in the training loop.
SEED = 0
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
np.random.seed(SEED)

device

device(type='cuda', index=0)

In [3]:
# MNIST train/test splits as float tensors (ToTensor scales pixel values to [0, 1]).
# Normalize(mean=0, std=1) leaves the values unchanged; it only makes the
# normalisation constants explicit and easy to tweak.
DATA_ROOT = '../../data'
BATCH_SIZE = 20

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,)),
])

# Training split: reshuffled every epoch.
mnist_dset_train = torchvision.datasets.MNIST(
    DATA_ROOT, train=True, transform=transform, target_transform=None, download=True,
)
train_loader = torch.utils.data.DataLoader(
    mnist_dset_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
)

# Test split: fixed order.
mnist_dset_test = torchvision.datasets.MNIST(
    DATA_ROOT, train=False, transform=transform, target_transform=None, download=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_dset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
)

In [None]:
# Network definition: layer sizes input -> hidden -> output.
architecture = [784, 500, 10]
activation = hard_sigmoid

# Per-layer learning rates handed to CSM.
# Presumably one entry per feedforward layer (W) and one per lateral layer (M);
# the model repr shows 2 W layers and 1 M layer — confirm against CSM.
alphas_W = [0.5, 0.375]
alphas_M = [-0.01]  # NOTE(review): negative sign looks deliberate (anti-Hebbian lateral update?) — confirm

# Element-wise squared error (no reduction); the model decides how to aggregate it.
criterion = torch.nn.MSELoss(reduction='none').to(device)

model = CSM(architecture, activation, alphas_W, alphas_M)
model = model.to(device)

# Multiply every learning rate in model.optimizer by 0.75 once per 15 scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.StepLR(model.optimizer, step_size=15, gamma=0.75)

In [5]:
# Baseline before any training: a randomly initialised 10-class model should sit near chance (~0.1).
evaluateEP(model.to(device), test_loader,
           20,    # presumably the number of relaxation steps (matches T1 used for training)
           0.5,   # neural step size, passed in the same position as neural_lr during training
           device)

Test accuracy :	 0.0818


0.0818

In [6]:
# Hyper-parameters and bookkeeping for the training loop.
mbs = train_loader.batch_size
start = time.time()
iter_per_epochs = math.ceil(len(train_loader.dataset) / mbs)

# Nudging strengths for the two phases: beta_1 for the first phase, beta_2 for the second.
betas = (0.0, 1.0)
beta_1, beta_2 = betas

# Step size forwarded to the model's neural dynamics and to evaluateEP.
neural_lr = 0.5

# NOTE(review): start, iter_per_epochs, train_acc, test_acc, best and epoch_sofar
# are not referenced by any visible cell — confirm they are still needed.
train_acc, test_acc = [10.0], [10.0]
best, epoch_sofar = 0.0, 0

model.train()

CSM(
  (W): ModuleList(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): Linear(in_features=500, out_features=10, bias=True)
  )
  (M): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
  (M_copy): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
)

In [None]:
# Two-phase contrastive training.
# - For each mini-batch the network is relaxed for T1 steps with beta_1, then for T2 more steps
#   with beta_2.
# - The synaptic update is computed from the two resulting states.
trn_acc_list = []
tst_acc_list = []
T1 = 20             # relaxation steps in the first phase
T2 = 4              # relaxation steps in the second phase
random_sign = True  # draw a random sign for beta_2 on every mini-batch
n_epochs = 30
for epoch_ in range(n_epochs):
    model.train()
    for x, y in tqdm(train_loader, total=len(train_loader), desc="Epoch {}".format(epoch_ + 1)):
        x, y = x.to(device), y.to(device)
        neurons = model.init_neurons(x.size(0), device)

        # Phase 1: relax with beta_1. `copy` comes from a star import (presumably copy.copy).
        neurons = model(x, y, neurons, T1, neural_lr, beta=beta_1, criterion=criterion)
        neurons_1 = copy(neurons)

        # Draw the sign for THIS batch only.
        # The previous code multiplied the global beta_2 by the sign and stored it back,
        # so beta_2/betas drifted from batch to batch and could remain negative after the
        # cell finished, silently changing any re-run.
        if random_sign and (beta_1 == 0.0):
            rnd_sgn = 2 * np.random.randint(2) - 1
            step_betas = (beta_1, rnd_sgn * beta_2)
        else:
            step_betas = (beta_1, beta_2)

        # Phase 2: continue from the phase-1 state with the (possibly sign-flipped) beta.
        neurons = model(x, y, neurons, T2, neural_lr, beta=step_betas[1], criterion=criterion)
        neurons_2 = copy(neurons)

        # NOTE(review): there is no explicit optimizer.step(), yet accuracy improves over epochs,
        # so compute_syn_grads presumably applies the weight update itself — confirm.
        model.compute_syn_grads(x, y, neurons_1, neurons_2, step_betas, alphas_M, criterion)
    scheduler.step()

    model.eval()
    # Evaluate with the same number of relaxation steps as the first phase.
    # The model is already on `device`, so it is not moved again here.
    trn_acc = evaluateEP(model, train_loader, T1, neural_lr, device, False)
    tst_acc = evaluateEP(model, test_loader, T1, neural_lr, device, False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [01:50, 27.23it/s]
4it [00:00, 34.08it/s]

Epoch : 1, Train Accuracy : 0.9538166666666666, Test Accuracy : 0.9502


3000it [01:22, 36.58it/s]
4it [00:00, 32.62it/s]

Epoch : 2, Train Accuracy : 0.97155, Test Accuracy : 0.9653


3000it [01:21, 36.69it/s]
3it [00:00, 22.36it/s]

Epoch : 3, Train Accuracy : 0.98215, Test Accuracy : 0.9735


3000it [01:39, 30.10it/s]
4it [00:00, 35.23it/s]

Epoch : 4, Train Accuracy : 0.9819333333333333, Test Accuracy : 0.9687


3000it [01:20, 37.28it/s]
4it [00:00, 36.25it/s]

Epoch : 5, Train Accuracy : 0.9848833333333333, Test Accuracy : 0.9715


3000it [01:20, 37.39it/s]
4it [00:00, 34.53it/s]

Epoch : 6, Train Accuracy : 0.9860166666666667, Test Accuracy : 0.9732


3000it [01:20, 37.44it/s]
4it [00:00, 35.18it/s]

Epoch : 7, Train Accuracy : 0.99075, Test Accuracy : 0.9775


3000it [01:20, 37.42it/s]
4it [00:00, 34.57it/s]

Epoch : 8, Train Accuracy : 0.9914666666666667, Test Accuracy : 0.9743


1604it [00:42, 37.75it/s]

In [None]:
# Per-epoch training accuracy collected by the training loop.
# The values are fractions in [0, 1] (printed as e.g. 0.95), so the axis is not
# labelled as a percentage.
plot_convergence_plot(trn_acc_list, xlabel='Number of Epochs', ylabel='Accuracy',
                      title='CSM Train Accuracy w.r.t. Epochs',
                      figsize=(12, 8), fontsize=25, linewidth=3)

In [None]:
# Per-epoch test accuracy collected by the training loop.
# The values are fractions in [0, 1] (printed as e.g. 0.97), so the axis is not
# labelled as a percentage.
plot_convergence_plot(tst_acc_list, xlabel='Number of Epochs', ylabel='Accuracy',
                      title='CSM Test Accuracy w.r.t. Epochs',
                      figsize=(12, 8), fontsize=25, linewidth=3)