In [1]:

from EvolutionStrategy import ESModel
from RandmanFunctions import get_randman_dataset
from Utilities import spike_to_label

import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

device = 'cuda'

## 1. SNN model

In [2]:
class SNN(nn.Module):
    def __init__(self, num_inputs, num_hidden, num_outputs, learn_beta, beta=0.95):
        super(SNN, self).__init__()
        self.fc1 = nn.Linear(num_inputs, num_hidden)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=learn_beta)

        self.fc2 = nn.Linear(num_hidden, num_outputs)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=learn_beta)

        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # import pdb; pdb.set_trace()
        batch_size, time_steps, num_neurons = x.shape
        x = x.permute(1, 0, 2)  # (time, batch, neurons)

        mem1, mem2 = [torch.zeros(batch_size, layer.out_features, device=x.device)
                      for layer in [self.fc1, self.fc2]]

        mem2_rec = []

        for t in range(time_steps):
            spk1, mem1 = self.lif1(self.fc1(x[t]), mem1)
            _, mem2 = self.lif2(self.fc2(spk1), mem2)
            mem2_rec.append(mem2)

        return torch.stack(mem2_rec, dim=0)  # (time_steps, batch_size, num_outputs)
    
def losscustom(pred, labels):
  # batch, classes, time_steps
  # import pdb; pdb.set_trace()
  mem = pred.permute(1,2,0)
  labels = labels.long()
  non_labels = 1-labels

  batch_idx = torch.arange(mem.shape[0])

  correct = mem[batch_idx, labels]
  non_correct = mem[batch_idx, non_labels]

  diff = non_correct - correct
  diff_activated = torch.where(diff > 0, diff, torch.zeros_like(diff))
  return (diff_activated).mean()


## 1.2 Training SNN for Randman

In [3]:
from sklearn.model_selection import train_test_split
import wandb
import os

def _run_snn_on_batch(model, x, y, loss_fn): 
    # shape: [time_steps, batch_size, classes]
    spike_train = model(x)
    pred_y = spike_to_label(spike_train, scheme = 'highest_voltage')
    
    loss = loss_fn(spike_train, y.long())
    correct = (pred_y == y).sum().item()
    
    return loss, correct

def log_model(es_model,run):
    filename = 'best-model.pth'
    model = es_model.get_best_model()
    torch.save(model.state_dict(), filename)
    run.log_model(path=filename)
    os.remove(filename)  

def val_loop_snn(es_model, dataloader, loss_fn):
    model = es_model.get_best_model()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        batch_loss, batch_correct = _run_snn_on_batch(model, x, y, loss_fn) 
        test_loss += batch_loss
        correct += batch_correct

    test_loss /= num_batches
    test_acc = correct / size
    print(f"Test Error: \nAccuracy: {(100*test_acc):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        
    return test_loss.item(), test_acc

def train_loop_snn(es_model, train_dataloader, val_dataloader, loss_fn, nb_model_samples, run):
    """ one epoch of training, going through all the batches once
    """    
    for batch, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)
        # train model with samples
        samples_loss = []
        for model in es_model.samples(nb_model_samples):
            loss, _ = _run_snn_on_batch(model, x, y, loss_fn)
            samples_loss.append(loss)            
            
        samples_loss = torch.stack(samples_loss) 
        es_model.gradient_descent(samples_loss)
    
        # best model loss and accuracy
        best_model = es_model.get_best_model()
        best_loss, best_correct = _run_snn_on_batch(best_model, x, y, loss_fn)   
        best_acc = best_correct / len(y)
        print(f"batch {batch}, loss: {best_loss:>7f}, accuracy: {100 * best_acc:>0.1f}%")
        
        # validation loss and accuracy
        val_loss, val_acc = val_loop_snn(es_model, val_dataloader, loss_fn)
        
        # record keeping
        run.log({'train_loss': best_loss.item(), 'train_acc' : best_acc, 'val_loss': val_loss, 'val_acc': val_acc}) 
        log_model(es_model, run)
    
  


