In [1]:

from EvolutionStrategy import ESModel
from RandmanFunctions import get_randman_dataset
from Utilities import spike_to_label, voltage_to_logits
from torch.nn.functional import cross_entropy

import snntorch as snn
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.optim as optim

# Use the GPU when one is visible to the kernel; otherwise fall back to CPU so the
# notebook still runs end to end (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## 1. SNN model

In [2]:
class SNN(nn.Module):
    """Two-layer leaky integrate-and-fire network read out by membrane voltage.

    Layout: ``fc1 -> lif1`` (spiking hidden layer), then ``fc2 -> lif2`` where
    ``lif2`` uses ``reset_mechanism='none'`` and its membrane potential is
    recorded at every time step. Both linear layers are bias-free and get
    Xavier-uniform weights.

    Args:
        num_inputs: input features per time step.
        num_hidden: number of hidden LIF neurons.
        num_outputs: number of output neurons.
        learn_beta: forwarded as ``learn_beta`` to both ``snn.Leaky`` layers.
        beta: membrane decay forwarded to both ``snn.Leaky`` layers.
    """

    def __init__(self, num_inputs, num_hidden, num_outputs, learn_beta, beta=0.95):
        super().__init__()
        # Construction order is deliberate: nn.Linear draws its default init from
        # the global RNG, so reordering these lines would change the weights.
        self.fc1 = nn.Linear(num_inputs, num_hidden, bias=False)
        self.lif1 = snn.Leaky(beta=beta, learn_beta=learn_beta)
        self.fc2 = nn.Linear(num_hidden, num_outputs, bias=False)
        self.lif2 = snn.Leaky(beta=beta, learn_beta=learn_beta, reset_mechanism='none')

        linear_layers = [module for module in self.modules() if isinstance(module, nn.Linear)]
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)

    def forward(self, x):
        """Simulate the network step by step.

        Args:
            x: input of shape (batch_size, time_steps, num_inputs).

        Returns:
            Output-layer membrane potential after every step, shape
            (time_steps, batch_size, num_outputs).
        """
        batch_size, _time_steps, _features = x.shape

        hidden_mem = torch.zeros(batch_size, self.fc1.out_features, device=x.device)
        out_mem = torch.zeros(batch_size, self.fc2.out_features, device=x.device)

        voltage_trace = []
        for x_t in x.unbind(dim=1):  # one (batch_size, num_inputs) slice per time step
            hidden_spk, hidden_mem = self.lif1(self.fc1(x_t), hidden_mem)
            _, out_mem = self.lif2(self.fc2(hidden_spk), out_mem)
            voltage_trace.append(out_mem)

        return torch.stack(voltage_trace, dim=0)
    
def losscustom(pred, labels):
    """Hinge-style margin loss on output membrane potentials.

    For every sample and time step, penalise how far the strongest competing
    class's potential rises above the target class's potential. Steps where
    the target class is ahead contribute zero.

    Args:
        pred: voltages stacked over time, shape (time_steps, batch, classes);
            permuted internally to (batch, classes, time_steps).
        labels: target class index per sample, shape (batch,); cast to long.

    Returns:
        Scalar tensor: mean of ``max(0, competitor - target)`` over all samples
        and time steps.
    """
    mem = pred.permute(1, 2, 0)  # (batch, classes, time)
    labels = labels.long().to(mem.device)
    num_classes = mem.shape[1]

    # Build the index on mem's device so the gather works on CPU and CUDA alike.
    batch_idx = torch.arange(mem.shape[0], device=mem.device)
    correct = mem[batch_idx, labels]  # (batch, time)

    # Strongest non-target class at each time step. With two classes this is
    # simply the other class (index 1 - label); it also supports > 2 classes.
    is_target = torch.arange(num_classes, device=mem.device) == labels.unsqueeze(1)  # (batch, classes)
    non_correct = mem.masked_fill(is_target.unsqueeze(-1), float('-inf')).amax(dim=1)  # (batch, time)

    diff = non_correct - correct
    diff_activated = torch.where(diff > 0, diff, torch.zeros_like(diff))
    return diff_activated.mean()

def regularized_cross_entropy(pred, y, sharpness=15.0):
    """Cross-entropy plus a penalty on near-tied (low-margin) predictions.

    The penalty is ``sigmoid(-sharpness * margin)`` averaged over the batch.
    Here ``margin`` is the gap between each sample's two largest logits; for
    two classes it equals ``|pred[:, 0] - pred[:, 1]|``. The penalty is 0.5
    for an exact tie and decays towards 0 as the top class pulls ahead.

    Args:
        pred: logits, shape (batch, classes) with classes >= 2.
        y: target class indices, as accepted by ``cross_entropy``.
        sharpness: slope of the sigmoid; larger values only penalise very
            small margins. Defaults to 15.

    Returns:
        Scalar tensor ``cross_entropy(pred, y) + mean(penalty)``.
    """
    top2 = pred.topk(2, dim=1).values  # sorted descending per sample
    margin = top2[:, 0] - top2[:, 1]
    regularization_term = torch.sigmoid(-sharpness * margin)

    return cross_entropy(pred, y) + regularization_term.mean()


