In [None]:
%matplotlib inline

In [None]:
import os
import pprint
import torch
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from EIANN import Network
import EIANN.utils as ut
import EIANN.plot as pt


# Load MNIST: every image is converted with ToTensor and flattened to a 1-D tensor.
MNIST_DATA_ROOT = '../data/mnist/datasets/MNIST_data/'
tensor_flatten = T.Compose([T.ToTensor(), T.Lambda(torch.flatten)])
MNIST_train_dataset = torchvision.datasets.MNIST(root=MNIST_DATA_ROOT, train=True, download=True,
                                                 transform=tensor_flatten)
MNIST_test_dataset = torchvision.datasets.MNIST(root=MNIST_DATA_ROOT, train=False, download=True,
                                                transform=tensor_flatten)


def _index_and_one_hot(dataset):
    """Return ``[(idx, data, one_hot_target), ...]`` for every sample of ``dataset``.

    The integer class label is replaced by the matching row of an identity matrix with
    ``len(dataset.classes)`` columns; the sample index is kept alongside the data.
    """
    # Build the identity matrix once rather than once per sample. Each row is cloned so
    # every target owns its own small storage (no aliasing between samples).
    one_hot_rows = torch.eye(len(dataset.classes))
    return [(idx, data, one_hot_rows[target].clone()) for idx, (data, target) in enumerate(dataset)]


# Add index and one-hot targets to train & test data
MNIST_train = _index_and_one_hot(MNIST_train_dataset)
MNIST_test = _index_and_one_hot(MNIST_test_dataset)

# Train on the first NUM_TRAIN_SAMPLES training images and validate on the last NUM_VAL_SAMPLES.
NUM_TRAIN_SAMPLES = 50000
NUM_VAL_SAMPLES = 10000
assert NUM_TRAIN_SAMPLES + NUM_VAL_SAMPLES <= len(MNIST_train), 'train/validation slices would overlap'

# Put data in dataloaders. `data_generator` drives the training shuffle and is reseeded
# before every training run below so each model sees the same sample order.
data_generator = torch.Generator()
train_dataloader = torch.utils.data.DataLoader(MNIST_train[:NUM_TRAIN_SAMPLES], shuffle=True,
                                               generator=data_generator)
val_dataloader = torch.utils.data.DataLoader(MNIST_train[-NUM_VAL_SAMPLES:], batch_size=NUM_VAL_SAMPLES,
                                             shuffle=False)
test_dataloader = torch.utils.data.DataLoader(MNIST_test, batch_size=10000, shuffle=False)

epochs = 1  # NOTE(review): not referenced by any cell shown here — confirm before removing
data_seed = 257  # seeds data_generator (training shuffle order)
network_seed = 66049  # passed to Network(seed=...) via get_network
train_steps = 2000  # 20000; forwarded to network.train as samples_per_epoch
analyze_receptive_fields = None  # e.g. ['H2E'] to plot receptive fields of that population

In [None]:
def get_network(config_file_path, network_seed):
    """Construct an EIANN ``Network`` from a YAML configuration file.

    The parsed config must contain the keys ``layer_config``, ``projection_config`` and
    ``training_kwargs``; the last one is expanded into keyword arguments of ``Network``.
    ``network_seed`` is passed as ``seed``.
    """
    config = ut.read_from_yaml(config_file_path)
    return Network(config['layer_config'],
                   config['projection_config'],
                   seed=network_seed,
                   **config['training_kwargs'])

In [None]:
def train_network_mnist(network, train_dataloader, val_dataloader, train_steps,
                        val_interval=(0, -1, 100), store_history_interval=(0, -1, 100)):
    """Train ``network`` in place with validation, history recording and a status bar enabled.

    Parameters
    ----------
    network : Network
        Model whose ``train`` method is called.
    train_dataloader, val_dataloader : torch.utils.data.DataLoader
        Training and validation data, passed positionally to ``network.train``.
    train_steps : int
        Forwarded to ``network.train`` as ``samples_per_epoch``.
    val_interval, store_history_interval : tuple
        Forwarded unchanged to ``network.train``. The defaults reproduce the previously
        hard-coded ``(0, -1, 100)`` (presumably start, end, step — see ``Network.train``).
    """
    network.train(train_dataloader, val_dataloader, samples_per_epoch=train_steps,
                  val_interval=val_interval, store_history=True,
                  store_history_interval=store_history_interval, status_bar=True)

