In [1]:
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F
import argparse
import matplotlib
from tqdm import tqdm
import glob
from PIL import Image
import os
from datetime import datetime
import time
import math
import sys
import sys
sys.path.append("../src")
from models import *
from visualization import *

In [2]:
# Prefer the first CUDA GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# CIFAR-10 per-channel mean / std. The std is deliberately multiplied by 3, which
# shrinks the normalised pixel values by a factor of 3 compared with the usual
# normalisation (presumably to suit the model's activation range — confirm).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (3*0.2023, 3*0.1994, 3*0.2010)
BATCH_SIZE = 20

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

# Both splits are downloaded to ./data on first use; only the training split is shuffled.
cifar_dset_train = torchvision.datasets.CIFAR10('./data', train=True, transform=transform,
                                                target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(cifar_dset_train, batch_size=BATCH_SIZE,
                                           shuffle=True, num_workers=0)

cifar_dset_test = torchvision.datasets.CIFAR10('./data', train=False, transform=transform,
                                               target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(cifar_dset_test, batch_size=BATCH_SIZE,
                                          shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Seed every RNG this notebook draws from so that Restart & Run All reproduces the
# run: torch (initialisation of CSM's nn.Linear layers, DataLoader shuffling, which
# draws from torch's default generator each epoch) and numpy (the random sign of
# beta_2 drawn with np.random.randint in the training loop).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

activation = hard_sigmoid  # provided by the star imports (presumably `models`)
# reduction='none' keeps per-element losses; presumably the model reduces them itself.
criterion = torch.nn.MSELoss(reduction='none').to(device)
# Layer sizes: flattened 3x32x32 CIFAR-10 image -> 500 hidden units -> 10 classes.
architecture = [int(32*32*3), 500, 10]
model = CSM(architecture, activation = activation)
model = model.to(device)

In [5]:
# Per-layer learning rates, indexed like the model's layer lists.
alphas_W = [
    0.059,  # model.W[0]: input (3072) -> hidden (500)
    0.017,  # model.W[1]: hidden (500) -> output (10)
]
alphas_M = [
    -0.067,  # model.M[0]: hidden (500) -> hidden (500); negative on purpose (sign-reversed update)
]

In [6]:
# One SGD parameter group per layer so each layer gets its own learning rate:
# model.W[i] uses alphas_W[i] and model.M[i] uses alphas_M[i]. W groups come first,
# then M groups.
# Fail loudly on a length mismatch instead of an opaque IndexError (too few rates)
# or silently ignoring extra rates (too many).
if len(alphas_W) != len(model.W) or len(alphas_M) != len(model.M):
    raise ValueError(
        "need one learning rate per layer: got {} alphas_W for {} W layers and "
        "{} alphas_M for {} M layers".format(len(alphas_W), len(model.W),
                                             len(alphas_M), len(model.M)))

optim_params = (
    [{'params': layer.parameters(), 'lr': lr} for layer, lr in zip(model.W, alphas_W)]
    + [{'params': layer.parameters(), 'lr': lr} for layer, lr in zip(model.M, alphas_M)]
)

In [7]:
# Plain SGD without momentum; every parameter group carries its own lr from optim_params.
# StepLR multiplies each group's lr by 0.9 once every 50 scheduler.step() calls.
# NOTE(review): the training cell calls scheduler.step() once per epoch with
# n_epochs = 50, so the first decay only takes effect after the last epoch — i.e.
# the schedule is a no-op for this run. Confirm this is intended.
optimizer = torch.optim.SGD(params=optim_params, momentum=0.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=50, gamma=0.9)

In [8]:
# Sanity check before training: test accuracy of the freshly initialised network
# (close to the 10% chance level of CIFAR-10).
eval_time_steps = 20  # time steps of the neuron dynamics; the training loop uses T1 = 20 as well
evaluateEP(model.to(device), test_loader, eval_time_steps, device)

Test accuracy :	 0.1208


0.1208

In [9]:
start = time.time()

# Mini-batch size and the resulting number of iterations per epoch.
mbs = train_loader.batch_size
iter_per_epochs = math.ceil(len(train_loader.dataset) / mbs)

# Nudging strengths passed as `beta` to the model: beta_1 for the first phase
# (0.0, presumably un-nudged) and beta_2 for the second phase.
betas = (0.0, 1.0)
beta_1, beta_2 = betas
neural_lr = 0.2  # step size of the neuron dynamics (passed as neural_lr to model(...))

# Bookkeeping. NOTE(review): the training cell keeps its own trn_acc_list /
# tst_acc_list and does not read these.
train_acc = [10.0]
test_acc = [10.0]
best = 0.0
epoch_sofar = 0

model.train()

CSM(
  (W): ModuleList(
    (0): Linear(in_features=3072, out_features=500, bias=True)
    (1): Linear(in_features=500, out_features=10, bias=True)
  )
  (M): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
  (M_copy): ModuleList(
    (0): Linear(in_features=500, out_features=500, bias=False)
  )
)

In [None]:
# Two-phase training per mini-batch: run the neuron dynamics for T1 steps with
# beta_1, then continue from that state for T2 more steps with beta_2; the synaptic
# gradients are computed from the two resulting neuron states.
trn_acc_list = []
tst_acc_list = []
T1 = 20             # time steps of the first phase (also used for evaluation)
T2 = 4              # time steps of the second phase
random_sign = True  # draw the sign of beta_2 at random for every mini-batch
n_epochs = 50


def _loader_accuracy(model, loader, T, neural_lr, device):
    """Return the fraction of samples in `loader` predicted correctly.

    The prediction is the argmax of the last (output) layer of neurons after running
    the dynamics for `T` steps from a fresh initial state, calling the model without
    beta/criterion (i.e. with its defaults).
    """
    correct = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        neurons = model.init_neurons(x.size(0), device)
        neurons = model(x, y, neurons, T, neural_lr=neural_lr)
        pred = torch.argmax(neurons[-1], dim=1).squeeze()
        correct += (y == pred).sum().item()
    return correct / len(loader.dataset)


for epoch_ in range(n_epochs):
    # Evaluation at the end of each epoch switches to eval mode, so switch back here;
    # otherwise every epoch after the first would train in eval mode.
    model.train()
    for x, y in tqdm(train_loader, total=len(train_loader), desc="epoch {}".format(epoch_ + 1)):
        x, y = x.to(device), y.to(device)
        neurons = model.init_neurons(x.size(0), device)
        # First phase: T1 steps with beta_1.
        neurons = model(x, y, neurons, T1, neural_lr=neural_lr, beta=beta_1, criterion=criterion)
        neurons_1 = copy(neurons)
        # Betas for this mini-batch. With random_sign, beta_2's sign is drawn fresh for
        # every batch. Kept local so the global beta_2 / betas are not mutated and a
        # re-run of this cell starts from the same state.
        if random_sign and beta_1 == 0.0:
            step_betas = (beta_1, (2 * np.random.randint(2) - 1) * beta_2)
        else:
            step_betas = (beta_1, beta_2)
        # Second phase: T2 more steps from the first-phase state with the (signed) beta_2.
        neurons = model(x, y, neurons, T2, neural_lr=neural_lr, beta=step_betas[1], criterion=criterion)
        neurons_2 = copy(neurons)
        # Reset gradients in place so nothing accumulates across mini-batches, whatever
        # compute_syn_grads does internally (set_to_none=False keeps existing grad tensors).
        optimizer.zero_grad(set_to_none=False)
        # NOTE(review): alphas_M is also the M learning rate inside the optimizer —
        # confirm the M update is not applied twice.
        model.compute_syn_grads(x, y, neurons_1, neurons_2, step_betas, alphas_M, criterion)
        optimizer.step()
    scheduler.step()

    model.eval()
    trn_acc = _loader_accuracy(model, train_loader, T1, neural_lr, device)
    tst_acc = _loader_accuracy(model, test_loader, T1, neural_lr, device)

    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

2500it [02:45, 15.12it/s]
2it [00:00, 13.59it/s]

Epoch : 1, Train Accuracy : 0.37652, Test Accuracy : 0.3643


2500it [02:46, 15.03it/s]
2it [00:00, 13.74it/s]

Epoch : 2, Train Accuracy : 0.1, Test Accuracy : 0.1


2500it [02:47, 14.93it/s]
2it [00:00, 13.66it/s]

Epoch : 3, Train Accuracy : 0.43454, Test Accuracy : 0.4044


2500it [02:48, 14.88it/s]
2it [00:00, 14.19it/s]

Epoch : 4, Train Accuracy : 0.10012, Test Accuracy : 0.1


2500it [02:45, 15.13it/s]
2it [00:00, 14.03it/s]

Epoch : 5, Train Accuracy : 0.45944, Test Accuracy : 0.4318


2500it [02:46, 15.02it/s]
2it [00:00, 13.86it/s]

Epoch : 6, Train Accuracy : 0.46106, Test Accuracy : 0.4375


2500it [02:47, 14.94it/s]
2it [00:00, 14.10it/s]

Epoch : 7, Train Accuracy : 0.48548, Test Accuracy : 0.4583


2500it [02:46, 15.02it/s]
2it [00:00, 13.65it/s]

Epoch : 8, Train Accuracy : 0.49632, Test Accuracy : 0.4681


2500it [02:48, 14.86it/s]
2it [00:00, 13.58it/s]

Epoch : 9, Train Accuracy : 0.48256, Test Accuracy : 0.4525


2500it [02:47, 14.89it/s]
2it [00:00, 13.81it/s]

Epoch : 10, Train Accuracy : 0.4248, Test Accuracy : 0.3985


2500it [02:48, 14.85it/s]
2it [00:00, 14.08it/s]

Epoch : 11, Train Accuracy : 0.49292, Test Accuracy : 0.4524


2500it [02:47, 14.90it/s]
2it [00:00, 14.10it/s]

Epoch : 12, Train Accuracy : 0.50852, Test Accuracy : 0.4737


2500it [02:48, 14.85it/s]
2it [00:00, 13.67it/s]

Epoch : 13, Train Accuracy : 0.44844, Test Accuracy : 0.4124


2500it [02:48, 14.88it/s]
2it [00:00, 13.52it/s]

Epoch : 14, Train Accuracy : 0.1612, Test Accuracy : 0.1555


2500it [02:48, 14.80it/s]
2it [00:00, 13.20it/s]

Epoch : 15, Train Accuracy : 0.32862, Test Accuracy : 0.3224


2500it [02:49, 14.77it/s]
2it [00:00, 14.09it/s]

Epoch : 16, Train Accuracy : 0.52264, Test Accuracy : 0.474


2500it [02:48, 14.84it/s]
2it [00:00, 13.75it/s]

Epoch : 17, Train Accuracy : 0.52058, Test Accuracy : 0.469


1258it [01:25, 14.39it/s]

In [None]:
# The training loop stores accuracies as fractions (correct / dataset size); scale
# them to percent so the plotted values agree with the 'Accuracy %' axis label.
# NOTE(review): assumes plot_convergence_plot plots values unchanged — confirm it
# does not already rescale them.
trn_acc_pct = [100 * acc for acc in trn_acc_list]
plot_convergence_plot(trn_acc_pct, xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Train Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3)

In [None]:
# The training loop stores accuracies as fractions (correct / dataset size); scale
# them to percent so the plotted values agree with the 'Accuracy %' axis label.
# NOTE(review): assumes plot_convergence_plot plots values unchanged — confirm it
# does not already rescale them.
tst_acc_pct = [100 * acc for acc in tst_acc_list]
plot_convergence_plot(tst_acc_pct, xlabel = 'Number of Epochs', ylabel = 'Accuracy %',
                      title = 'CSM Test Accuracy w.r.t. Epochs',
                      figsize = (12,8), fontsize = 25, linewidth = 3)