In [4]:
def train_snn():   
    config = { # Dataset:
              'nb_input' : 100,
              'nb_output' : 2,  
              'nb_steps' : 200,
              'nb_data_samples': 1000,
              # SNN:
              'nb_hidden' : 100,
              'learn_beta' : False,             
              # Evolution Strategy:
              'nb_model_samples' : 1000, 
              # Training: 
              'std' : 0.05,
              'epochs' : 50, 
              'batch_size' : 256,
              # Optimization:
              'loss': 'Parashbuh',
              'optimizer' : 'AdamW',
              'lr' : 0.01,
              'regularization':'None'}
    with torch.no_grad(), wandb.init(entity = 'DarwinNeuron', project = 'DarwinNeuron', config=config) as run:  
        train_dataset, val_dataset = train_test_split(get_randman_dataset(nb_classes=run.config.nb_output, nb_units=run.config.nb_input, nb_steps=run.config.nb_steps, nb_samples=run.config.nb_data_samples), test_size=0.2)
        train_dataloader = DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)      
        es_model = ESModel(SNN, run.config.nb_input, run.config.nb_hidden, run.config.nb_output, 0.95, param_std = run.config.std, Optimizer=optim.AdamW, lr=run.config.lr, weight_decay=1e-3)
        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            # train the model
            train_loop_snn(es_model,train_dataloader, val_dataloader, losscustom, run.config.nb_model_samples, run)
    
train_snn() 

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myixing[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0
-------------------------------
batch 0, loss: 0.019187, accuracy: 46.5%
Test Error: 
Accuracy: 57.8%, Avg loss: 0.018650 

batch 1, loss: 0.039958, accuracy: 57.0%
Test Error: 
Accuracy: 52.0%, Avg loss: 0.042622 

batch 2, loss: 0.083525, accuracy: 52.7%
Test Error: 
Accuracy: 48.0%, Avg loss: 0.086030 

batch 3, loss: 0.126849, accuracy: 44.1%
Test Error: 
Accuracy: 51.0%, Avg loss: 0.125744 

batch 4, loss: 0.140900, accuracy: 60.5%
Test Error: 
Accuracy: 55.5%, Avg loss: 0.143034 

batch 5, loss: 0.154596, accuracy: 52.3%
Test Error: 
Accuracy: 47.0%, Avg loss: 0.150514 

batch 6, loss: 0.150560, accuracy: 60.9%
Test Error: 
Accuracy: 51.2%, Avg loss: 0.153189 

Epoch 1
-------------------------------
batch 0, loss: 0.157537, accuracy: 52.3%
Test Error: 
Accuracy: 52.8%, Avg loss: 0.155472 

batch 1, loss: 0.160378, accuracy: 46.9%
Test Error: 
Accuracy: 50.2%, Avg loss: 0.158860 

batch 2, loss: 0.159865, accuracy: 48.8%
Test Error: 
Accuracy: 50.0%, Avg loss: 0.160475 



Traceback (most recent call last):
  File "/tmp/ipykernel_925953/1949407616.py", line 29, in train_snn
    train_loop_snn(es_model,train_dataloader, val_dataloader, losscustom, run.config.nb_model_samples, run)
  File "/tmp/ipykernel_925953/4221535779.py", line 48, in train_loop_snn
    loss, _ = _run_snn_on_batch(model, x, y, loss_fn)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_925953/4221535779.py", line 7, in _run_snn_on_batch
    spike_train = model(x)
                  ^^^^^^^^
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_925953/1753872684.py", line 27, in f

BrokenPipeError: [Errno 32] Broken pipe

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7f1be392f140>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f1be3989fa0, execution_count=4 error_before_exec=None error_in_exec=[Errno 32] Broken pipe info=<ExecutionInfo object at 7f1be395bb30, raw_cell="def train_snn():   
    config = { # Dataset:
    .." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://wsl%2Bubuntu-22.04/home/wyx/darwin_neuron/Workspace.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
def test_retrieve():   
  with torch.no_grad(), wandb.init() as run:
    model_path = run.use_model('DarwinNeuron/SingleNeuron/run-eouglkvm-test-model.pth:v1')  
    model = torch.load(model_path, weights_only=False)

test_retrieve()