In [1]:
import sys
sys.path.append("../src")
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F

import glob
import os
from datetime import datetime
import time
import math
from tqdm import tqdm

from itertools import repeat
from torch.nn.parameter import Parameter
import collections
import matplotlib
from torch_utils import *
from PC import *
from visualization import *
# matplotlib.use('Agg')

In [2]:
def evaluatePC(model, loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                         neural_lr_decay_multiplier,
                         neural_dynamic_iterations, device, printing = True):
    # Evaluate the model on a dataloader with T steps for the dynamics
    #model.eval()
    correct=0
    phase = 'Train' if loader.dataset.train else 'Test'
    
    for x, y in loader:
        x = activation_inverse(2*x.view(x.size(0),-1).to(device).T - 1, "sigmoid")
#         x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid").to(device)
#         x = x.view(x.size(0),-1).T.to(device)
        y = y.to(device)
        
        neurons = model.forward(x)
        
#         # dynamics for T time steps
#         neurons = model.run_neural_dynamics(x, y_one_hot, neurons, neural_lr_start, neural_lr_stop, 
#                                             neural_lr_rule,
#                                             neural_lr_decay_multiplier, neural_dynamic_iterations, 0, "test")
        pred = torch.argmax(neurons[-1], dim=0).squeeze()  # in this case prediction is done directly on the last (output) layer of neurons
        correct += (y == pred).sum().item()

    acc = correct/len(loader.dataset) 
    if printing:
        print(phase+' accuracy :\t', acc)   
    return acc

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [4]:
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), 
                                            torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,))])

mnist_dset_train = torchvision.datasets.MNIST('./data', train=True, transform=transform, target_transform=None, download=True)
train_loader = torch.utils.data.DataLoader(mnist_dset_train, batch_size=20, shuffle=True, num_workers=0)

mnist_dset_test = torchvision.datasets.MNIST('./data', train=False, transform=transform, target_transform=None, download=True)
test_loader = torch.utils.data.DataLoader(mnist_dset_test, batch_size=20, shuffle=False, num_workers=0)

In [5]:
activation = torch.sigmoid
architecture = [784, 500, 500, 10]

neural_lr_start = 0.1
neural_lr_stop = 0.05
neural_lr_rule = "constant"
neural_lr_decay_multiplier = 0.01
neural_dynamic_iterations = 50

lr_start = {'ff' : 0.001}

model = SupervisedPredictiveCodingV2(architecture, activation)

In [6]:
optim_params = []
for idx in range(len(model.Wff)):
    for key_ in ["weight", "bias"]:
        optim_params.append(  {'params': model.Wff[idx][key_], 'lr': lr_start["ff"]}  )

In [7]:
optimizer = torch.optim.Adam(optim_params, maximize = True)

In [8]:
trn_acc_list = []
tst_acc_list = []

n_epochs = 50
lr = lr_start
for epoch_ in range(n_epochs):
#     lr = {'ff' : lr_start['ff'] * (0.999)**epoch_}
    for idx, (x, y) in tqdm(enumerate(train_loader)):
#         x = x.view(x.size(0),-1).T#.to(device)
        x, y = x.to(device), y.to(device)
        x = activation_inverse(2*x.view(x.size(0),-1).T - 1, "sigmoid")