In [None]:
def _plot_plateau_and_dendritic_histories(network):
    """Draw a 4x3 figure of plateau and dendritic-state histories for the H1/H2/Output E populations."""
    fig, axes = plt.subplots(4, 3, figsize=(10., 12.))

    binned_train_steps = network.val_history_train_steps
    xmin = binned_train_steps[0]
    xmax = binned_train_steps[-1] + 1

    # Row 0: plateau-history heatmaps (unit 0 at the top). Row 1: plateau history averaged over units.
    for col, layer_name in enumerate(('H1', 'H2', 'Output')):
        population = getattr(network, layer_name).E
        label = f'{layer_name}.E'

        ax = axes[0][col]
        im = ax.imshow(population.plateau_history.T, aspect='auto', interpolation='none',
                       extent=(xmin, xmax, population.size, 0))
        plt.colorbar(im, ax=ax)
        ax.set_title(f'Plateau history: {label}')
        ax.set_ylabel('Unit ID')
        ax.set_xlabel('Training steps')

        ax = axes[1][col]
        ax.plot(binned_train_steps, population.plateau_history.mean(dim=1))
        ax.set_title(f'Plateau_history: {label}')
        ax.set_xlabel('Training steps')
    axes[1][0].set_ylabel('Mean plateau amp')

    # Rows 2-3: forward (col 0) and backward (col 1) dendritic state of H1.E / H2.E, averaged over units.
    for row, layer_name in ((2, 'H1'), (3, 'H2')):
        population = getattr(network, layer_name).E
        panels = (('Forward', population.forward_dendritic_state_history),
                  ('Backward', population.backward_dendritic_state_history))
        for col, (direction, history) in enumerate(panels):
            ax = axes[row][col]
            ax.plot(binned_train_steps, history.mean(dim=1))
            ax.set_title(f'{direction} dendritic state: {layer_name}.E')
            ax.set_xlabel('Training steps')
        axes[row][0].set_ylabel('Mean amplitude')
        # Column 2 has no panel in these rows; hide the empty axes.
        axes[row][2].axis('off')

    fig.tight_layout()
    # plt.show() works with the inline backend; fig.show() warns on non-GUI backends.
    plt.show()


def plot_mnist_network_intermediates(network, analyze_receptive_fields=None, test_dataloader=None):
    """Plot plateau/dendritic-state histories and, optionally, receptive fields of a trained network.

    Returns immediately (plotting nothing) unless ``'H2E'`` is in ``network.populations``.

    Parameters
    ----------
    network : Network
        Trained EIANN network. When ``network.H2.E.plateau_history`` exists and is not None,
        a summary figure is drawn from the ``plateau_history``,
        ``forward_dendritic_state_history`` and ``backward_dendritic_state_history`` of the
        H1/H2/Output ``.E`` populations against ``network.val_history_train_steps``.
    analyze_receptive_fields : iterable of str or None
        Keys of ``network.populations`` whose receptive fields are computed and plotted;
        names absent from the network are skipped. None (or False) skips this step.
    test_dataloader : torch.utils.data.DataLoader or None
        Passed to ``ut.compute_representation_metrics``. When None, falls back to the
        notebook-level ``test_dataloader`` that this function previously read implicitly.
    """
    if 'H2E' not in network.populations:
        return

    if getattr(network.H2.E, 'plateau_history', None) is not None:
        _plot_plateau_and_dendritic_histories(network)

    # False (analyze_network_mnist's old default) used to be iterated and raise TypeError.
    if analyze_receptive_fields is None or analyze_receptive_fields is False:
        return

    if test_dataloader is None:
        # The parameter shadows the notebook global of the same name, so look it up explicitly.
        test_dataloader = globals().get('test_dataloader')

    for population_name in analyze_receptive_fields:
        if population_name not in network.populations:
            continue
        population = network.populations[population_name]
        receptive_fields = ut.compute_maxact_receptive_fields(population)
        # Output units: unsorted, one row. Other populations: sort=True, two rows.
        is_output = population is network.Output.E
        pt.plot_receptive_fields(receptive_fields, sort=not is_output, num_cols=10,
                                 num_rows=1 if is_output else 2, title=population_name)
        ut.compute_representation_metrics(population, test_dataloader, receptive_fields, plot=True)

