In [1]:
import torch
import matplotlib.pyplot as plt
import EIANN.utils as ut
import EIANN.plot as pt
import EIANN._network as nt
import EIANN.generate_figures as gf

from tqdm.autonotebook import tqdm
import numpy as np
import matplotlib.gridspec as gs
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell executes and edits to the package are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Apply the plotting defaults provided by EIANN.plot
# (presumably matplotlib rcParams — TODO confirm).
pt.update_plot_defaults()

  from tqdm.autonotebook import tqdm


### MNIST Networks

In [None]:
# Fetch the MNIST dataloaders together with the generator object that is
# re-seeded before training further down.
(
    train_dataloader,
    train_sub_dataloader,
    val_dataloader,
    test_dataloader,
    data_generator,
) = ut.get_MNIST_dataloaders(sub_dataloader_size=20_000)

In [None]:
# Build network
network_name = "EIANN_1_hidden_mnist_vanBP"
config_path = f"../network_config/MNIST_templates/{network_name}.yaml"
# NOTE(review): this path has no "../" prefix, unlike config_path above and
# the "../saved_networks/..." path used in the spiral section — confirm which
# directory is intended.
saved_network_path = f"saved_networks/{network_name}.pkl"
bp_net = ut.build_EIANN_from_config(config_path, network_seed=66049)

# Train network
# Re-seed the generator returned with the dataloaders before training
# (presumably it controls batch order — TODO confirm).
data_seed = 257
data_generator.manual_seed(data_seed)
# NOTE(review): the second positional argument is test_dataloader, while
# val_dataloader from the cell above is unused here — confirm this is intended.
bp_net.train(train_sub_dataloader, 
                test_dataloader, 
                epochs=1,
                val_interval=(0,-1,500),
                store_history=True, 
                store_params=True,
                status_bar=True)
ut.save_network(bp_net, path=saved_network_path)

# Reload the network that was just saved; the MNIST plotting cells read
# `bp_network`, while `bp_net` remains the in-memory trained object.
bp_network = ut.load_network(saved_network_path)

In [None]:
# Line plot of the validation-loss history stored on the reloaded MNIST network.
plt.plot(bp_network.val_loss_history)

In [None]:
# Batch-accuracy plot for the reloaded MNIST network on the test dataloader,
# requesting all populations.
pt.plot_batch_accuracy(bp_network, test_dataloader, population='all')

### Spiral Networks

In [None]:
# Fetch the spiral dataloaders and the generator that is re-seeded before training.
spiral_train_dataloader, spiral_val_dataloader, spiral_test_dataloader, spiral_data_generator = ut.get_spiral_dataloaders(points_per_spiral_arm=2000)

# Build network
# NOTE: network_name, config_path and saved_network_path are reassigned here,
# replacing the values set in the MNIST section; later cells that read
# config_path therefore get the spiral config.
network_name = "EIANN_2_hidden_spiral_bpDale_learned_bias"
config_path = f"../network_config/spiral/{network_name}.yaml" # For optimized network_config files
# config_path = f"../optimize/network_config/spiral/{network_name}.yaml" # For network_config files to be optimized
saved_network_path = f"../saved_networks/{network_name}.pkl"
spiral_net = ut.build_EIANN_from_config(config_path, network_seed=0)

# Train network
data_seed = 1
spiral_data_generator.manual_seed(data_seed)
# Unlike the MNIST training call, this one also passes store_dynamics=True.
# NOTE(review): spiral_test_dataloader is passed as the second positional
# argument and spiral_val_dataloader is unused — confirm this is intended.
spiral_net.train(spiral_train_dataloader, 
                spiral_test_dataloader, 
                epochs=1,
                val_interval=(0,-1,500),
                store_history=True,
                store_dynamics=True, 
                store_params=True,
                status_bar=True)
# Saving/reloading is disabled; the cells below use the in-memory `spiral_net`.
# ut.save_network(spiral_net, path=saved_network_path)
# network = ut.load_network(saved_network_path)

<torch._C.Generator at 0x14dc6838ef0>

In [None]:
# Batch-accuracy plot for the trained spiral network on the spiral test
# dataloader, requesting all populations.
pt.plot_batch_accuracy(spiral_net, spiral_test_dataloader, population='all')

In [None]:
# Spiral-specific accuracy plot for the trained network on the test dataloader.
pt.plot_spiral_accuracy(spiral_net, spiral_test_dataloader)

In [None]:
# Unordered collection of the parameter names in the MNIST network's state dict.
net1_params = {*bp_net.state_dict().keys()}
net1_params

In [None]:
# Unordered collection of the parameter names in the spiral network's state dict.
net2_params = {*spiral_net.state_dict().keys()}
net2_params

