In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from pprint import pprint
import matplotlib.pyplot as plt

from EIANN import Network
import EIANN.utils as ut
import EIANN.plot as ep

# Apply the EIANN plotting defaults once, before any figure is created.
# NOTE(review): presumably this adjusts matplotlib settings — confirm in EIANN.plot.
ep.update_plot_defaults()

In [None]:
# Images are only converted to tensors. The Normalize(mean=[0.1307], std=[0.3081])
# step that used to follow ToTensor is intentionally left disabled here.
tensor_normalize = transforms.Compose([transforms.ToTensor()])

# Arguments shared by the train and test splits.
mnist_kwargs = dict(
    root='../datasets/MNIST_data/',
    transform=tensor_normalize,
    download=True,
)

MNIST_train = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
# Subset made of the first 1000 training samples.
MNIST_train_sub = torch.utils.data.Subset(MNIST_train, range(1000))
MNIST_test = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

In [None]:
def flatten_and_one_hot(dataset):
    """Convert an image dataset into ``(index, flat_image, one_hot_target)`` tuples.

    Args:
        dataset: Iterable yielding ``(image_tensor, class_index)`` pairs that also
            exposes a ``classes`` attribute; ``len(dataset.classes)`` sets the
            length of the one-hot target vectors.

    Returns:
        list: One ``(idx, data, target)`` tuple per sample, where ``idx`` is the
        position of the sample in ``dataset``, ``data`` is the flattened image
        tensor and ``target`` is the one-hot row for the sample's class index.
    """
    # Build the identity matrix once: row k is the one-hot vector for class k.
    one_hot_rows = torch.eye(len(dataset.classes))
    return [
        # clone() gives every sample its own small target tensor instead of a
        # view that aliases (and keeps alive) the shared identity matrix.
        (idx, image.flatten(), one_hot_rows[label].clone())
        for idx, (image, label) in enumerate(dataset)
    ]


flat_MNIST_train = flatten_and_one_hot(MNIST_train)
flat_MNIST_test = flatten_and_one_hot(MNIST_test)

# One generator drives the shuffling of both training loaders, so seeding it
# (data_generator.manual_seed(...)) before a training run fixes the sample order.
data_generator = torch.Generator()

train_dataloader = DataLoader(flat_MNIST_train, shuffle=True, generator=data_generator)

# Quick-iteration loader over the first 2000 flattened training samples.
# NOTE: this is independent of MNIST_train_sub, which holds the first 1000 raw samples.
train_sub_dataloader = DataLoader(flat_MNIST_train[:2000], shuffle=True, generator=data_generator)

# The whole test set is served as a single, unshuffled batch.
test_dataloader = DataLoader(flat_MNIST_test, batch_size=len(flat_MNIST_test), shuffle=False)

In [None]:
# data_seed re-seeds the DataLoader shuffling generator before each training run;
# seed is handed to the Network constructor.
data_seed, seed = 0, 42

### Backprop (no weight constraints)

In [None]:
# Read the network specification from YAML and echo it for reference.
config_path = '../config/EIANN_2_hidden_mnist_backprop_relu_SGD_config.yaml'
network_config = ut.read_from_yaml(config_path)
pprint(network_config)

# Pull out the three sections that are passed to the Network constructor.
layer_config, projection_config, training_kwargs = (
    network_config[section]
    for section in ('layer_config', 'projection_config', 'training_kwargs')
)

network = Network(layer_config, projection_config, seed=seed, **training_kwargs)

In [None]:
# Re-seed the generator shared by the shuffling training loaders so the sample
# order is reproducible, then train for one epoch on the training subset
# without storing history.
data_generator.manual_seed(data_seed)
network.train(
    train_sub_dataloader,
    status_bar=True,
    store_history=False,
    epochs=1,
)

In [None]:
# Inspect the trained network with the EIANN plotting helpers, then show the
# value returned by compute_batch_accuracy for the test dataloader.
ep.plot_performance(network)
ep.plot_MNIST_examples(network, test_dataloader)
batch_accuracy = ut.compute_batch_accuracy(network, test_dataloader)
batch_accuracy

### Backprop (Dale's Law)

In [None]:
# Read the Dale's Law network specification from YAML.
config_path = '../config/EIANN_2_hidden_mnist_backprop_Dale_relu_SGD_config.yaml'
network_config = ut.read_from_yaml(config_path)

# Pull out the three sections that are passed to the Network constructor.
layer_config, projection_config, training_kwargs = (
    network_config[section]
    for section in ('layer_config', 'projection_config', 'training_kwargs')
)

network = Network(layer_config, projection_config, seed=seed, **training_kwargs)

In [None]:
# Re-seed the generator shared by the shuffling training loaders so the sample
# order is reproducible, then train for one epoch on the training subset.
# History is stored this time because the cells below read `activity_history`.
data_generator.manual_seed(data_seed)
network.train(
    train_sub_dataloader,
    status_bar=True,
    store_history=True,
    epochs=1,
)

In [None]:
# Sanity check: is any entry of the gradient on the H2.E -> Output.E weights non-zero?
h2_to_output_grad = network.Output.E.H2.E.weight.grad
torch.any(h2_to_output_grad)

In [None]:
print(network.Output.E.activity_history.shape)

# Overlay one trace per unit, taken from the last entry along the first axis of
# the stored history: Output.E units in grey, Output.FBI units in red.
plt.figure()
for population, colour in ((network.Output.E, 'grey'), (network.Output.FBI, 'r')):
    for unit in range(population.size):
        plt.plot(population.activity_history[-1, :, unit], c=colour)

In [None]:
# Inspect the trained Dale's Law network with the EIANN plotting helpers, then
# show the value returned by compute_batch_accuracy for the test dataloader.
ep.plot_performance(network)
ep.plot_MNIST_examples(network, test_dataloader)
batch_accuracy = ut.compute_batch_accuracy(network, test_dataloader)
batch_accuracy

In [None]:
# Draw one heatmap per excitatory population, ordered from the output layer
# back to the input layer. Each figure is titled with its population so the
# four plots can be told apart.
# NOTE(review): `.activity` presumably holds the most recent forward pass — confirm.
populations = {
    'Output.E': network.Output.E,
    'H2.E': network.H2.E,
    'H1.E': network.H1.E,
    'Input.E': network.Input.E,
}
for label, population in populations.items():
    plt.figure()
    # The activity matrix is transposed for display; aspect='auto' lets the
    # image fill the axes regardless of its shape.
    plt.imshow(population.activity.detach().numpy().T, aspect='auto')
    # Title is set before the colorbar is added so it lands on the image axes.
    plt.title(f'{label} activity')
    plt.colorbar()

In [None]:
# Show the shape of the Output.E activity tensor.
output_activity = network.Output.E.activity
output_activity.shape