In [1]:

from EvolutionStrategy import ESModel
from RandmanFunctions import get_randman_dataset
from Utilities import spike_to_label, voltage_to_logits


import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim
from torch.nn.functional import cross_entropy

device = 'cuda'

## 1. SNN model

In [None]:
class ExcitationPopulation(nn.Module):
    def __init__(self, nb_neurons, nb_input, nb_inh, beta, nb_decision_steps):
        super().__init__()        
        self.input_fc = nn.Linear(nb_input, nb_neurons, bias = False)
        self.recurrent_fc = nn.Linear(nb_neurons, nb_neurons, bias = False)
        self.inhibition_fc = nn.Linear(nb_inh, nb_neurons, bias = False)
        self.lif = snn.Leaky(beta, learn_beta = False, threshold=1, reset_mechanism='subtract')
        self.readout_fc = nn.Linear(nb_neurons, 1, bias=False)
        self.readout_lif = snn.Leaky(beta=0.95, reset_mechanism = 'none')
        self.nb_decision_steps = nb_decision_steps
        self.init_states(nb_decision_steps)       

    def get_nb_neurons(self):
        return self.input_fc.out_features
    
    def init_states(self, nb_decision_steps=None):
        # excitatory neurons
        self.mem = self.lif.init_leaky().to(device)
        self.last_spks_queue = [torch.zeros(self.get_nb_neurons(), device=device) for _ in range(nb_decision_steps if nb_decision_steps != None else len(self.last_spks_queue))]  
        
        # readout neuron
        self.readout_mem = self.readout_lif.init_leaky().to(device)
        self.readout_mem_rec = []

    def forward(self, input, inhibition):
        # excitatory neurons
        # TODO: 
        # PROBLEM: THE ABS() IS APPLIED IN THE WRONG PLACE. SHOULD APPLY ON WEIGHTS BEFORE LINEAR COMBINATION
        curr = torch.abs(self.input_fc(input)) + torch.abs(self.recurrent_fc(self.last_spks_queue[-1])) - torch.abs(self.inhibition_fc(inhibition))
        spk, self.mem = self.lif(curr, self.mem)
        
        # readout neuron
        readout_curr = torch.abs(self.readout_fc(spk))
        _, self.readout_mem = self.readout_lif(readout_curr, self.readout_mem)    
        
        # update spk record
        self.last_spks_queue.pop(0)
        self.last_spks_queue.append(spk.clone().detach())
        
        # update readout record
        self.readout_mem_rec.append(self.readout_mem.squeeze(dim=-1).clone())
        
        return spk
        
    def get_last_spikes_means(self):
        # stacked shape: [nb_decision_steps, batch_size, nb_neurons]
        # For each batch, the mean should include all the final steps and all the neurons (first and last dimension)
        # return shape: [batch_size,]
        return torch.stack(self.last_spks_queue).mean(dim = [0, 2])
    
    def get_readout(self):    
        # stacked shape [nb_decision_steps, batch_size]
        return torch.stack(self.readout_mem_rec[-self.nb_decision_steps: ]).mean(dim=0)
    
# def test_ep():
#     ep = ExcitationPopulation(nb_neurons=3, nb_input=10, nb_inh=1, beta=0.95, nb_decision_steps=5).to(device)
#     for _ in range(100):
#         fake_spk = torch.rand([64, 10], device=device)
#         fake_inh = torch.rand([64, 1], device=device)
#         out = ep(fake_spk, fake_inh)
#     print(ep.get_readout().shape)
        
# test_ep()