In [None]:
# Report parameter names that appear in only one of the two sets.
# Sets have no defined iteration order, so pairing their elements with zip()
# would compare unrelated names and silently ignore the surplus of the larger
# set; set differences give the real mismatch.
for param_name in sorted(net1_params - net2_params):
    print(f"Param mismatch: {param_name} is only in net1_params")
for param_name in sorted(net2_params - net1_params):
    print(f"Param mismatch: {param_name} is only in net2_params")

In [None]:
# Build a second, untrained network from the same config and seed as
# spiral_net, used below for structural comparison against the trained one.
spiral_net2 = ut.build_EIANN_from_config(config_path, network_seed=0)

In [None]:
import EIANN.generate_figures as gf

In [None]:
# Snapshot the module_dict and parameter_dict key order of the trained spiral
# network (net1_*) and of the freshly built copy (net2_*).
net1_module_dict_keys = [*spiral_net.module_dict.keys()]
net1_parameter_dict_keys = [*spiral_net.parameter_dict.keys()]
net2_module_dict_keys = [*spiral_net2.module_dict.keys()]
net2_parameter_dict_keys = [*spiral_net2.parameter_dict.keys()]

In [None]:
# Print both key lists one after the other for visual comparison.
# (Bare expressions in the middle of a cell are evaluated and discarded, so
# only the print calls are needed.)
print(net1_module_dict_keys)
print(net2_module_dict_keys)

In [None]:
# True only if both networks list the same module_dict keys in the same order
# (list equality is order-sensitive).
net1_module_dict_keys == net2_module_dict_keys

In [None]:
# True only if both networks list the same parameter_dict keys in the same
# order (list equality is order-sensitive).
net1_parameter_dict_keys == net2_parameter_dict_keys

In [None]:
# Compare the trained spiral network against the freshly built copy:
# first by parameter name, then by tensor shape for the names they share.
# Each state dict is built once and reused.
net1_state = spiral_net.state_dict()
net2_state = spiral_net2.state_dict()

# NOTE: these reassign the notebook-level name sets defined earlier, which
# later cells display.
net1_params = set(net1_state.keys())
net2_params = set(net2_state.keys())

mismatch_found = False

# Names present in only one network (indexing the other would raise KeyError).
for key in sorted(net1_params ^ net2_params):
    owner = 'spiral_net' if key in net1_params else 'spiral_net2'
    print(f'Mismatch in {key}: only present in {owner}')
    mismatch_found = True

# Shape comparison for the shared names, in spiral_net's state-dict order.
for key in net1_state:
    if key not in net2_state:
        continue
    param1_shape = net1_state[key].shape
    param2_shape = net2_state[key].shape
    if param1_shape != param2_shape:
        print(f'Mismatch in {key}: {param1_shape} != {param2_shape}')
        mismatch_found = True

# Only claim success when nothing was reported above.
if not mismatch_found:
    print('No mismatches found')

In [None]:
# Build two networks from different configs for gf.compare_networks below.
# The commented-out lines are alternative choices for net1.
# net1 = ut.build_EIANN_from_config("../network_config/spiral/20250108_EIANN_2_hidden_spiral_BP_like_1_fixed_SomaI_learned_bias_config_complete_optimized.yaml", network_seed=0)
# net1 = ut.build_EIANN_from_config(config_path, network_seed=0)
net1 = ut.build_EIANN_from_config('../network_config/spiral/20250108_EIANN_0_hidden_spiral_van_bp_relu_learned_bias_config_complete_optimized.yaml', network_seed=0)
# net2 uses whatever config_path currently holds (last assigned in the spiral training cell).
net2 = ut.build_EIANN_from_config(config_path, network_seed=0)

In [None]:
# Run EIANN's network comparison helper on the two freshly built networks.
gf.compare_networks(net1, net2)

In [None]:
# Keep only those net1 parameters whose names also exist in net2.
# net2's parameter names are collected once up front instead of rebuilding
# net2.state_dict() for every entry of net1.
net1_state_dict = net1.state_dict()
net2_param_names = set(net2.state_dict().keys())
# Despite its name, net2_state_dict holds net1's tensors, filtered to the
# names that net2 shares.
net2_state_dict = {name: param for name, param in net1_state_dict.items() if name in net2_param_names}

print(net1_state_dict.keys())
print(net2_state_dict.keys())

In [None]:
# Display the net1_params name set.
# NOTE(review): in top-to-bottom order this set was last assigned from
# spiral_net's state dict, not from the `net1` network built above — confirm
# this is the set you mean to inspect.
net1_params

In [None]:
# Display the net2_params name set.
# NOTE(review): in top-to-bottom order this set was last assigned from
# spiral_net2's state dict, not from the `net2` network built above — confirm
# this is the set you mean to inspect.
net2_params

