In [1]:
import wandb
import torch, wandb, os
import torch.optim as optim
from torch.nn.functional import cross_entropy
from functools import partial

from src.RandmanFunctions import RandmanConfig, split_and_load
from src.Models import RandmanSNN, spike_regularized_cross_entropy
from src.EvolutionAlgorithms.EvolutionStrategy import ESModel
from src.EvolutionAlgorithms.PseudoPSO import PPSOModel, PPSOModelWithPooling
from src.SurrogateGD.VanillaSGD import SGDModel
from src.Training import train_loop_snn
from src.Utilities import init_result_csv, set_seed
# from src.LandscapeAnalysis import LossSurfacePlotter
from src.LandscapeAnalysis.LossSurfacePlotter import LossSurfacePlotter

# Use the GPU when one is present; fall back to CPU so the notebook can still
# run on machines without CUDA instead of failing when tensors are moved.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Weight of the spike regularization term, forwarded as ``lambda_reg``.
SPIKE_REG_LAMBDA = 1e-3


def loss_fn_spk(logits, y, spikes):
    """Spike-regularized loss with a fixed regularization weight.

    Thin wrapper around ``spike_regularized_cross_entropy`` that pins
    ``lambda_reg`` to ``SPIKE_REG_LAMBDA``. It keeps the same positional
    ``(logits, y, spikes)`` signature the previous lambda had, so it can be
    handed to the training loop unchanged.
    """
    return spike_regularized_cross_entropy(
        logits, y, spikes, lambda_reg=SPIKE_REG_LAMBDA
    )

def _default_config():
    """Return a fresh dict with the default hyper-parameters of the SGD run."""
    return {  # Dataset:
        "nb_input": 10,
        "nb_output": 10,
        "nb_steps": 50,
        "nb_data_samples": 1000,
        "dim_manifold": 2,
        "alpha": 2.0,
        # SNN:
        "nb_hidden": 100,
        "learn_beta": False,
        "recurrent": False,
        # Evolution Strategy (not used by SGDModel; "std" and
        # "nb_model_samples" still feed into the run name):
        "nb_model_samples": 100,
        "mirror": True,
        # Training:
        "std": 0.1,
        "epochs": 5,
        "batch_size": 256,
        "method": "SGD",
        # Optimization:
        "loss": "loss_fn_spk",  # resolved to the module-level loss_fn_spk
        "optimizer": "Adam",  # resolved to torch.optim.<name>
        "lr": 0.01,
        # NOTE(review): loss_fn_spk applies spike regularization even though
        # this says "none" -- confirm what this key is meant to control.
        "regularization": "none",
    }


def _run_name_from_config(cfg):
    """Join selected config values, ordered by sorted key name, with '-'.

    ``cfg`` only needs attribute access for the keys below (``run.config``
    from W&B is accessed this way).
    """
    keys = ["method", "std", "batch_size", "lr", "nb_model_samples", "loss", "nb_input"]
    return "-".join(f"{getattr(cfg, key)}" for key in sorted(keys))


