In [1]:
import sys
sys.path.append("../src")
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F

import glob
import os
from datetime import datetime
import time
import math
from tqdm import tqdm

from itertools import repeat
from torch.nn.parameter import Parameter
import collections
import matplotlib
from torch_utils import *
from ExplicitModels import *
from visualization import *
# matplotlib.use('Agg')

In [2]:
# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# The only preprocessing step is the conversion of each image to a tensor.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Training split (downloaded into ./data on first use), served in shuffled mini-batches of 20.
mnist_dset_train = torchvision.datasets.MNIST(
    './data',
    train=True,
    transform=transform,
    target_transform=None,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    mnist_dset_train,
    batch_size=20,
    shuffle=True,
    num_workers=0,
)

# Test split: same batch size, but iterated in a fixed order.
mnist_dset_test = torchvision.datasets.MNIST(
    './data',
    train=False,
    transform=transform,
    target_transform=None,
    download=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_dset_test,
    batch_size=20,
    shuffle=False,
    num_workers=0,
)

In [4]:
# Hyperparameters for the CorInfoMaxNudgedV1 run.
# A previous configuration differed only in lambda_ = 0.99999 and
# supervised_lambda_weight = 1e-3; every other value was the same as below.
activation = F.relu
architecture = [784, 128, 64, 10]  # first entry matches the flattened 28x28 input, last the 10 labels
lambda_ = 0.9999
epsilon = 0.2
supervised_lambda_weight = 1e-1

# Settings forwarded to the neural-dynamics phase of every batch step.
neural_lr_start, neural_lr_stop = 0.001, 0.0005
neural_lr_rule = "constant"
neural_lr_decay_multiplier = 0.005
neural_dynamic_iterations = 50

model = CorInfoMaxNudgedV1(
    architecture,
    lambda_,
    epsilon,
    activation,
    use_stepLR=True,
    sgd_nesterov=False,
    optimizer_type="sgd",
    optim_lr=1e-3,
    # 3000 equals the number of batches per epoch reported by the training
    # progress output, so this is presumably "every 10 epochs" -- confirm.
    stepLR_step_size=10 * 3000,
)

In [None]:
# Train for `n_epochs` passes over the training loader and record the train/test
# accuracy returned by `evaluatePC` after every epoch.
trn_acc_list = []
tst_acc_list = []
random_sign = False  # when True, each batch uses the supervised weight with a random sign
n_epochs = 50

