In [1]:
import sys
sys.path.append("../src")
import torch
import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torch.nn.functional as F

import glob
import os
from datetime import datetime
import time
import math
from tqdm import tqdm

from itertools import repeat
from torch.nn.parameter import Parameter
import collections
import matplotlib
from torch_utils import *
from models import *
from visualization import *
# matplotlib.use('Agg')

In [2]:
# Run on the first CUDA GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

In [3]:
# Convert images to tensors; Normalize with mean 0 / std 1 leaves the values unchanged.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(mean=(0.0,), std=(1.0,)),
])

# Training split: downloaded to ./data if missing, reshuffled every epoch, batches of 20.
mnist_dset_train = torchvision.datasets.MNIST(
    './data', train=True, transform=transform, target_transform=None, download=True,
)
train_loader = torch.utils.data.DataLoader(
    mnist_dset_train, batch_size=20, shuffle=True, num_workers=0,
)

# Test split: same preprocessing, fixed order.
mnist_dset_test = torchvision.datasets.MNIST(
    './data', train=False, transform=transform, target_transform=None, download=True,
)
test_loader = torch.utils.data.DataLoader(
    mnist_dset_test, batch_size=20, shuffle=False, num_workers=0,
)

In [4]:
# Nonlinearity handed to the model constructor. `hard_sigmoid` is not defined in this
# notebook; presumably it comes from the star import of torch_utils -- TODO confirm.
activation = hard_sigmoid
# Element-wise (unreduced) MSE loss.
# NOTE(review): `criterion` is not referenced anywhere else in the visible notebook.
criterion = torch.nn.MSELoss(reduction='none').to(device)

In [5]:
# Layer sizes: input -> hidden -> output.
architecture = [784, 500, 10]

# Pull one batch for interactive inspection. Images are flattened per sample and
# transposed to (features, batch); labels are one-hot encoded and transposed to (10, batch).
x,y = next(iter(train_loader))
x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T

# Hyperparameters forwarded to the model / training step.
lambda_h = 0.9999
lambda_y = 0.9999
epsilon = 0.1
one_over_epsilon = 1 / epsilon  # reciprocal of epsilon, used by the neural dynamics
# Learning rates keyed by synapse type: 'ff' (feedforward), 'fb' (feedback), 'lat' (lateral).
lr = {'ff' : 0.1, 'fb': 0.1, 'lat': 1e-3}
neural_lr = 0.05  # step size of the neural-dynamics iterations
model = TwoLayerCorInfoMax(architecture = architecture, lambda_h = lambda_h, lambda_y = lambda_y, 
                           epsilon = epsilon, activation = activation)

In [None]:
# Per-epoch accuracy history.
trn_acc_list = []
tst_acc_list = []
# Number of neural-dynamics iterations in the free phase and in the nudged phase.
neural_dynamic_iterations_free = 20
neural_dynamic_iterations_nudged = 4
# lambda_h = 0.01
# lambda_y = 0.01
# epsilon = 1
# one_over_epsilon = 1 / epsilon
beta = 1
n_epochs = 50
# lr = {'ff' : 1e-3, 'fb': 1e-3, 'lat': 1e-3}
# neural_lr = 0.25