In [3]:
class CompetitionModel(nn.Module):
    def __init__(self, nb_input, nb_ext, nb_inh, beta_ext, beta_inh, nb_decision_steps):
        super().__init__()
        
        # excitatory
        self.excitatory_1 = ExcitationPopulation(nb_ext, nb_input, nb_inh, beta_ext, nb_decision_steps)
        self.excitatory_2 = ExcitationPopulation(nb_ext, nb_input, nb_inh, beta_ext, nb_decision_steps)
        
        # inhibitory.
        self.inh_fc = nn.Linear(nb_ext, 1, bias = False) # Note: two ext share same inh weights
        self.inh_lif = snn.Leaky(beta_inh, learn_beta = False)
        
        # records
        self.nb_decision_steps = nb_decision_steps
        
    def get_nb_ext(self):
        return self.excitatory_1.get_nb_neurons()
    
    def get_nb_inh(self):
        return self.inh_fc.out_features
    
    def init_states(self):
        self.excitatory_1.init_states()
        self.excitatory_2.init_states()
        self.mem_inh = self.inh_lif.init_leaky()
        
    def forward(self, x):        
        # change x shape from [batch, time steps, nb_input] to [time steps, batch, nb_input]
        x = x.permute([1, 0, 2])
        
        # pad time steps for model to go to steady states
        x = torch.cat([x, torch.zeros(5 + self.nb_decision_steps, x.shape[1], x.shape[2], device=device)])
        
        # initalize membrane potentials
        self.init_states()
        
        # init spikes with shape [nb_neurons]. The batch size will be broadcasted
        inh_spk = torch.zeros([self.get_nb_inh()], device=device)
        
        for t in range(len(x)):          
            # excitation
            ext_1_spk = self.excitatory_1(x[t], inh_spk)
            ext_2_spk = self.excitatory_2(x[t], inh_spk)
            
            # inhibition. Inhibitory neurons are excited, so curr should be positive
            curr_inh = torch.abs(self.inh_fc(ext_1_spk)) + torch.abs(self.inh_fc(ext_2_spk))
            
            inh_spk, self.mem_inh = self.inh_lif(curr_inh, self.mem_inh)
        
        # return shape: [batch_size, 2], where column 0 is ext1, column 1 is ext2
        return torch.stack([self.excitatory_1.get_readout(), self.excitatory_2.get_readout()], dim = 1)
    
    def get_mem_rec(self):
        # stacked shape: [batch, time_steps]
        mem_rec_1 = torch.stack(self.excitatory_1.readout_mem_rec, dim = 1)
        mem_rec_2 = torch.stack(self.excitatory_2.readout_mem_rec, dim = 1)
        
        # return shape: [batch, time_steps, 2]
        return torch.stack([mem_rec_1, mem_rec_2], dim=2)

# def test_cm():
#     cm = CompetitionModel(nb_input=10, nb_ext=3, nb_inh=1, beta_ext=0.75, beta_inh=0.95, nb_decision_steps=10)
#     x = torch.rand([64, 100, 10])
#     print(cm(x).shape)
# test_cm()
            

## 2 Training SNN for Randman

In [5]:
from sklearn.model_selection import train_test_split
import wandb
import os

def _run_snn_on_batch(model, x, y, loss_fn): 
    # shape: [time_steps, batch_size, classes]
    logits = model(x)
    pred_y = torch.argmax(logits, dim=1)
    
    loss = loss_fn(logits, y.long())
    correct = (pred_y == y).sum().item()
    
    return loss, correct

def log_model(es_model,run):
    filename = 'best-model.pth'
    model = es_model.get_best_model()
    torch.save(model.state_dict(), filename)
    run.log_model(path=filename)
    os.remove(filename)  

def val_loop_snn(es_model, dataloader, loss_fn):
    model = es_model.get_best_model()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        batch_loss, batch_correct = _run_snn_on_batch(model, x, y, loss_fn) 
        test_loss += batch_loss
        correct += batch_correct

    test_loss /= num_batches
    test_acc = correct / size
    print(f"Test Error: \nAccuracy: {(100*test_acc):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        
    return test_loss.item(), test_acc

def train_loop_snn(es_model, train_dataloader, val_dataloader, loss_fn, nb_model_samples, run):
    """ one epoch of training, going through all the batches once
    """    
    for batch, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)
        # train model with samples
        samples_loss = []
        for model in es_model.samples(nb_model_samples):
            loss, _ = _run_snn_on_batch(model, x, y, loss_fn)
            samples_loss.append(loss)            
            
        samples_loss = torch.stack(samples_loss) 
        es_model.gradient_descent(samples_loss)
    
        # best model loss and accuracy
        best_model = es_model.get_best_model()
        best_loss, best_correct = _run_snn_on_batch(best_model, x, y, loss_fn)   
        best_acc = best_correct / len(y)
        print(f"batch {batch}, loss: {best_loss:>7f}, accuracy: {100 * best_acc:>0.1f}%")
        
        # validation loss and accuracy
        val_loss, val_acc = val_loop_snn(es_model, val_dataloader, loss_fn)
        
        # record keeping
        run.log({'train_loss': best_loss.item(), 'train_acc' : best_acc, 'val_loss': val_loss, 'val_acc': val_acc}) 
        log_model(es_model, run)