for epoch_ in range(n_epochs):
    # tqdm wraps the loader itself (not `enumerate(...)`) so it can read the
    # loader's length and render a real progress bar with an ETA.
    for idx, (x, y) in enumerate(tqdm(train_loader)):
        x, y = x.to(device), y.to(device)
        # Flatten every image and transpose to (features, batch).
        x = x.view(x.size(0), -1).T
        # One-hot labels, transposed to (10, batch); already on `device` because `y` is.
        y_one_hot = F.one_hot(y, 10).T
        # Optional label smoothing, currently disabled:
        #y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)

        # Use a per-batch copy of the weight so the configured hyperparameter is
        # never overwritten; flipping the global in place could leave it negated
        # for later cells and re-runs.
        batch_lambda_weight = supervised_lambda_weight
        if random_sign:
            rnd_sgn = 2*np.random.randint(2) - 1
            batch_lambda_weight = rnd_sgn * supervised_lambda_weight

        model.batch_step(
            x, y_one_hot, batch_lambda_weight,
            neural_lr_start, neural_lr_stop, neural_lr_rule,
            neural_lr_decay_multiplier, neural_dynamic_iterations,
        )

    # Accuracy on both splits after the epoch (fourth positional flag left at False).
    trn_acc = evaluatePC(model, train_loader, device, False, printing=False)
    tst_acc = evaluatePC(model, test_loader, device, False, printing=False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [03:10, 15.71it/s]
2it [00:00, 11.40it/s]

Epoch : 1, Train Accuracy : 0.9072666666666667, Test Accuracy : 0.9129


3000it [03:11, 15.70it/s]
1it [00:00,  9.39it/s]

Epoch : 2, Train Accuracy : 0.9239833333333334, Test Accuracy : 0.9289


3000it [03:26, 14.56it/s]
2it [00:00, 13.19it/s]

Epoch : 3, Train Accuracy : 0.9339833333333334, Test Accuracy : 0.9355


3000it [03:33, 14.06it/s]
2it [00:00, 13.13it/s]

Epoch : 4, Train Accuracy : 0.9386333333333333, Test Accuracy : 0.9405


3000it [03:33, 14.07it/s]
2it [00:00, 13.47it/s]

Epoch : 5, Train Accuracy : 0.9414166666666667, Test Accuracy : 0.9406


3000it [03:32, 14.15it/s]
2it [00:00, 13.33it/s]

Epoch : 6, Train Accuracy : 0.942, Test Accuracy : 0.9412


3000it [03:32, 14.14it/s]
2it [00:00, 13.24it/s]

Epoch : 7, Train Accuracy : 0.9460166666666666, Test Accuracy : 0.9457


3000it [03:32, 14.12it/s]
2it [00:00, 13.26it/s]

Epoch : 8, Train Accuracy : 0.9411, Test Accuracy : 0.9423


3000it [03:33, 14.04it/s]
2it [00:00, 10.95it/s]

Epoch : 9, Train Accuracy : 0.9374333333333333, Test Accuracy : 0.9382


3000it [03:34, 13.98it/s]
2it [00:00, 14.09it/s]

Epoch : 10, Train Accuracy : 0.9478166666666666, Test Accuracy : 0.9477


3000it [03:32, 14.09it/s]
2it [00:00, 13.24it/s]

Epoch : 11, Train Accuracy : 0.9423833333333334, Test Accuracy : 0.9423


3000it [03:33, 14.08it/s]
2it [00:00, 13.31it/s]

Epoch : 12, Train Accuracy : 0.9449333333333333, Test Accuracy : 0.947


3000it [03:33, 14.05it/s]
2it [00:00, 13.42it/s]

Epoch : 13, Train Accuracy : 0.9465, Test Accuracy : 0.9445


3000it [03:34, 14.01it/s]
2it [00:00, 13.27it/s]

Epoch : 14, Train Accuracy : 0.9424833333333333, Test Accuracy : 0.9431


3000it [03:34, 13.98it/s]
2it [00:00, 13.13it/s]

Epoch : 15, Train Accuracy : 0.94505, Test Accuracy : 0.9458


3000it [03:35, 13.91it/s]
2it [00:00, 12.60it/s]

Epoch : 16, Train Accuracy : 0.9426833333333333, Test Accuracy : 0.9452


3000it [03:35, 13.91it/s]
2it [00:00, 13.09it/s]

Epoch : 17, Train Accuracy : 0.9474, Test Accuracy : 0.9454


3000it [03:36, 13.87it/s]
2it [00:00, 13.15it/s]

Epoch : 18, Train Accuracy : 0.9514666666666667, Test Accuracy : 0.9519


3000it [03:35, 13.94it/s]
2it [00:00, 12.64it/s]

Epoch : 19, Train Accuracy : 0.9478333333333333, Test Accuracy : 0.9491


3000it [03:34, 14.00it/s]
2it [00:00, 13.12it/s]

Epoch : 20, Train Accuracy : 0.9498333333333333, Test Accuracy : 0.948


3000it [03:37, 13.80it/s]
2it [00:00, 12.97it/s]

Epoch : 21, Train Accuracy : 0.9508333333333333, Test Accuracy : 0.9505


3000it [03:36, 13.86it/s]
2it [00:00, 13.12it/s]

Epoch : 22, Train Accuracy : 0.9513666666666667, Test Accuracy : 0.9494


3000it [03:35, 13.94it/s]
2it [00:00, 13.21it/s]

Epoch : 23, Train Accuracy : 0.9444166666666667, Test Accuracy : 0.9433


3000it [03:37, 13.80it/s]
2it [00:00, 13.53it/s]

Epoch : 24, Train Accuracy : 0.9507666666666666, Test Accuracy : 0.9512


3000it [03:35, 13.91it/s]
2it [00:00, 12.92it/s]

Epoch : 25, Train Accuracy : 0.9527833333333333, Test Accuracy : 0.9509


3000it [03:36, 13.86it/s]
2it [00:00, 12.84it/s]

Epoch : 26, Train Accuracy : 0.9503833333333334, Test Accuracy : 0.9487


3000it [03:36, 13.87it/s]
2it [00:00, 12.53it/s]

Epoch : 27, Train Accuracy : 0.953, Test Accuracy : 0.9495


3000it [03:36, 13.85it/s]
2it [00:00, 13.13it/s]

Epoch : 28, Train Accuracy : 0.9538166666666666, Test Accuracy : 0.9501


3000it [03:36, 13.85it/s]
2it [00:00, 13.14it/s]

Epoch : 29, Train Accuracy : 0.95355, Test Accuracy : 0.9512


3000it [03:36, 13.85it/s]
2it [00:00, 12.63it/s]

Epoch : 30, Train Accuracy : 0.9517333333333333, Test Accuracy : 0.9513


3000it [03:36, 13.83it/s]
2it [00:00, 12.78it/s]

Epoch : 31, Train Accuracy : 0.9529166666666666, Test Accuracy : 0.9531


3000it [03:35, 13.89it/s]
2it [00:00, 12.77it/s]

Epoch : 32, Train Accuracy : 0.9527666666666667, Test Accuracy : 0.953


3000it [03:36, 13.85it/s]
2it [00:00, 12.59it/s]

Epoch : 33, Train Accuracy : 0.9560166666666666, Test Accuracy : 0.9535


3000it [03:36, 13.88it/s]
2it [00:00, 13.22it/s]

Epoch : 34, Train Accuracy : 0.9533166666666667, Test Accuracy : 0.952


3000it [03:35, 13.92it/s]
2it [00:00, 13.12it/s]

Epoch : 35, Train Accuracy : 0.9511666666666667, Test Accuracy : 0.948


3000it [03:35, 13.92it/s]
2it [00:00, 13.34it/s]

Epoch : 36, Train Accuracy : 0.9527, Test Accuracy : 0.9534


3000it [03:33, 14.03it/s]
2it [00:00, 13.08it/s]

Epoch : 37, Train Accuracy : 0.9509166666666666, Test Accuracy : 0.9495


3000it [03:29, 14.31it/s]
1it [00:00,  9.73it/s]

Epoch : 38, Train Accuracy : 0.9529666666666666, Test Accuracy : 0.9525


3000it [03:30, 14.22it/s]
2it [00:00, 12.82it/s]

Epoch : 39, Train Accuracy : 0.9534333333333334, Test Accuracy : 0.9521


3000it [03:31, 14.19it/s]
2it [00:00, 13.33it/s]

Epoch : 40, Train Accuracy : 0.9555666666666667, Test Accuracy : 0.9541


3000it [03:34, 13.99it/s]
2it [00:00, 13.21it/s]

Epoch : 41, Train Accuracy : 0.9545, Test Accuracy : 0.9533


3000it [03:31, 14.17it/s]
2it [00:00, 13.38it/s]

Epoch : 42, Train Accuracy : 0.95165, Test Accuracy : 0.9521


3000it [03:33, 14.08it/s]
2it [00:00, 13.37it/s]

Epoch : 43, Train Accuracy : 0.9429166666666666, Test Accuracy : 0.9431


3000it [03:31, 14.21it/s]
2it [00:00, 13.39it/s]

Epoch : 44, Train Accuracy : 0.9525666666666667, Test Accuracy : 0.9548


2780it [03:12, 14.08it/s]

In [None]:
# Hyperparameters for a fresh CorInfoMaxNudged model (this cell rebinds `model`).
activation = F.relu
architecture = [784, 128, 64, 10]
lambda_ = 0.9999
epsilon = 0.1
supervised_lambda_weight = 1e-3

# Settings forwarded to the neural-dynamics phase.
neural_lr_start, neural_lr_stop = 0.001, 0.0005
neural_lr_rule = "constant"
neural_lr_decay_multiplier = 0.005
neural_dynamic_iterations = 50

model = CorInfoMaxNudged(
    architecture,
    lambda_,
    epsilon,
    activation,
    use_stepLR=True,
    sgd_nesterov=False,
    optimizer_type="sgd",
    optim_lr=1e-3,
    stepLR_step_size=10 * 3000,
)

# Pull a single mini-batch and lay it out as (features, batch); labels become (10, batch).
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
x = x.view(x.size(0), -1).T
y_one_hot = F.one_hot(y, 10).T

# Forward pass to obtain the initial neural states, then run the nudged dynamics once.
neurons = model.fast_forward(x, no_grad=True)
model.run_neural_dynamics(
    x, y_one_hot, neurons, supervised_lambda_weight,
    neural_lr_start, neural_lr_stop,
    lr_rule=neural_lr_rule,
    lr_decay_multiplier=neural_lr_decay_multiplier,
    neural_dynamic_iterations=neural_dynamic_iterations,
)

In [None]:
# Differentiate the model's CorInfo cost with respect to every layer's neural state.
mbs = x.size(1)  # x was laid out as (features, batch) when it was created, so this is the batch size
for jj, layer_state in enumerate(neurons):
    # requires_grad_() flags the tensor in place and returns it.
    neurons[jj] = layer_state.requires_grad_()
# NOTE(review): the integer labels `y` are passed here, whereas run_neural_dynamics
# received `y_one_hot` -- confirm which form CorInfo_Cost expects.
corinfo_cost = model.CorInfo_Cost(x, y, neurons)
# One unit seed gradient per batch element (presumably the cost is one value per sample -- confirm).
init_grads = torch.ones(mbs, dtype=torch.float, device=device, requires_grad=True)
# dPhi/ds for every entry of `neurons`.
grads = torch.autograd.grad(corinfo_cost, neurons, grad_outputs=init_grads, create_graph=False)

In [None]:
# Scratch re-derivation of the per-layer cost terms from the model's own parameters.
one_over_epsilon, gam_ = model.one_over_epsilon, model.gam_
Wff, B = model.Wff, model.B

# Fresh mini-batch laid out as (features, batch); labels become (10, batch).
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
x = x.view(x.size(0), -1).T
y_one_hot = F.one_hot(y, 10).T

neurons = model.fast_forward(x, no_grad=True)

# layers[0] is the input batch, layers[k] the state returned for layer k.
layers = [x] + neurons
for jj in range(len(Wff)):
    # The first layer reads the raw input; deeper layers read the activated previous state.
    presynaptic = layers[jj] if jj == 0 else model.activation(layers[jj])
    prediction = Wff[jj]['weight'] @ presynaptic + Wff[jj]['bias']
    error = - one_over_epsilon * (layers[jj + 1] - prediction)

    # NOTE(review): if layers[jj + 1] is (units, batch), this product is a
    # (batch, batch) matrix rather than one quadratic form per sample -- confirm
    # the intended shape. The value is not used again in this cell.
    lateral_term = gam_ * 0.5 * (layers[jj + 1].T @ B[jj]['weight'] @ layers[jj + 1])
    # Squared error summed over dim 0. Every pass overwrites `error`,
    # `lateral_term` and `corinfo_cost`, so only the last layer's values
    # survive the loop.
    corinfo_cost = torch.sum(error * error, 0)
    


In [None]:
# Shapes of the raw error matrix and of its squared sum over dim 0.
# A notebook cell only displays its last expression, so both shapes are
# returned together as a tuple instead of silently dropping the first one.
error.shape, torch.sum(error * error, 0).shape

In [None]:
# Shape check for the project helper `outer_prod_broadcasting` applied to B @ s and s.T,
# where s = layers[jj + 1] and `jj` is whatever index the last executed loop left behind.
# NOTE(review): presumably a batched outer product -- confirm against the helper's definition.
outer_prod_broadcasting((B[jj]['weight'] @ layers[jj + 1]), layers[jj + 1].T).shape

In [None]:
# Quadratic form v.T @ B @ v for the single column at index 2 of layers[jj + 1].
# NOTE(review): if layers[jj + 1] is 2-D, the slice is 1-D and `.T` does not change it -- confirm.
layers[jj + 1][:,2].T @ B[jj]['weight'] @ layers[jj + 1][:,2]

In [None]:
# Element-wise product of B @ s with s, summed over dim 0: one value per column of s = layers[jj + 1].
# For 2-D s, entry i equals the quadratic form s[:, i] @ B @ s[:, i].
torch.sum((B[jj]['weight'] @ layers[jj + 1]) * layers[jj + 1], 0)

In [None]:
# Shapes of B @ s and of s = layers[jj + 1], shown side by side as a tuple.
(B[jj]['weight'] @ layers[jj + 1]).shape, layers[jj + 1].shape

In [None]:
# Train the current `model` for `n_epochs` passes over the training loader and
# record the train/test accuracy returned by `evaluatePC` after every epoch.
# The accuracy lists are reset here, so earlier results are discarded.
trn_acc_list = []
tst_acc_list = []
random_sign = False  # when True, each batch uses the supervised weight with a random sign
n_epochs = 50

for epoch_ in range(n_epochs):
    # tqdm wraps the loader itself (not `enumerate(...)`) so it can read the
    # loader's length and render a real progress bar with an ETA.
    for idx, (x, y) in enumerate(tqdm(train_loader)):
        x, y = x.to(device), y.to(device)
        # Flatten every image and transpose to (features, batch).
        x = x.view(x.size(0), -1).T
        # One-hot labels, transposed to (10, batch); already on `device` because `y` is.
        y_one_hot = F.one_hot(y, 10).T
        # Optional label smoothing, currently disabled:
        #y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)

        # Use a per-batch copy of the weight so the configured hyperparameter is
        # never overwritten; flipping the global in place could leave it negated
        # for later cells and re-runs.
        batch_lambda_weight = supervised_lambda_weight
        if random_sign:
            rnd_sgn = 2*np.random.randint(2) - 1
            batch_lambda_weight = rnd_sgn * supervised_lambda_weight

        model.batch_step(
            x, y_one_hot, batch_lambda_weight,
            neural_lr_start, neural_lr_stop, neural_lr_rule,
            neural_lr_decay_multiplier, neural_dynamic_iterations,
        )

    # Accuracy on both splits after the epoch (fourth positional flag left at False).
    trn_acc = evaluatePC(model, train_loader, device, False, printing=False)
    tst_acc = evaluatePC(model, test_loader, device, False, printing=False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)

    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))