In [1]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data


import pickle
import math
from utils import get_shd_dataset

# The coarse network structure and the time steps are dicated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 20

time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 64

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "hdspikes"
get_shd_dataset(cache_dir, cache_subdir)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "hdspikes"
get_shd_dataset(cache_dir, cache_subdir)

train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1
        
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

def dist_fn(dist):
    return {
        'gamma': lambda mean, k, size: np.random.gamma(k, scale=mean/k, size=size),
        'normal': lambda mean, k, size: np.random.normal(loc=mean, scale=mean/np.sqrt(k), size=size), #change standard deviation to match gamma
        'uniform': lambda _, maximum, size: np.random.uniform(low=0, high=maximum, size=size),
    }[dist.lower()]

cuda
Available at: /rds/general/user/aqa20/home/data/hdspikes/shd_train.h5
Available at: /rds/general/user/aqa20/home/data/hdspikes/shd_test.h5
Available at: /rds/general/user/aqa20/home/data/hdspikes/shd_train.h5
Available at: /rds/general/user/aqa20/home/data/hdspikes/shd_test.h5


In [2]:
def run_snn_hetero(inputs):
    # Expand parameters locally to match original sizes
    alpha_1_local = alpha_hetero_1.repeat_interleave(group_size, dim=1)
    beta_1_local = beta_hetero_1.repeat_interleave(group_size, dim=1)
    thresholds_local = thresholds_1.repeat_interleave(group_size, dim=1)
    reset_local = reset_1.repeat_interleave(group_size, dim=1)
    rest_local = rest_1.repeat_interleave(group_size, dim=1)
    alpha_2_local = alpha_hetero_2.repeat_interleave(group_size, dim=1)
    beta_2_local = beta_hetero_2.repeat_interleave(group_size, dim=1)
#     print(beta_2_local)
#     print(thresholds_local.shape)


    # Initialize memory and synaptic variables
    syn = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []

    # Compute hidden layer activity
    out = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))
    for t in range(nb_steps):
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - thresholds_local
        out = spike_fn(mthr)
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c]

        new_syn = alpha_1_local * syn + h1
        new_mem = beta_1_local * (mem - rest_local) + rest_local + (1 - beta_1_local) * syn - rst * (thresholds_local - reset_local)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
    flt = torch.zeros((batch_size_hetero, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((batch_size_hetero, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(nb_steps):
        new_flt = alpha_2_local * flt + h2[:, t]
        new_out = beta_2_local * out + (1 - beta_2_local) * flt

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec]
    return out_rec, other_recs

In [3]:
def compute_classification_accuracy_hetero(x_data, y_data):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time, shuffle=False):
        output,_ = run_snn_hetero(x_local.to_dense())
        m,_= torch.max(output,1) # max over time
        _,am=torch.max(m,1)      # argmax over output units
        tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
        accs.append(tmp)
    return np.mean(accs)

In [4]:
def train_snn_hetero(x_data, y_data, lr=1e-3, nb_epochs=10):
    params = [w1, w2, v1, alpha_hetero_1, beta_hetero_1,
              alpha_hetero_2, beta_hetero_2,
              thresholds_1, reset_1, rest_1]
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))
    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    loss_hist = []
    best_accuracy = 0
    best_params = params

    for e in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