## 2. Training the SNN on Randman

In [3]:
from sklearn.model_selection import train_test_split
import wandb
import os

def _run_snn_on_batch(model, x, y, loss_fn):
    """Forward one batch through ``model`` and score the result.

    Args:
        model: network returning output voltages; for ``SNN`` the shape is
            (time_steps, batch_size, classes).
        x: input batch fed directly to ``model``.
        y: target labels; cast to ``long`` before being passed to ``loss_fn``.
        loss_fn: callable(logits, labels) -> loss.

    Returns:
        ``(loss, correct)``: the value returned by ``loss_fn``, and the number
        of samples whose predicted label equals ``y`` (via ``.item()``).
    """
    out_voltages = model(x)

    # Both the hard prediction and the logits come from Utilities helpers.
    # NOTE(review): the two scheme strings differ ('highest_voltage' vs
    # 'highest-voltage'); confirm each matches what its helper expects.
    predicted_labels = spike_to_label(out_voltages, scheme = 'highest_voltage')
    class_logits = voltage_to_logits(out_voltages, scheme='highest-voltage')

    batch_loss = loss_fn(class_logits, y.long())
    n_correct = (predicted_labels == y).sum().item()

    return batch_loss, n_correct

def log_model(es_model, run):
    """Upload the ES model's current best network weights to W&B.

    The state dict is written as ``best-model.pth`` inside a private temporary
    directory and passed to ``run.log_model``. The directory is removed when
    the call returns, even if saving or uploading raises. A same-named file in
    the working directory is never touched.

    Args:
        es_model: object whose ``get_best_model()`` returns an ``nn.Module``.
        run: active W&B run (anything with ``log_model(path=...)``).
    """
    import tempfile

    model = es_model.get_best_model()
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Fixed basename: the uploaded file is always named best-model.pth.
        filename = os.path.join(tmp_dir, 'best-model.pth')
        torch.save(model.state_dict(), filename)
        run.log_model(path=filename)

def val_loop_snn(es_model, dataloader, loss_fn):
    """Evaluate the ES model's current best network on a dataloader.

    Args:
        es_model: object whose ``get_best_model()`` returns the network to test.
        dataloader: yields ``(x, y)`` batches; ``len(dataloader.dataset)`` is
            the sample count used for accuracy.
        loss_fn: callable(logits, labels) -> scalar loss tensor.

    Returns:
        ``(avg_loss, accuracy)``: the mean per-batch loss as a Python float,
        and the fraction of samples classified correctly.

    Raises:
        ValueError: if the dataloader has no samples or no batches.
    """
    model = es_model.get_best_model()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    if size == 0 or num_batches == 0:
        raise ValueError("val_loop_snn: dataloader is empty")

    total_loss, correct = 0.0, 0
    # Evaluation never needs autograd; do not rely on the caller's no_grad context.
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            batch_loss, batch_correct = _run_snn_on_batch(model, x, y, loss_fn)
            total_loss += batch_loss.item()
            correct += batch_correct

    test_loss = total_loss / num_batches
    test_acc = correct / size
    print(f"Test Error: \nAccuracy: {(100*test_acc):>0.1f}%, Avg loss: {test_loss:>8f} \n")

    return test_loss, test_acc