In [None]:
def analyze_network_mnist(network, test_dataloader, analyze_receptive_fields=None):
    """Plot test accuracy, loss histories and internal dynamics of a trained network.

    Parameters
    ----------
    network : Network
        Trained EIANN network.
    test_dataloader : torch.utils.data.DataLoader
        Data passed to ``pt.plot_batch_accuracy``.
    analyze_receptive_fields : iterable of str or None
        Population names forwarded to ``plot_mnist_network_intermediates``; None skips
        receptive-field analysis. ``False`` (the former default) is treated as None.
    """
    # The old default False slipped past plot_mnist_network_intermediates' `is not None`
    # check and was then iterated, raising TypeError for networks with an H2E population.
    if analyze_receptive_fields is False:
        analyze_receptive_fields = None
    pt.plot_batch_accuracy(network, test_dataloader, population='all',
                           sorted_output_idx=torch.arange(0, network.Output.E.size), title='')
    pt.plot_train_loss_history(network)
    pt.plot_validate_loss_history(network)
    plot_mnist_network_intermediates(network, analyze_receptive_fields=analyze_receptive_fields)

## Feedforward ANN (Backprop)

In [None]:
# Feedforward ANN (Backprop): build from its config, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20231129_EIANN_2_hidden_mnist_van_bp_relu_SGD_config_G_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
FFANN_bp_network = network

## EIANN - Learned SomaI (Backprop)

In [None]:
# EIANN - Learned SomaI (Backprop): build from its config, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20240419_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_F_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
EIANN_bp_learned_somaI_network = network

## EIANN - Fixed SomaI (Backprop)

In [None]:
# EIANN - Fixed SomaI (Backprop): build from its config, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20231129_EIANN_2_hidden_mnist_bpDale_relu_SGD_config_G_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
EIANN_bp_fixed_somaI_network = network

## EIANN - Learned SomaI (Hebb, Top-Layer Supervised)

In [None]:
# EIANN - Learned SomaI (Hebb, Top-Layer Supervised): build from its config, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20241105_EIANN_2_hidden_mnist_Top_Layer_Supervised_Hebb_WeightNorm_config_7_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
EIANN_hebb_network = network

## Dendritic Target Propagation (LDS, Top-down Weight Symmetry)

In [None]:
# Dendritic Target Propagation (LDS, Top-down Weight Symmetry): build, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20241009_EIANN_2_hidden_mnist_BP_like_config_5J_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
DTP_LDS_network = network

In [None]:
# Re-plot the DTP (LDS) analysis by explicit name, so re-running this cell after later
# experiment cells (which rebind `network`) still analyzes the right model.
analyze_network_mnist(DTP_LDS_network, test_dataloader, analyze_receptive_fields=analyze_receptive_fields)

## Dendritic Target Propagation (BTSP, Top-down Weight Symmetry)

In [None]:
# Dendritic Target Propagation (BTSP, Top-down Weight Symmetry): build, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20241212_EIANN_2_hidden_mnist_BTSP_config_5L_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
DTP_BTSP_network = network

## Dendritic Target Propagation (BTSP, Learned Top-down Weights)

In [None]:
# Dendritic Target Propagation (BTSP, Learned Top-down weights): build, train on MNIST, then plot the analysis.
# Reseeding data_generator (train_dataloader's shuffle generator) gives every model the same sample order.
config_file_path = '../network_config/mnist/20241216_EIANN_2_hidden_mnist_BTSP_config_5L_learn_TD_HTCWN_3_complete_optimized.yaml'
data_generator.manual_seed(data_seed)
network = get_network(config_file_path=config_file_path, network_seed=network_seed)
train_network_mnist(network=network, train_dataloader=train_dataloader,
                    val_dataloader=val_dataloader, train_steps=train_steps)
analyze_network_mnist(network=network, test_dataloader=test_dataloader,
                      analyze_receptive_fields=analyze_receptive_fields)
# Keep a model-specific handle: `network` is rebound by each experiment cell.
DTP_BTSP_TD_HTCWN_network = network