In [1]:
import sys
sys.path.append("../src")
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F

import glob
import os
from datetime import datetime
import time
import math
from tqdm import tqdm

from itertools import repeat
from torch.nn.parameter import Parameter
import collections
import matplotlib
from torch_utils import *
from ExplicitModels import *
from visualization import *
# matplotlib.use('Agg')

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [3]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 
                                            torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,))])

mnist_dset_train = torchvision.datasets.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, batch_size=64, shuffle=True, num_workers=0)

mnist_dset_test = torchvision.datasets.MNIST('./data', train=False, transform=transform, target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, batch_size=64, shuffle=False, num_workers=0)

In [4]:
activation = F.relu
architecture = [784, 128, 64, 10]

neural_lr_start = 0.1/1000
neural_lr_stop = 0.05/1000
neural_lr_rule = "constant"
neural_lr_decay_multiplier = 0.005
neural_dynamic_iterations = 50

model = SupervisedPredictiveCodingNudged_wAutoGrad(architecture, activation, use_stepLR = True, 
                                                   sgd_nesterov = True, optimizer_type = "sgd", 
                                                   optim_lr = 1e-3, stepLR_step_size = 10*3000,
                                                   supervised_lambda_weight = 1e-3)

In [5]:
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
y_one_hot = F.one_hot(y, 10).to(device).T
x = x.view(x.size(0),-1).T
neurons = model.fast_forward(x, no_grad = True)
z = torch.clone(neurons[-1])[:,0]#.requires_grad_(False)

z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "none")
l = loss(z.T, y[0])
l.backward()
torch.norm(z.grad - (F.softmax(z, 0) - y_one_hot[:,0]))

  l = loss(z.T, y[0])


tensor(3.6500e-08, device='cuda:0', grad_fn=<CopyBackwards>)

In [6]:
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
y_one_hot = F.one_hot(y, 10).to(device).T
x = x.view(x.size(0),-1).T
neurons = model.fast_forward(x, no_grad = True)
z = torch.clone(neurons[-1])#[:,0]#.requires_grad_(False)

z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "sum")
l = loss(z.T, y_one_hot.to(torch.float).T)
l.backward()
torch.norm(z.grad - (F.softmax(z, 0) - y_one_hot))

tensor(3.6591e-07, device='cuda:0', grad_fn=<CopyBackwards>)

In [7]:
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
y_one_hot = F.one_hot(y, 10).to(device).T
x = x.view(x.size(0),-1).T
neurons = model.fast_forward(x, no_grad = True)
z = torch.clone(neurons[-1])#[:,0]#.requires_grad_(False)

z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "sum")
l1 = loss(z.T, y_one_hot.to(torch.float).T)

loss = torch.nn.CrossEntropyLoss(reduction = "mean")
l2 = loss(z.T, y_one_hot.to(torch.float).T)

In [10]:
l2*64 - l1

tensor(0., device='cuda:0', grad_fn=<SubBackward0>)

In [11]:
x.shape

torch.Size([784, 64])

In [65]:
F.softmax(z, 0).sum()

tensor(1.0000, device='cuda:0')

In [84]:
z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "none")
l = loss(z.T, y[0])
l.backward()
torch.norm(z.grad - (F.softmax(z, 0) - y_one_hot[:,0]))

tensor(3.9425e-08, device='cuda:0', grad_fn=<CopyBackwards>)

In [58]:
y_one_hot[:,0]

tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0], device='cuda:0')

In [85]:
z.grad

tensor([ 0.1087,  0.1276,  0.0928,  0.0928,  0.0928,  0.0928,  0.0928, -0.9040,
         0.1109,  0.0928], device='cuda:0')

In [86]:
F.softmax(z, 0) - y_one_hot[:,0]

tensor([ 0.1087,  0.1276,  0.0928,  0.0928,  0.0928,  0.0928,  0.0928, -0.9040,
         0.1109,  0.0928], device='cuda:0', grad_fn=<SubBackward0>)

In [77]:
z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "mean")
l = loss(F.softmax(z, 0), y_one_hot.to(torch.float))
l.backward()
torch.norm(z.grad - (F.softmax(z, 0) - y_one_hot))

tensor(7.5311, device='cuda:0', grad_fn=<CopyBackwards>)

In [8]:
z.shape

torch.Size([10])

In [None]:
trn_acc_list = []
tst_acc_list = []

n_epochs = 50

for epoch_ in range(n_epochs):
#     if epoch_ > 12:
#         neural_lr_start = 0.05
#     if epoch_ > 17:
#         neural_lr_start = 0.03
    for idx, (x, y) in tqdm(enumerate(train_loader)):
        x, y = x.to(device), y.to(device)
        x = x.to(device).view(x.size(0),-1).T
        y_one_hot = F.one_hot(y, 10).to(device).T
        #y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)
        
        model.batch_step(  x, y_one_hot, neural_lr_start, neural_lr_stop, neural_lr_rule,
                               neural_lr_decay_multiplier, neural_dynamic_iterations,
                               )

    trn_acc = evaluatePC(model, train_loader, device, False, 
                         printing = False)
    tst_acc = evaluatePC(model, test_loader, device, False, 
                         printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

In [None]:
trn_acc = evaluatePC(model, train_loader, device, False, 
                     printing = False)
tst_acc = evaluatePC(model, test_loader, device, False, 
                     printing = False)
trn_acc_list.append(trn_acc)
tst_acc_list.append(tst_acc)

print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

In [None]:
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
y_one_hot = F.one_hot(y, 10).to(device).T
x = x.view(x.size(0),-1).T
neurons = model.fast_forward(x, no_grad = True)
z = torch.clone(neurons[-1])#.requires_grad_(False)

In [None]:
neurons[-1].shape, y.shape, z.requires_grad

In [None]:
neurons[-1].requires_grad

In [None]:
F.softmax(z.T, 1).sum(1)

In [None]:
z.requires_grad_()
loss = torch.nn.CrossEntropyLoss(reduction = "none")
l = loss(F.softmax(z, 0).T, y_one_hot.to(torch.float).T)
l.backward()
torch.norm(z.grad - (F.softmax(z, 0) - y_one_hot))

In [None]:
l.shape

In [None]:
z.grad

In [None]:
(F.softmax(z, 0) - y_one_hot)

In [None]:
z.grad / (F.softmax(z, 0) - y_one_hot)

In [None]:
1/10

In [None]:
F.softmax(z, 0) - y_one_hot

In [None]:
# optim_params = []
# for idx in range(len(model.Wff)):
#     for key_ in ["weight", "bias"]:
#         optim_params.append(  {'params': model.Wff[idx][key_], 'lr': lr_start["ff"]}  )

In [None]:
# optimizer = torch.optim.Adam(optim_params, maximize = True)

In [None]:
trn_acc_list = []
tst_acc_list = []

n_epochs = 50

for epoch_ in range(n_epochs):
    if epoch_ > 12:
        neural_lr_start = 0.05
    if epoch_ > 17:
        neural_lr_start = 0.03
    for idx, (x, y) in tqdm(enumerate(train_loader)):
        x, y = x.to(device), y.to(device)
        x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid")
        y_one_hot = F.one_hot(y, 10).to(device).T
        y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)
        
        model.batch_step(  x, y_one_hot, neural_lr_start, neural_lr_stop, neural_lr_rule,
                               neural_lr_decay_multiplier, neural_dynamic_iterations,
                               )

    trn_acc = evaluatePC(  model, train_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    tst_acc = evaluatePC(  model, test_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))