In [None]:
# Compute the decision data for the trained spiral network on the test
# dataloader, then draw it as a scatter plot.
decision_data = ut.compute_spiral_decisions_data(spiral_net, spiral_test_dataloader)
pt.plot_spiral_decisions(decision_data, graph='scatter')

In [None]:
# Display the raw decision data computed in the previous cell.
decision_data

In [None]:
# Plot the network dynamics of the trained spiral network
# (training above was run with store_dynamics=True).
pt.plot_network_dynamics(spiral_net)

In [None]:
# Evaluate the trained spiral network on the test dataloader. Returns a
# percent-correct value and a dict of average population activity; the next
# cell indexes that dict with the key 'InputE'.
percent_correct, average_pop_activity_dict = ut.compute_test_activity(spiral_net, spiral_test_dataloader, sort=True)

In [None]:
# Display the average-activity entry stored under the 'InputE' key.
average_pop_activity_dict['InputE']

In [None]:
# Display the stored activity history of the Input.E population.
spiral_net.Input.E.activity_history

In [None]:
# Test batch inputs
# Reads Input.E's current `activity`; presumably this is the most recent batch
# passed through the network (the test batch from compute_test_activity
# above) — TODO confirm.
inputs = spiral_net.Input.E.activity
inputs

In [None]:
# Predicted labels: index of the largest Output.E activity along dim 1.
# torch.argmax returns the indices directly, which avoids assigning the
# discarded max values to `_` (IPython's "last output" variable).
outputs = spiral_net.Output.E.activity
predicted = torch.argmax(outputs, dim=1)
predicted

In [None]:
# Take the first batch from the spiral test dataloader and put it on the
# network's device. sample_target holds the test labels for that batch.
dataloader_iter = spiral_test_dataloader
sample_idx, sample_data, sample_target = next(iter(dataloader_iter))

# Drop singleton dimensions, as the original loop body did.
sample_data = torch.squeeze(sample_data)
sample_target = torch.squeeze(sample_target)

# Record whether the batch was already on the network's device, then move it.
# Tensor.to() returns the tensor unchanged when no transfer is needed.
on_device = sample_data.device == spiral_net.device
sample_data = sample_data.to(spiral_net.device)
sample_target = sample_target.to(spiral_net.device)

# Show the batch regardless of whether a device transfer was needed.
print(sample_data)
print(sample_target)

In [None]:
# Display the target tensor of the test batch taken in the previous cell.
sample_target

In [None]:
# Test labels: index of the largest target entry along dim 1.
# torch.argmax returns the indices directly, which avoids assigning the
# discarded max values to `_` (IPython's "last output" variable).
test_labels = torch.argmax(sample_target, dim=1)
test_labels

In [None]:
# Sanity check: fraction of the batch where the predicted class index equals
# the target class index.
total = test_labels.shape[0]
correct = int((predicted == test_labels).sum())
accuracy = correct / total
accuracy

In [None]:
# Heat map of the stored forward dendritic state history of H1.E, with a colorbar.
# NOTE(review): this history is plotted without transposing, whereas the H2.E
# and Output.E heat maps in the following cells use `.T` — confirm the
# intended axis orientation.
plt.figure()
plt.imshow(spiral_net.H1.E.forward_dendritic_state_history, aspect='auto', interpolation='none')
plt.colorbar()

In [None]:
# Line plot of the H1.E forward dendritic state history averaged over dim 1
# (presumably the unit dimension — TODO confirm).
plt.figure()
plt.plot(torch.mean(spiral_net.H1.E.forward_dendritic_state_history, dim=1))

In [None]:
# Heat map of the transposed forward dendritic state history of H2.E, with a colorbar.
plt.figure()
# Disabled alternative (mean absolute state); it refers to `spiral_bp`, a name
# that is not defined anywhere in the visible notebook.
# plt.plot(torch.mean(torch.abs(spiral_bp.H2.E.forward_dendritic_state_history), dim=1))
plt.imshow(spiral_net.H2.E.forward_dendritic_state_history.T, aspect='auto', interpolation='none')
plt.colorbar()

In [None]:
# Heat map of the transposed plateau history of Output.E, with a colorbar.
plt.figure()
# Disabled alternative (mean plateau); it refers to `spiral_bp`, a name that
# is not defined anywhere in the visible notebook.
# plt.plot(torch.mean(spiral_bp.Output.E.plateau_history, dim=1))
plt.imshow(spiral_net.Output.E.plateau_history.T, aspect='auto', interpolation='none')
# plt.xlim(0, 1000)
plt.colorbar()

