In [None]:
import wandb
import torch, wandb, os
import torch.optim as optim
from torch.nn.functional import cross_entropy
from functools import partial

from src.RandmanFunctions import RandmanConfig, split_and_load
from src.Models import RandmanSNN, spike_regularized_cross_entropy
from src.EvolutionAlgorithms.EvolutionStrategy import ESModel
from src.EvolutionAlgorithms.PseudoPSO import PPSOModel, PPSOModelWithPooling
from src.SurrogateGD.VanillaSGD import SGDModel
from src.Training import train_loop_snn
from src.Utilities import init_result_csv, set_seed
from src.LandscapeAnalysis.LossSurfacePlotter import LossSurfacePlotter

# Use the GPU when one is available; fall back to CPU instead of failing on
# machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"


def loss_fn_spk(logits, y, spikes):
    """Spike-regularised cross-entropy with a fixed regularisation weight.

    Thin wrapper around ``spike_regularized_cross_entropy`` that pins
    ``lambda_reg=1e-3``. The ``(logits, y, spikes)`` signature is the one
    handed to ``train_loop_snn`` as its loss function.

    Returns:
        Whatever ``spike_regularized_cross_entropy`` returns for these inputs.
    """
    return spike_regularized_cross_entropy(logits, y, spikes, lambda_reg=1e-3)

@torch.no_grad()
def train_snn(plot_loss_surface=False):
    """Train a RandmanSNN with an evolution strategy and log the run to W&B.

    The ``config`` dict below holds the default hyper-parameters. A W&B sweep
    may override any of them, so every downstream value is read back from
    ``run.config`` rather than from the local dict.

    The whole run executes under ``torch.no_grad()``. Presumably this is
    because ``ESModel`` updates parameters from sampled perturbations rather
    than autograd (TODO confirm against ``ESModel``).

    Args:
        plot_loss_surface: When True, a ``LossSurfacePlotter`` targeting
            ``results/<project>/runs/<run id>/illuminated_loss_surface.npz`` is
            passed to ``train_loop_snn``. The default False trains without a
            plotter, as the original code effectively did (it passed None).
    """
    run_name = "es"  # placeholder; replaced below once run.config is resolved
    config = {  # Dataset:
        "nb_input": 10,
        "nb_output": 10,
        "nb_steps": 50,
        "nb_data_samples": 1000,
        "dim_manifold": 2,
        "alpha": 2.0,
        # SNN:
        "nb_hidden": 100,
        "learn_beta": False,
        # Evolution Strategy:
        "nb_model_samples": 100,
        "mirror": True,
        # Training:
        "std": 0.1,
        "epochs": 5,
        "batch_size": 256,
        "method": "es",
        # Optimization:
        # These labels must describe the loss actually passed to
        # train_loop_snn below: loss_fn_spk, i.e. spike-regularised
        # cross-entropy (lambda_reg=1e-3). The previous "loss_fn"/"none"
        # values mislabelled every logged run.
        "loss": "loss_fn_spk",
        "optimizer": "Adam",  # name of a torch.optim class, resolved below
        "lr": 0.01,
        "regularization": "spike",
    }
    with wandb.init(
        entity="DarwinNeuron", project="big-sweep-test", name=run_name, config=config
    ) as run:
        # Name the run after its key hyper-parameters, in alphabetical key
        # order so the naming scheme is stable across runs.
        keys = ["method", "std", "batch_size", "lr", "nb_model_samples", "loss", "nb_input"]
        run.name = "-".join(f"{getattr(run.config, k)}" for k in sorted(keys))

        # setting up local csv recording (optional); the path is not used
        # further in this function.
        result_path, _, _ = init_result_csv(dict(run.config), run.project)

        # Honour the configured optimiser instead of hard-coding Adam, so a
        # sweep over "optimizer" (e.g. "SGD") actually takes effect.
        optimizer_cls = getattr(optim, run.config.optimizer)

        my_model = ESModel(
            RandmanSNN,
            run.config.nb_input,
            run.config.nb_hidden,
            run.config.nb_output,
            run.config.learn_beta,
            sample_size=run.config.nb_model_samples,
            param_std=run.config.std,
            Optimizer=optimizer_cls,
            lr=run.config.lr,
            device=device,
            mirror=run.config.mirror,
        )

        # load dataset
        train_loader, val_loader = split_and_load(
            RandmanConfig(
                nb_classes=run.config.nb_output,
                nb_units=run.config.nb_input,
                nb_steps=run.config.nb_steps,
                nb_samples=run.config.nb_data_samples,
                dim_manifold=run.config.dim_manifold,
                alpha=run.config.alpha,
            ).read_dataset(),
            run.config.batch_size,
        )

        # Per-run output directory (also the home of the optional loss-surface
        # file). The plotter is only built when it will actually be used.
        plotter_dir = f"results/{run.project}/runs/{run.id}/"
        os.makedirs(plotter_dir, exist_ok=True)
        loss_plotter = (
            LossSurfacePlotter(os.path.join(plotter_dir, "illuminated_loss_surface.npz"))
            if plot_loss_surface
            else None
        )

        # epochs
        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")

            # train the model
            train_loop_snn(
                my_model,
                train_loader,
                val_loader,
                loss_fn_spk,
                device,
                run,
                epoch,
                loss_plotter=loss_plotter,
            )

# Launch a training run only when executed directly (notebook cell or script),
# not when this module is imported; Jupyter runs cells with __name__ == "__main__".
if __name__ == "__main__":
    train_snn()

Epoch 0
-------------------------------
batch 0, loss: 2.300166, accuracy: 10.5%
Accuracy: 10.7%, Avg loss: 2.301792 

batch 1, loss: 2.301507, accuracy: 9.8%
Accuracy: 10.3%, Avg loss: 2.302051 

batch 2, loss: 2.302517, accuracy: 9.4%
Accuracy: 10.4%, Avg loss: 2.301648 

batch 3, loss: 2.298573, accuracy: 12.9%
Accuracy: 10.5%, Avg loss: 2.301483 

batch 4, loss: 2.302034, accuracy: 12.5%
Accuracy: 10.4%, Avg loss: 2.301721 

batch 5, loss: 2.299236, accuracy: 9.8%
Accuracy: 10.5%, Avg loss: 2.301780 

batch 6, loss: 2.300574, accuracy: 11.3%
Accuracy: 10.5%, Avg loss: 2.301749 

batch 7, loss: 2.302858, accuracy: 8.6%
Accuracy: 10.5%, Avg loss: 2.301686 

batch 8, loss: 2.301569, accuracy: 10.9%
Accuracy: 10.7%, Avg loss: 2.301609 

batch 9, loss: 2.302341, accuracy: 9.0%
Accuracy: 10.7%, Avg loss: 2.301648 

batch 10, loss: 2.301835, accuracy: 9.8%
Accuracy: 10.7%, Avg loss: 2.301631 

batch 11, loss: 2.301969, accuracy: 13.3%
Accuracy: 10.7%, Avg loss: 2.301694 

batch 12, loss: 