In [None]:
import torch
from torch.utils.data import DataLoader
from pprint import pprint
import matplotlib.pyplot as plt

from EIANN import Network
import EIANN.utils as ut
import EIANN.plot as pt

# Apply the EIANN plotting module's defaults once, before any figure is drawn.
# NOTE(review): presumably adjusts matplotlib settings; the exact changes live in
# EIANN.plot.update_plot_defaults and are not visible in this notebook.
pt.update_plot_defaults()

- Export the weights along the training trajectory for both models
- Combine the two weight histories
- Run PCA on the combined weights
- Interpolate (±10%)
- Create the set of weights
- Probe the loss of both networks
    - Convert the BP net to ReLU before computing the loss
- Plot the 2D loss landscape as a 3D heatmap (including the 2 actual trajectories; plot separately)

In [None]:
# One-hot dataset: the identity matrix gives `input_size` patterns, one per row,
# and each pattern's target is the matching one-hot row.
input_size = 21
dataset = torch.eye(input_size)  # each row is a different pattern
target = torch.eye(dataset.shape[0])

# Dedicated generator for shuffling; the cells that train a network re-seed it
# with `data_seed` immediately before training.
data_seed = 0
data_generator = torch.Generator()
sample_indexes = torch.arange(len(dataset))
# (index, pattern, target) triples, built once and shared by both loaders.
indexed_samples = list(zip(sample_indexes, dataset, target))

# Training loader: default batch size (one sample per step), shuffled by
# `data_generator`.
dataloader = DataLoader(indexed_samples,
                        shuffle=True,
                        generator=data_generator)

# Test loader: a single unshuffled batch holding every pattern. The batch size
# is derived from the dataset so it stays a full batch if `input_size` changes.
test_dataloader = DataLoader(indexed_samples,
                             batch_size=len(dataset))

epochs = 30  # number of training epochs passed to Network.train
seed = 42    # seed passed to the Network constructors

### Gjorgjieva learning rule

In [None]:
# Read the Gjorgjieva/Hebbian network configuration and split it into the three
# sections that the Network constructor consumes.
network_config = ut.read_from_yaml('../optimize/data/20220902_EIANN_1_hidden_Gjorgieva_Hebb_config_A.yaml')
layer_config, projection_config, training_kwargs = (
    network_config['layer_config'],
    network_config['projection_config'],
    network_config['training_kwargs'],
)

# Build the network with the notebook-wide `seed`; the training options from the
# YAML file are forwarded as keyword arguments.
gj_network = Network(
    layer_config,
    projection_config,
    seed=seed,
    **training_kwargs,
)

In [None]:
# Re-seed the shuffle generator right before training so the DataLoader's
# shuffling starts from the same generator state every time this cell is run.
data_generator.manual_seed(data_seed)
# Train for `epochs` epochs, asking the network to keep its history and weights
# (later cells read `gj_network.param_history`; presumably populated by these flags).
gj_network.train(dataloader, epochs, store_history=True, store_weights=True, status_bar=True)

# Plot the test-loss history of the trained network, evaluated with the full-batch test loader.
pt.plot_test_loss_history(gj_network, test_dataloader)

In [None]:
# Derive a sorting from the trained network and the test data (named for
# minimising the loss), then recompute the stored history under that sorting.
# NOTE(review): the return value of recompute_history is unused, so it appears to
# modify gj_network in place; cells below then see the re-sorted history — confirm.
min_loss_sorting = ut.get_optimal_sorting(gj_network, test_dataloader)
ut.recompute_history(gj_network, min_loss_sorting)

# Re-plot the test-loss history after the recomputation, for comparison with the plot above.
pt.plot_test_loss_history(gj_network, test_dataloader)

In [None]:
# Flatten the Gjorgjieva network's stored parameter history; the helper returns
# two values and only the first is used here.
flat_param_history_gj,_ = pt.get_flat_param_history(gj_network)
# Plot that history (by its name, projected onto principal components — confirm in EIANN.plot).
pt.plot_param_history_PCs(flat_param_history_gj)

In [None]:
# Loss landscape for the Gjorgjieva network alone, evaluated on the test data.
# NOTE(review): num_points presumably sets the grid resolution — confirm in EIANN.plot.
pt.plot_loss_landscape(test_dataloader, gj_network, num_points=20)

### Backprop (softplus; SGD)

In [None]:
# Read the matched backprop (softplus, SGD) configuration and split it into the
# three sections that the Network constructor consumes.
network_config = ut.read_from_yaml('../config/EIANN_1_hidden_backprop_softplus_SGD_matched_config.yaml')
layer_config, projection_config, training_kwargs = (
    network_config['layer_config'],
    network_config['projection_config'],
    network_config['training_kwargs'],
)

# Build the network with the notebook-wide `seed`; the training options from the
# YAML file are forwarded as keyword arguments.
bp_network = Network(
    layer_config,
    projection_config,
    seed=seed,
    **training_kwargs,
)

In [None]:
# Initialize the backprop network with the same starting weights as the
# Gjorgjieva network: entry 0 of the Gjorgjieva network's stored parameter
# history is loaded into bp_network as its state dict.
# NOTE(review): ut.recompute_history was applied to gj_network in an earlier cell;
# confirm that param_history[0] is still the untrained initial state in the
# layout bp_network expects.
gj_initial_state = gj_network.param_history[0]
bp_network.load_state_dict(gj_initial_state)

In [None]:
# Reset the shuffle generator to `data_seed`, the same seed used before training
# gj_network, so both training runs start from the same generator state.
data_generator.manual_seed(data_seed)
# Train for `epochs` epochs, asking the network to keep its history and weights.
bp_network.train(dataloader, epochs, store_history=True, store_weights=True, status_bar=True)

# Plot the test-loss history of the trained network, evaluated with the full-batch test loader.
pt.plot_test_loss_history(bp_network, test_dataloader)

In [None]:
# Flatten the backprop network's stored parameter history; the helper returns
# two values and only the first is used here.
flat_param_history_bp,_ = pt.get_flat_param_history(bp_network)
# Plot that history (by its name, projected onto principal components — confirm in EIANN.plot).
pt.plot_param_history_PCs(flat_param_history_bp)

In [None]:
# Loss landscape for the backprop network alone, evaluated on the test data.
# NOTE(review): num_points presumably sets the grid resolution — confirm in EIANN.plot.
pt.plot_loss_landscape(test_dataloader, bp_network, num_points=20)

## Combined loss landscape

In [None]:
# Flatten each network's stored parameter history (the second return value of
# the helper is not needed here).
flat_param_history1, _ = pt.get_flat_param_history(gj_network)
flat_param_history2, _ = pt.get_flat_param_history(bp_network)

# Concatenate the two histories along the first dimension, Gjorgjieva entries
# first, and plot the joint history.
combined_param_history = torch.cat(
    (flat_param_history1, flat_param_history2),
    dim=0,
)
pt.plot_param_history_PCs(combined_param_history)

In [None]:
# Give each network a human-readable name (presumably used for labelling by the
# plotting helper — confirm in EIANN.plot).
gj_network.name = 'Gjorgjieva'
bp_network.name = 'Backprop'
# Joint loss landscape for both networks, evaluated on the test data.
# NOTE(review): the plan at the top of the notebook mentions interpolating ±10%,
# while extension=0.5 is passed here — confirm what `extension` means and whether
# this value is intended.
pt.plot_loss_landscape(test_dataloader, gj_network, bp_network, num_points=20, extension=0.5)

In [None]:
# TODO: turn the notebooks into pytest scripts

# TODO: add a hardcoded reference to the output population on the network (e.g. network.output_pop)