In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
from tqdm import tqdm
import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
import sys
sys.path.append("../src")
from ContrastiveModels import *
from visualization import *

In [2]:
# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0 / std 1 is an identity transform,
# so the network sees the raw [0, 1] pixel values.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,)),
    ]
)

# Settings shared by both splits. NOTE: despite the `mnist_` prefix these are FashionMNIST.
_dset_kwargs = dict(root='./data', transform=transform, target_transform=None, download=True)
_loader_kwargs = dict(batch_size=20, num_workers=0)

# Training split: reshuffled every epoch.
mnist_dset_train = torchvision.datasets.FashionMNIST(train=True, **_dset_kwargs)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, shuffle=True, **_loader_kwargs)

# Test split: fixed order.
mnist_dset_test = torchvision.datasets.FashionMNIST(train=False, **_dset_kwargs)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, shuffle=False, **_loader_kwargs)

In [4]:
# CSM network: 784 inputs -> 500 hidden units -> 10 outputs (see the printed module structure:
# W[0] is 784->500, W[1] is 500->10, and a single 500x500 lateral layer M[0]).
architecture = [784, 500, 10]
activation = hard_sigmoid

# Per-layer coefficients handed to CSM: one entry per feedforward layer W[i] and one per
# lateral layer M[i]. NOTE(review): presumably step sizes; the negative sign on alphas_M
# looks intentional — TODO confirm against the CSM implementation.
alphas_W = [0.15, 0.1]
alphas_M = [-0.005]

# Element-wise MSE (reduction='none'); presumably reduced inside the model — TODO confirm.
criterion = torch.nn.MSELoss(reduction='none').to(device)

model = CSM(architecture, activation, alphas_W, alphas_M).to(device)

In [5]:
# Learning-rate schedule for the optimizer owned by the model: multiply the learning rate by
# 0.9 every 5 scheduler steps (the training loop calls scheduler.step() once per epoch).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=model.optimizer, step_size=5, gamma=0.9)

In [6]:
# Baseline before any training: test-split accuracy of the freshly initialised network.
# The 3rd and 4th positional arguments are filled here with 20 and 0.5, the same positions
# (and values) the training cell later fills with T1 and neural_lr.
evaluateEP(
    model.to(device),
    test_loader,
    20,
    0.5,
    device,
)

Test accuracy :	 0.0953


0.0953

In [7]:
# Loader-derived bookkeeping.
mbs = train_loader.batch_size                                 # mini-batch size
start = time.time()                                           # wall-clock reference point
iter_per_epochs = math.ceil(len(train_loader.dataset) / mbs)  # mini-batches per epoch

# Hyper-parameters consumed by the training cell: neural_lr is passed to model(...) and
# evaluateEP; beta_1 / beta_2 are the `beta` values of the first and second phase.
neural_lr = 0.5
betas = (0.0, 1.0)
beta_1 = betas[0]
beta_2 = betas[1]

# Running records. NOTE(review): none of these are read by the training loop below;
# 10.0 presumably means chance-level accuracy in percent — TODO confirm or remove.
train_acc = [10.0]
test_acc = [10.0]
best = 0.0
epoch_sofar = 0

model.train()

CSM(
  (W): ModuleList(
    (0): Linear(in_features=784, out_features=500, bias=True)
    (1): Linear(in_features=500, out_features=10, bias=True)
  )
  (M): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
  (M_copy): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
)

In [None]:
# Two-phase training of CSM. For every mini-batch:
#   phase 1: relax freshly initialised neurons for T1 steps with beta=beta_1;
#   phase 2: continue from the phase-1 state for T2 steps with beta=beta_2
#            (its sign optionally randomised per mini-batch);
#   then model.compute_syn_grads(...) receives both neuron snapshots.
# After each epoch the LR scheduler is stepped and train/test accuracy is recorded
# (evaluateEP returns fractions in [0, 1]).
trn_acc_list = []
tst_acc_list = []
T1 = 20             # number of phase-1 steps
T2 = 4              # number of phase-2 steps
random_sign = True  # randomise the sign of beta_2 per mini-batch (only when beta_1 == 0)
n_epochs = 50

