In [None]:
import wandb
import torch, wandb, os
import torch.optim as optim
from torch.nn.functional import cross_entropy
from functools import partial

from src.RandmanFunctions import RandmanConfig, split_and_load
from src.Models import RandmanSNN, spike_regularized_cross_entropy, RandmanSNNConfig
from src.EvolutionAlgorithms.EvolutionStrategy import ESModel
from src.EvolutionAlgorithms.PseudoPSO import PPSOModel, PPSOModelWithPooling
from src.SurrogateGD.VanillaSGD import SGDModel
from src.Training import train_loop_snn
from src.Utilities import init_result_csv, set_seed
# from src.LandscapeAnalysis import LossSurfacePlotter
from src.LandscapeAnalysis.LossSurfacePlotter import LossSurfacePlotter

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines where CUDA is not available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def loss_fn_spk(logits, y, spikes):
    """Spike-regularized cross-entropy with a fixed regularization weight.

    Thin wrapper around ``spike_regularized_cross_entropy`` that pins
    ``lambda_reg=1e-3``. This lets it be passed wherever a
    ``(logits, y, spikes)`` loss callable is expected (e.g. ``train_loop_snn``).

    Args:
        logits: model outputs, forwarded unchanged.
        y: targets, forwarded unchanged.
        spikes: spike recordings, forwarded unchanged.

    Returns:
        Whatever ``spike_regularized_cross_entropy`` returns.
    """
    return spike_regularized_cross_entropy(logits, y, spikes, lambda_reg=1e-3)

def train_snn(use_loss_plotter=False):
    """Run one surrogate-gradient (SGD) training experiment on Randman data, logged to W&B.

    The hyper-parameters are declared in ``config`` and sent to ``wandb.init``.
    Everything used to build the model, dataset and optimizer is read back from
    ``run.config``, so W&B sweeps can override it.

    Args:
        use_loss_plotter: When True, a ``LossSurfacePlotter`` writing to
            ``results/<project>/runs/<run id>/illuminated_loss_surface.npz`` is
            constructed and passed to ``train_loop_snn``. Defaults to False, in
            which case ``loss_plotter=None`` is passed.
    """
    # NOTE: this function is intentionally not wrapped in @torch.no_grad(): the
    # SGD model is trained with a torch.optim optimizer, which needs autograd.
    # (A no_grad variant is presumably only meant for the evolution-based
    # models imported at the top — confirm before reusing.)
    run_name = "SGD-test"
    config = {  # Dataset:
        "nb_input": 10,
        "nb_output": 2,
        "nb_steps": 50,
        "nb_data_samples": 1000,
        "dim_manifold": 1,
        "alpha": 3.0,
        # SNN:
        "nb_hidden_1": 30,
        "nb_hidden_2": None,
        "learn_beta": True,
        "recurrent": False,
        # Evolution Strategy (not used to build the SGD model; these values
        # only appear in the logged config and the run name):
        "nb_model_samples": 100,
        "mirror": True,
        # Training:
        "std": 0.1,
        "epochs": 20,
        "batch_size": 256,
        "method": "SGD",
        # Optimization:
        "loss": "loss_fn_spk",  # NOTE(review): only logged; loss_fn_spk is always used below
        "optimizer": "Adam",  # name of a class in torch.optim
        "lr": 0.01,
        "regularization": "none",
    }
    with wandb.init(
        entity="DarwinNeuron", project="Test", name=run_name, config=config
    ) as run:
        # Replace the placeholder name with one derived from the key hyper-parameters
        # (values joined in alphabetical order of their config keys).
        keys = ["method", "std", "batch_size", "lr", "nb_model_samples", "loss", "nb_input"]
        sorted_items = [f"{getattr(run.config, k)}" for k in sorted(keys)]
        run.name = "-".join(sorted_items)

        # setting up local csv recording (optional)
        result_path, _, _ = init_result_csv(dict(run.config), run.project)

        # Resolve the optimizer class from the config (e.g. "Adam" -> torch.optim.Adam)
        # so that the logged value is the one actually used.
        optimizer_cls = getattr(optim, run.config.optimizer)

        # initialize the surrogate-gradient SGD model
        my_model = SGDModel(
            RandmanSNN,
            run.config.nb_input,
            run.config.nb_output,
            snn_config=RandmanSNNConfig(
                nb_hidden_1=run.config.nb_hidden_1,
                nb_hidden_2=run.config.nb_hidden_2,
                beta=0.9,
                learn_beta=run.config.learn_beta,
                recurrent=run.config.recurrent,
                parameter_type="weights",
            ),
            Optimizer=optimizer_cls,
            lr=run.config.lr,
            device=device,
            spike_grad=None,  # can change based on config.spike_grad
        )

        # load dataset
        train_loader, val_loader = split_and_load(
            RandmanConfig(
                nb_classes=run.config.nb_output,
                nb_units=run.config.nb_input,
                nb_steps=run.config.nb_steps,
                nb_samples=run.config.nb_data_samples,
                dim_manifold=run.config.dim_manifold,
                alpha=run.config.alpha,
            ).read_dataset(),
            run.config.batch_size,
        )

        # Per-run output directory (always created); the loss-surface plotter
        # itself is only built when requested.
        plotter_dir = f"results/{run.project}/runs/{run.id}/"
        os.makedirs(plotter_dir, exist_ok=True)
        loss_plotter = (
            LossSurfacePlotter(os.path.join(plotter_dir, "illuminated_loss_surface.npz"))
            if use_loss_plotter
            else None
        )

        # epochs
        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")

            # train the model for one epoch (train_loop_snn also receives the val loader)
            train_loop_snn(
                my_model,
                train_loader,
                val_loader,
                loss_fn_spk,
                device,
                run,
                epoch,
                loss_plotter=loss_plotter,
            )

# Launch a training run when this file is executed directly (a notebook kernel
# or `python file.py`, where __name__ == "__main__"), but not when the module
# is imported.
if __name__ == "__main__":
    train_snn()