def train_loop_snn(es_model, train_dataloader, val_dataloader, loss_fn, nb_model_samples, run, eval_every=1):
    """Run one epoch of evolution-strategy training, going through all batches once.

    For each training batch:
      1. score ``nb_model_samples`` networks from ``es_model.samples`` with
         ``loss_fn`` and pass the stacked losses to ``es_model.gradient_descent``;
      2. print the best model's loss and accuracy on that same batch;
      3. every ``eval_every`` batches, validate the best model on
         ``val_dataloader`` and upload its weights with ``log_model``.
    Metrics are sent to ``run.log`` once per batch.

    Args:
        es_model: ES wrapper exposing ``samples``, ``gradient_descent`` and
            ``get_best_model``.
        train_dataloader: yields ``(x, y)`` training batches.
        val_dataloader: yields ``(x, y)`` validation batches.
        loss_fn: callable(logits, labels) -> scalar loss tensor.
        nb_model_samples: number of sampled networks scored per batch.
        run: active W&B run.
        eval_every: validation/upload period in batches. The default of 1 runs
            them after every batch. Both are expensive (a full validation pass
            plus an artifact upload), so larger values speed up long runs.

    Raises:
        ValueError: if ``eval_every`` < 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")

    for batch, (x, y) in enumerate(train_dataloader):
        x, y = x.to(device), y.to(device)

        # 1. ES update driven by the losses of the sampled networks.
        sample_losses = []
        for model in es_model.samples(nb_model_samples):
            loss, _ = _run_snn_on_batch(model, x, y, loss_fn)
            sample_losses.append(loss)
        es_model.gradient_descent(torch.stack(sample_losses))

        # 2. Best model's loss and accuracy on the current training batch.
        best_model = es_model.get_best_model()
        best_loss, best_correct = _run_snn_on_batch(best_model, x, y, loss_fn)
        best_acc = best_correct / len(y)
        print(f"batch {batch}, loss: {best_loss:>7f}, accuracy: {100 * best_acc:>0.1f}%")

        metrics = {'train_loss': best_loss.item(), 'train_acc' : best_acc}
        is_eval_step = (batch + 1) % eval_every == 0
        if is_eval_step:
            # 3. Validation (val_loop_snn prints its own summary).
            val_loss, val_acc = val_loop_snn(es_model, val_dataloader, loss_fn)
            metrics['val_loss'] = val_loss
            metrics['val_acc'] = val_acc

        run.log(metrics)
        if is_eval_step:
            log_model(es_model, run)

In [None]:
def train_snn():
    """Train the SNN on the Randman task with an evolution strategy, logged to W&B.

    All hyper-parameters are declared in ``config`` and read back through
    ``run.config``, so the values recorded by W&B are the values used:
    ``learn_beta`` reaches the SNN constructor, ``loss`` selects the loss
    function, and ``optimizer`` names the ``torch.optim`` class.
    Everything runs under ``torch.no_grad()``: the ES update only consumes
    sampled losses (see ``train_loop_snn``), not autograd gradients.
    """
    run_name = 'cross-entropy-no-weight'
    config = { # Dataset:
              'nb_input' : 100,
              'nb_output' : 2,
              'nb_steps' : 50,
              'nb_data_samples': 2000,
              # SNN:
              'nb_hidden' : 10,
              'learn_beta' : False,
              # Evolution Strategy:
              'nb_model_samples' : 1000,
              # Training:
              'std' : 0.05,
              'epochs' : 50,
              'batch_size' : 256,
              # Optimization:
              'loss': 'cross-entropy',
              'optimizer' : 'Adam',
              'lr' : 0.01,
              'regularization':'none'}
    # Losses selectable through config['loss']; each takes (logits, labels).
    loss_fns = {
        'cross-entropy': cross_entropy,
        'regularized-cross-entropy': regularized_cross_entropy,
    }

    with torch.no_grad(), wandb.init(entity = 'DarwinNeuron', project = 'DarwinNeuron', name=run_name, config=config) as run:
        cfg = run.config

        dataset = get_randman_dataset(cfg.nb_output, cfg.nb_input, cfg.nb_steps, cfg.nb_data_samples)
        train_dataset, val_dataset = train_test_split(dataset, test_size=0.2, shuffle=False)
        train_dataloader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
        val_dataloader = DataLoader(val_dataset, batch_size=512, shuffle=False)

        loss_fn = loss_fns[cfg.loss]
        optimizer_cls = getattr(optim, cfg.optimizer)

        # SNN's signature is (num_inputs, num_hidden, num_outputs, learn_beta, beta).
        # Passing only 0.95 here would land it in the learn_beta slot (truthy),
        # making beta learnable regardless of config['learn_beta'].
        # NOTE(review): assumes ESModel forwards these positional args to SNN(...).
        es_model = ESModel(SNN, cfg.nb_input, cfg.nb_hidden, cfg.nb_output, cfg.learn_beta, 0.95,
                           param_std=cfg.std, Optimizer=optimizer_cls, lr=cfg.lr)

        for epoch in range(cfg.epochs):
            print(f"Epoch {epoch}\n-------------------------------")
            train_loop_snn(es_model, train_dataloader, val_dataloader, loss_fn, cfg.nb_model_samples, run)


train_snn()

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33myixing[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0
-------------------------------
batch 0, loss: 0.759352, accuracy: 52.0%
Test Error: 
Accuracy: 43.2%, Avg loss: 0.786760 

batch 1, loss: 0.766292, accuracy: 42.6%
Test Error: 
Accuracy: 44.5%, Avg loss: 0.773640 

batch 2, loss: 0.766191, accuracy: 43.0%
Test Error: 
Accuracy: 44.1%, Avg loss: 0.761532 

batch 3, loss: 0.744340, accuracy: 44.9%
Test Error: 
Accuracy: 45.0%, Avg loss: 0.749920 

batch 4, loss: 0.724434, accuracy: 45.3%
Test Error: 
Accuracy: 45.5%, Avg loss: 0.747345 

batch 5, loss: 0.702150, accuracy: 52.7%
Test Error: 
Accuracy: 45.6%, Avg loss: 0.743667 

batch 6, loss: 0.724966, accuracy: 49.6%
Test Error: 
Accuracy: 47.0%, Avg loss: 0.735923 

batch 7, loss: 0.741561, accuracy: 44.1%
Test Error: 
Accuracy: 47.0%, Avg loss: 0.732970 

batch 8, loss: 0.716242, accuracy: 46.5%
Test Error: 
Accuracy: 46.6%, Avg loss: 0.731681 

batch 9, loss: 0.708410, accuracy: 52.3%
Test Error: 
Accuracy: 46.6%, Avg loss: 0.729497 

batch 10, loss: 0.700385, accuracy: 53.5

## 3. Analysis


- ~~Print the last-layer (output) voltage trace over time~~
- ~~Hidden-layer spike train~~
- ~~Hidden-layer voltage~~