for epoch_ in range(n_epochs):
    # Wrap the loader itself (not enumerate(loader)) so tqdm can read len(loader)
    # and show a real progress bar with a total instead of a bare iteration counter.
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)
        # Flatten each image and put samples in columns: (features, batch).
        x = x.view(x.size(0),-1).T
        y_one_hot = F.one_hot(y, 10).to(device).T

        h, y_hat = model.batch_step(  x, y_one_hot, lr, neural_lr, neural_dynamic_iterations_free, 
                                      neural_dynamic_iterations_nudged, beta)

    # Evaluate on the device selected at the top of the notebook instead of a
    # hard-coded 'cuda', so the loop also runs on CPU-only machines.
    trn_acc = evaluateCorInfoMax(model, train_loader, neural_lr, 20, device = device, printing = False)
    tst_acc = evaluateCorInfoMax(model, test_loader, neural_lr, 20, device = device, printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    # Learning-rate schedule hook after the 5th epoch.
    # NOTE(review): the values assigned here are currently the same as the initial ones.
    if epoch_ == 4:
        lr = {'ff' : 0.1, 'fb': 0.1, 'lat': 1e-3}
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

3000it [00:34, 85.79it/s]
9it [00:00, 88.39it/s]

Epoch : 1, Train Accuracy : 0.19958333333333333, Test Accuracy : 0.1989


690it [00:07, 89.93it/s]

In [None]:
class TwoLayerCorInfoMax():
    """Two-layer network (input -> hidden -> output) with separate feedforward,
    feedback and lateral synapses, trained by contrasting a free phase with a
    label-nudged phase of its neural dynamics.

    Activations are stored column-wise: ``x`` is (n_input, batch), ``h`` is
    (n_hidden, batch) and ``y_hat`` is (n_output, batch).

    Attributes:
        Wff: per-layer dicts ``{'weight': (n_out, n_in), 'bias': (n_out, 1)}``.
        Wfb: per-layer dicts ``{'weight': (n_in, n_out), 'bias': (n_in, 1)}``.
        B:   per-layer dicts ``{'weight': (n_out, n_out)}`` (lateral weights).
        Each of the three is a NumPy object array holding those dicts.
    """

    def __init__(self, architecture, lambda_h, lambda_y, epsilon, activation = hard_sigmoid):
        """Store the hyperparameters and randomly initialise all synapses.

        Args:
            architecture: layer sizes, e.g. ``[784, 500, 10]``. The dynamics and
                the training step index layers 0 and 1 explicitly, so exactly
                three sizes are expected.
            lambda_h, lambda_y: scale the lateral term of the hidden / output
                dynamics through ``(1 - lambda) / lambda``.
            epsilon: its reciprocal weights the prediction-error terms.
            activation: callable applied after every neural-dynamics update.
        """
        self.architecture = architecture
        self.lambda_h = lambda_h
        self.lambda_y = lambda_y
        self.epsilon = epsilon
        # Derived from the constructor argument. It used to be read from a
        # notebook-level global, which silently broke whenever that global was
        # stale or missing.
        self.one_over_epsilon = 1 / epsilon
        self.activation = activation
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        n_layers = len(architecture) - 1

        # Feedforward synapses (layer idx -> idx + 1): Xavier-uniform weights, zero biases.
        # All tensors go to self.device rather than to a notebook-level `device`.
        Wff = []
        for idx in range(n_layers):
            weight = torch.randn(architecture[idx + 1], architecture[idx], requires_grad = False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            bias = torch.zeros(architecture[idx + 1], 1, requires_grad = False).to(self.device)
            Wff.append({'weight': weight, 'bias': bias})

        # Feedback synapses (layer idx + 1 -> idx): Xavier-uniform weights, zero biases.
        Wfb = []
        for idx in range(n_layers):
            weight = torch.randn(architecture[idx], architecture[idx + 1], requires_grad = False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            bias = torch.zeros(architecture[idx], 1, requires_grad = False).to(self.device)
            Wfb.append({'weight': weight, 'bias': bias})

        # Lateral synapses: W @ W.T of a Xavier-uniform matrix, hence symmetric.
        B = []
        for idx in range(n_layers):
            weight = torch.randn(architecture[idx + 1], architecture[idx + 1], requires_grad = False).to(self.device)
            torch.nn.init.xavier_uniform_(weight)
            weight = weight @ weight.T
            B.append({'weight': weight})

        # (An identity-based initialisation -- eye weights, B = 10 * eye -- was
        # tried earlier and is no longer used.)
        self.Wff = np.array(Wff)
        self.Wfb = np.array(Wfb)
        self.B = np.array(B)

    def init_neurons(self, mbs, random_initialize = True, device = None):
        """Create initial activations for every non-input layer.

        Args:
            mbs: mini-batch size.
            random_initialize: standard-normal values if True, zeros otherwise.
            device: where to allocate; defaults to ``self.device`` (the former
                hard-coded ``'cuda'`` default failed on CPU-only machines).

        Returns:
            List with one ``(layer_size, mbs)`` tensor per entry of ``architecture[1:]``.
        """
        if device is None:
            device = self.device
        factory = torch.randn if random_initialize else torch.zeros
        return [factory((mbs, size), requires_grad=False, device=device).T
                for size in self.architecture[1:]]

    def calculate_neural_dynamics_grad(self, x, h, y_hat, y, beta):
        """Return the update directions ``(grad_h, grad_y)`` of the neural dynamics.

        Each direction combines a feedback prediction-error term, a lateral term
        scaled by ``(1 - lambda) / lambda`` and a feedforward prediction-error
        term; the output direction additionally gets the nudging term
        ``2 * beta * (y - y_hat)``.
        """
        Wff = self.Wff
        Wfb = self.Wfb
        B = self.B
        lambda_h = self.lambda_h
        lambda_y = self.lambda_y
        one_over_epsilon = self.one_over_epsilon

        grad_h = 0.5*(one_over_epsilon * Wfb[0]['weight'].T @ (x - (Wfb[0]['weight'] @ h + Wfb[0]['bias'])) + 
             ((1 - lambda_h) / lambda_h) * B[0]['weight'] @ h -
             one_over_epsilon * (h - (Wff[0]['weight'] @ x + Wff[0]['bias'])))

        grad_y = 0.5*(one_over_epsilon * Wfb[1]['weight'].T @ (h - (Wfb[1]['weight'] @ y_hat + Wfb[1]['bias'])) +
             ((1 - lambda_y) / lambda_y) * B[1]['weight'] @ y_hat - 
             one_over_epsilon * (y_hat - (Wff[1]['weight'] @ h + Wff[1]['bias']))) + 2 * beta * (y - y_hat)

        return grad_h, grad_y

    def run_neural_dynamics(self, x, h, y_hat, y, neural_lr, neural_dynamic_iterations, beta):
        """Iterate the neural dynamics and return the final ``(h, y_hat)``.

        Every iteration moves both layers by ``neural_lr`` times their update
        direction and then applies ``self.activation``. With ``beta == 0`` the
        nudging term vanishes, so ``y`` may be any placeholder (callers pass 0).
        """
        with torch.no_grad():
            for _ in range(neural_dynamic_iterations):
                grad_h, grad_y = self.calculate_neural_dynamics_grad(x, h, y_hat, y, beta)
                h = self.activation(h + neural_lr * grad_h)
                y_hat = self.activation(y_hat + neural_lr * grad_y)
        return h, y_hat

    @staticmethod
    def _mean_outer(a, b):
        """Batch mean of per-sample outer products for column-wise ``a`` (m, batch)
        and ``b`` (n, batch).

        ``a @ b.T`` sums the outer products of matching columns, so dividing by
        the batch size gives the same (m, n) matrix as averaging explicit
        per-sample outer products, without the (batch, m, n) intermediate.
        """
        return a @ b.T / a.size(1)

    def batch_step(self, x, lr, neural_lr, neural_dynamic_iterations_free, 
                   neural_dynamic_iterations_nudged, beta, y = None):
        """Run one free phase and one nudged phase on a batch and update all synapses.

        Args:
            x: inputs, (n_input, batch).
            lr: dict with keys 'ff', 'fb' and 'lat'.
            neural_lr: step size of the neural dynamics.
            neural_dynamic_iterations_free: iterations with ``beta = 0``.
            neural_dynamic_iterations_nudged: further iterations with ``beta``.
            beta: nudging strength towards the target.
            y: one-hot targets, (n_output, batch). When omitted, the
                notebook-level ``y_one_hot`` is used, which is what this method
                originally read implicitly.

        Returns:
            ``(h, y_hat)`` at the end of the nudged phase.
        """
        if y is None:
            # Backward-compatible fallback for callers that do not pass targets.
            y = y_one_hot

        Wff, Wfb, B = self.Wff, self.Wfb, self.B

        # Use this instance (not whatever the global `model` happens to be).
        h, y_hat = self.init_neurons(x.size(1), device = self.device)

        # Free phase. run_neural_dynamics rebinds rather than mutates, so these
        # references stay valid after the nudged phase.
        h, y_hat = self.run_neural_dynamics(x, h, y_hat, y, neural_lr, 
                                            neural_dynamic_iterations_free, 0)
        h_free, y_free = h, y_hat

        error_hx_free = h_free - (Wff[0]['weight'] @ x + Wff[0]['bias'])
        error_xh_free = x - (Wfb[0]['weight'] @ h_free + Wfb[0]['bias'])

        error_yh_free = y_free - (Wff[1]['weight'] @ h_free + Wff[1]['bias'])
        error_hy_free = h_free - (Wfb[1]['weight'] @ y_free + Wfb[1]['bias'])

        # Nudged phase, continuing from the free-phase state.
        h, y_hat = self.run_neural_dynamics(x, h, y_hat, y, neural_lr, 
                                            neural_dynamic_iterations_nudged, beta)
        h_nudged, y_nudged = h, y_hat

        error_hx_nudged = h_nudged - (Wff[0]['weight'] @ x + Wff[0]['bias'])
        error_xh_nudged = x - (Wfb[0]['weight'] @ h_nudged + Wfb[0]['bias'])

        error_yh_nudged = y_nudged - (Wff[1]['weight'] @ h_nudged + Wff[1]['bias'])
        error_hy_nudged = h_nudged - (Wfb[1]['weight'] @ y_nudged + Wfb[1]['bias'])

        mean_outer = self._mean_outer

        ### Weight updates (in place, so self.Wff / self.Wfb / self.B see them directly).
        # NOTE(review): every weight uses lr['ff'] and every bias uses lr['fb'],
        # including the feedback weights and the feedforward biases. This looks
        # like a key mix-up but is kept as-is -- confirm the intended mapping.
        Wff[0]['weight'] -= lr['ff'] * mean_outer(error_hx_free - error_hx_nudged, x)
        Wfb[0]['weight'] -= lr['ff'] * (mean_outer(error_xh_free, h_free) - mean_outer(error_xh_nudged, h_nudged))
        Wff[1]['weight'] -= lr['ff'] * (mean_outer(error_yh_free, h_free) - mean_outer(error_yh_nudged, h_nudged))
        Wfb[1]['weight'] -= lr['ff'] * (mean_outer(error_hy_free, y_free) - mean_outer(error_hy_nudged, y_nudged))

        # NOTE(review): the bias updates use (nudged - free) while the weight
        # updates use (free - nudged); kept as-is -- confirm the intended sign.
        Wff[0]['bias'] -= lr['fb'] * torch.mean(error_hx_nudged - error_hx_free, axis = 1, keepdims = True)
        Wfb[0]['bias'] -= lr['fb'] * torch.mean(error_xh_nudged - error_xh_free, axis = 1, keepdims = True)
        Wff[1]['bias'] -= lr['fb'] * torch.mean(error_yh_nudged - error_yh_free, axis = 1, keepdims = True)
        Wfb[1]['bias'] -= lr['fb'] * torch.mean(error_hy_nudged - error_hy_free, axis = 1, keepdims = True)

        # Lateral weights: difference of the batch-mean activation outer products.
        B[0]['weight'] -= lr['lat'] * (mean_outer(h_nudged, h_nudged) - mean_outer(h_free, h_free))
        B[1]['weight'] -= lr['lat'] * (mean_outer(y_nudged, y_nudged) - mean_outer(y_free, y_free))

        self.Wff = Wff
        self.Wfb = Wfb
        self.B = B

        return h, y_hat

In [None]:
def evaluateCorInfoMax(model, loader, neural_lr, T, device, printing = True):
    """Return the classification accuracy of `model` over every batch of `loader`.

    For each batch the neurons are freshly initialised, the free dynamics
    (beta = 0, placeholder target 0) are run for `T` iterations, and the
    prediction is the arg-max over the output layer.

    Args:
        model: object providing `init_neurons`, `run_neural_dynamics` and `device`.
        loader: iterable of (images, labels) whose `dataset` has a `train` flag
            and a length.
        neural_lr: step size of the neural dynamics.
        T: number of dynamics iterations per batch.
        device: device the inputs and labels are moved to.
        printing: print the accuracy, prefixed with 'Train' or 'Test', when True.

    Returns:
        Number of correct predictions divided by `len(loader.dataset)`.
    """
    phase = 'Test'
    if loader.dataset.train:
        phase = 'Train'

    n_correct = 0
    for images, labels in loader:
        # Flatten each sample and put samples in columns: (features, batch).
        inputs = images.view(images.size(0), -1).to(device).T
        labels = labels.to(device)

        hidden, output = model.init_neurons(inputs.size(1), device = model.device)

        # Free dynamics for T steps.
        hidden, output = model.run_neural_dynamics(inputs, hidden, output, 0, neural_lr = neural_lr,
                                                   neural_dynamic_iterations = T, beta = 0)

        # Prediction is read directly from the output-layer neurons.
        predictions = torch.argmax(output, dim=0).squeeze()
        n_correct += (labels == predictions).sum().item()

    acc = n_correct / len(loader.dataset)
    if printing:
        print(phase+' accuracy :\t', acc)
    return acc

In [None]:
# Layer sizes: input -> hidden -> output.
architecture = [784, 500, 10]

# One batch for interactive inspection: x becomes (features, batch),
# y_one_hot becomes (10, batch).
x,y = next(iter(train_loader))
x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T

# Second hyperparameter configuration; builds a fresh model that replaces the previous one.
lambda_h = 0.5
lambda_y = 0.5
epsilon = 1
one_over_epsilon = 1 / epsilon  # reciprocal of epsilon
# Learning rates keyed by synapse type: 'ff' (feedforward), 'fb' (feedback), 'lat' (lateral).
lr = {'ff' : 1e-1, 'fb': 1e-1, 'lat': 1e-3}
neural_lr = 0.02  # step size of the neural-dynamics iterations
model = TwoLayerCorInfoMax(architecture = architecture, lambda_h = lambda_h, lambda_y = lambda_y, 
                           epsilon = epsilon, activation = activation)

In [None]:
# evaluateCorInfoMax(model, test_loader, neural_lr, 20, device = 'cuda')

# Training

In [None]:
# Per-epoch accuracy history.
trn_acc_list = []
tst_acc_list = []
# Number of neural-dynamics iterations in the free phase and in the nudged phase.
neural_dynamic_iterations_free = 20
neural_dynamic_iterations_nudged = 4
# NOTE(review): these four re-assignments do not change the already constructed
# `model`; its dynamics read the values stored on the instance at construction time.
lambda_h = 0.5
lambda_y = 0.5
epsilon = 1
one_over_epsilon = 1 / epsilon
beta = 1
n_epochs = 50
# lr = 1e-3
lr = {'ff' : 1e-3, 'fb': 1e-3, 'lat': 1e-3}
neural_lr = 0.1
# Snapshot of the first feedforward weight matrix before training; the
# weight-difference plot further down needs it.
Wff_old = torch.clone(model.Wff[0]['weight'])
# print(model.Wff[0]['weight'])
for epoch_ in range(n_epochs):
    # Wrap the loader itself (not enumerate(loader)) so tqdm can read len(loader)
    # and show a real progress bar with a total.
    for x, y in tqdm(train_loader):
        x, y = x.to(device), y.to(device)
        # Flatten each image and put samples in columns: (features, batch).
        x = x.view(x.size(0),-1).T
        # batch_step of the notebook-defined class reads this notebook-level name.
        y_one_hot = F.one_hot(y, 10).to(device).T

        h, y_hat = model.batch_step(  x, lr, neural_lr, neural_dynamic_iterations_free, 
                                      neural_dynamic_iterations_nudged, beta)

#         break
#     break
    # Evaluate on the device selected at the top of the notebook instead of a
    # hard-coded 'cuda', so the loop also runs on CPU-only machines.
    trn_acc = evaluateCorInfoMax(model, train_loader, neural_lr, 20, device = device, printing = False)
    tst_acc = evaluateCorInfoMax(model, test_loader, neural_lr, 20, device = device, printing = False)
    trn_acc_list.append(trn_acc)
    tst_acc_list.append(tst_acc)
    
    print("Epoch : {}, Train Accuracy : {}, Test Accuracy : {}".format(epoch_+1, trn_acc, tst_acc))

# print(model.Wff[0]['weight'])

In [None]:
def torch2numpy(x):
    """Convert a tensor to a NumPy array: detach from autograd, move to CPU, then convert."""
    detached = x.detach()
    on_cpu = detached.cpu()
    return on_cpu.numpy()

In [None]:
# Show the change of the first feedforward weight matrix as an image.
# NOTE(review): `Wff_old` must hold a snapshot of model.Wff[0]['weight'] taken at
# notebook level before training; without one this cell raises NameError.
plt.imshow(torch2numpy(Wff_old - model.Wff[0]['weight']))

In [None]:
# Display the whole first feedforward weight matrix (the repr of a large tensor).
model.Wff[0]['weight']

In [None]:
# Re-run the free dynamics on whatever batch `x` currently holds in the kernel.
h, y_hat = model.init_neurons(x.size(1), device = model.device)

# dynamics for T time steps (beta = 0, so the target argument 0 is only a placeholder)
h, y_hat = model.run_neural_dynamics(x, h, y_hat, 0, neural_lr = neural_lr, 
                                     neural_dynamic_iterations = 20, beta = 0) 

In [None]:
# Display all feedforward synapse dicts (weights and biases of both layers).
model.Wff

In [None]:
# Layer sizes: input -> hidden -> output.
architecture = [784, 500, 10]

# One batch for inspection: x becomes (features, batch), y_one_hot becomes (10, batch).
x,y = next(iter(train_loader))
x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T

lambda_h = 1e-2
lambda_y = 1e-2
epsilon = 1e-2
one_over_epsilon = 1 / epsilon
# NOTE(review): `lr` is a plain float here, whereas the training cells use a dict
# with keys 'ff' / 'fb' / 'lat'; running a training cell after this one needs `lr` reset.
lr = 1e-3

model = TwoLayerCorInfoMax(architecture = architecture, lambda_h = lambda_h, lambda_y = lambda_y, 
                           epsilon = epsilon, activation = activation)

# Notebook-level aliases of the model's synapse containers (same objects, not copies).
Wff = model.Wff
Wfb = model.Wfb
B = model.B

h, y_hat = model.init_neurons(x.size(1), device = model.device)

# Free phase: 20 iterations with step size 1e-3 and beta = 0.
h, y_hat = model.run_neural_dynamics(x, h, y_hat, y_one_hot, 1e-3, 20, 0)

In [None]:
# Predicted class minus true label per sample; zero entries are correct predictions.
torch.argmax(y_hat, dim=0).squeeze() - y.to(device)

In [None]:
# Manual, step-by-step version of one training update (free phase, nudged phase,
# synapse updates) written at notebook level for inspection.
architecture = [784, 500, 10]

# One batch: x becomes (features, batch), y_one_hot becomes (10, batch).
x,y = next(iter(train_loader))
x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T

lambda_h = 1e-2
lambda_y = 1e-2
epsilon = 1e-2
one_over_epsilon = 1 / epsilon
# NOTE(review): `lr` is a plain float here, unlike the dict used by the training cells.
lr = 1e-3

model = TwoLayerCorInfoMax(architecture = architecture, lambda_h = lambda_h, lambda_y = lambda_y, 
                           epsilon = epsilon, activation = activation)

# Aliases of the model's synapse containers: assigning into these dicts below
# also changes what `model` sees.
Wff = model.Wff
Wfb = model.Wfb
B = model.B

h, y_hat = model.init_neurons(x.size(1), device = model.device)

# Free phase (beta = 0): 20 iterations with step size 1e-3.
h, y_hat = model.run_neural_dynamics(x, h, y_hat, y_one_hot, 1e-3, 20, 0)
neurons1 = [h, y_hat].copy()

# Prediction errors at the end of the free phase (feedforward and feedback, both layers).
error_hx_free = h - (Wff[0]['weight'] @ x + Wff[0]['bias'])
error_xh_free = x - (Wfb[0]['weight'] @ h + Wfb[0]['bias'])

error_yh_free = y_hat - (Wff[1]['weight'] @ h + Wff[1]['bias'])
error_hy_free = h - (Wfb[1]['weight'] @ y_hat + Wfb[1]['bias'])

# Nudged phase (beta = 1): 4 more iterations starting from the free-phase state.
h, y_hat = model.run_neural_dynamics(x, h, y_hat, y_one_hot, 1e-3, 4, 1)
neurons2 = [h, y_hat].copy()

# Prediction errors at the end of the nudged phase.
error_hx_nudged = h - (Wff[0]['weight'] @ x + Wff[0]['bias'])
error_xh_nudged = x - (Wfb[0]['weight'] @ h + Wfb[0]['bias'])

error_yh_nudged = y_hat - (Wff[1]['weight'] @ h + Wff[1]['bias'])
error_hy_nudged = h - (Wfb[1]['weight'] @ y_hat + Wfb[1]['bias'])


# Weight updates: batch mean of free-minus-nudged error/activation outer products.
Wff[0]['weight'] = Wff[0]['weight'] - lr * torch.mean(outer_prod_broadcasting((error_hx_free - error_hx_nudged).T, x.T), axis = 0)
Wfb[0]['weight'] = Wfb[0]['weight'] - lr * torch.mean(outer_prod_broadcasting(error_xh_free.T, neurons1[0].T) - outer_prod_broadcasting(error_xh_nudged.T, neurons2[0].T), axis = 0)
Wff[1]['weight'] = Wff[1]['weight'] - lr * torch.mean(outer_prod_broadcasting(error_yh_free.T, neurons1[0].T) - outer_prod_broadcasting(error_yh_nudged.T, neurons2[0].T), axis = 0)
Wfb[1]['weight'] = Wfb[1]['weight'] - lr * torch.mean(outer_prod_broadcasting(error_hy_free.T, neurons1[1].T) - outer_prod_broadcasting(error_hy_nudged.T, neurons2[1].T), axis = 0)

# Bias updates: batch mean of nudged-minus-free errors.
# NOTE(review): opposite sign convention to the weight updates above -- confirm intent.
Wff[0]['bias'] = Wff[0]['bias'] - lr * torch.mean(error_hx_nudged - error_hx_free, axis = 1, keepdims = True)
Wfb[0]['bias'] = Wfb[0]['bias'] - lr * torch.mean(error_xh_nudged - error_xh_free, axis = 1, keepdims = True)
Wff[1]['bias'] = Wff[1]['bias'] - lr * torch.mean(error_yh_nudged - error_yh_free, axis = 1, keepdims = True)
Wfb[1]['bias'] = Wfb[1]['bias'] - lr * torch.mean(error_hy_nudged - error_hy_free, axis = 1, keepdims = True)

# Lateral updates: nudged-minus-free batch-mean activation outer products.
B[0]['weight'] = B[0]['weight'] - lr * (torch.mean(outer_prod_broadcasting(neurons2[0].T, neurons2[0].T), axis = 0) - torch.mean(outer_prod_broadcasting(neurons1[0].T, neurons1[0].T), axis = 0))
B[1]['weight'] = B[1]['weight'] - lr * (torch.mean(outer_prod_broadcasting(neurons2[1].T, neurons2[1].T), axis = 0) - torch.mean(outer_prod_broadcasting(neurons1[1].T, neurons1[1].T), axis = 0))

In [None]:
# Shape check: first feedforward bias (created as a (layer_size, 1) column).
Wff[0]['bias'].shape

In [None]:
# Shape check: the bias update term; keepdims keeps it a column so it can be
# subtracted from the bias.
torch.mean(error_hx_nudged - error_hx_free, axis = 1, keepdims = True).shape

In [None]:
# Shape check: hidden-layer prediction error next to the input batch.
error_hx_free.shape, x.shape

In [None]:
# Shape check: batch-averaged outer product used for the first feedforward weight update.
torch.mean(outer_prod_broadcasting(error_hx_free.T, x.T), axis = 0).shape

In [None]:
# Sanity check for sample k: the k-th slice of outer_prod_broadcasting should equal
# the explicit outer product of the k-th columns (a norm of 0 means they agree).
k = 5
torch.norm(outer_prod_broadcasting(error_hx_free.T, x.T)[k] - (torch.outer(error_hx_free[:,k], x[:,k])))

In [None]:
# NOTE(review): [0] selects the first ROW of each matrix (one unit across the batch),
# unlike the per-sample column indexing [:, k] used in the sanity check -- confirm
# this is the shape that was meant to be inspected.
(torch.outer(error_hx_free[0], x[0])).shape

In [None]:
# Predicted class per sample: arg-max over the output units (rows) of y_hat.
torch.argmax(y_hat, dim=0).squeeze()

In [None]:
# Shape check: batch-averaged outer product of h with itself (the lateral update term).
torch.mean(outer_prod_broadcasting(h.T, h.T), axis = 0).shape

In [None]:
# Stand-alone (class-free) construction of the synapses, for notebook-level experiments.
architecture = [784, 500, 10]

# Earlier drafts built these synapses as torch.nn.Linear modules inside
# torch.nn.ModuleList containers; plain tensors stored in dicts are used instead.

n_layers = len(architecture) - 1

# Feedforward synapses: Xavier-uniform weights of shape (n_out, n_in), zero column biases.
Wff = []
for idx in range(n_layers):
    n_in, n_out = architecture[idx], architecture[idx + 1]
    weight = torch.randn(n_out, n_in, requires_grad = False).to(device)
    torch.nn.init.xavier_uniform_(weight)
    bias = torch.zeros(n_out, 1, requires_grad = False).to(device)
    Wff.append({'weight': weight, 'bias': bias})

# Feedback synapses: Xavier-uniform weights of shape (n_in, n_out), zero column biases.
Wfb = []
for idx in range(n_layers):
    n_in, n_out = architecture[idx], architecture[idx + 1]
    weight = torch.randn(n_in, n_out, requires_grad = False).to(device)
    torch.nn.init.xavier_uniform_(weight)
    bias = torch.zeros(n_in, 1, requires_grad = False).to(device)
    Wfb.append({'weight': weight, 'bias': bias})

# Lateral synapses: W @ W.T of a Xavier-uniform square matrix (symmetric by construction).
B = []
for idx in range(n_layers):
    n_out = architecture[idx + 1]
    weight = torch.randn(n_out, n_out, requires_grad = False).to(device)
    torch.nn.init.xavier_uniform_(weight)
    weight = weight @ weight.T
    B.append({'weight': weight})

# NOTE(review): a1, a2, b1, b2 are not used by any visible cell.
a1 = a2 = b1 = b2 = 1
beta = 0

lambda_h = 1e-2
lambda_y = 1e-2
epsilon = 1e-1
one_over_epsilon = 1 / epsilon

In [None]:
def init_neurons(mbs, architecture, random_initialize = False, device = 'cuda'):
    """Create initial activations for every non-input layer.

    Args:
        mbs: mini-batch size.
        architecture: layer sizes; the first entry (input layer) is skipped.
        random_initialize: standard-normal values if True, zeros otherwise.
        device: device the tensors are allocated on.

    Returns:
        List with one ``(layer_size, mbs)`` tensor per entry of ``architecture[1:]``.
    """
    factory = torch.randn if random_initialize else torch.zeros
    # Allocate as (batch, size) and transpose so that samples end up in columns.
    return [factory((mbs, size), requires_grad=False, device=device).T
            for size in architecture[1:]]

def calculate_neural_dynamics_grad(x, h, y_hat, Wff, Wfb, B, one_over_epsilon, lambda_h, lambda_y, beta, y = None):
    """Return the update directions ``(grad_h, grad_y)`` of the neural dynamics.

    Each direction combines a feedback prediction-error term, a lateral term
    scaled by ``(1 - lambda) / lambda`` and a feedforward prediction-error term;
    the output direction additionally gets the nudging term
    ``2 * beta * (y - y_hat)``.

    Args:
        x, h, y_hat: column-wise activations of the input, hidden and output layers.
        Wff, Wfb: per-layer dicts with 'weight' and 'bias'.
        B: per-layer dicts with 'weight' (lateral synapses).
        one_over_epsilon, lambda_h, lambda_y, beta: scalar hyperparameters.
        y: one-hot targets shaped like ``y_hat``. When omitted, the notebook-level
            ``y_one_hot`` is used. The function used to read the notebook-level
            ``y``, which holds raw class labels that were never moved to the
            compute device.
    """
    if y is None:
        y = y_one_hot

    # The feedforward prediction is W @ x + bias. It was written with "- bias",
    # which disagreed with the class implementation and with the error definitions.
    grad_h = 0.5*(one_over_epsilon * Wfb[0]['weight'].T @ (x - (Wfb[0]['weight'] @ h + Wfb[0]['bias'])) + 
         ((1 - lambda_h) / lambda_h) * B[0]['weight'] @ h -
         one_over_epsilon * (h - (Wff[0]['weight'] @ x + Wff[0]['bias'])))

    grad_y = 0.5*(one_over_epsilon * Wfb[1]['weight'].T @ (h - (Wfb[1]['weight'] @ y_hat + Wfb[1]['bias'])) +
         ((1 - lambda_y) / lambda_y) * B[1]['weight'] @ y_hat - 
         one_over_epsilon * (y_hat - (Wff[1]['weight'] @ h + Wff[1]['bias']))) + 2 * beta * (y - y_hat)
    
    return grad_h, grad_y

In [None]:
# Fresh batch: x becomes (features, batch) on `device`, y_one_hot becomes (10, batch).
# Note that `y` itself stays wherever the loader put it; only y_one_hot is moved.
x,y = next(iter(train_loader))
x = x.view(x.size(0),-1).to(device).T
y_one_hot = F.one_hot(y, 10).to(device).T

# Random initial hidden / output activations, via the stand-alone init_neurons helper.
h,y_hat = init_neurons(x.size(1), architecture, random_initialize = True, device = 'cuda')

In [None]:
# Shape check: input, hidden, output and one-hot target tensors.
x.shape, h.shape, y_hat.shape, y_one_hot.shape

In [None]:
# Single evaluation of the stand-alone dynamics directions with the notebook-level synapses.
grad_h, grad_y = calculate_neural_dynamics_grad(x, h, y_hat, Wff, Wfb, B, one_over_epsilon, lambda_h, lambda_y, beta)

In [None]:
# Manual neural-dynamics loop at notebook level.
neural_dynamic_iterations = 20
neural_lr = 1e-3
for iter_count in range(neural_dynamic_iterations):
    with torch.no_grad():
        grad_h, grad_y = calculate_neural_dynamics_grad(x, h, y_hat, Wff, Wfb, B, one_over_epsilon, lambda_h, lambda_y, beta)
        # NOTE(review): this steps with "- neural_lr * grad", whereas
        # TwoLayerCorInfoMax.run_neural_dynamics steps with "+" -- confirm which is intended.
        h = activation(h - neural_lr * grad_h)
        y_hat = activation(y_hat - neural_lr * grad_y)
        # NOTE(review): this break ends the loop after the first iteration, so
        # neural_dynamic_iterations has no effect; looks like leftover debugging.
        break

In [None]:
# Shape check: second feedback weight matrix.
Wfb[1]['weight'].shape

In [None]:
# Hidden-layer update direction, written out at notebook level.
# The feedforward prediction is W @ x + bias; it was written with "- bias" here,
# which disagreed with TwoLayerCorInfoMax and with the error definitions used elsewhere.
grad_h = 0.5*(one_over_epsilon * Wfb[0]['weight'].T @ (x - (Wfb[0]['weight'] @ h + Wfb[0]['bias'])) + 
         ((1 - lambda_h) / lambda_h) * B[0]['weight'] @ h -
         one_over_epsilon * (h - (Wff[0]['weight'] @ x + Wff[0]['bias'])))

In [None]:
# Output-layer update direction, written out at notebook level.
# The nudging target is the one-hot matrix y_one_hot, which has y_hat's shape and lives
# on `device`. The raw label vector `y` used before is neither one-hot nor moved to the device.
grad_y = 0.5*(one_over_epsilon * Wfb[1]['weight'].T @ (h - (Wfb[1]['weight'] @ y_hat + Wfb[1]['bias'])) +
         ((1 - lambda_y) / lambda_y) * B[1]['weight'] @ y_hat - 
         one_over_epsilon * (y_hat - (Wff[1]['weight'] @ h + Wff[1]['bias']))) + 2 * beta * (y_one_hot - y_hat)

In [None]:
# Shape check: both update directions.
grad_h.shape, grad_y.shape

In [None]:
# NOTE(review): every visible cell that builds Wfb stores dicts, so `.weight` would raise
# AttributeError; this looks like a leftover from the commented-out torch.nn.Linear version.
Wfb[0].weight.data.shape

In [None]:
# NOTE(review): relies on Wfb[0] being a module with `.weight`; the visible cells build
# Wfb from dicts -- presumably stale, confirm and remove.
(Wfb[0].weight.data @ h.T).shape

In [None]:
# NOTE(review): relies on Wfb[0] being a module with `.bias`; the visible cells build
# Wfb from dicts -- presumably stale, confirm and remove.
Wfb[0].bias.data.shape

In [None]:
# Compares calling Wfb[0] as a module with the explicit weight @ input + bias form.
# NOTE(review): requires Wfb[0] to be callable; the visible cells build Wfb from
# dicts -- presumably stale, confirm and remove.
torch.norm(Wfb[0](h) - (Wfb[0].weight.data @ h.T + Wfb[0].bias.data.view(-1,1)).T)

In [None]:
# Shape check: hidden-layer update direction.
grad_h.shape

In [None]:
# NOTE(review): requires Wff[0] to be a callable module with `.weight`; the visible cells
# build Wff from dicts -- presumably stale, confirm and remove.
Wff[0](x) - (Wff[0].weight.data @ x.T).T

In [None]:
# NOTE(review): `my_x` is not assigned in any visible cell; this raises NameError on a
# fresh kernel unless it is defined elsewhere.
x.shape, my_x.shape

In [None]:
# NOTE(review): `W` is not assigned in any visible cell; this raises NameError on a
# fresh kernel unless it is defined elsewhere.
W[0].weight

In [None]:
# Compares the explicit x @ weight.T product with calling W[0] on x.
# NOTE(review): `W` is not assigned in any visible cell -- presumably stale.
torch.norm(x @ W[0].weight.data.T - W[0](x))

In [None]:
# Toy stand-in for the synapse containers: a list of dicts holding NumPy arrays.
list_of_dict = [{'weight': np.ones(3), 'bias': np.array([0])}, {'weight': np.zeros(3), 'bias': np.array([1])}]

In [None]:
# Replace the first weight with a slightly decreased copy (rebinds the dict entry).
# Not idempotent: every re-run of this cell subtracts another 1e-3.
list_of_dict[0]['weight'] = list_of_dict[0]['weight'] - 1e-3 * np.ones(3)

In [None]:
# Display the toy container after the update.
list_of_dict

In [None]:
class foo:
    """Toy class to check that updating a dict entry through a stored list is
    visible to every holder of that list."""

    def __init__(self, l):
        # Keep a reference to the caller's list (no copy is made).
        self.l = l
        
    def step(self):
        """Decrease the first entry's 'weight' by 1e-3 (the dict entry is rebound)."""
        # Operate on the instance's own list instead of the notebook global
        # `list_of_dict`. The result is the same when the two are the same object,
        # and the class no longer fails when that global does not exist.
        self.l[0]['weight'] = self.l[0]['weight'] - 1e-3 * np.ones(3)

In [None]:
# Hand the toy container to an instance; foo stores the same list object, not a copy.
myfoo = foo(list_of_dict)

In [None]:
# Display the instance's list before calling step().
myfoo.l

In [None]:
# Apply one update; not idempotent, every run subtracts another 1e-3.
myfoo.step()

In [None]:
# Display the instance's list again to see whether the update is visible through it.
myfoo.l