In [None]:
# Heat map of the last 10 entries of Output.E's activity history (transposed),
# with a colorbar.
plt.figure()
# Stale disabled line: it plots plateau_history (not activity_history) and
# refers to `spiral_bp`, a name not defined in the visible notebook.
# plt.plot(torch.mean(spiral_bp.Output.E.plateau_history, dim=1))
plt.imshow(spiral_net.Output.E.activity_history[-10:].T, aspect='auto', interpolation='none')
# plt.xlim(0, 10)
plt.colorbar()

In [None]:
# Line plot of a weight history averaged over dims 1 and 2.
# NOTE(review): the attribute chain H2.E.H2.E presumably addresses a
# projection onto H2.E from H2.E itself; confirm that this projection exists
# in the loaded config, otherwise this cell raises AttributeError.
plt.figure()
plt.plot(torch.mean(spiral_net.H2.E.H2.E.weight_history, dim=(1,2)))

In [None]:
# List which attribute histories were recorded for the H2.E population.
spiral_net.H2.E.attribute_history_dict.keys()

In [None]:
# Line plot of the validation-loss history stored on the trained spiral network.
plt.plot(spiral_net.val_loss_history)

In [None]:
# Line plot of the loss history stored on the trained spiral network.
plt.plot(spiral_net.loss_history)

In [None]:
# Take the first (index, data, label) batch from the spiral test dataloader
# and print the network's forward output next to the labels.
# NOTE(review): unlike the earlier batch-inspection cell, `data` is passed to
# forward() without torch.squeeze() or a move to spiral_net.device — confirm
# forward() accepts the raw batch.
index, data, label = next(iter(spiral_test_dataloader))

print(spiral_net.forward(data))
print(label)

In [None]:
# Line plot of column 1 of the batch labels taken in the previous cell.
plt.plot(label[:,1])

In [None]:
# Same evaluation as the earlier compute_test_activity call, but with
# sort=False; the returned values are displayed rather than stored.
ut.compute_test_activity(spiral_net, spiral_test_dataloader, sort=False)

In [4]:
# Load the YAML file at config_path (last assigned in the spiral training
# cell) and display the resulting nested dict, which has 'layer_config' and
# 'projection_config' entries.
d = ut.read_from_yaml(config_path)
d

{'layer_config': {'Input': {'E': {'size': 2}},
  'H1': {'E': {'size': 128,
    'activation': 'relu',
    'include_bias': True,
    'bias_learning_rule': 'Backprop',
    'bias_learning_rule_kwargs': {'learning_rate': 0.0023723741273216116}},
   'SomaI': {'size': 32, 'activation': 'relu'}},
  'H2': {'E': {'size': 32,
    'activation': 'relu',
    'include_bias': True,
    'bias_learning_rule': 'Backprop',
    'bias_learning_rule_kwargs': {'learning_rate': 0.002384074673608417}},
   'SomaI': {'size': 8, 'activation': 'relu'}},
  'Output': {'E': {'size': 4,
    'activation': 'relu',
    'include_bias': True,
    'bias_learning_rule': 'Backprop',
    'bias_learning_rule_kwargs': {'learning_rate': 0.008060654677551906}},
   'SomaI': {'size': 4, 'activation': 'relu'}}},
 'projection_config': {'H1': {'E': {'Input': {'E': {'weight_init': 'half_kaiming',
      'weight_init_args': (0.45341851689677354,),
      'weight_bounds': [0, None],
      'direction': 'F',
      'learning_rule': 'Backprop',


In [6]:
# Display only the projection section of the loaded config.
d['projection_config']

{'H1': {'E': {'Input': {'E': {'weight_init': 'half_kaiming',
     'weight_init_args': (0.45341851689677354,),
     'weight_bounds': [0, None],
     'direction': 'F',
     'learning_rule': 'Backprop',
     'learning_rule_kwargs': {'learning_rate': 0.17515893364196825}}},
   'H1': {'SomaI': {'weight_init': 'half_kaiming',
     'weight_init_args': (0.07472398800493957,),
     'weight_bounds': [None, 0],
     'direction': 'R',
     'learning_rule': 'Backprop',
     'learning_rule_kwargs': {'learning_rate': 0.07836630590668701}}}},
  'SomaI': {'Input': {'E': {'weight_init': 'half_kaiming',
     'weight_init_args': (4.7663561644212695,),
     'weight_bounds': [0, None],
     'direction': 'F',
     'learning_rule': 'Backprop',
     'learning_rule_kwargs': {'learning_rate': 0.24925779245776264}}},
   'H1': {'E': {'weight_init': 'half_kaiming',
     'weight_init_args': (3.808885460644247,),
     'weight_bounds': [0, None],
     'direction': 'R',
     'learning_rule': 'Backprop',
     'learning_