In [2]:
import torch
import matplotlib.pyplot as plt
import EIANN.utils as ut
import EIANN.plot as pt

from tqdm.autonotebook import tqdm
import numpy as np
import matplotlib.gridspec as gs
%load_ext autoreload
%autoreload 2

pt.update_plot_defaults()

train_dataloader, train_sub_dataloader, val_dataloader, test_dataloader, data_generator = ut.get_MNIST_dataloaders(sub_dataloader_size=20_000)

## Figure 1. Hebbian network comparison

### Van BP  **|**  Unsup. HebbWeightNorm  **|**  Top-layer sup. HebbWeightNorm  **|**  Dale BP (fully learned E+I)

In [37]:
# Load VanBP model
config_path = "../network_config/mnist/20231120_EIANN_1_hidden_mnist_van_bp_relu_SGD_config_G_optimized.yaml"
saved_network_path = "../data/mnist/20231120_EIANN_1_hidden_mnist_van_bp_relu_SGD_config_G_66049_257.pkl"
bp_network = ut.build_EIANN_from_config(config_path, network_seed=66049)
bp_network.load(saved_network_path)

# Load (U)HebbWeightNorm model
config_path = "../network_config/mnist/20231025_EIANN_1_hidden_mnist_Gjorgjieva_Hebb_config_F_optimized.yaml"
saved_network_path = "saved_networks/20231025_EIANN_1_hidden_mnist_Gjorgjieva_Hebb_config_F_66049_257_retrained.pkl"
ugj_network = ut.build_EIANN_from_config(config_path, network_seed=66049)
ut.rename_population(ugj_network, 'I', 'SomaI')
ugj_network.load(saved_network_path)

# Load (S)HebbWeightNorm model
config_path = "../network_config/mnist/20231025_EIANN_1_hidden_mnist_Supervised_Gjorgjieva_Hebb_config_F_optimized.yaml"
saved_network_path = "../data/mnist/20230505_EIANN_1_hidden_mnist_Supervised_Gjorgjieva_Hebb_config_F_66049_257.pkl"
gj_network = ut.build_EIANN_from_config(config_path, network_seed=66049)
ut.rename_population(gj_network, 'I', 'SomaI')
gj_network.load(saved_network_path)

# Load bpDale (full) model
config_path = "../network_config/mnist/20231018_EIANN_1_hidden_mnist_bpDale_relu_SGD_config_G_optimized.yaml"
saved_network_path = "saved_networks/20230815_EIANN_1_hidden_mnist_bpDale_softplus_SGD_config_G_66050_257_retrained.pkl"
bpDale_network = ut.build_EIANN_from_config(config_path, network_seed=66049)
bpDale_network.load(saved_network_path)

Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Samples:   0%|          | 0/20000 [00:00<?, ?it/s]

Model saved to saved_networks/20230815_EIANN_1_hidden_mnist_bpDale_softplus_SGD_config_G_66050_257_retrained.pkl
Loading model data from 'saved_networks/20230815_EIANN_1_hidden_mnist_bpDale_softplus_SGD_config_G_66050_257_retrained.pkl'...
Loading parameters into the network...
Model successfully loaded from 'saved_networks/20230815_EIANN_1_hidden_mnist_bpDale_softplus_SGD_config_G_66050_257_retrained.pkl'


In [38]:
fig, ax = plt.subplots(1, 4, figsize=(20, 5))

# Plot example receptive fields for each network
# VanBP
pt.plot_hidden_weights(bp_network.module_dict['H1E_InputE'].weight, sort=True, max_units=10, ax=ax[0])

Network(
  (criterion): MSELoss()
  (module_dict): ModuleDict(
    (H1E_InputE): Projection(in_features=784, out_features=500, bias=False)
    (H1E_H1SomaI): Projection(in_features=50, out_features=500, bias=False)
    (H1SomaI_InputE): Projection(in_features=784, out_features=50, bias=False)
    (H1SomaI_H1E): Projection(in_features=500, out_features=50, bias=False)
    (H1SomaI_H1SomaI): Projection(in_features=50, out_features=50, bias=False)
    (OutputE_H1E): Projection(in_features=500, out_features=10, bias=False)
    (OutputE_OutputSomaI): Projection(in_features=10, out_features=10, bias=False)
    (OutputSomaI_H1E): Projection(in_features=500, out_features=10, bias=False)
    (OutputSomaI_OutputE): Projection(in_features=10, out_features=10, bias=False)
    (OutputSomaI_OutputSomaI): Projection(in_features=10, out_features=10, bias=False)
  )
  (parameter_dict): ParameterDict(
      (H1E_bias): Parameter containing: [torch.FloatTensor of size 500]
      (H1SomaI_bias): Parameter

In [30]:
for i in bpDale_network.state_dict().keys():
    print(i)

module_dict.H1E_InputE.weight
module_dict.H1E_H1SomaI.weight
module_dict.H1SomaI_InputE.weight
module_dict.H1SomaI_H1E.weight
module_dict.H1SomaI_H1SomaI.weight
module_dict.OutputE_H1E.weight
module_dict.OutputE_OutputSomaI.weight
module_dict.OutputSomaI_H1E.weight
module_dict.OutputSomaI_OutputE.weight
module_dict.OutputSomaI_OutputSomaI.weight
parameter_dict.H1E_bias
parameter_dict.H1SomaI_bias
parameter_dict.OutputE_bias
parameter_dict.OutputSomaI_bias


In [18]:
bp_network.state_dict().keys()

odict_keys(['module_dict.H1E_InputE.weight', 'module_dict.OutputE_H1E.weight', 'projections.H1E_InputE.weight', 'projections.OutputE_H1E.weight', 'parameter_dict.H1E_bias', 'parameter_dict.OutputE_bias'])

## Figure 2. Comparison with fixed I weights (only learn E)

### Dale BP  **|**  Top-layer sup. HebbWeightNorm  **|**  unsup. BCM  **|**  Top-layer sup. BCM 

## Figure 3. 

### BP-like-1 (perfect dendI subtraction)  **|**  BP-like-2 (learned dendI)  **|**  Hebb (?)  **|**  BCM (?)

## Figure 4. BTSP variants

### BTSP (Burstprop-like with multiplexing) **|**