for epoch_ in range(n_epochs):
    # Iterating the loader directly (no enumerate) lets tqdm read len(train_loader),
    # so it can show a real progress bar and ETA.
    for x, y in tqdm(train_loader, desc="Epoch {}".format(epoch_ + 1)):
        x, y = x.to(device), y.to(device)

        # Phase 1: relax from fresh neuron states.
        neurons = model.init_neurons(x.size(0), device)
        neurons = model(x, y, neurons, T1, neural_lr, beta=beta_1, criterion=criterion)
        # NOTE(review): `copy` is not imported explicitly in this notebook; it presumably
        # comes from one of the star imports (copy.copy?) — TODO import it explicitly.
        neurons_1 = copy(neurons)

        # Per-batch betas are kept local so the configured `beta_2` / `betas` are never
        # overwritten; re-running this cell therefore starts from the same hyper-parameters.
        beta_2_step = beta_2
        if random_sign and (beta_1 == 0.0):
            beta_2_step = (2 * np.random.randint(2) - 1) * beta_2  # sign is uniformly -1 or +1
        betas_step = (beta_1, beta_2_step)

        # Phase 2: continue from the phase-1 state.
        neurons = model(x, y, neurons, T2, neural_lr, beta=beta_2_step, criterion=criterion)
        neurons_2 = copy(neurons)
        # NOTE(review): there is no explicit optimizer.step() here; presumably
        # compute_syn_grads applies the update via model.optimizer — TODO confirm.
        model.compute_syn_grads(x, y, neurons_1, neurons_2, betas_step, alphas_M, criterion)

    scheduler.step()
    # `model` was already moved to `device` when it was built, so it is passed as-is.
    trn_acc = evaluateEP(model, train_loader, T1, neural_lr, device, False)
    tst_acc = evaluateEP(model, test_loader, T1, neural_lr, device, False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [01:37, 30.93it/s]
2it [00:00, 14.42it/s]

Epoch : 1, Train Accuracy : 0.8467166666666667, Test Accuracy : 0.8341


3000it [03:04, 16.29it/s]
2it [00:00, 14.37it/s]

Epoch : 2, Train Accuracy : 0.8743, Test Accuracy : 0.8567


3000it [03:17, 15.17it/s]
2it [00:00, 14.07it/s]

Epoch : 3, Train Accuracy : 0.86525, Test Accuracy : 0.8469


3000it [03:15, 15.36it/s]
2it [00:00, 14.43it/s]

Epoch : 4, Train Accuracy : 0.8831, Test Accuracy : 0.8628


3000it [03:15, 15.38it/s]
2it [00:00, 14.08it/s]

Epoch : 5, Train Accuracy : 0.8935166666666666, Test Accuracy : 0.8723


3000it [03:15, 15.37it/s]
2it [00:00, 14.23it/s]

Epoch : 6, Train Accuracy : 0.8866166666666667, Test Accuracy : 0.8648


3000it [03:16, 15.30it/s]
2it [00:00, 14.06it/s]

Epoch : 7, Train Accuracy : 0.88975, Test Accuracy : 0.8623


3000it [03:14, 15.42it/s]
2it [00:00, 14.13it/s]

Epoch : 8, Train Accuracy : 0.8954333333333333, Test Accuracy : 0.873


3000it [03:18, 15.10it/s]
2it [00:00, 14.50it/s]

Epoch : 9, Train Accuracy : 0.9003, Test Accuracy : 0.8731


3000it [03:18, 15.13it/s]
2it [00:00, 13.76it/s]

Epoch : 10, Train Accuracy : 0.90475, Test Accuracy : 0.8749


3000it [03:16, 15.25it/s]
2it [00:00, 14.48it/s]

Epoch : 11, Train Accuracy : 0.9049333333333334, Test Accuracy : 0.8784


3000it [03:16, 15.26it/s]
2it [00:00, 14.13it/s]

Epoch : 12, Train Accuracy : 0.9102333333333333, Test Accuracy : 0.8818


3000it [03:16, 15.25it/s]
2it [00:00, 13.96it/s]

Epoch : 13, Train Accuracy : 0.9089166666666667, Test Accuracy : 0.8751


3000it [03:16, 15.26it/s]
2it [00:00, 13.00it/s]

Epoch : 14, Train Accuracy : 0.9060666666666667, Test Accuracy : 0.876


3000it [03:18, 15.09it/s]
2it [00:00, 14.28it/s]

Epoch : 15, Train Accuracy : 0.9071333333333333, Test Accuracy : 0.875


3000it [03:16, 15.28it/s]
2it [00:00, 13.97it/s]

Epoch : 16, Train Accuracy : 0.9177666666666666, Test Accuracy : 0.8779


3000it [03:17, 15.19it/s]
2it [00:00, 13.81it/s]

Epoch : 17, Train Accuracy : 0.9252333333333334, Test Accuracy : 0.8887


3000it [03:20, 14.99it/s]
2it [00:00, 14.11it/s]

Epoch : 18, Train Accuracy : 0.9179666666666667, Test Accuracy : 0.8773


3000it [03:17, 15.16it/s]
2it [00:00, 14.16it/s]

Epoch : 19, Train Accuracy : 0.9236666666666666, Test Accuracy : 0.8829


3000it [03:21, 14.92it/s]
2it [00:00, 13.28it/s]

Epoch : 20, Train Accuracy : 0.9259, Test Accuracy : 0.8858


3000it [03:19, 15.03it/s]
2it [00:00, 14.00it/s]

Epoch : 21, Train Accuracy : 0.9293, Test Accuracy : 0.8871


3000it [03:20, 14.96it/s]
2it [00:00, 13.56it/s]

Epoch : 22, Train Accuracy : 0.92295, Test Accuracy : 0.8761


3000it [03:20, 14.97it/s]
2it [00:00, 13.84it/s]

Epoch : 23, Train Accuracy : 0.9273333333333333, Test Accuracy : 0.8867


3000it [03:21, 14.91it/s]
2it [00:00, 13.69it/s]

Epoch : 24, Train Accuracy : 0.9368, Test Accuracy : 0.8882


3000it [03:21, 14.92it/s]
2it [00:00, 13.84it/s]

Epoch : 25, Train Accuracy : 0.9376333333333333, Test Accuracy : 0.891


3000it [02:47, 17.86it/s]
3it [00:00, 27.76it/s]

Epoch : 26, Train Accuracy : 0.93495, Test Accuracy : 0.8882


3000it [01:39, 30.20it/s]
3it [00:00, 27.93it/s]

Epoch : 27, Train Accuracy : 0.9132833333333333, Test Accuracy : 0.8638


3000it [01:40, 29.76it/s]
3it [00:00, 27.76it/s]

Epoch : 28, Train Accuracy : 0.9407833333333333, Test Accuracy : 0.8857


3000it [01:40, 29.99it/s]
3it [00:00, 29.02it/s]

Epoch : 29, Train Accuracy : 0.9177666666666666, Test Accuracy : 0.8725


3000it [01:44, 28.59it/s]
2it [00:00, 16.60it/s]

Epoch : 30, Train Accuracy : 0.9337, Test Accuracy : 0.8784


3000it [02:16, 21.99it/s]
3it [00:00, 28.16it/s]

Epoch : 31, Train Accuracy : 0.94195, Test Accuracy : 0.8901


3000it [02:06, 23.70it/s]


In [None]:
# evaluateEP returns accuracies as fractions in [0, 1]; scale them to percent so the plotted
# values agree with the "Accuracy %" axis label.
plot_convergence_plot([100 * acc for acc in trn_acc_list],
                      xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Train Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3)

In [None]:
# evaluateEP returns accuracies as fractions in [0, 1]; scale them to percent so the plotted
# values agree with the "Accuracy %" axis label.
plot_convergence_plot([100 * acc for acc in tst_acc_list],
                      xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Test Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3)