def train_snn(
    config=None,
    run_name="SGD-test",
    entity="DarwinNeuron",
    project="Test",
    use_loss_plotter=False,
):
    """Train a RandmanSNN with surrogate-gradient SGD and log the run to W&B.

    Args:
        config: Hyper-parameter dict passed to ``wandb.init``. ``None`` uses
            ``_default_config()``.
        run_name: Initial W&B run name. It is replaced right after
            ``wandb.init`` by a name built from selected config values.
        entity: W&B entity to log to.
        project: W&B project to log to.
        use_loss_plotter: If True, a ``LossSurfacePlotter`` writing to
            ``results/<project>/runs/<run id>/illuminated_loss_surface.npz``
            is passed to ``train_loop_snn``. Otherwise ``loss_plotter=None``
            is passed.

    Uses the module-level ``device`` and ``loss_fn_spk``.

    Raises:
        ValueError: If ``config["optimizer"]`` is not an optimizer class in
            ``torch.optim`` or ``config["loss"]`` is not a known loss name.
    """
    if config is None:
        config = _default_config()

    with wandb.init(
        entity=entity, project=project, name=run_name, config=config
    ) as run:
        # Replace the placeholder name with one derived from key hyper-parameters.
        run.name = _run_name_from_config(run.config)

        # Local CSV recording (optional); the returned paths are not needed here.
        init_result_csv(dict(run.config), run.project)

        # Resolve optimizer and loss from the config so what is logged is what
        # is actually used.
        optimizer_cls = getattr(optim, run.config.optimizer, None)
        if not (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, optim.Optimizer)):
            raise ValueError(f"Unknown optimizer in config: {run.config.optimizer!r}")
        loss_fns = {"loss_fn_spk": loss_fn_spk}
        try:
            loss_fn = loss_fns[run.config.loss]
        except KeyError:
            raise ValueError(f"Unknown loss in config: {run.config.loss!r}") from None

        # Surrogate-gradient SGD model.
        my_model = SGDModel(
            RandmanSNN,
            run.config.nb_input,
            run.config.nb_hidden,
            run.config.nb_output,
            run.config.learn_beta,
            spike_grad=None,  # None -> the model's default surrogate gradient
            recurrent=run.config.recurrent,
            Optimizer=optimizer_cls,
            lr=run.config.lr,
            device=device,
        )

        # Load the Randman dataset and split it into train/validation loaders.
        train_loader, val_loader = split_and_load(
            RandmanConfig(
                nb_classes=run.config.nb_output,
                nb_units=run.config.nb_input,
                nb_steps=run.config.nb_steps,
                nb_samples=run.config.nb_data_samples,
                dim_manifold=run.config.dim_manifold,
                alpha=run.config.alpha,
            ).read_dataset(),
            run.config.batch_size,
        )

        # Only build the loss-surface plotter (and its output directory) when
        # it will actually be passed to the training loop.
        loss_plotter = None
        if use_loss_plotter:
            plotter_dir = os.path.join("results", run.project, "runs", run.id)
            os.makedirs(plotter_dir, exist_ok=True)
            loss_plotter = LossSurfacePlotter(
                os.path.join(plotter_dir, "illuminated_loss_surface.npz")
            )

        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop_snn(
                my_model,
                train_loader,
                val_loader,
                loss_fn,
                device,
                run,
                epoch,
                loss_plotter=loss_plotter,
            )

# Guard the entry point so importing this code (e.g. after exporting the
# notebook to a .py module) does not start a training run. In a Jupyter
# kernel __name__ is "__main__", so executing the cell still trains.
if __name__ == "__main__":
    train_snn()

[34m[1mwandb[0m: Currently logged in as: [33myixing[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


filepath data/randman\a818832ebb0748ed8125b520de22a658.pt
Epoch 0
-------------------------------
batch 0, loss: 2.302585, accuracy: 10.5%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 1, loss: 2.302585, accuracy: 7.0%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 2, loss: 2.302585, accuracy: 12.5%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 3, loss: 2.302585, accuracy: 7.8%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 4, loss: 2.302585, accuracy: 10.2%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 5, loss: 2.302585, accuracy: 11.3%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 6, loss: 2.302585, accuracy: 11.3%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 7, loss: 2.302585, accuracy: 9.4%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 8, loss: 2.302585, accuracy: 7.8%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 9, loss: 2.302585, accuracy: 12.1%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 10, loss: 2.302585, accuracy: 9.4%
Accuracy: 10.2%, Avg loss: 2.302585 

batch 11, loss: 2.302585, accuracy: 7

0,1
epoch,▁▁▁▁▁▁▁▁▃▃▃▃▃▃▃▃▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆█████████
epoch_val_acc,▁▁▁▃█
epoch_val_average_neuron_spikes,▁▁▁▂█
epoch_val_loss,███▇▁
epoch_val_spike_percentage,▁▁▁▂█
train_acc,▂▁▂▁▂▃▁▁▂▁▂▁▂▁▂▃▂▂▂▁▂▂▂▂▁▂▁▂▃▁▂▁▄▄▂▅▅▇▆█
train_average_neuron_spikes,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▄▅▆▆█
train_loss,███████████████████████████████▇█▆▆▅▃▃▂▁
train_spike_percentage,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▄▄▆█
val_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▃▃▄▅▅▆█

0,1
epoch,4.0
epoch_val_acc,0.367
epoch_val_average_neuron_spikes,0.11569
epoch_val_loss,1.93399
epoch_val_spike_percentage,0.11462
train_acc,0.375
train_average_neuron_spikes,0.12545
train_loss,1.91589
train_spike_percentage,0.12402
val_acc,0.367