In [6]:
def train_snn():   
    run_name = 'ABSOLUUUUUTE'
    config = { # Dataset:
              'nb_input' : 100,
              'nb_output' : 2,  
              'nb_steps' : 50,
              'nb_data_samples': 2000,
              # SNN:
              'nb_ext' : 3,
              'nb_inh' : 1,
              'beta_ext': 0.95,
              'beta_inh' : 0.75,
              'nb_decision_steps' : 10,           
              # Evolution Strategy:
              'nb_model_samples' : 1000, 
              # Training: 
              'std' : 0.15,
              'epochs' : 50, 
              'batch_size' : 256,
              # Optimization:
              'loss': 'cross-entropy',
              'optimizer' : 'Adam',
              'lr' : 0.01,
              'regularization':'none'}
    with torch.no_grad(), wandb.init(entity = 'DarwinNeuron', project = 'ES-Randman-Competition', name=run_name, config=config) as run:  
        train_dataset, val_dataset = train_test_split(get_randman_dataset(run.config.nb_output, run.config.nb_input, run.config.nb_steps, run.config.nb_data_samples), test_size=0.2, shuffle=False)
        train_dataloader = DataLoader(train_dataset, batch_size=run.config.batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)      
        es_model = ESModel(CompetitionModel, run.config.nb_input, run.config.nb_ext, run.config.nb_inh, run.config.beta_ext, run.config.beta_inh, run.config.nb_decision_steps, param_std = run.config.std, Optimizer=optim.Adam, lr=run.config.lr)
        for epoch in range(run.config.epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            # train the model
            train_loop_snn(es_model,train_dataloader, val_dataloader, cross_entropy, run.config.nb_model_samples, run)
    
train_snn() 

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myixing[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0
-------------------------------
batch 0, loss: 0.712599, accuracy: 49.2%
Test Error: 
Accuracy: 49.4%, Avg loss: 0.711836 

batch 1, loss: 0.706591, accuracy: 50.4%
Test Error: 
Accuracy: 47.8%, Avg loss: 0.717977 

batch 2, loss: 0.697645, accuracy: 53.5%
Test Error: 
Accuracy: 50.2%, Avg loss: 0.717653 

batch 3, loss: 0.701385, accuracy: 52.7%
Test Error: 
Accuracy: 47.9%, Avg loss: 0.729832 

batch 4, loss: 0.693850, accuracy: 54.3%
Test Error: 
Accuracy: 47.5%, Avg loss: 0.724438 

batch 5, loss: 0.713635, accuracy: 50.8%
Test Error: 
Accuracy: 48.8%, Avg loss: 0.727030 

batch 6, loss: 0.727859, accuracy: 52.0%
Test Error: 
Accuracy: 47.9%, Avg loss: 0.719463 

batch 7, loss: 0.709250, accuracy: 48.4%
Test Error: 
Accuracy: 48.5%, Avg loss: 0.731352 

batch 8, loss: 0.688512, accuracy: 52.7%
Test Error: 
Accuracy: 47.4%, Avg loss: 0.721888 

batch 9, loss: 0.704369, accuracy: 51.6%
Test Error: 
Accuracy: 49.9%, Avg loss: 0.716532 

batch 10, loss: 0.694753, accuracy: 54.7

Traceback (most recent call last):
  File "/tmp/ipykernel_1074863/1566236211.py", line 33, in train_snn
    train_loop_snn(es_model,train_dataloader, val_dataloader, cross_entropy, run.config.nb_model_samples, run)
  File "/tmp/ipykernel_1074863/2498007126.py", line 65, in train_loop_snn
    log_model(es_model, run)
  File "/tmp/ipykernel_1074863/2498007126.py", line 18, in log_model
    torch.save(model.state_dict(), filename)
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/torch/serialization.py", line 849, in save
    with _open_zipfile_writer(f) as opened_zipfile:
         ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/torch/serialization.py", line 716, in _open_zipfile_writer
    return container(name_or_buffer)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/torch/serialization.py", line 687, in __init__
    super().__init__(torch._C.PyTorchFileWriter(self.name))

OSError: [Errno 19] No such device: '/mnt/d/darwin_neuron'

--- Logging error ---
Traceback (most recent call last):
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/logging/__init__.py", line 1164, in emit
    self.flush()
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/logging/__init__.py", line 1144, in flush
    self.stream.flush()
OSError: [Errno 5] Input/output error
Call stack:
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
    self.io_loop.start()
  File "/home/wyx/miniconda3/envs/snn/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
    self.asyncio_