#         print(thresholds_1)
#         print(beta_hetero_2)
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time):
            output, recs = run_snn_hetero(x_local.to_dense())
            _, spks = recs
            m, _ = torch.max(output, 1)

            _, am = torch.max(m, 1)  # argmax over output units
            tmp = np.mean((y_local == am).detach().cpu().numpy())  # compare to labels
            accs.append(tmp)

            log_p_y = nn.LogSoftmax(dim=1)(m)
            ground_loss = loss_fn(log_p_y, y_local)

            reg_loss = 1e-6 * torch.sum(spks)  # L1 loss on total number of spikes
            reg_loss += 1e-6 * torch.mean(torch.sum(torch.sum(spks, dim=0), dim=0) ** 2)  # L2 loss on spikes per neuron

            loss_val = ground_loss + reg_loss

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            # Clamping the values
            with torch.no_grad():
                alpha_hetero_1.clamp_(0.5, 0.995)
                beta_hetero_1.clamp_(0.5, 0.995)
                alpha_hetero_2.clamp_(0.5, 0.995)
                beta_hetero_2.clamp_(0.5, 0.995)
                thresholds_1.clamp_(0.5, 1.5)
                
            local_loss.append(loss_val.item())

        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {e + 1}: loss={mean_loss:.5f}")

        current_accuracy = compute_classification_accuracy_hetero(x_test, y_test)
        print(f"Epoch {e + 1}: Train= {np.mean(accs):.5f} Test Accuracy={current_accuracy:.5f}")

        saved_params_hetero = {
            'w1': w1.clone(),
            'w2': w2.clone(),
            'v1': v1.clone(),
            'alpha': alpha_hetero_1.clone(),
            'beta': beta_hetero_1.clone(),
            'threshold': thresholds_1.clone(),
            'reset': reset_1.clone(),
            'rest': rest_1.clone(),
            'alpha_2': alpha_hetero_2.clone(),
            'beta_2': beta_hetero_2.clone()
        }

        # Save parameters along with the current epoch and accuracy
        directory = f'hetero_group/epochs_{group_size}'

        # Create the directory if it does not exist
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Save the file in the specified directory
        file_path = os.path.join(directory, f'snn_{e + 1}.pth')
        torch.save({
            'epoch': e + 1,
            'accuracy': current_accuracy,
            'params': saved_params_hetero,
            'loss': loss_hist
        }, file_path)

        # Print the best accuracy so far
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy
            print(f"Epoch {e + 1}: Best Test Accuracy={best_accuracy:.5f}")

            directory = f'hetero_group/best_{group_size}'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            # Save parameters only when a new best accuracy is achieved
            # Create a dictionary of current parameters to save
            saved_params_hetero = {
                'w1': w1.clone(),
                'w2': w2.clone(),
                'v1': v1.clone(),
                'alpha': alpha_hetero_1.clone(),
                'beta': beta_hetero_1.clone(),
                'threshold': thresholds_1.clone(),
                'reset': reset_1.clone(),
                'rest': rest_1.clone(),
                'alpha_2': alpha_hetero_2.clone(),
                'beta_2': beta_hetero_2.clone()
            }

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'best_snn.pth')
            torch.save({
                'epoch': e + 1,
                'accuracy': best_accuracy,
                'params': saved_params_hetero,
                'loss': loss_hist
            }, file_path)
        else:
            print('Best', best_accuracy)

    return loss_hist

In [6]:
weight_scale = 0.2

w1 = torch.empty((nb_inputs, nb_hidden),  device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))

w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w2, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