#         x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid")
        y_one_hot = F.one_hot(y, 10).to(device).T
        y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)
        
        optimizer.zero_grad()
        
        model.batch_step(  x, y_one_hot, lr, neural_lr_start, neural_lr_stop, neural_lr_rule,
                               neural_lr_decay_multiplier, neural_dynamic_iterations,
                               )

        optimizer.step()
        
    trn_acc = evaluatePC(  model, train_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    tst_acc = evaluatePC(  model, test_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [02:49, 17.72it/s]
2it [00:00, 16.29it/s]

Epoch : 1, Train Accuracy : 0.9029666666666667, Test Accuracy : 0.9075


3000it [02:50, 17.56it/s]
2it [00:00, 15.79it/s]

Epoch : 2, Train Accuracy : 0.9094333333333333, Test Accuracy : 0.9101


3000it [02:48, 17.76it/s]
2it [00:00, 16.45it/s]

Epoch : 3, Train Accuracy : 0.9372833333333334, Test Accuracy : 0.937


3000it [02:49, 17.70it/s]
2it [00:00, 16.21it/s]

Epoch : 4, Train Accuracy : 0.93495, Test Accuracy : 0.9336


3000it [02:51, 17.46it/s]
2it [00:00, 16.07it/s]

Epoch : 5, Train Accuracy : 0.9492, Test Accuracy : 0.9505


3000it [02:49, 17.69it/s]
2it [00:00, 15.91it/s]

Epoch : 6, Train Accuracy : 0.9494, Test Accuracy : 0.9484


3000it [02:49, 17.71it/s]
2it [00:00, 15.70it/s]

Epoch : 7, Train Accuracy : 0.9530666666666666, Test Accuracy : 0.9513


3000it [02:51, 17.47it/s]
2it [00:00, 16.08it/s]

Epoch : 8, Train Accuracy : 0.96295, Test Accuracy : 0.9611


3000it [02:50, 17.65it/s]
2it [00:00, 15.87it/s]

Epoch : 9, Train Accuracy : 0.9626, Test Accuracy : 0.9608


3000it [02:50, 17.64it/s]
2it [00:00, 15.78it/s]

Epoch : 10, Train Accuracy : 0.9591833333333334, Test Accuracy : 0.9524


3000it [02:52, 17.42it/s]
2it [00:00, 16.15it/s]

Epoch : 11, Train Accuracy : 0.9661666666666666, Test Accuracy : 0.9606


3000it [02:50, 17.57it/s]
2it [00:00, 15.94it/s]

Epoch : 12, Train Accuracy : 0.96465, Test Accuracy : 0.9548


3000it [02:51, 17.48it/s]
2it [00:00, 15.89it/s]

Epoch : 13, Train Accuracy : 0.9678833333333333, Test Accuracy : 0.963


3000it [02:52, 17.43it/s]
2it [00:00, 15.95it/s]

Epoch : 14, Train Accuracy : 0.9655, Test Accuracy : 0.9612


3000it [02:50, 17.58it/s]
2it [00:00, 16.04it/s]

Epoch : 15, Train Accuracy : 0.9545166666666667, Test Accuracy : 0.9492


3000it [02:50, 17.58it/s]
2it [00:00, 15.74it/s]

Epoch : 16, Train Accuracy : 0.9581166666666666, Test Accuracy : 0.9548


3000it [02:52, 17.37it/s]
2it [00:00, 15.83it/s]

Epoch : 17, Train Accuracy : 0.9562333333333334, Test Accuracy : 0.9509


3000it [02:50, 17.57it/s]
2it [00:00, 15.68it/s]

Epoch : 18, Train Accuracy : 0.9292, Test Accuracy : 0.927


3000it [02:51, 17.53it/s]
2it [00:00, 15.78it/s]

Epoch : 19, Train Accuracy : 0.9459333333333333, Test Accuracy : 0.9429


3000it [02:52, 17.36it/s]
2it [00:00, 16.08it/s]

Epoch : 20, Train Accuracy : 0.9517833333333333, Test Accuracy : 0.9485


3000it [02:50, 17.57it/s]
2it [00:00, 15.97it/s]

Epoch : 21, Train Accuracy : 0.8996333333333333, Test Accuracy : 0.8952


3000it [02:51, 17.47it/s]
2it [00:00, 16.11it/s]

Epoch : 22, Train Accuracy : 0.932, Test Accuracy : 0.9312


3000it [02:52, 17.43it/s]
2it [00:00, 15.75it/s]

Epoch : 23, Train Accuracy : 0.9356833333333333, Test Accuracy : 0.9321


3000it [02:50, 17.62it/s]
2it [00:00, 15.72it/s]

Epoch : 24, Train Accuracy : 0.9248166666666666, Test Accuracy : 0.9198


3000it [02:31, 19.86it/s]
2it [00:00, 14.53it/s]

Epoch : 25, Train Accuracy : 0.9239166666666667, Test Accuracy : 0.9214


3000it [02:17, 21.81it/s]
2it [00:00, 14.65it/s]

Epoch : 26, Train Accuracy : 0.9081333333333333, Test Accuracy : 0.9081


3000it [02:18, 21.66it/s]
2it [00:00, 15.03it/s]

Epoch : 27, Train Accuracy : 0.8918166666666667, Test Accuracy : 0.8873


3000it [02:18, 21.59it/s]
2it [00:00, 15.22it/s]

Epoch : 28, Train Accuracy : 0.89995, Test Accuracy : 0.8961


3000it [02:18, 21.62it/s]
2it [00:00, 14.72it/s]

Epoch : 29, Train Accuracy : 0.9102833333333333, Test Accuracy : 0.9096


3000it [02:18, 21.71it/s]
2it [00:00, 15.28it/s]

Epoch : 30, Train Accuracy : 0.89125, Test Accuracy : 0.8913


3000it [02:20, 21.38it/s]
2it [00:00, 13.64it/s]

Epoch : 31, Train Accuracy : 0.9251333333333334, Test Accuracy : 0.9214


3000it [02:21, 21.13it/s]
2it [00:00, 15.08it/s]

Epoch : 32, Train Accuracy : 0.9272, Test Accuracy : 0.9251


3000it [02:21, 21.20it/s]
2it [00:00, 14.33it/s]

Epoch : 33, Train Accuracy : 0.8973, Test Accuracy : 0.8916


3000it [02:22, 21.10it/s]
2it [00:00, 14.64it/s]

Epoch : 34, Train Accuracy : 0.89775, Test Accuracy : 0.8998


3000it [02:23, 20.96it/s]
2it [00:00, 12.71it/s]

Epoch : 35, Train Accuracy : 0.9133666666666667, Test Accuracy : 0.9119


3000it [02:22, 21.03it/s]
2it [00:00, 14.91it/s]

Epoch : 36, Train Accuracy : 0.8977833333333334, Test Accuracy : 0.8919


3000it [02:22, 21.03it/s]
2it [00:00, 14.20it/s]

Epoch : 37, Train Accuracy : 0.9089333333333334, Test Accuracy : 0.9061


3000it [02:18, 21.66it/s]
2it [00:00, 15.14it/s]

Epoch : 38, Train Accuracy : 0.9290166666666667, Test Accuracy : 0.9254


3000it [02:18, 21.68it/s]
2it [00:00, 14.57it/s]

Epoch : 39, Train Accuracy : 0.9357833333333333, Test Accuracy : 0.9347


3000it [02:18, 21.68it/s]
2it [00:00, 14.46it/s]

Epoch : 40, Train Accuracy : 0.8966666666666666, Test Accuracy : 0.8938


3000it [02:18, 21.67it/s]
2it [00:00, 13.41it/s]

Epoch : 41, Train Accuracy : 0.8866166666666667, Test Accuracy : 0.8803


3000it [02:18, 21.62it/s]
2it [00:00, 14.23it/s]

Epoch : 42, Train Accuracy : 0.94215, Test Accuracy : 0.9394


3000it [02:18, 21.64it/s]
2it [00:00, 14.67it/s]

Epoch : 43, Train Accuracy : 0.9009, Test Accuracy : 0.8888


3000it [02:18, 21.66it/s]
2it [00:00, 14.65it/s]

Epoch : 44, Train Accuracy : 0.9351333333333334, Test Accuracy : 0.931


3000it [02:17, 21.75it/s]
2it [00:00, 14.58it/s]

Epoch : 45, Train Accuracy : 0.9378833333333333, Test Accuracy : 0.9345


3000it [02:18, 21.66it/s]
2it [00:00, 14.76it/s]

Epoch : 46, Train Accuracy : 0.9357666666666666, Test Accuracy : 0.9316


3000it [02:18, 21.68it/s]
2it [00:00, 14.59it/s]

Epoch : 47, Train Accuracy : 0.8747166666666667, Test Accuracy : 0.8744


3000it [02:17, 21.77it/s]
2it [00:00, 15.12it/s]

Epoch : 48, Train Accuracy : 0.9394833333333333, Test Accuracy : 0.9385


3000it [02:17, 21.77it/s]
2it [00:00, 15.57it/s]

Epoch : 49, Train Accuracy : 0.9047166666666666, Test Accuracy : 0.8994


3000it [02:18, 21.72it/s]


Epoch : 50, Train Accuracy : 0.9436166666666667, Test Accuracy : 0.9401


In [9]:
Wff_init = torch.clone(model.Wff[0]['weight'])

In [10]:
x,y = next(iter(train_loader))
# x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T
x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid").to(device)
x.shape

torch.Size([784, 20])

In [11]:
model.batch_step(x, y_one_hot, lr_start, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                   neural_lr_decay_multiplier, neural_dynamic_iterations)

In [12]:
optimizer.step()

In [13]:
torch.norm(model.Wff[0]["weight"] - Wff_init)

tensor(0.0451, device='cuda:0', grad_fn=<CopyBackwards>)

In [14]:
neurons = model.forward(x, no_grad = False)
neurons[-1] = y_one_hot.to(torch.float)
neurons2 = model.run_neural_dynamics(x, y, neurons, neural_lr_start, neural_lr_stop, lr_rule = neural_lr_rule,
                          lr_decay_multiplier = neural_lr_decay_multiplier, 
                            neural_dynamic_iterations = 50)

In [15]:
neurons2

[tensor([[ 34.6358,  36.5416,  37.6643,  ...,  35.5534,  35.6299,  34.6885],
         [-35.7291, -34.8083, -36.8174,  ..., -34.2573, -37.4625, -33.6485],
         [ 32.1466,  33.3480,  33.1700,  ...,  32.3953,  32.9427,  32.1870],
         ...,
         [-29.4734, -30.1304, -30.0850,  ..., -29.8606, -29.7603, -29.5905],
         [-27.0878, -25.3572, -25.0104,  ..., -25.5114, -25.4776, -25.5118],
         [-23.9726, -22.7487, -23.1222,  ..., -24.3451, -25.2882, -22.5572]],
        device='cuda:0', requires_grad=True),
 tensor([[ 5.9757e+02,  5.9757e+02,  5.9758e+02,  ...,  5.9757e+02,
           5.9757e+02,  5.9761e+02],
         [ 4.7166e+02,  4.7165e+02,  4.7166e+02,  ...,  4.7166e+02,
           4.7166e+02,  4.7168e+02],
         [-1.9773e-01, -1.3857e-01, -1.9575e-01,  ..., -1.1359e-01,
          -2.0894e-01, -1.2648e-01],
         ...,
         [-5.1652e+02, -5.1651e+02, -5.1652e+02,  ..., -5.1651e+02,
          -5.1651e+02, -5.1655e+02],
         [-5.2754e+02, -5.2753e+02, -5.2754

In [16]:
torch.norm(neurons[1] - neurons2[1])

tensor(0., device='cuda:0', grad_fn=<CopyBackwards>)

In [17]:
neurons = model.forward(x)
neurons[-1] = y_one_hot.to(torch.float)
pc_loss = model.PC_loss(x, neurons).mean()
pc_loss.backward()
model.Wff[1]['weight'].grad

tensor([[-2.8870e-03, -5.4945e-13, -2.8870e-03,  ..., -6.9556e-11,
         -2.1319e-10, -3.1035e-10],
        [-1.3245e-03, -2.1212e-13, -1.3245e-03,  ..., -3.1860e-11,
         -5.2922e-11, -6.6940e-11],
        [ 1.9955e-01, -1.1831e-11,  1.9955e-01,  ..., -1.2546e-10,
          4.5661e-08, -1.2444e-08],
        ...,
        [ 3.5828e-03,  5.7774e-13,  3.5828e-03,  ...,  8.5537e-11,
          4.9043e-10,  4.3372e-10],
        [ 1.8738e-03,  3.7105e-13,  1.8738e-03,  ...,  4.5327e-11,
          1.9336e-10,  1.2972e-10],
        [ 3.8940e-03,  6.5016e-13,  3.8940e-03,  ...,  9.3730e-11,
          4.0010e-10,  5.2361e-10]], device='cuda:0')

In [18]:
torch.norm(model.Wff[1]['weight'].grad)

tensor(22.1732, device='cuda:0')

In [19]:
F = 0
Wff = model.Wff
layers = [x] + neurons
for jj in range(len(Wff)):
    error = (layers[jj + 1] - (Wff[jj]['weight'] @ model.activation(layers[jj]) + Wff[jj]['bias'])) / model.variances[jj]
    # print(error.shape, torch.sum(error * error, 0).shape)
    F -= model.variances[jj + 1] * torch.sum(error * error, 0)

In [20]:
layers

[tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0'),
 tensor([[ 34.6358,  36.5416,  37.6643,  ...,  35.5534,  35.6299,  34.6885],
         [-35.7291, -34.8083, -36.8174,  ..., -34.2573, -37.4625, -33.6485],
         [ 32.1466,  33.3480,  33.1700,  ...,  32.3953,  32.9427,  32.1870],
         ...,
         [-29.4734, -30.1304, -30.0850,  ..., -29.8606, -29.7603, -29.5905],
         [-27.0878, -25.3572, -25.0104,  ..., -25.5114, -25.4776, -25.5118],
         [-23.9726, -22.7487, -23.1222,  ..., -24.3451, -25.2882, -22.5572]],
        device='cuda:0', grad_fn=<AddBackward0>),
 tensor([[ 5.9757e+02,  5.9757e+02,  5.9758e+02,  ...,  5.9757e+02,
           5.9757e+02,  5.9761e+02],
         [ 4.7166e+02,  4.7165e+02,  4.7166e+02,  ...,  4.7166e+02,
           4.7166e+02

In [21]:
torch.norm(model.Wff[1]['weight'].grad)

tensor(22.1732, device='cuda:0')

In [22]:
neurons[0]

tensor([[ 34.6358,  36.5416,  37.6643,  ...,  35.5534,  35.6299,  34.6885],
        [-35.7291, -34.8083, -36.8174,  ..., -34.2573, -37.4625, -33.6485],
        [ 32.1466,  33.3480,  33.1700,  ...,  32.3953,  32.9427,  32.1870],
        ...,
        [-29.4734, -30.1304, -30.0850,  ..., -29.8606, -29.7603, -29.5905],
        [-27.0878, -25.3572, -25.0104,  ..., -25.5114, -25.4776, -25.5118],
        [-23.9726, -22.7487, -23.1222,  ..., -24.3451, -25.2882, -22.5572]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [23]:
model.Wff[0]['weight']

tensor([[ 0.0997,  0.0660, -0.1457,  ..., -0.1604,  0.2353,  0.0943],
        [-0.0935,  0.0213,  0.1141,  ...,  0.0240, -0.0926, -0.2190],
        [ 0.1823,  0.0187,  0.1770,  ..., -0.1504, -0.1178,  0.0643],
        ...,
        [ 0.1172,  0.1195,  0.1770,  ..., -0.1164, -0.1110, -0.3403],
        [-0.2720,  0.0665, -0.2049,  ..., -0.0628,  0.1066, -0.2318],
        [ 0.0061,  0.1548,  0.1897,  ...,  0.0848, -0.0991,  0.0012]],
       device='cuda:0', requires_grad=True)

In [24]:
loss2 = (model.Wff[0]['weight'] @ x).sum()
loss2.backward()
model.Wff[0]['weight'].grad

tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [7.6547e-16, 7.6547e-16, 7.6547e-16,  ..., 7.6547e-16, 7.6547e-16,
         7.6547e-16],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        ...,
        [2.3083e-14, 2.3083e-14, 2.3083e-14,  ..., 2.3083e-14, 2.3083e-14,
         2.3083e-14],
        [7.6947e-07, 7.6947e-07, 7.6947e-07,  ..., 7.6947e-07, 7.6947e-07,
         7.6947e-07],
        [2.5572e-08, 2.5572e-08, 2.5572e-08,  ..., 2.5572e-08, 2.5572e-08,
         2.5572e-08]], device='cuda:0')

In [25]:
pc_loss.backward()
model.Wff[0]['weight'].grad

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
torch.norm(neurons[-1] - neurons2[-1])

In [None]:
neurons = model.forward(x, no_grad = True)
neurons = model.run_neural_dynamics(x, y, neurons, neural_lr_start, neural_lr_stop, lr_rule = neural_lr_rule,
                          lr_decay_multiplier = neural_lr_decay_multiplier, 
                            neural_dynamic_iterations = 50)

In [None]:
pc_loss = model.PC_loss(x, neurons).mean()
pc_loss.backward()
model.Wff[1]['weight'].grad

In [None]:
torch.norm(model.Wff[1]['weight'].grad)

In [None]:
F = 0
Wff = model.Wff
layers = [x] + neurons
for jj in range(len(Wff)):
    error = (layers[jj + 1] - (Wff[jj]['weight'] @ model.activation(layers[jj]) + Wff[jj]['bias'])) / model.variances[jj]
    # print(error.shape, torch.sum(error * error, 0).shape)
    F -= model.variances[jj + 1] * torch.sum(error * error, 0)
    
F

In [None]:
neurons = model.forward(x)
for jj in range(len(neurons) - 1):
    neurons[jj] = neurons[jj].requires_grad_()
pc_loss = model.PC_loss(x, neurons)
init_grads = torch.tensor([1 for i in range(20)], dtype=torch.float, device=device, requires_grad=True) #Initializing gradients
grads = torch.autograd.grad(pc_loss, neurons[:-1], grad_outputs=init_grads, create_graph=False) # dPhi/ds


In [None]:
x.shape

In [None]:
neurons[1].requires_grad_()

In [None]:
torch.sum(error * error, 0)

In [None]:
trn_acc_list = []
tst_acc_list = []

n_epochs = 50
lr = lr_start
for epoch_ in range(n_epochs):
#     lr = {'ff' : lr_start['ff'] * (0.999)**epoch_}
    for idx, (x, y) in tqdm(enumerate(train_loader)):
#         x = x.view(x.size(0),-1).T#.to(device)
        x, y = x.to(device), y.to(device)
        x = activation_inverse(2*x.view(x.size(0),-1).T - 1, "sigmoid")
#         x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid")
        y_one_hot = F.one_hot(y, 10).to(device).T
        y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)
        _ = model.batch_step(  x, y_one_hot, lr, neural_lr_start, neural_lr_stop, neural_lr_rule,
                               neural_lr_decay_multiplier, neural_dynamic_iterations,
                               optimizer = "adam")

    trn_acc = evaluatePC(  model, train_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    tst_acc = evaluatePC(  model, test_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                           neural_lr_decay_multiplier,
                           neural_dynamic_iterations, device, printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

In [None]:
x, y = next(iter(train_loader))
x, y = x.to(device), y.to(device)
x = activation_inverse(x.view(x.size(0),-1).T, "sigmoid")
y_one_hot = F.one_hot(y, 10).to(device).T
y_one_hot = 0.94 * y_one_hot + 0.03 * torch.ones(*y_one_hot.shape, device = device)

neurons = model.fast_forward(x)
mode = "train"
if mode == "train":
    neurons[-1] = y_one_hot.to(torch.float)
    
neurons = model.run_neural_dynamics( x, y, neurons, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                            neural_lr_decay_multiplier, neural_dynamic_iterations)

layers = [x] + neurons  # concatenate the input to other layers
layers_after_activation = [list(model.activation_func(layers[jj], model.activation_type)) for jj in range(len(layers))]
error_layers = [(layers[jj+1] - (model.Wff[jj]['weight'] @ layers_after_activation[jj][0] + model.Wff[jj]['bias'])) / model.variances[jj + 1] for jj in range(len(layers) - 1)]


In [None]:
error_layers[2].shape, layers_after_activation[2][0].shape

In [None]:
(1/20) * (error_layers[jj] @ layers_after_activation[0][0].T)

In [None]:
jj = 0
torch.mean(outer_prod_broadcasting(error_layers[jj].T, layers_after_activation[jj][0].T), axis = 0)

In [None]:
error_layers[0]

In [None]:
model.Wff[0]["weight"].shape

In [None]:
model.Wff[0]['weight'].shape

In [None]:
model.activation_func(x, model.activation_type)[0].shape

In [None]:
neurons = model.fast_forward(activation_func(x, "sigmoid")[0])
neurons[-1] = y_one_hot.to(torch.float)
neurons = model.run_neural_dynamics(x, y, neurons, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                          neural_lr_decay_multiplier, neural_dynamic_iterations)

In [None]:
tst_acc = evaluatePC(  model, test_loader, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                                 neural_lr_decay_multiplier,
                                 neural_dynamic_iterations, device, printing = True)

In [None]:
Wff = model.Wff
Wff[0]

In [None]:
model.batch_step(x, y_one_hot, lr_start, neural_lr_start, neural_lr_stop, neural_lr_rule, 
                          neural_lr_decay_multiplier, neural_dynamic_iterations)

In [None]:
model.Wff

In [None]:
evalua

In [None]:
neurons = model.fast_forward(activation_func(x, "sigmoid")[0])

In [None]:
layers_after_activation, error_layers, grads = model.calculate_neural_dynamics_grad(x, y, neurons)

In [None]:
grads

In [None]:
len(layers_after_activation), len(error_layers), len(grads)

In [None]:
layers_after_activation[jj + 1][1].shape

In [None]:
jj = 0
error_layers[jj + 1].shape, model.Wff[jj + 1]['weight'].shape

In [None]:
len(layers_after_activation)

In [None]:
jj = 0
Wff[jj]['weight'] @ layers_after_activation[jj][0]

In [None]:
variances

In [None]:
Wff = model.Wff
variances = model.variances
layers = [x] + neurons

[(layers[jj+1] - (Wff[jj]['weight'] @ layers_after_activation[jj][0] + Wff[jj]['bias'])) / variances[jj + 1] for jj in range(len(layers) - 1)]

In [None]:
def activation_func(x, type_ = "linear"):
    if type_ == "linear":
        f_x = x
        fp_x = torch.ones(*x.shape, device = x.device)
    elif type_ == "tanh":
        f_x = torch.tanh(x)
        fp_x = torch.ones(*x.shape, device = x.device) - f_x ** 2
    elif type_ == "sigmoid":
        ones_vec = torch.ones(*x.shape, device = x.device)
        f_x = 1 / (ones_vec + torch.exp(-x))
        fp_x = f_x * (ones_vec - f_x)
    elif type_ == "relu":
        f_x = torch.maximum(x, torch.tensor([0], device = x.device))
        fp_x = 1 * (x > 0)
    elif type_ == "exp":
        f_x = torch.exp(x)
        fp_x = f_x
    else: # Use linear
        f_x = x
        fp_x = torch.ones(*x.shape, device = x.device)
        
    return f_x, fp_x

In [None]:
x = torch.randn(3,1, device = "cuda")
x

In [None]:
x = torch.randn(3,1, device = "cuda")
print(x)
activation_func(x, type_ = "sigmoid")

In [None]:
device = "cuda"

In [None]:
(2 * torch.rand(3, 3, requires_grad = False).to(device) - 1) * (4 * np.sqrt(6 / (3 + 3)))