Single Run

In [None]:
import wandb
import torch, wandb, os
import torch.optim as optim
from torch.nn.functional import cross_entropy
from functools import partial

from src.RandmanFunctions import RandmanConfig, split_and_load
from src.Models import RandmanSNN, spike_regularized_cross_entropy
from src.EvolutionAlgorithms.EvolutionStrategy import ESModel
from src.EvolutionAlgorithms.PseudoPSO import PPSOModel, PPSOModelWithPooling
from src.SurrogateGD.VanillaSGD import SGDModel
from src.Training import train_loop_snn
from src.Utilities import init_result_csv, set_seed
from src.LandscapeAnalysis.LossSurfacePlotter import LossSurfacePlotter

# Use the GPU when one is visible; fall back to CPU so the notebook still runs
# on machines without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def loss_fn_spk(logits, y, spikes, lambda_reg=1e-3):
    """Spike-regularized cross-entropy used as the default training loss.

    Thin wrapper around ``spike_regularized_cross_entropy`` that fixes the
    spike-regularization weight, so it can be handed to ``train_loop_snn`` as a
    plain ``(logits, y, spikes)`` callable.

    Args:
        logits: Model outputs forwarded unchanged to
            ``spike_regularized_cross_entropy``.
        y: Targets forwarded unchanged.
        spikes: Spike record forwarded unchanged (format defined by the model
            / ``train_loop_snn`` — not visible here).
        lambda_reg: Weight of the spike-regularization term; 1e-3 by default.

    Returns:
        Whatever ``spike_regularized_cross_entropy`` returns.
    """
    return spike_regularized_cross_entropy(logits, y, spikes, lambda_reg=lambda_reg)

@torch.no_grad()
def train_snn(config=None, loss_fn=None, use_loss_plotter=False):
    """Train a RandmanSNN with pooled pseudo-PSO inside a single wandb run.

    Steps:

    - Build the default hyper-parameters and open a wandb run in
      ``DarwinNeuron/big-sweep-test``.
    - Rename the run from a few key hyper-parameters.
    - Initialise the local result CSV.
    - Wrap ``RandmanSNN`` in ``PPSOModelWithPooling`` and load the Randman
      dataset.
    - Call ``train_loop_snn`` once per epoch.

    The whole function runs under ``torch.no_grad()``, so autograd records no
    graph. Presumably the PPSO update does not need gradients; confirm in
    ``PPSOModelWithPooling``.

    Args:
        config: Optional mapping of overrides applied on top of the defaults
            below, e.g. ``{"epochs": 10, "lr": 0.05}``. ``None`` keeps every
            default.
        loss_fn: Loss callable ``(logits, y, spikes)`` forwarded to
            ``train_loop_snn``. ``None`` selects the module-level
            ``loss_fn_spk``, which is looked up at call time.
        use_loss_plotter: When True, create
            ``results/<project>/runs/<run id>/`` and pass ``train_loop_snn`` a
            ``LossSurfacePlotter`` that writes
            ``illuminated_loss_surface.npz`` there. Off by default. In that
            case no directory is created and ``loss_plotter=None`` is
            forwarded.
    """
    if loss_fn is None:
        loss_fn = loss_fn_spk

    run_name = "ppso"  # placeholder; replaced by a hyper-parameter name below
    defaults = {  # Dataset:
        "nb_input": 10,
        "nb_output": 10,
        "nb_steps": 50,
        "nb_data_samples": 1000,
        "dim_manifold": 2,
        "alpha": 2.0,
        # SNN:
        "nb_hidden": 100,
        "learn_beta": False,
        # Evolution Strategy:
        "nb_model_samples": 100,
        "mirror": True,
        # Training:
        "std": 0.1,
        "epochs": 5,
        "batch_size": 256,
        "method": "ppso",
        # Optimization:
        # NOTE(review): the model/data setup below never reads "optimizer"
        # or "regularization", and "loss" only feeds the run name. The loss
        # actually used is `loss_fn`, which is spike-regularized by default.
        # These labels may therefore misdescribe the run; confirm.
        "loss": "loss_fn",
        "optimizer": "Adam",
        "lr": 0.01,
        "regularization": "none",
    }
    run_config = {**defaults, **(config or {})}

    with wandb.init(
        entity="DarwinNeuron", project="big-sweep-test", name=run_name, config=run_config
    ) as run:
        # Name the run after key hyper-parameter values, joined with "-" in
        # alphabetical order of their keys.
        keys = ["method", "std", "batch_size", "lr", "nb_model_samples", "loss"]
        run.name = "-".join(f"{getattr(run.config, k)}" for k in sorted(keys))

        # Local CSV recording (optional). Only the call's side effects matter
        # here; the returned paths are not used.
        init_result_csv(dict(run.config), run.project)

        my_model = PPSOModelWithPooling(
            RandmanSNN,
            run.config.nb_input,
            run.config.nb_hidden,
            run.config.nb_output,
            run.config.learn_beta,
            beta=0.95,
            sample_size=run.config.nb_model_samples,
            param_std=run.config.std,
            lr=run.config.lr,
            device=device,
            mirror=run.config.mirror,
            acc_threshold=0.90,
            topk_ratio=0.25,
        )

        # Load the Randman dataset and split it into train/validation loaders.
        train_loader, val_loader = split_and_load(
            RandmanConfig(
                nb_classes=run.config.nb_output,
                nb_units=run.config.nb_input,
                nb_steps=run.config.nb_steps,
                nb_samples=run.config.nb_data_samples,
                dim_manifold=run.config.dim_manifold,
                alpha=run.config.alpha,
            ).read_dataset(),
            run.config.batch_size,
        )

        # Only build the loss-surface plotter (and its output directory) when
        # it will actually be handed to the training loop.
        loss_plotter = None
        if use_loss_plotter:
            plotter_dir = f"results/{run.project}/runs/{run.id}/"
            os.makedirs(plotter_dir, exist_ok=True)
            loss_plotter = LossSurfacePlotter(plotter_dir + "illuminated_loss_surface.npz")

        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop_snn(
                my_model, train_loader, val_loader, loss_fn, device, run, epoch,
                loss_plotter=loss_plotter,
            )


train_snn()

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mwx2178[0m ([33mRNNFNNneuron[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


filepath data/randman/a818832ebb0748ed8125b520de22a658.pt
Epoch 0
-------------------------------
batch 0, loss: 2.293459, accuracy: 7.8%
Accuracy: 9.2%, Avg loss: 2.299597 

batch 1, loss: 2.278652, accuracy: 15.2%
Accuracy: 12.5%, Avg loss: 2.292703 



Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x1508eee524a0>>
Traceback (most recent call last):
  File "/scratch/wx2178/.conda/envs/cns/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


batch 2, loss: 2.288794, accuracy: 14.8%
Accuracy: 12.5%, Avg loss: 2.292703 

batch 3, loss: 2.300711, accuracy: 12.9%
Accuracy: 12.5%, Avg loss: 2.292703 

batch 4, loss: 2.296641, accuracy: 12.5%
Accuracy: 12.5%, Avg loss: 2.292703 

batch 5, loss: 2.275235, accuracy: 11.3%
Accuracy: 12.2%, Avg loss: 2.280998 

batch 6, loss: 2.272424, accuracy: 16.0%
Accuracy: 14.5%, Avg loss: 2.288818 

batch 7, loss: 2.261144, accuracy: 15.6%
Accuracy: 14.6%, Avg loss: 2.286752 

batch 8, loss: 2.270815, accuracy: 15.6%
Accuracy: 14.6%, Avg loss: 2.286752 