v1 = torch.empty((nb_hidden, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(v1, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

group_size = 10

tau_syn = 10e-3
tau_mem = 20e-3
distribution = dist_fn('gamma')

# Initialize parameters with sizes divided by group size
thresholds_1 = torch.empty((1, nb_hidden // group_size), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(thresholds_1, a=0.5, b=1.5)

reset_1 = torch.empty((1, nb_hidden // group_size), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(reset_1, a=-0.5, b=0.5)

rest_1 = torch.empty((1, nb_hidden // group_size), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(rest_1, a=-0.5, b=0.5)

alpha_hetero_1_dist = torch.tensor(distribution(tau_syn, 3, (1, nb_hidden // group_size)), device=device, dtype=dtype)
alpha_hetero_1 = torch.exp(-time_step / alpha_hetero_1_dist)
alpha_hetero_1.requires_grad_(True)

beta_hetero_1_dist = torch.tensor(distribution(tau_mem, 3, (1, nb_hidden // group_size)), device=device, dtype=dtype)
beta_hetero_1 = torch.exp(-time_step / beta_hetero_1_dist)
beta_hetero_1.requires_grad_(True)

alpha_hetero_2_dist = torch.tensor(distribution(tau_syn, 3, (1, nb_outputs // group_size)), device=device, dtype=dtype)
alpha_hetero_2 = torch.exp(-time_step / alpha_hetero_2_dist)
alpha_hetero_2.requires_grad_(True)

beta_hetero_2_dist = torch.tensor(distribution(tau_mem, 3, (1, nb_outputs // group_size)), device=device, dtype=dtype)
beta_hetero_2 = torch.exp(-time_step / beta_hetero_2_dist)
beta_hetero_2.requires_grad_(True)

tensor([[0.8771, 0.9271]], device='cuda:0', requires_grad=True)

In [100]:
print(thresholds_1.shape)

torch.Size([1, 20])


In [98]:
# group size 20
nb_epochs_snn_hetero = 150
batch_size_hetero = 64
loss_hist_snn_hetero = train_snn_hetero(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_snn_hetero)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=2.91877
Epoch 1: Train= 0.12094 Test Accuracy=0.22232
Epoch 1: Best Test Accuracy=0.22232
Epoch 2: loss=2.13737
Epoch 2: Train= 0.39333 Test Accuracy=0.46563
Epoch 2: Best Test Accuracy=0.46563
Epoch 3: loss=1.40638
Epoch 3: Train= 0.62217 Test Accuracy=0.58348
Epoch 3: Best Test Accuracy=0.58348
Epoch 4: loss=1.08133
Epoch 4: Train= 0.73155 Test Accuracy=0.64018
Epoch 4: Best Test Accuracy=0.64018
Epoch 5: loss=0.89405
Epoch 5: Train= 0.78752 Test Accuracy=0.63482
Best 0.6401785714285714
Epoch 6: loss=0.65510
Epoch 6: Train= 0.84252 Test Accuracy=0.71518
Epoch 6: Best Test Accuracy=0.71518
Epoch 7: loss=0.52789
Epoch 7: Train= 0.88398 Test Accuracy=0.69554
Best 0.7151785714285714
Epoch 8: loss=0.45395
Epoch 8: Train= 0.90650 Test Accuracy=0.71116
Best 0.7151785714285714
Epoch 9: loss=0.39933
Epoch 9: Train= 0.92015 Test Accuracy=0.70759
Best 0.7151785714285714
Epoch 10: loss=0.35738
Epoch 10: Train= 0.93393 Test Accuracy=0.72232
Epoch 10: Best Test Accuracy=0.72232
Epoch

KeyboardInterrupt: 

In [134]:
# group size 10
nb_epochs_snn_hetero = 150
batch_size_hetero = 64
loss_hist_snn_hetero = train_snn_hetero(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_snn_hetero)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=2.87604
Epoch 1: Train= 0.12623 Test Accuracy=0.25268
Epoch 1: Best Test Accuracy=0.25268
Epoch 2: loss=2.32391
Epoch 2: Train= 0.32074 Test Accuracy=0.48482
Epoch 2: Best Test Accuracy=0.48482
Epoch 3: loss=1.61333
Epoch 3: Train= 0.56398 Test Accuracy=0.58080
Epoch 3: Best Test Accuracy=0.58080
Epoch 4: loss=1.17592
Epoch 4: Train= 0.67938 Test Accuracy=0.62813
Epoch 4: Best Test Accuracy=0.62813
Epoch 5: loss=0.93247
Epoch 5: Train= 0.75283 Test Accuracy=0.67277
Epoch 5: Best Test Accuracy=0.67277
Epoch 6: loss=0.75331
Epoch 6: Train= 0.80536 Test Accuracy=0.71161
Epoch 6: Best Test Accuracy=0.71161
Epoch 7: loss=0.63463
Epoch 7: Train= 0.84732 Test Accuracy=0.72143
Epoch 7: Best Test Accuracy=0.72143
Epoch 8: loss=0.57061
Epoch 8: Train= 0.86270 Test Accuracy=0.68884
Best 0.7214285714285714
Epoch 9: loss=0.51218
Epoch 9: Train= 0.87426 Test Accuracy=0.73661
Epoch 9: Best Test Accuracy=0.73661
Epoch 10: loss=0.45911
Epoch 10: Train= 0.89087 Test Accuracy=0.73973
Epoch 

KeyboardInterrupt: 

In [None]:
# DONT RUN NEXT ONE

In [91]:
# group size 5
nb_epochs_snn_hetero = 150
batch_size_hetero = 64
loss_hist_snn_hetero = train_snn_hetero(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_snn_hetero)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=2.94189
Epoch 1: Train= 0.09560 Test Accuracy=0.20714
Epoch 1: Best Test Accuracy=0.20714
Epoch 2: loss=2.50568
Epoch 2: Train= 0.27153 Test Accuracy=0.37545
Epoch 2: Best Test Accuracy=0.37545
Epoch 3: loss=2.00341
Epoch 3: Train= 0.41941 Test Accuracy=0.52054
Epoch 3: Best Test Accuracy=0.52054
Epoch 4: loss=1.55043
Epoch 4: Train= 0.55217 Test Accuracy=0.55357
Epoch 4: Best Test Accuracy=0.55357
Epoch 5: loss=1.17684
Epoch 5: Train= 0.66302 Test Accuracy=0.63393
Epoch 5: Best Test Accuracy=0.63393
Epoch 6: loss=0.93456
Epoch 6: Train= 0.74975 Test Accuracy=0.65045
Epoch 6: Best Test Accuracy=0.65045
Epoch 7: loss=0.78834
Epoch 7: Train= 0.79860 Test Accuracy=0.68973
Epoch 7: Best Test Accuracy=0.68973
Epoch 8: loss=0.67215
Epoch 8: Train= 0.83022 Test Accuracy=0.70045
Epoch 8: Best Test Accuracy=0.70045
Epoch 9: loss=0.59558
Epoch 9: Train= 0.85273 Test Accuracy=0.73527
Epoch 9: Best Test Accuracy=0.73527
Epoch 10: loss=0.52889
Epoch 10: Train= 0.87180 Test Accuracy=0.

KeyboardInterrupt: 

# MLP

In [140]:
class hetero_mlp_a_b_spikes(nn.Module):
    def __init__(self, group_size=1):
        super(hetero_mlp_a_b_spikes, self).__init__()
        self.input_size = 1961 #(adding 40 for alpha and beta 2)
        self.hidden_size = 2048
        self.group_size = group_size
        self.output_size = (200 // group_size) * 5 + (20 // group_size) * 2

        self.layers = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.output_size),
            nn.Sigmoid()
        )

        self.init_weights()

    def forward(self, x):
        return self.layers(x)

    def init_weights(self):
        with torch.no_grad():
            # Initialize first layer weights and biases
            self.layers[0].weight.fill_(0)
            self.layers[0].bias.fill_(0)

            # Pass the first 1000 inputs directly to the hidden layer; the rest to 0
            for i in range(1040):
                self.layers[0].weight[i, i] = 1

            # Initialize second layer weights and biases
            self.layers[2].weight.fill_(0)
            self.layers[2].bias.fill_(0)

            # Pass the first 1000 inputs directly to the hidden layer; the rest to 0
            for i in range(1040):
                self.layers[2].weight[i, i] = 1


In [147]:
def run_snn_hybrid_alpha_beta_spikes_HETERO(inputs, mlp, mlp_interval, batch_size_MLP):
    device, dtype = inputs.device, inputs.dtype

    # Initialize local copies of alpha, beta, threshold, reset, and rest for all 200 hidden neurons
    alpha_1_local = alpha_hetero_1.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 200).detach().clone()
    beta_1_local = beta_hetero_1.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 200).detach().clone()
    thresholds_local = thresholds_1.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 200).detach().clone()
    reset_local = reset_1.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 200).detach().clone()
    rest_local = rest_1.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 200).detach().clone()
    alpha_2_local = alpha_hetero_2.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 20).detach().clone()
    beta_2_local = beta_hetero_2.repeat_interleave(group_size, dim=1).expand(batch_size_MLP, 20).detach().clone()

    # Initialize synaptic and membrane potentials
    syn = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes
    mem_rec = []
    spk_rec = []

    # Initialize outputs for the hidden layer
    out = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))

    # Prepare readout layer variables
    flt2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out_rec = [out2]

    for t in range(nb_steps):
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - thresholds_local
        out = spike_fn(mthr)
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c]

        # Update synaptic and membrane potentials
        syn = alpha_1_local * syn + h1
        mem = beta_1_local * (mem - rest_local) + rest_local + (1 - beta_1_local) * syn - rst * (thresholds_local - reset_local)

        # Record membrane potentials and spikes
        mem_rec.append(mem)
        spk_rec.append(out)

        # Now compute h2 on the fly
        h2_t = torch.einsum("ab,bc->ac", (out, w2))
        flt2 = alpha_2_local * flt2 + h2_t
        out2 = beta_2_local * out2 + flt2 * (1 - beta_2_local)
        out_rec.append(out2)

        # Flatten and concatenate spikes for each item in the batch
        input_spikes_flat = inputs[:, t, :].reshape(batch_size_MLP, -1)  # Shape: [batch_size, 700]
        hidden_spikes_flat = out.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 200]
        output_spikes_flat = out2.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 20]

        # Time tensor
        time_tensor = torch.full((batch_size_MLP, 1), t, device=device, dtype=dtype)

        # Concatenate tensors
        mlp_input = torch.cat([
            alpha_1_local, beta_1_local, thresholds_local, reset_local, rest_local,
            alpha_2_local, beta_2_local,
            time_tensor,
            input_spikes_flat, hidden_spikes_flat, output_spikes_flat
        ], dim=1)

        # Process with MLP (in a single call for the whole batch)
        mlp_outputs = mlp(mlp_input)

        if t % mlp_interval == 0:
            # Update alpha_local and beta_local based on MLP outputs
            num_groups = 200 // group_size
            for i in range(num_groups):
                alpha_1_update = mlp_outputs[:, i].unsqueeze(1).expand(-1, group_size)
                beta_1_update = mlp_outputs[:, num_groups + i].unsqueeze(1).expand(-1, group_size)
                thresholds_update = mlp_outputs[:, 2*num_groups + i].unsqueeze(1).expand(-1, group_size) + 0.5
                reset_update = mlp_outputs[:, 3*num_groups + i].unsqueeze(1).expand(-1, group_size) - 0.5
                rest_update = mlp_outputs[:, 4*num_groups + i].unsqueeze(1).expand(-1, group_size) - 0.5
                
                alpha_1_local = alpha_1_local.clone()
                beta_1_local = beta_1_local.clone()
                thresholds_local = thresholds_local.clone()
                reset_local = reset_local.clone()
                rest_local = rest_local.clone()
                
                alpha_1_local[:, i*group_size:(i+1)*group_size] = alpha_1_update
                beta_1_local[:, i*group_size:(i+1)*group_size] = beta_1_update
                thresholds_local[:, i*group_size:(i+1)*group_size] = thresholds_update
                reset_local[:, i*group_size:(i+1)*group_size] = reset_update
                rest_local[:, i*group_size:(i+1)*group_size] = rest_update

            num_groups2 = 20 // group_size
            for i in range(num_groups2):
                alpha_2_update = mlp_outputs[:, 5*num_groups + i].unsqueeze(1).expand(-1, group_size)
                beta_2_update = mlp_outputs[:, 5*num_groups + num_groups2 + i].unsqueeze(1).expand(-1, group_size)
                
                alpha_2_local = alpha_2_local.clone()
                beta_2_local = beta_2_local.clone()
                
                alpha_2_local[:, i*group_size:(i+1)*group_size] = alpha_2_update
                beta_2_local[:, i*group_size:(i+1)*group_size] = beta_2_update

    # Stack recordings for output
    mem_rec = torch.stack(mem_rec, dim=1).to(device)
    spk_rec = torch.stack(spk_rec, dim=1).to(device)
    out_rec = torch.stack(out_rec[1:], dim=1).to(device)  # Skip the initial zero tensor

    other_recs = [mem_rec, spk_rec]

    return out_rec, other_recs


In [142]:
def compute_classification_accuracy_MLP(x_data, y_data, mlp, mlp_interval):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    with torch.no_grad():
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time, shuffle=False):
            output, _ = run_snn_hybrid_alpha_beta_spikes_HETERO(x_local.to_dense(),  mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_hetero)
            m,_= torch.max(output,1) # max over time
            _,am=torch.max(m,1)      # argmax over output units
            tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
            accs.append(tmp)
    return np.mean(accs)

In [143]:
def train_hybrid(mlp, x_data, y_data, lr=1e-3, nb_epochs=10, mlp_interval=10):

    snn_params = [w1, w2, v1, 
                  alpha_hetero_1, beta_hetero_1,
                  thresholds_1, reset_1, rest_1,
                  alpha_hetero_2, beta_hetero_2]
#  

    # Optimizers
    combined_params = [
        {'params': snn_params, 'lr': lr},  # Parameters for SNN with specific learning rate
        {'params': mlp.parameters(), 'lr': lr}  # Parameters for MLP with its own learning rate
    ]

    # Using a single optimizer for both SNN and MLP
    combined_optimizer = torch.optim.Adam(combined_params)


    #Loss functions
    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    best_accuracy = 0

    loss_hist = []
    for epoch in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
#         print(w1)
#         print(alpha_hetero_1)
#         print(reset_1)
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time):
                output, recs = run_snn_hybrid_alpha_beta_spikes_HETERO(inputs=x_local.to_dense(), mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_hetero)
                _ , spks = recs
                m, _ = torch.max(output,1)

                _,am=torch.max(m,1)      # argmax over output units
                tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
                accs.append(tmp)

                log_p_y = log_softmax_fn(m)
                ground_loss = loss_fn(log_p_y, y_local)
                reg_loss = 1e-6*torch.sum(spks) # L1 loss on total number of spikes
                reg_loss += 1e-6*torch.mean(torch.sum(torch.sum(spks,dim=0),dim=0)**2) # L2 loss on spikes per neuron

                loss_MLP = ground_loss + reg_loss

                combined_optimizer.zero_grad()
                loss_MLP.backward()
                combined_optimizer.step()
                
                # Clamping the values
                with torch.no_grad():
                    alpha_hetero_1.clamp_(0.367, 0.995)
                    beta_hetero_1.clamp_(0.367, 0.995)
                    alpha_hetero_2.clamp_(0.367, 0.995)
                    beta_hetero_2.clamp_(0.367, 0.995)
                    thresholds_1.clamp_(0.5, 1.5)

                local_loss.append(loss_MLP.item())
                local_ground_loss.append(ground_loss.item())
                local_reg_loss.append(reg_loss.item())


        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {epoch+1}: loss={mean_loss:.5f}")
        print("ground_loss", np.mean(local_ground_loss))
        print("reg_loss", np.mean(local_reg_loss))
        current_accuracy = compute_classification_accuracy_MLP(x_test, y_test, mlp, mlp_interval)
        print(f"Epoch {epoch+1}: Train= {np.mean(accs):.5f} Test Accuracy={current_accuracy:.5f}")

        # Print the best accuracy so far
        if current_accuracy > best_accuracy:
            best_accuracy = current_accuracy

            directory = f'hetero_group/hybrid_{group_size}'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            best_model_state = mlp_mlp.state_dict()

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'mlp.pt')
            torch.save(best_model_state, file_path)

            # Save parameters only when a new best accuracy is achieved
            # Create a dictionary of current parameters to save
            saved_params_hetero = {
                'w1': w1.clone(),
                'w2': w2.clone(),
                'v1': v1.clone(),
                'alpha': alpha_hetero_1.clone(),
                'beta': beta_hetero_1.clone(),
                'threshold': thresholds_1.clone(),
                'reset': reset_1.clone(),
                'rest': rest_1.clone(),
                'alpha_2': alpha_hetero_2.clone(),
                'beta_2': beta_hetero_2.clone()
            }

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'snn.pth')
            torch.save({
                'epoch': epoch + 1,
                'accuracy': best_accuracy,
                'params': saved_params_hetero,
                'loss': loss_hist
            }, file_path)
        else:
            print("Best", best_accuracy)

    return loss_hist

In [125]:
# loaded_weights_snn = torch.load('Hetero/3_Hetero/epochs_2048/snn_9.pth')
loaded_weights_snn = torch.load('hetero_group/epochs_5/snn_11.pth')
# loaded_weights_snn = torch.load('hybrid_hetero/snn_2048.pth')
# Convert tensors to parameters and ensure they are leaf tensors by re-wrapping them
# Move tensors to device first and then wrap them as parameters
w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['accuracy'])

0.7651785714285714


In [126]:
nb_epochs_mlp = 100
batch_size_hetero = 64
group_size = 5
mlp_mlp = hetero_mlp_a_b_spikes().to(device)
# mlp_mlp = hetero_mlp_no_hidden().to(device)
loss_hist_MLP = train_hybrid(mlp_mlp, x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_mlp, mlp_interval = 10)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=0.74033
ground_loss 0.6297724763239463
reg_loss 0.11056003267840138
Epoch 1: Train= 0.79970 Test Accuracy=0.75089
Epoch 2: loss=0.55367
ground_loss 0.4610331496150475
reg_loss 0.09263711023753084
Epoch 2: Train= 0.84953 Test Accuracy=0.70625
Best 0.7508928571428571
Epoch 3: loss=0.51194
ground_loss 0.42002318811228895
reg_loss 0.09191734707496298
Epoch 3: Train= 0.86097 Test Accuracy=0.71964
Best 0.7508928571428571
Epoch 4: loss=0.47671
ground_loss 0.3872542736802514
reg_loss 0.0894587017361104
Epoch 4: Train= 0.87525 Test Accuracy=0.66473
Best 0.7508928571428571
Epoch 5: loss=0.45719
ground_loss 0.37025028197314797
reg_loss 0.08693560752577668
Epoch 5: Train= 0.88017 Test Accuracy=0.76250
Epoch 6: loss=0.41809
ground_loss 0.33189681160637713
reg_loss 0.08619182879530539
Epoch 6: Train= 0.89542 Test Accuracy=0.74062
Best 0.7625
Epoch 7: loss=0.38184
ground_loss 0.3024492536004134
reg_loss 0.07939259238599793
Epoch 7: Train= 0.90674 Test Accuracy=0.73571
Best 0.7625
Epoch 

KeyboardInterrupt: 

In [149]:
# loaded_weights_snn = torch.load('Hetero/3_Hetero/epochs_2048/snn_9.pth')
loaded_weights_snn = torch.load('hetero_group/epochs_10/snn_7.pth')
# loaded_weights_snn = torch.load('hybrid_hetero/snn_2048.pth')
# Convert tensors to parameters and ensure they are leaf tensors by re-wrapping them
# Move tensors to device first and then wrap them as parameters
w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['accuracy'])

0.7214285714285714


In [150]:
print(alpha_hetero_2)

Parameter containing:
tensor([[0.9841, 0.9950]], device='cuda:0', requires_grad=True)


In [151]:
nb_epochs_mlp = 100
batch_size_hetero = 64
group_size = 10
mlp_mlp = hetero_mlp_a_b_spikes().to(device)
# mlp_mlp = hetero_mlp_no_hidden().to(device)
loss_hist_MLP = train_hybrid(mlp_mlp, x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_mlp, mlp_interval = 10)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=0.72540
ground_loss 0.6006458985993243
reg_loss 0.1247497781759172
Epoch 1: Train= 0.80438 Test Accuracy=0.72991
Epoch 2: loss=0.58929
ground_loss 0.4800604441034512
reg_loss 0.10922743649933282
Epoch 2: Train= 0.84781 Test Accuracy=0.71920
Best 0.7299107142857143
Epoch 3: loss=0.55015
ground_loss 0.4474364751436579
reg_loss 0.10271459336825244
Epoch 3: Train= 0.85605 Test Accuracy=0.71875
Best 0.7299107142857143
Epoch 4: loss=0.49637
ground_loss 0.3964260559617065
reg_loss 0.09994044307414002
Epoch 4: Train= 0.87574 Test Accuracy=0.76339
Epoch 5: loss=0.49527
ground_loss 0.3982166673724107
reg_loss 0.09705540461568382
Epoch 5: Train= 0.87315 Test Accuracy=0.73661
Best 0.7633928571428571
Epoch 6: loss=0.43391
ground_loss 0.3408663220058276
reg_loss 0.09304651829201406
Epoch 6: Train= 0.89579 Test Accuracy=0.73884
Best 0.7633928571428571
Epoch 7: loss=0.41770
ground_loss 0.32535451298623574
reg_loss 0.09234427731102846
Epoch 7: Train= 0.90367 Test Accuracy=0.77857
Epoch 8:

KeyboardInterrupt: 

In [1]:
import torch

In [5]:
loaded_weights_snn = torch.load('SHD_Results/hetero_group/best_20/best_snn.pth')

# Convert tensors to parameters and ensure they are leaf tensors by re-wrapping them
# Move tensors to device first and then wrap them as parameters
# w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
# w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
# v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
# alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
# beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
# thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
# reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
# rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
# alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
# beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
# print(loaded_weights_snn['train_acc_hist'])
# print(loaded_weights_snn['test_acc_hist'])
print(loaded_weights_snn['accuracy'])

0.7580357142857143


In [47]:
loaded_weights_snn = torch.load('SSC/Python_Tests/Hybrid_Hetero_no_reg/epochs/snn_12.pth')
# loaded_weights_snn = torch.load('SSC/Python_Tests/Hybrid_Homo_no_reg/epochs/snn_24.pth')
# loaded_weights_snn = torch.load('SSC/Python_Tests/Hybrid_Hetero_no_reg/snn_best.pth')
# loaded_weights_snn = torch.load('SSC/Python_Tests/Hybrid_Homo_no_reg/snn_best.pth')

# Convert tensors to parameters and ensure they are leaf tensors by re-wrapping them
# Move tensors to device first and then wrap them as parameters
# w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
# w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
# v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
# alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
# beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
# thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
# reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
# rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
# alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
# beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['train_acc_hist'])
print(loaded_weights_snn['test_acc_hist'])
print(loaded_weights_snn['epoch'])

[0.49601092027141647, 0.5752358990670059, 0.6082882739609838, 0.6287505301102629, 0.641128604749788, 0.6568861323155216, 0.664585983884648, 0.6778917514843087, 0.6891831000848176, 0.6979696776929601, 0.7086779050042409, 0.7160596904156065]
[0.5218160377358491, 0.5707547169811321, 0.584561713836478, 0.60844143081761, 0.6238207547169812, 0.6235259433962265, 0.6201356132075472, 0.6298643867924528, 0.6335986635220126, 0.6460298742138365, 0.6479952830188679, 0.6422955974842768]
12
