In [1]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data


import pickle
import math
from utils import get_shd_dataset

# The coarse network structure and the time steps are dicated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 35

time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 64

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "ssc_data"
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_train.h5'), 'r')
validation_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_valid.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_valid = validation_file['spikes']
y_valid = validation_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
# test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

# x_train = train_file['spikes']
# y_train = train_file['labels']
# x_test = test_file['spikes']
# y_test = test_file['labels']



def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1
        
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

def dist_fn(dist):
    return {
        'gamma': lambda mean, k, size: np.random.gamma(k, scale=mean/k, size=size),
        'normal': lambda mean, k, size: np.random.normal(loc=mean, scale=mean/np.sqrt(k), size=size), #change standard deviation to match gamma
        'uniform': lambda _, maximum, size: np.random.uniform(low=0, high=maximum, size=size),
    }[dist.lower()]
print("init done")

cuda
init done


In [12]:
def run_snn_hetero(inputs):
    # Initialize memory and synaptic variables
    syn = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)

    mem_rec = []
    spk_rec = []

    # Compute hidden layer activity
    out = torch.zeros((batch_size_hetero, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))
    for t in range(nb_steps):
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - thresholds_1
        out = spike_fn(mthr)
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c]
        # rst = out.detach() * reset  # Reset mechanism considering individual reset values

        new_syn = alpha_hetero_1 * syn + h1
        new_mem = beta_hetero_1 * (mem - rest_1) + rest_1 + (1 - beta_hetero_1) * syn - rst * (thresholds_1 - reset_1)

        mem_rec.append(mem)
        spk_rec.append(out)

        mem = new_mem
        syn = new_syn

    mem_rec = torch.stack(mem_rec, dim=1)
    spk_rec = torch.stack(spk_rec, dim=1)

    # Readout layer
    h2 = torch.einsum("abc,cd->abd", (spk_rec, w2))
    flt = torch.zeros((batch_size_hetero, nb_outputs), device=device, dtype=dtype)
    out = torch.zeros((batch_size_hetero, nb_outputs), device=device, dtype=dtype)
    out_rec = [out]
    for t in range(nb_steps):
        # print(alpha_hetero.shape)
        new_flt = alpha_hetero_2 * flt + h2[:, t]  # Assume alpha for the output layer
        # new_out = beta_hetero_2 * out + flt
        new_out = beta_hetero_2 * out + (1 - beta_hetero_2)*flt  # Assume beta for the output layer

        flt = new_flt
        out = new_out

        out_rec.append(out)

    out_rec = torch.stack(out_rec, dim=1)
    other_recs = [mem_rec, spk_rec]
    return out_rec, other_recs

In [13]:
def compute_classification_accuracy_hetero(x_data, y_data):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time, shuffle=False):
        output,_ = run_snn_hetero(x_local.to_dense())
        m,_= torch.max(output,1) # max over time
        _,am=torch.max(m,1)      # argmax over output units
        tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
        accs.append(tmp)
    return np.mean(accs)

In [20]:
def train_snn_hetero(x_data, y_data, lr=1e-3, nb_epochs=10):
    params = [w1, w2, v1, alpha_hetero_1, beta_hetero_1,
              alpha_hetero_2, beta_hetero_2,
              thresholds_1, reset_1, rest_1]
    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))
    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    loss_hist = []
    train_acc_hist = []
    test_acc_hist = []
    best_accuracy = 0
    best_params = params

    for e in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []

        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time):
            output, recs = run_snn_hetero(x_local.to_dense())
            _, spks = recs
            m, _ = torch.max(output, 1)

            _, am = torch.max(m, 1)  # argmax over output units
            tmp = np.mean((y_local == am).detach().cpu().numpy())  # compare to labels
            accs.append(tmp)

            log_p_y = nn.LogSoftmax(dim=1)(m)

            ground_loss = loss_fn(log_p_y, y_local)

            reg_loss = 1e-6 * torch.sum(spks)  # L1 loss on total number of spikes
            reg_loss += 1e-6 * torch.mean(torch.sum(torch.sum(spks, dim=0), dim=0) ** 2)  # L2 loss on spikes per neuron

            loss_val = ground_loss + reg_loss

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            # Clamping the values
            with torch.no_grad():
                alpha_hetero_1.clamp_(0.367, 0.995)
                beta_hetero_1.clamp_(0.367, 0.995)
                alpha_hetero_2.clamp_(0.367, 0.995)
                beta_hetero_2.clamp_(0.367, 0.995)
                thresholds_1.clamp_(0.5, 1.5)
                
            local_loss.append(loss_val.item())

        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {e + 1}: loss={mean_loss:.5f}")

        train_accuracy = np.mean(accs)
        test_accuracy = compute_classification_accuracy_hetero(x_test, y_test)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        print(f"Epoch {e + 1}: Train= {train_accuracy:.5f} Test Accuracy={test_accuracy:.5f}")

        saved_params_hetero = {
            'w1': w1.clone(),
            'w2': w2.clone(),
            'v1': v1.clone(),
            'alpha': alpha_hetero_1.clone(),
            'beta': beta_hetero_1.clone(),
            'threshold': thresholds_1.clone(),
            'reset': reset_1.clone(),
            'rest': rest_1.clone(),
            'alpha_2': alpha_hetero_2.clone(),
            'beta_2': beta_hetero_2.clone()
        }

        # Save parameters along with the current epoch and accuracy
        directory = 'SSC_test_2/epochs_hetero'

        # Create the directory if it does not exist
        if not os.path.exists(directory):
            os.makedirs(directory)

        # Save the file in the specified directory
        file_path = os.path.join(directory, f'snn_{e + 1}.pth')
        torch.save({
            'epoch': e + 1,
            'accuracy': test_accuracy,
            'params': saved_params_hetero,
            'loss': loss_hist,
            'train_acc_hist': train_acc_hist,
            'test_acc_hist': test_acc_hist
        }, file_path)

        # Print the best accuracy so far
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print(f"Epoch {e + 1}: Best Test Accuracy={best_accuracy:.5f}")

            directory = 'SSC_test_2/best_hetero'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            # Save parameters only when a new best accuracy is achieved
            # Create a dictionary of current parameters to save
            saved_params_hetero = {
                'w1': w1.clone(),
                'w2': w2.clone(),
                'v1': v1.clone(),
                'alpha': alpha_hetero_1.clone(),
                'beta': beta_hetero_1.clone(),
                'threshold': thresholds_1.clone(),
                'reset': reset_1.clone(),
                'rest': rest_1.clone(),
                'alpha_2': alpha_hetero_2.clone(),
                'beta_2': beta_hetero_2.clone()
            }

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'best_snn.pth')
            torch.save({
                'epoch': e + 1,
                'accuracy': best_accuracy,
                'params': saved_params_hetero,
                'loss': loss_hist,
                'train_acc_hist': train_acc_hist,
                'test_acc_hist': test_acc_hist
            }, file_path)
        else:
            print('Best', best_accuracy)

    return loss_hist

In [21]:
# Creating tensors with requires_grad=True
thresholds_1 = torch.empty((1, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(thresholds_1, a=0.5, b=1.5)  # Thresholds uniformly distributed between 0.5 and 1.5

reset_1 = torch.empty((1, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(reset_1, a=-0.5, b=0.5)  # Reset potentials uniformly distributed between -0.5 and 0.5

rest_1 = torch.empty((1, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.uniform_(rest_1, a=-0.5, b=0.5)  # Rest potentials uniformly distributed between -0.5 and 0.5

# Gamma distribution for alpha and beta
tau_syn = 10e-3
tau_mem = 20e-3
distribution = dist_fn('gamma')

alpha_hetero_1_dist = torch.tensor(distribution(tau_syn, 3, (1, nb_hidden)), device=device, dtype=dtype)
alpha_hetero_1 = torch.exp(-time_step / alpha_hetero_1_dist)
alpha_hetero_1.requires_grad_(True)

beta_hetero_1_dist = torch.tensor(distribution(tau_mem, 3, (1, nb_hidden)), device=device, dtype=dtype)
beta_hetero_1 = torch.exp(-time_step / beta_hetero_1_dist)
beta_hetero_1.requires_grad_(True)

alpha_hetero_2_dist = torch.tensor(distribution(tau_syn, 3, (1, nb_outputs)), device=device, dtype=dtype)
alpha_hetero_2 = torch.exp(-time_step / alpha_hetero_2_dist)
alpha_hetero_2.requires_grad_(True)

beta_hetero_2_dist = torch.tensor(distribution(tau_mem, 3, (1, nb_outputs)), device=device, dtype=dtype)
beta_hetero_2 = torch.exp(-time_step / beta_hetero_2_dist)
beta_hetero_2.requires_grad_(True)

weight_scale = 0.2

w1 = torch.empty((nb_inputs, nb_hidden),  device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w1, mean=0.0, std=weight_scale/np.sqrt(nb_inputs))

w2 = torch.empty((nb_hidden, nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(w2, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

v1 = torch.empty((nb_hidden, nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(v1, mean=0.0, std=weight_scale/np.sqrt(nb_hidden))

tensor([[-0.0034,  0.0015, -0.0127,  ..., -0.0058, -0.0292, -0.0093],
        [ 0.0188,  0.0311, -0.0013,  ...,  0.0205, -0.0194, -0.0179],
        [ 0.0107,  0.0169,  0.0124,  ..., -0.0079,  0.0329, -0.0007],
        ...,
        [-0.0110, -0.0262,  0.0215,  ...,  0.0261,  0.0097, -0.0130],
        [ 0.0009, -0.0216, -0.0070,  ..., -0.0128,  0.0114,  0.0172],
        [-0.0007,  0.0024, -0.0260,  ...,  0.0034,  0.0070,  0.0336]],
       device='cuda:0', requires_grad=True)

In [22]:
nb_epochs_snn_hetero = 150
batch_size_hetero = 64
loss_hist_snn_hetero = train_snn_hetero(x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_snn_hetero)

  labels_ = np.array(y,dtype=int)


Epoch 1: loss=3.10464
Epoch 1: Train= 0.17645 Test Accuracy=0.24887
Epoch 1: Best Test Accuracy=0.24887
Epoch 2: loss=2.57702
Epoch 2: Train= 0.29619 Test Accuracy=0.32188
Epoch 2: Best Test Accuracy=0.32188
Epoch 3: loss=2.34167
Epoch 3: Train= 0.36191 Test Accuracy=0.37534
Epoch 3: Best Test Accuracy=0.37534
Epoch 4: loss=2.17475
Epoch 4: Train= 0.40488 Test Accuracy=0.40247
Epoch 4: Best Test Accuracy=0.40247
Epoch 5: loss=2.03988
Epoch 5: Train= 0.43808 Test Accuracy=0.43519
Epoch 5: Best Test Accuracy=0.43519
Epoch 6: loss=1.94486
Epoch 6: Train= 0.46371 Test Accuracy=0.44084
Epoch 6: Best Test Accuracy=0.44084
Epoch 7: loss=1.86686
Epoch 7: Train= 0.48324 Test Accuracy=0.45765
Epoch 7: Best Test Accuracy=0.45765
Epoch 8: loss=1.80510
Epoch 8: Train= 0.49720 Test Accuracy=0.47003
Epoch 8: Best Test Accuracy=0.47003
Epoch 9: loss=1.74801
Epoch 9: Train= 0.51343 Test Accuracy=0.47666
Epoch 9: Best Test Accuracy=0.47666
Epoch 10: loss=1.71089
Epoch 10: Train= 0.52074 Test Accuracy=0.

KeyboardInterrupt: 

In [7]:
for i in range(1,11):
    loaded_weights_snn = torch.load(f'SSC_first_test/epochs_hetero/snn_{i}.pth')
    print(compute_classification_accuracy_hetero(s_train, y_train))
    print(loaded_weights_snn['accuracy'])

0.2368317610062893
0.3331367924528302
0.377063679245283
0.4137185534591195
0.42914701257861637
0.45936517295597484
0.4750884433962264
0.48452240566037735
0.49164701257861637


In [9]:
loaded_weights_snn = torch.load('Python_Tests/SSC_test_2/best_hetero/best_snn.pth')

w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['test_acc_hist'])
# print(compute_classification_accuracy_hetero(x_train, y_train))

0.5548349056603774


In [34]:
# loaded_weights_snn = torch.load('Python_Tests/SSC_hetero_lr_1e-3/epochs_hetero/snn_70.pth', map_location=torch.device('cpu'))
# loaded_weights_snn = torch.load('Python_Tests/SSC_homo_lr_1e-3/epochs/snn_69.pth', map_location=torch.device('cpu'))
# loaded_weights_snn = torch.load('Python_Tests/SSC_hetero_no_reg/epochs_hetero/snn_71.pth', map_location=torch.device('cpu'))
# loaded_weights_snn = torch.load('Python_Tests/SSC_homo_no_reg/epochs/snn_68.pth', map_location=torch.device('cpu'))
# loaded_weights_snn = torch.load('Python_Tests/SSC_hetero_no_reg/best_hetero/best_snn.pth', map_location=torch.device('cpu'))
# loaded_weights_snn = torch.load('Python_Tests/SSC_homo_no_reg/best/best_snn.pth', map_location=torch.device('cpu'))
w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['test_acc_hist'])
# print(rest_1)

[0.21363993710691823, 0.2592865566037736, 0.2797268081761006, 0.30625982704402516, 0.33176100628930816, 0.35834316037735847, 0.36453419811320753, 0.3553950471698113, 0.40433372641509435, 0.40521816037735847, 0.415438286163522, 0.41037735849056606, 0.4165683962264151, 0.42914701257861637, 0.44079205974842767, 0.4392688679245283, 0.41592963836477986, 0.44089033018867924, 0.44978380503144655, 0.39219732704402516, 0.4367138364779874, 0.45902122641509435, 0.4495381289308176, 0.4537637578616352, 0.4410868710691824, 0.44025157232704404, 0.4054147012578616, 0.4374508647798742, 0.43209512578616355, 0.4349941037735849, 0.4679638364779874, 0.4427574685534591, 0.4785770440251572, 0.4221206761006289, 0.44585298742138363, 0.455188679245283, 0.4410377358490566, 0.4641312893081761, 0.47110849056603776, 0.4659001572327044, 0.47420400943396224, 0.4490467767295597, 0.4882566823899371, 0.47214033018867924, 0.4900746855345912, 0.47420400943396224, 0.4878636006289308, 0.47921580188679247, 0.4706662735849056

In [30]:
# loaded_weights_snn = torch.load('Python_Tests/Hybrid_Hetero_int_5/snn_best.pth', map_location=torch.device('cpu'))
loaded_weights_snn = torch.load('Python_Tests/Hybrid_Homo/epochs/snn_66.pth', map_location=torch.device('cpu'))
print(loaded_weights_snn['test_acc_hist'])

[0.5519850628930818, 0.5764544025157232, 0.5718356918238994, 0.5355738993710691, 0.5731623427672956, 0.5780758647798742, 0.5762087264150944, 0.5833333333333334, 0.5679048742138365, 0.58500393081761, 0.5812205188679245, 0.5873624213836478, 0.5796973270440252, 0.5815153301886793, 0.5829402515723271, 0.584561713836478, 0.5839720911949685, 0.5828419811320755, 0.5840703616352201, 0.5792059748427673, 0.5759139150943396, 0.5826945754716981, 0.5714426100628931, 0.5782232704402516, 0.5714917452830188, 0.5747838050314465, 0.5715408805031447, 0.5723270440251572, 0.5613207547169812, 0.5681014150943396, 0.5709512578616353, 0.56343356918239, 0.567560927672956, 0.5658411949685535, 0.5675117924528302, 0.5656937893081762, 0.568936713836478, 0.5602397798742138, 0.566185141509434, 0.5497739779874213, 0.5555227987421384, 0.5637775157232704, 0.5609768081761006, 0.5491843553459119, 0.5337067610062893, 0.5500196540880503, 0.5556702044025157, 0.5448113207547169, 0.5343946540880503, 0.5593062106918238, 0.54633

# HOMO HYBRID

In [None]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data


import pickle
import math
# from utils import get_shd_dataset

# The coarse network structure and the time steps are dicated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 35

time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 64

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "ssc_data"
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_train.h5'), 'r')
validation_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_valid.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_valid = validation_file['spikes']
y_valid = validation_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
# test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

# x_train = train_file['spikes']
# y_train = train_file['labels']
# x_test = test_file['spikes']
# y_test = test_file['labels']



def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1
        
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

def dist_fn(dist):
    return {
        'gamma': lambda mean, k, size: np.random.gamma(k, scale=mean/k, size=size),
        'normal': lambda mean, k, size: np.random.normal(loc=mean, scale=mean/np.sqrt(k), size=size), #change standard deviation to match gamma
        'uniform': lambda _, maximum, size: np.random.uniform(low=0, high=maximum, size=size),
    }[dist.lower()]
print("init done")


class MLP_alpha_beta_single1(nn.Module):
    def __init__(self):
        super(MLP_alpha_beta_single1, self).__init__()
        self.input_size = 928 + 15
        self.hidden_size = 1024
        self.output_size = 7
        self.layers = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.output_size),
            nn.Sigmoid()
        )
        self.init_weights()

    def forward(self, x):
        return self.layers(x)

    def init_weights(self):
        with torch.no_grad():
            # Initialize first layer weights and biases
            self.layers[0].weight.fill_(0)
            self.layers[0].bias.fill_(0)
            for i in range(self.output_size):
                self.layers[0].weight[i, i] = 1

            # Initialize second layer weights and biases
            self.layers[2].weight.fill_(0)
            self.layers[2].bias.fill_(0)
            for i in range(self.output_size):
                self.layers[2].weight[i, i] = 1
                
                
def run_snn_hybrid_alpha_beta_spikes_homo(inputs, mlp, mlp_interval, batch_size_MLP):

    # Initialize local copies of alpha, beta, threshold, reset and rest for all 200 hidden neurons
    alpha_1_local = alpha_homo_1
    beta_1_local = beta_homo_1
    thresholds_local = thresholds_1
    reset_local = reset_1
    rest_local = rest_1
    alpha_2_local = alpha_homo_2
    beta_2_local = beta_homo_2
    alpha_1_local = alpha_1_local.expand(batch_size_MLP, 1)
    beta_1_local = beta_1_local.expand(batch_size_MLP, 1)
    thresholds_local = thresholds_local.expand(batch_size_MLP, 1)
    reset_local = reset_local.expand(batch_size_MLP, 1)
    rest_local = rest_local.expand(batch_size_MLP, 1)
    alpha_2_local = alpha_2_local.expand(batch_size_MLP, 1)
    beta_2_local = beta_2_local.expand(batch_size_MLP, 1)

    # Initialize synaptic and membrane potentials
    syn = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes
    mem_rec = []
    spk_rec = []

    test = 0

    # Initialize outputs for the hidden layer
    out = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))

    # Prepare readout layer variables
    flt2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out_rec = [out2]

    for t in range(nb_steps):
        # Compute hidden layer activity
        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - thresholds_local
        out = spike_fn(mthr)
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c]
        # rst = out.detach() * reset  # Reset mechanism considering individual reset values

        # Update synaptic and membrane potentials
        syn = alpha_1_local * syn + h1

        mem = beta_1_local * (mem - rest_local) + rest_local + (1 - beta_1_local) * syn - rst * (thresholds_local - reset_local)

        # Record membrane potentials and spikes
        mem_rec.append(mem)
        spk_rec.append(out)


        # Now compute h2 on the fly
        h2_t = torch.einsum("ab,bc->ac", (out, w2))
#         flt2 = alpha_1_local * flt2 + h2_t
#         out2 = beta_1_local * out2 + flt2*(1-beta_1_local)
        flt2 = alpha_2_local * flt2 + h2_t
        out2 = beta_2_local * out2 + flt2*(1-beta_2_local)
        
        out_rec.append(out2)

         # Flatten and concatenate spikes for each item in the batch
        input_spikes_flat = inputs[:, t, :].reshape(batch_size_MLP, -1)  # Shape: [batch_size, 700]
        hidden_spikes_flat = out.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 200]
        output_spikes_flat = out2.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 20]

        # Time tensor
        time_tensor = torch.full((batch_size_MLP, 1), t, device=device, dtype=dtype)

        # Concatenate tensors
        mlp_input = torch.cat([
            alpha_1_local, beta_1_local, 
            thresholds_local, reset_local, rest_local,
            alpha_2_local, beta_2_local,
            time_tensor,
            input_spikes_flat, hidden_spikes_flat, output_spikes_flat
        ], dim=1)

        # Process with MLP (in a single call for the whole batch)
       

        if t % mlp_interval == 0 and t != 0:
            mlp_outputs = mlp(mlp_input)
            # Update alpha_local and beta_local based on MLP outputs
            alpha_1_local = mlp_outputs[:, 0].unsqueeze(1)
            beta_1_local = mlp_outputs[:, 1].unsqueeze(1)
            threshold_local = (mlp_outputs[:, 2] + 0.5).unsqueeze(1)
            reset_local = (mlp_outputs[:, 3] - 0.5).unsqueeze(1)
            rest_local = (mlp_outputs[:, 4] - 0.5).unsqueeze(1)
            alpha_2_local = mlp_outputs[:, 5].unsqueeze(1)
            beta_2_local = mlp_outputs[:, 6].unsqueeze(1)


    # Stack recordings for output
    mem_rec = torch.stack(mem_rec, dim=1).to(device)
    spk_rec = torch.stack(spk_rec, dim=1).to(device)
    out_rec = torch.stack(out_rec[1:], dim=1).to(device)  # Skip the initial zero tensor

    other_recs = [mem_rec, spk_rec]

    return out_rec, other_recs


def compute_classification_accuracy_MLP_homo(x_data, y_data, mlp, mlp_interval):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_homo, nb_steps, nb_inputs, max_time, shuffle=False):
        output, _ = run_snn_hybrid_alpha_beta_spikes_homo(x_local.to_dense(),  mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_homo)
        m,_= torch.max(output,1) # max over time
        _,am=torch.max(m,1)      # argmax over output units
        tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
        accs.append(tmp)
    return np.mean(accs)

def train_hybrid(mlp, x_data, y_data, lr=1e-3, nb_epochs=10, mlp_interval=10):

    snn_params = [w1, w2, v1, 
                  alpha_homo_1, beta_homo_1,
                  thresholds_1, reset_1, rest_1,
                  alpha_homo_2, beta_homo_2]
#  

    # Optimizers
    combined_params = [
        {'params': snn_params, 'lr': lr},  # Parameters for SNN with specific learning rate
        {'params': mlp.parameters(), 'lr': lr}  # Parameters for MLP with its own learning rate
    ]

    # Using a single optimizer for both SNN and MLP
    combined_optimizer = torch.optim.Adam(combined_params)


    #Loss functions
    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    best_accuracy = 0

    loss_hist = []
    train_acc_hist = []
    test_acc_hist = []
    for epoch in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_homo, nb_steps, nb_inputs, max_time):
                output, recs = run_snn_hybrid_alpha_beta_spikes_homo(inputs=x_local.to_dense(), mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_homo)
                _ , spks = recs
                m, _ = torch.max(output,1)

                _,am=torch.max(m,1)      # argmax over output units
                tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
                accs.append(tmp)

                log_p_y = log_softmax_fn(m)
                ground_loss = loss_fn(log_p_y, y_local)
                reg_loss = 1e-6*torch.sum(spks) # L1 loss on total number of spikes
                reg_loss += 1e-6*torch.mean(torch.sum(torch.sum(spks,dim=0),dim=0)**2) # L2 loss on spikes per neuron

                loss_MLP = ground_loss + reg_loss

                combined_optimizer.zero_grad()
                loss_MLP.backward()
                combined_optimizer.step()
                
                # Clamping the values
                with torch.no_grad():
                    alpha_homo_1.clamp_(0.367, 0.995)
                    beta_homo_1.clamp_(0.367, 0.995)
                    alpha_homo_2.clamp_(0.367, 0.995)
                    beta_homo_2.clamp_(0.367, 0.995)
                    thresholds_1.clamp_(0.5, 1.5)

                local_loss.append(loss_MLP.item())
                local_ground_loss.append(ground_loss.item())
                local_reg_loss.append(reg_loss.item())


        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {epoch+1}: loss={mean_loss:.5f}")
        print("ground_loss", np.mean(local_ground_loss))
        print("reg_loss", np.mean(local_reg_loss))
        train_accuracy = np.mean(accs)
        test_accuracy = compute_classification_accuracy_MLP_homo(x_test, y_test, mlp, mlp_interval)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        print(f"Epoch {epoch + 1}: Train= {train_accuracy:.5f} Test Accuracy={test_accuracy:.5f}")

        # Print the best accuracy so far
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy

            directory = 'Hybrid_Homo'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            best_model_state = mlp_mlp.state_dict()

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'mlp.pt')
            torch.save(best_model_state, file_path)

            # Save parameters only when a new best accuracy is achieved
            # Create a dictionary of current parameters to save
            saved_params_homo = {
                'w1': w1.clone(),
                'w2': w2.clone(),
                'v1': v1.clone(),
                'alpha': alpha_homo_1.clone(),
                'beta': beta_homo_1.clone(),
                'threshold': thresholds_1.clone(),
                'reset': reset_1.clone(),
                'rest': rest_1.clone(),
                'alpha_2': alpha_homo_2.clone(),
                'beta_2': beta_homo_2.clone()
            }

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'snn.pth')
            torch.save({
                'epoch': epoch + 1,
                'accuracy': best_accuracy,
                'params': saved_params_homo,
                'loss': loss_hist,
                'train_acc_hist': train_acc_hist,
                'test_acc_hist': test_acc_hist
            }, file_path)
        else:
            print("Best", best_accuracy)

    return loss_hist

loaded_weights_snn = torch.load('Python_Tests/SSC_homo/epochs/snn_35.pth')

w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_homo_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_homo_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_homo_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_homo_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['accuracy'])
# print(compute_classification_accuracy_homo(x_train, y_train))


nb_epochs_mlp = 80
batch_size_homo = 64
mlp_interval = 10
mlp_mlp = MLP_alpha_beta_single1().to(device)
loss_hist_MLP = train_hybrid(mlp_mlp, x_train, y_train, lr=2e-4, nb_epochs=nb_epochs_mlp, mlp_interval = mlp_interval)

cuda
init done
0.5752751572327044


  labels_ = np.array(y,dtype=int)


Epoch 1: loss=1.65236
ground_loss 1.6100557048609947
reg_loss 0.04229970570366144
Epoch 1: Train= 0.55821 Test Accuracy=0.56014
Epoch 2: loss=1.41226
ground_loss 1.3753598979006791
reg_loss 0.036897322429980735
Epoch 2: Train= 0.61599 Test Accuracy=0.56604
Epoch 3: loss=1.37327
ground_loss 1.3384665056646021
reg_loss 0.034807233291538314
Epoch 3: Train= 0.62356 Test Accuracy=0.55120
Best 0.5660377358490566
Epoch 4: loss=1.34750
ground_loss 1.3122466347094206
reg_loss 0.035249515772308206
Epoch 4: Train= 0.63078 Test Accuracy=0.57370
Epoch 5: loss=1.32519
ground_loss 1.2906800935691043
reg_loss 0.03451130066487221
Epoch 5: Train= 0.63298 Test Accuracy=0.57690
Epoch 6: loss=1.30405
ground_loss 1.269622088338077
reg_loss 0.034432012666778304
Epoch 6: Train= 0.63788 Test Accuracy=0.56820
Best 0.5768966194968553
Epoch 7: loss=1.28070
ground_loss 1.2464253988379235
reg_loss 0.034277205176152244
Epoch 7: Train= 0.64471 Test Accuracy=0.56815
Best 0.5768966194968553
Epoch 8: loss=1.26482
ground

# HETERO HYBRID

In [2]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data


import pickle
import math
# from utils import get_shd_dataset

# The coarse network structure and the time steps are dicated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 35

time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 64

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "ssc_data"
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_train.h5'), 'r')
validation_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_valid.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_valid = validation_file['spikes']
y_valid = validation_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

# Here we load the Dataset
# cache_dir = os.path.expanduser("~/data")
# cache_subdir = "hdspikes"
# get_shd_dataset(cache_dir, cache_subdir)

# train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_train.h5'), 'r')
# test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'shd_test.h5'), 'r')

# x_train = train_file['spikes']
# y_train = train_file['labels']
# x_test = test_file['spikes']
# y_test = test_file['labels']



def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1
        
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

def dist_fn(dist):
    return {
        'gamma': lambda mean, k, size: np.random.gamma(k, scale=mean/k, size=size),
        'normal': lambda mean, k, size: np.random.normal(loc=mean, scale=mean/np.sqrt(k), size=size), #change standard deviation to match gamma
        'uniform': lambda _, maximum, size: np.random.uniform(low=0, high=maximum, size=size),
    }[dist.lower()]
print("init done")


class hetero_mlp_a_b_spikes(nn.Module):
    def __init__(self):
        super(hetero_mlp_a_b_spikes, self).__init__()
        self.input_size = 1961+45 #(adding 40 for alpha and beta 2)
        self.hidden_size = 2048
        self.output_size = 1070 # 200 each for alpha, beta, threshold, reset, rest (adding 40 for alpha and beta 2)

        self.layers = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.output_size),
            nn.Sigmoid()
        )

        self.init_weights()

    def forward(self, x):
        return self.layers(x)

    def init_weights(self):
        with torch.no_grad():
            # Initialize first layer weights and biases
            self.layers[0].weight.fill_(0)
            self.layers[0].bias.fill_(0)

            # Pass the first 1000 inputs directly to the hidden layer; the rest to 0
            for i in range(1040):
                self.layers[0].weight[i, i] = 1

            # Initialize second layer weights and biases
            self.layers[2].weight.fill_(0)
            self.layers[2].bias.fill_(0)

            # Pass the first 1000 inputs directly to the hidden layer; the rest to 0
            for i in range(1040):
                self.layers[2].weight[i, i] = 1
                
                
def run_snn_hybrid_alpha_beta_spikes_HETERO(inputs, mlp, mlp_interval, batch_size_MLP):

    # Initialize local copies of alpha, beta, threshold, reset and rest for all 200 hidden neurons
    alpha_1_local = alpha_hetero_1
    beta_1_local = beta_hetero_1
    thresholds_local = thresholds_1
    reset_local = reset_1
    rest_local = rest_1
    alpha_2_local = alpha_hetero_2
    beta_2_local = beta_hetero_2
    alpha_1_local = alpha_1_local.expand(batch_size_MLP, 200)
    beta_1_local = beta_1_local.expand(batch_size_MLP, 200)
    thresholds_local = thresholds_local.expand(batch_size_MLP, 200)
    reset_local = reset_local.expand(batch_size_MLP, 200)
    rest_local = rest_local.expand(batch_size_MLP, 200)
    alpha_2_local = alpha_2_local.expand(batch_size_MLP, 35)
    beta_2_local = beta_2_local.expand(batch_size_MLP, 35)

    # Initialize synaptic and membrane potentials
    syn = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    mem = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes
    mem_rec = []
    spk_rec = []

    # Initialize outputs for the hidden layer
    out = torch.zeros((batch_size_MLP, nb_hidden), device=device, dtype=dtype)
    h1_from_input = torch.einsum("abc,cd->abd", (inputs, w1))

    # Prepare readout layer variables
    flt2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out2 = torch.zeros((batch_size_MLP, nb_outputs), device=device, dtype=dtype)
    out_rec = [out2]

    for t in range(nb_steps):

        h1 = h1_from_input[:, t] + torch.einsum("ab,bc->ac", (out, v1))
        mthr = mem - thresholds_local
        out = spike_fn(mthr)
        rst = torch.zeros_like(mem)
        c = (mthr > 0)
        rst[c] = torch.ones_like(mem)[c]
        # rst = out.detach() * reset  # Reset mechanism considering individual reset values

        # Update synaptic and membrane potentials
        syn = alpha_1_local * syn + h1
        mem = beta_1_local * (mem - rest_local) + rest_local + (1 - beta_1_local) * syn - rst * (thresholds_local - reset_local)

        # Record membrane potentials and spikes
        mem_rec.append(mem)
        spk_rec.append(out)

        # Now compute h2 on the fly
        h2_t = torch.einsum("ab,bc->ac", (out, w2))
        flt2 = alpha_2_local * flt2 + h2_t
#         out2 = beta_2_local * out2 + flt2
        out2 = beta_2_local * out2 + flt2*(1-beta_2_local)
        out_rec.append(out2)

         # Flatten and concatenate spikes for each item in the batch
        input_spikes_flat = inputs[:, t, :].reshape(batch_size_MLP, -1)  # Shape: [batch_size, 700]
        hidden_spikes_flat = out.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 200]
        output_spikes_flat = out2.reshape(batch_size_MLP, -1)  # Shape: [batch_size, 20]

        # Time tensor
        time_tensor = torch.full((batch_size_MLP, 1), t, device=device, dtype=dtype)

        # Concatenate tensors
        mlp_input = torch.cat([
            alpha_1_local, beta_1_local, thresholds_local, reset_local, rest_local,
            alpha_2_local, beta_2_local,
            time_tensor,
            input_spikes_flat, hidden_spikes_flat, output_spikes_flat
        ], dim=1)


        if t % mlp_interval == 0:
            mlp_outputs = mlp(mlp_input)
            # Update alpha_local and beta_local based on MLP outputs
            alpha_1_local, beta_1_local = mlp_outputs[:, :200], mlp_outputs[:, 200:400]
            threshold_local = mlp_outputs[:, 400:600] + 0.5
            reset_local, rest_local = mlp_outputs[:, 600:800] - 0.5, mlp_outputs[:, 800:1000] - 0.5
#             print(beta_2_local.shape)
            alpha_2_local, beta_2_local = mlp_outputs[:, 1000:1035], mlp_outputs[:, 1035:1070]
#             print(beta_2_local.shape)

    # Stack recordings for output
    mem_rec = torch.stack(mem_rec, dim=1).to(device)
    spk_rec = torch.stack(spk_rec, dim=1).to(device)
    out_rec = torch.stack(out_rec[1:], dim=1).to(device)  # Skip the initial zero tensor

    other_recs = [mem_rec, spk_rec]

    return out_rec, other_recs


def compute_classification_accuracy_MLP(x_data, y_data, mlp, mlp_interval):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    with torch.no_grad():
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time, shuffle=False):
            output, _ = run_snn_hybrid_alpha_beta_spikes_HETERO(x_local.to_dense(),  mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_hetero)
            m,_= torch.max(output,1) # max over time
            _,am=torch.max(m,1)      # argmax over output units
            tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
            accs.append(tmp)
    return np.mean(accs)

def train_hybrid(mlp, x_data, y_data, lr=1e-3, nb_epochs=10, mlp_interval=10):

    snn_params = [w1, w2, v1, 
                  alpha_hetero_1, beta_hetero_1,
                  thresholds_1, reset_1, rest_1,
                  alpha_hetero_2, beta_hetero_2]
#  

    # Optimizers
    combined_params = [
        {'params': snn_params, 'lr': lr},  # Parameters for SNN with specific learning rate
        {'params': mlp.parameters(), 'lr': lr}  # Parameters for MLP with its own learning rate
    ]

    # Using a single optimizer for both SNN and MLP
    combined_optimizer = torch.optim.Adam(combined_params)


    #Loss functions
    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    best_accuracy = 0

    loss_hist = []
    train_acc_hist = []
    test_acc_hist = []
    for epoch in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
#         print(w1)
#         print(alpha_hetero_1)
#         print(reset_1)
        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, batch_size_hetero, nb_steps, nb_inputs, max_time):
                output, recs = run_snn_hybrid_alpha_beta_spikes_HETERO(inputs=x_local.to_dense(), mlp=mlp, mlp_interval=mlp_interval, batch_size_MLP=batch_size_hetero)
                _ , spks = recs
                m, _ = torch.max(output,1)

                _,am=torch.max(m,1)      # argmax over output units
                tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
                accs.append(tmp)

                log_p_y = log_softmax_fn(m)
                ground_loss = loss_fn(log_p_y, y_local)
                reg_loss = 1e-6*torch.sum(spks) # L1 loss on total number of spikes
                reg_loss += 1e-6*torch.mean(torch.sum(torch.sum(spks,dim=0),dim=0)**2) # L2 loss on spikes per neuron

                loss_MLP = ground_loss + reg_loss

                combined_optimizer.zero_grad()
                loss_MLP.backward()
                combined_optimizer.step()
                
                # Clamping the values
                with torch.no_grad():
                    alpha_hetero_1.clamp_(0.367, 0.995)
                    beta_hetero_1.clamp_(0.367, 0.995)
                    alpha_hetero_2.clamp_(0.367, 0.995)
                    beta_hetero_2.clamp_(0.367, 0.995)
                    thresholds_1.clamp_(0.5, 1.5)

                local_loss.append(loss_MLP.item())
                local_ground_loss.append(ground_loss.item())
                local_reg_loss.append(reg_loss.item())


        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {epoch+1}: loss={mean_loss:.5f}")
        print("ground_loss", np.mean(local_ground_loss))
        print("reg_loss", np.mean(local_reg_loss))
        train_accuracy = np.mean(accs)
        test_accuracy = compute_classification_accuracy_MLP(x_test, y_test, mlp, mlp_interval)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        print(f"Epoch {epoch + 1}: Train= {train_accuracy:.5f} Test Accuracy={test_accuracy:.5f}")

        # Print the best accuracy so far
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy

            directory = 'Hybrid_Hetero'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            best_model_state = mlp_mlp.state_dict()

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'mlp.pt')
            torch.save(best_model_state, file_path)

            # Save parameters only when a new best accuracy is achieved
            # Create a dictionary of current parameters to save
            saved_params_hetero = {
                'w1': w1.clone(),
                'w2': w2.clone(),
                'v1': v1.clone(),
                'alpha': alpha_hetero_1.clone(),
                'beta': beta_hetero_1.clone(),
                'threshold': thresholds_1.clone(),
                'reset': reset_1.clone(),
                'rest': rest_1.clone(),
                'alpha_2': alpha_hetero_2.clone(),
                'beta_2': beta_hetero_2.clone()
            }

            # Save the file in the specified directory
            file_path = os.path.join(directory, 'snn.pth')
            torch.save({
                'epoch': epoch + 1,
                'accuracy': best_accuracy,
                'params': saved_params_hetero,
                'loss': loss_hist,
                'train_acc_hist': train_acc_hist,
                'test_acc_hist': test_acc_hist
            }, file_path)
        else:
            print("Best", best_accuracy)

    return loss_hist

loaded_weights_snn = torch.load('SSC_test_2/epochs_hetero/snn_15.pth')

w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['accuracy'])
# print(compute_classification_accuracy_hetero(x_train, y_train))


nb_epochs_mlp = 80
batch_size_hetero = 64
mlp_interval = 10
mlp_mlp = hetero_mlp_a_b_spikes().to(device)
loss_hist_MLP = train_hybrid(mlp_mlp, x_train, y_train, lr=1e-3, nb_epochs=nb_epochs_mlp, mlp_interval = mlp_interval)

cuda
init done
0.5246167452830188


  labels_ = np.array(y,dtype=int)


Epoch 1: loss=2.34134
ground_loss 2.307302144546444
reg_loss 0.03403388401785703
Epoch 1: Train= 0.35103 Test Accuracy=0.42448


KeyboardInterrupt: 

# Double HOMO

In [None]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data

import pickle
from utils import get_shd_dataset

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "ssc_data"
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_train.h5'), 'r')
validation_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_valid.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_valid = validation_file['spikes']
y_valid = validation_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1

class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

tau_syn = 10e-3
tau_mem = 20e-3

time_step = 1e-3
nb_steps = 100
max_time = 1.4

s_batch_size = 64

# SNN 1 Network
s1_nb_inputs  = 700
s1_nb_hidden  = 200
s1_nb_outputs = 35

# SNN 2 Network
s2_nb_inputs = s1_nb_hidden
s2_nb_hidden = 14


def run_double_snn_homo(s1_inputs, s2_interval):
    s1_alpha_1_local = s1_alpha_homo_1
    s1_beta_1_local = s1_beta_homo_1
    s1_thresholds_local = s1_thresholds_1
    s1_reset_local = s1_reset_1
    s1_rest_local = s1_rest_1
    s1_alpha_2_local = s1_alpha_homo_2
    s1_beta_2_local = s1_beta_homo_2
    s1_alpha_1_local = s1_alpha_1_local.expand(s_batch_size, 1)
    s1_beta_1_local = s1_beta_1_local.expand(s_batch_size, 1)
    s1_thresholds_local = s1_thresholds_local.expand(s_batch_size, 1)
    s1_reset_local = s1_reset_local.expand(s_batch_size, 1)
    s1_rest_local = s1_rest_local.expand(s_batch_size, 1)
    s1_alpha_2_local = s1_alpha_2_local.expand(s_batch_size, 1)
    s1_beta_2_local = s1_beta_2_local.expand(s_batch_size, 1)

    # Initialize synaptic and membrane potentials for SNN 1
    s1_syn = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)
    s1_mem = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes for SNN 1
    s1_mem_rec = []
    s1_spk_rec = []

    # Initialize outputs for the hidden layer for SNN 1
    s1_out = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)
    s1_h1_from_input = torch.einsum("abc,cd->abd", (s1_inputs, s1_w1))

    # Prepare readout layer variables for SNN 1
    s1_flt2 = torch.zeros((s_batch_size, s1_nb_outputs), device=device, dtype=dtype)
    s1_out2 = torch.zeros((s_batch_size, s1_nb_outputs), device=device, dtype=dtype)
    s1_out_rec = [s1_out2]

    # Initialize synaptic and membrane potentials for SNN 2
    s2_syn = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)
    s2_mem = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes for SNN 2
    s2_mem_rec = []
    s2_spk_rec = []

    # Initialize outputs for the hidden layer for SNN 2
    s2_out = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    # Buffer to store spikes from the last s2_interval timesteps for SNN 2
    s2_spike_buffer = torch.zeros((s2_interval, s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    for s1_t in range(nb_steps):
        # Compute hidden layer activity for SNN 1
        s1_h1 = s1_h1_from_input[:, s1_t] + torch.einsum("ab,bc->ac", (s1_out, s1_v1))
        s1_mthr = s1_mem - s1_thresholds_local
        s1_out = spike_fn(s1_mthr)
        s1_rst = torch.zeros_like(s1_mem)
        c = (s1_mthr > 0)
        s1_rst[c] = torch.ones_like(s1_mem)[c]

        s1_syn = s1_alpha_1_local * s1_syn + s1_h1  # Reshape and expand to match s1_mem
        s1_mem = s1_beta_1_local * (s1_mem - s1_rest_local) + s1_rest_local + (1 - s1_beta_1_local) * s1_syn - s1_rst * (s1_thresholds_local - s1_reset_local)

        s1_mem_rec.append(s1_mem)
        s1_spk_rec.append(s1_out)

        # Compute output for the readout layer for SNN 1
        s1_h2_t = torch.einsum("ab,bc->ac", (s1_out, s1_w2))
        s1_flt2 = s1_alpha_2_local * s1_flt2 + s1_h2_t  # Reshape and expand to match s1_flt2
        s1_out2 = s1_beta_2_local * s1_out2 + (1 - s1_beta_2_local) * s1_flt2  # Reshape and expand to match s1_flt2

        s1_out_rec.append(s1_out2)

        # Compute hidden layer activity for SNN 2 using spikes from SNN 1 as input
        s2_h1 = torch.einsum("ab,bc->ac", (s1_out, s2_w1))
        s2_mthr = s2_mem - s2_thresholds_1  # Reshape and expand to match s2_mem
        s2_out = spike_fn(s2_mthr)
        s2_rst = torch.zeros_like(s2_mem)
        c = (s2_mthr > 0)
        s2_rst[c] = torch.ones_like(s2_mem)[c]

        s2_new_syn = s2_alpha_homo_1 * s2_syn + s2_h1  # Reshape and expand to match s2_syn
        s2_new_mem = s2_beta_homo_1 * (s2_mem - s2_rest_1) + s2_rest_1 + (1 - s2_beta_homo_1) * s2_syn - s2_rst * (s2_thresholds_1 - s2_reset_1)

        s2_mem_rec.append(s2_mem)
        s2_spk_rec.append(s2_out)

        s2_mem = s2_new_mem
        s2_syn = s2_new_syn

        # Update the spike buffer
        s2_spike_buffer[s1_t % s2_interval] = s2_out

        # Use spikes from the last s2_interval timesteps to modulate parameters of SNN 1
        if s1_t > s2_interval and s1_t % s2_interval == 0:
            # Aggregate spikes over the last s2_interval timesteps
            recent_spikes = s2_spike_buffer.sum(dim=0)

            # Neurons for modulating s1_alpha_homo_1
            mod_alpha_1_pos = torch.sum(recent_spikes[:, 0:1], dim=1, keepdim=True)
            mod_alpha_1_neg = torch.sum(recent_spikes[:, 1:2], dim=1, keepdim=True)

            # Neurons for modulating s1_beta_homo_1
            mod_beta_1_pos = torch.sum(recent_spikes[:, 2:3], dim=1, keepdim=True)
            mod_beta_1_neg = torch.sum(recent_spikes[:, 3:4], dim=1, keepdim=True)

            # Neurons for modulating s1_alpha_homo_2
            mod_alpha_2_pos = torch.sum(recent_spikes[:, 4:5], dim=1, keepdim=True)
            mod_alpha_2_neg = torch.sum(recent_spikes[:, 5:6], dim=1, keepdim=True)

            # Neurons for modulating s1_beta_homo_2
            mod_beta_2_pos = torch.sum(recent_spikes[:, 6:7], dim=1, keepdim=True)
            mod_beta_2_neg = torch.sum(recent_spikes[:, 7:8], dim=1, keepdim=True)

            # Neurons for modulating s1_thresholds_1
            mod_thresholds_pos = torch.sum(recent_spikes[:, 8:9], dim=1, keepdim=True)
            mod_thresholds_neg = torch.sum(recent_spikes[:, 9:10], dim=1, keepdim=True)

            # Neurons for modulating s1_reset_1
            mod_reset_pos = torch.sum(recent_spikes[:, 10:11], dim=1, keepdim=True)
            mod_reset_neg = torch.sum(recent_spikes[:, 11:12], dim=1, keepdim=True)

            # Neurons for modulating s1_rest_1
            mod_rest_pos = torch.sum(recent_spikes[:, 12:13], dim=1, keepdim=True)
            mod_rest_neg = torch.sum(recent_spikes[:, 13:14], dim=1, keepdim=True)

            # Calculate modulation effects
            mod_alpha_1_effect = mod_alpha_1_pos * s2_alpha_mod_factor1_pos - mod_alpha_1_neg * s2_alpha_mod_factor1_neg
            mod_beta_1_effect = mod_beta_1_pos * s2_beta_mod_factor1_pos - mod_beta_1_neg * s2_beta_mod_factor1_neg
            mod_alpha_2_effect = mod_alpha_2_pos * s2_alpha_mod_factor2_pos - mod_alpha_2_neg * s2_alpha_mod_factor2_neg
            mod_beta_2_effect = mod_beta_2_pos * s2_beta_mod_factor2_pos - mod_beta_2_neg * s2_beta_mod_factor2_neg
            mod_thresholds_effect = mod_thresholds_pos * s2_thresholds_mod_factor_pos - mod_thresholds_neg * s2_thresholds_mod_factor_neg
            mod_reset_effect = mod_reset_pos * s2_reset_mod_factor_pos - mod_reset_neg * s2_reset_mod_factor_neg
            mod_rest_effect = mod_rest_pos * s2_rest_mod_factor_pos - mod_rest_neg * s2_rest_mod_factor_neg

            # Update the actual parameters using in-place operations
            s1_alpha_1_local = s1_alpha_1_local + mod_alpha_1_effect
            s1_beta_1_local = s1_beta_1_local + mod_beta_1_effect
            s1_alpha_2_local = s1_alpha_2_local + mod_alpha_2_effect
            s1_beta_2_local = s1_beta_2_local + mod_beta_2_effect
            s1_thresholds_local = s1_thresholds_local + mod_thresholds_effect
            s1_reset_local = s1_reset_local+ mod_reset_effect
            s1_rest_local = s1_rest_local + mod_rest_effect

            # Ensure modulation effects are stable
            s1_alpha_1_local = torch.clamp(s1_alpha_1_local, min=0.367, max=0.995)
            s1_beta_1_local = torch.clamp(s1_beta_1_local, min=0.367, max=0.995)
            s1_alpha_2_local = torch.clamp(s1_alpha_2_local, min=0.367, max=0.995)
            s1_beta_2_local = torch.clamp(s1_beta_2_local, min=0.367, max=0.995)
            s1_thresholds_local = torch.clamp(s1_thresholds_local, min=0.5, max=1.5)
            s1_reset_local = torch.clamp(s1_reset_local, min=-0.5, max=0.5)
            s1_rest_local = torch.clamp(s1_rest_local, min=-0.5, max=0.5)

    # Stack recordings for output
    s1_mem_rec = torch.stack(s1_mem_rec, dim=1)
    s1_spk_rec = torch.stack(s1_spk_rec, dim=1)
    s1_out_rec = torch.stack(s1_out_rec[1:], dim=1)  # Skip the initial zero tensor

    s1_other_recs = [s1_mem_rec, s1_spk_rec]

    return s1_out_rec, s1_other_recs

def train_double_snn(x_data, y_data, lr=1e-3, nb_epochs=10, s2_interval=10):
    # Include all learnable parameters in the list
    params = [
        s1_w1, s1_w2, s1_v1,
        s2_w1, s2_v1,
        s2_alpha_mod_factor1_pos, s2_alpha_mod_factor1_neg,
        s2_beta_mod_factor1_pos, s2_beta_mod_factor1_neg,
        s2_alpha_mod_factor2_pos, s2_alpha_mod_factor2_neg,
        s2_beta_mod_factor2_pos, s2_beta_mod_factor2_neg,
        s2_thresholds_mod_factor_pos, s2_thresholds_mod_factor_neg,
        s2_reset_mod_factor_pos, s2_reset_mod_factor_neg,
        s2_rest_mod_factor_pos, s2_rest_mod_factor_neg,
        s1_thresholds_1, s1_reset_1, s1_rest_1,
        s1_alpha_homo_1, s1_beta_homo_1,
        s1_alpha_homo_2, s1_beta_homo_2,
        s2_thresholds_1, s2_reset_1, s2_rest_1,
        s2_alpha_homo_1, s2_beta_homo_1
    ]


    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))

    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    best_accuracy = 0
    loss_hist = []
    train_acc_hist = []
    test_acc_hist = []

    for epoch in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
#         print(s1_w1)
#         print(s2_alpha_mod_factor2_pos)
#         print(s1_thresholds_1)
#         print(s2_alpha_homo_1)

        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, s_batch_size, nb_steps, s1_nb_inputs, max_time):
            # Ensure x_local is detached from any previous computation graph
            x_local = x_local.to_dense().detach()

            output, recs = run_double_snn_homo(x_local, s2_interval=s2_interval)
            _, spks = recs
            m, _ = torch.max(output, 1)

            _, am = torch.max(m, 1)  # argmax over output units
            tmp = np.mean((y_local == am).detach().cpu().numpy())  # compare to labels
            accs.append(tmp)

            log_p_y = log_softmax_fn(m)
            ground_loss = loss_fn(log_p_y, y_local)

            # L2 loss on spikes per neuron
#             reg_loss = 1e-6 * torch.sum(spks)  # L1 loss on total number of spikes
#             reg_loss += 1e-6 * torch.mean(torch.sum(torch.sum(spks, dim=0), dim=0)**2)  # L2 loss on spikes per neuron
            tot_num_neurons = nb_hidden + nb_outputs
            N_samp = 6412
            T = nb_steps
            sl = 1
            thetal = 0.01
            su = 0.06
            thetau = 100
            tmp = (torch.clamp((1 / T) * torch.sum(spks, 1) - thetal, min=0.)) ** 2
            L1_batch = torch.sum(tmp, (0, 1))
            reg_loss = (sl / (N_samp + tot_num_neurons)) * L1_batch

            tmp2 = (torch.clamp((1 / nb_hidden) * torch.sum(spks, (1, 2)) - thetau, min=0.)) ** 2
            L2_batch = torch.sum(tmp2)
            reg_loss += (su / N_samp) * L2_batch
            loss_val = ground_loss + reg_loss

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            local_loss.append(loss_val.item())
            local_ground_loss.append(ground_loss.item())
            local_reg_loss.append(reg_loss.item())

        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {epoch + 1}: loss={mean_loss:.5f}")
        print("ground_loss", np.mean(local_ground_loss))
        print("reg_loss", np.mean(local_reg_loss))
        
        train_accuracy = np.mean(accs)
        test_accuracy = compute_classification_accuracy_double(x_test, y_test, s2_interval)
        train_acc_hist.append(train_accuracy)
        test_acc_hist.append(test_accuracy)
        print(f"Epoch {epoch + 1}: Train= {train_accuracy:.5f} Test Accuracy={test_accuracy:.5f}")

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print("best_accuracy", best_accuracy)

            # Create a dictionary of curent parameters to save
            saved_weights_snn = {
                's1_w1': s1_w1,
                's1_w2': s1_w2,
                's1_v1': s1_v1,
                's2_w1': s2_w1,
                's2_v1': s2_v1,
                's2_alpha_mod_factor1_pos': s2_alpha_mod_factor1_pos,
                's2_alpha_mod_factor1_neg': s2_alpha_mod_factor1_neg,
                's2_beta_mod_factor1_pos': s2_beta_mod_factor1_pos,
                's2_beta_mod_factor1_neg': s2_beta_mod_factor1_neg,
                's2_alpha_mod_factor2_pos': s2_alpha_mod_factor2_pos,
                's2_alpha_mod_factor2_neg': s2_alpha_mod_factor2_neg,
                's2_beta_mod_factor2_pos': s2_beta_mod_factor2_pos,
                's2_beta_mod_factor2_neg': s2_beta_mod_factor2_neg,
                's2_thresholds_mod_factor_pos': s2_thresholds_mod_factor_pos,
                's2_thresholds_mod_factor_neg': s2_thresholds_mod_factor_neg,
                's2_reset_mod_factor_pos': s2_reset_mod_factor_pos,
                's2_reset_mod_factor_neg': s2_reset_mod_factor_neg,
                's2_rest_mod_factor_pos': s2_rest_mod_factor_pos,
                's2_rest_mod_factor_neg': s2_rest_mod_factor_neg,
                's1_thresholds_1': s1_thresholds_1,
                's1_reset_1': s1_reset_1,
                's1_rest_1': s1_rest_1,
                's2_thresholds_1': s2_thresholds_1,
                's2_reset_1': s2_reset_1,
                's2_rest_1': s2_rest_1,
                's2_alpha_homo_1': s2_alpha_homo_1,
                's2_beta_homo_1': s2_beta_homo_1,
                's1_alpha_homo_1': s1_alpha_homo_1,
                's1_beta_homo_1': s1_beta_homo_1,
                's1_alpha_homo_2': s1_alpha_homo_2,
                's1_beta_homo_2': s1_beta_homo_2
            }
            
            directory = 'SSC_Double_Homo'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            # Save parameters along with the current epoch and accuracy
            file_path = os.path.join(directory, 'best_double.pth')

            torch.save({
                'epoch': epoch + 1,
                'accuracy': test_accuracy,
                'params': saved_weights_snn,
                'loss': loss_hist,
                'train_acc_hist': train_acc_hist,
                'test_acc_hist': test_acc_hist
            }, file_path)
        else:
            print("best_accuracy", best_accuracy)

    return loss_hist

def compute_classification_accuracy_double(x_data, y_data, s2_interval):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, s_batch_size, nb_steps, s1_nb_inputs, max_time, shuffle=False):
        output,_ = run_double_snn_homo(x_local.to_dense(), s2_interval = s2_interval)
        m,_= torch.max(output,1) # max over time
        _,am=torch.max(m,1)      # argmax over output units
        tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
        accs.append(tmp)
    return np.mean(accs)

# Initialize SNN 1 and SNN 2 parameters

# SNN 1 parameters

s1_thresholds_1 = nn.Parameter(torch.ones((1, 1), device=device, dtype=dtype, requires_grad=True))
s1_reset_1 = nn.Parameter(torch.zeros((1, 1), device=device, dtype=dtype, requires_grad=True))
s1_rest_1 = nn.Parameter(torch.zeros((1, 1), device=device, dtype=dtype, requires_grad=True))

const_alpha = float(np.exp(-time_step/tau_syn))
const_beta = float(np.exp(-time_step/tau_mem))

s1_alpha_homo_1 = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s1_beta_homo_1 = nn.Parameter(torch.full((1, 1), const_beta, device=device, dtype=dtype, requires_grad=True))

s1_alpha_homo_2 = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s1_beta_homo_2 = nn.Parameter(torch.full((1, 1), const_beta, device=device, dtype=dtype, requires_grad=True))

# SNN 2 parameters

s2_thresholds_1 = nn.Parameter(torch.ones((1, 1), device=device, dtype=dtype, requires_grad=True))
s2_reset_1 = nn.Parameter(torch.zeros((1, 1), device=device, dtype=dtype, requires_grad=True))
s2_rest_1 = nn.Parameter(torch.zeros((1, 1), device=device, dtype=dtype, requires_grad=True))

s2_alpha_homo_1 = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_homo_1 = nn.Parameter(torch.full((1, 1), const_beta, device=device, dtype=dtype, requires_grad=True))


# Initialize modulation factors
s2_alpha_mod_factor1_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_alpha_mod_factor1_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_mod_factor1_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_mod_factor1_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_alpha_mod_factor2_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_alpha_mod_factor2_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_mod_factor2_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_mod_factor2_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_thresholds_mod_factor_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_thresholds_mod_factor_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_reset_mod_factor_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_reset_mod_factor_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_rest_mod_factor_pos = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_rest_mod_factor_neg = nn.Parameter(torch.full((1, 1), const_alpha, device=device, dtype=dtype, requires_grad=True))

# SNN 1 weights
s1_weight_scale = 0.2

s1_w1 = torch.empty((s1_nb_inputs, s1_nb_hidden),  device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_w1, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_inputs))

s1_w2 = torch.empty((s1_nb_hidden, s1_nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_w2, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_hidden))

s1_v1 = torch.empty((s1_nb_hidden, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_v1, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_hidden))

# SNN 2 weights
s2_weight_scale = 0.2

s2_w1 = torch.empty((s2_nb_inputs, s2_nb_hidden),  device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s2_w1, mean=0.0, std=s2_weight_scale/np.sqrt(s2_nb_inputs))

s2_v1 = torch.empty((s2_nb_hidden, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s2_v1, mean=0.0, std=s2_weight_scale/np.sqrt(s2_nb_hidden))

loaded_weights_snn = torch.load('Python_Tests/SSC_homo_no_reg/epochs/snn_13.pth', map_location=torch.device(device))
s1_w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
s1_w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
s1_v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
s1_alpha_homo_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
s1_beta_homo_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
s1_thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
s1_reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
s1_rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
s1_alpha_homo_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
s1_beta_homo_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))

# INTERVAL (DOUBLE) HERE
s2_interval = 10
nb_epochs_double = 100
s_batch_size = 64
lr_double = 2e-4
loss_hist_snn = train_double_snn(x_train, y_train, lr=lr_double, nb_epochs=nb_epochs_double, s2_interval = s2_interval)

cuda


  labels_ = np.array(y,dtype=int)


Epoch 1: loss=1.41865
ground_loss 1.4065279819583165
reg_loss 0.012126956692722209
Epoch 1: Train= 0.60949 Test Accuracy=0.57184
best_accuracy 0.5718356918238994
Epoch 2: loss=1.43688
ground_loss 1.4250233568206898
reg_loss 0.011852954442441412
Epoch 2: Train= 0.59999 Test Accuracy=0.58608
best_accuracy 0.5860849056603774
Epoch 5: loss=1.44366
ground_loss 1.4241916570145767
reg_loss 0.019465153803456453
Epoch 5: Train= 0.59999 Test Accuracy=0.59547
best_accuracy 0.5954697327044025
Epoch 6: loss=1.33269
ground_loss 1.3222405119807767
reg_loss 0.010447952909785322
Epoch 6: Train= 0.63031 Test Accuracy=0.60122
best_accuracy 0.6012185534591195
Epoch 7: loss=1.46227
ground_loss 1.453848522785258
reg_loss 0.008423414372422743
Epoch 7: Train= 0.59686 Test Accuracy=0.56314
best_accuracy 0.6012185534591195
Epoch 15: loss=1.18872
ground_loss 1.1815861485791064
reg_loss 0.00713437555580573


# Double HETERO

In [6]:
import os
import h5py

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import seaborn as sns

import torch
# print(torch.__version__)
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
# import torchvision
from torch.utils import data


import pickle
import math
from utils import get_shd_dataset

# The coarse network structure and the time steps are dicated by the SHD dataset.
nb_inputs  = 700
nb_hidden  = 200
nb_outputs = 35

time_step = 1e-3
nb_steps = 100
max_time = 1.4

batch_size = 64

dtype = torch.float

# Check whether a GPU is available
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Here we load the Dataset
cache_dir = os.path.expanduser("~/data")
cache_subdir = "ssc_data"
train_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_train.h5'), 'r')
validation_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_valid.h5'), 'r')
test_file = h5py.File(os.path.join(cache_dir, cache_subdir, 'ssc_test.h5'), 'r')

x_train = train_file['spikes']
y_train = train_file['labels']
x_valid = validation_file['spikes']
y_valid = validation_file['labels']
x_test = test_file['spikes']
y_test = test_file['labels']

def sparse_data_generator_from_hdf5_spikes(X, y, batch_size, nb_steps, nb_units, max_time, shuffle=True):
    """ This generator takes a spike dataset and generates spiking network input as sparse tensors.

    Args:
        X: The data ( sample x event x 2 ) the last dim holds (time,neuron) tuples
        y: The labels
    """

    labels_ = np.array(y,dtype=int)
    number_of_batches = len(labels_)//batch_size
    sample_index = np.arange(len(labels_))

    # compute discrete firing times
    firing_times = X['times']
    units_fired = X['units']

    time_bins = np.linspace(0, max_time, num=nb_steps)

    if shuffle:
        np.random.shuffle(sample_index)

    total_batch_count = 0
    counter = 0
    while counter<number_of_batches:
        batch_index = sample_index[batch_size*counter:batch_size*(counter+1)]

        coo = [ [] for i in range(3) ]
        for bc,idx in enumerate(batch_index):
            times = np.digitize(firing_times[idx], time_bins)
            units = units_fired[idx]
            batch = [bc for _ in range(len(times))]

            coo[0].extend(batch)
            coo[1].extend(times)
            coo[2].extend(units)

        i = torch.LongTensor(coo).to(device)
        v = torch.FloatTensor(np.ones(len(coo[0]))).to(device)

        X_batch = torch.sparse.FloatTensor(i, v, torch.Size([batch_size,nb_steps,nb_units])).to(device)
        y_batch = torch.tensor(labels_[batch_index],device=device)

        yield X_batch.to(device=device), y_batch.to(device=device)

        counter += 1
        
class SurrGradSpike(torch.autograd.Function):
    """
    Here we implement our spiking nonlinearity which also implements
    the surrogate gradient. By subclassing torch.autograd.Function,
    we will be able to use all of PyTorch's autograd functionality.
    Here we use the normalized negative part of a fast sigmoid
    as this was done in Zenke & Ganguli (2018).
    """

    scale = 100.0 # controls steepness of surrogate gradient

    @staticmethod
    def forward(ctx, input):
        """
        In the forward pass we compute a step function of the input Tensor
        and return it. ctx is a context object that we use to stash information which
        we need to later backpropagate our error signals. To achieve this we use the
        ctx.save_for_backward method.
        """
        ctx.save_for_backward(input)
        out = torch.zeros_like(input)
        out[input > 0] = 1.0
        return out

    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor we need to compute the
        surrogate gradient of the loss with respect to the input.
        Here we use the normalized negative part of a fast sigmoid
        as this was done in Zenke & Ganguli (2018).
        """
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad = grad_input/(SurrGradSpike.scale*torch.abs(input)+1.0)**2
        return grad

# here we overwrite our naive spike function by the "SurrGradSpike" nonlinearity which implements a surrogate gradient
spike_fn  = SurrGradSpike.apply

def dist_fn(dist):
    return {
        'gamma': lambda mean, k, size: np.random.gamma(k, scale=mean/k, size=size),
        'normal': lambda mean, k, size: np.random.normal(loc=mean, scale=mean/np.sqrt(k), size=size), #change standard deviation to match gamma
        'uniform': lambda _, maximum, size: np.random.uniform(low=0, high=maximum, size=size),
    }[dist.lower()]
    
    
    
s_batch_size = 64

# SNN 1 Network
s1_nb_inputs  = 700
s1_nb_hidden  = 200
s1_nb_outputs = 35

s_group_size = 1

# SNN 2 Network
s2_nb_inputs = 1961+45
s2_nb_hidden = 1070*2/s_group_size



def run_double_snn_hetero(s1_inputs, s2_interval):
    s1_alpha_1_local = s1_alpha_hetero_1
    s1_beta_1_local = s1_beta_hetero_1
    s1_thresholds_local = s1_thresholds_1
    s1_reset_local = s1_reset_1
    s1_rest_local = s1_rest_1
    s1_alpha_2_local = s1_alpha_hetero_2
    s1_beta_2_local = s1_beta_hetero_2
#     mod_factors_local = mod_factors
    s1_alpha_1_local = s1_alpha_1_local.expand(s_batch_size, s1_nb_hidden)
    s1_beta_1_local = s1_beta_1_local.expand(s_batch_size, s1_nb_hidden)
    s1_thresholds_local = s1_thresholds_local.expand(s_batch_size, s1_nb_hidden)
    s1_reset_local = s1_reset_local.expand(s_batch_size, s1_nb_hidden)
    s1_rest_local = s1_rest_local.expand(s_batch_size, s1_nb_hidden)
    s1_alpha_2_local = s1_alpha_2_local.expand(s_batch_size, s1_nb_outputs)
    s1_beta_2_local = s1_beta_2_local.expand(s_batch_size, s1_nb_outputs)
#     mod_factors_local = mod_factors_local.expand(s_batch_size, 

    # Initialize synaptic and membrane potentials for SNN 1
    s1_syn = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)
    s1_mem = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes for SNN 1
    s1_mem_rec = []
    s1_spk_rec = []

    # Initialize outputs for the hidden layer for SNN 1
    s1_out = torch.zeros((s_batch_size, s1_nb_hidden), device=device, dtype=dtype)
    s1_h1_from_input = torch.einsum("abc,cd->abd", (s1_inputs, s1_w1))

    # Prepare readout layer variables for SNN 1
    s1_flt2 = torch.zeros((s_batch_size, s1_nb_outputs), device=device, dtype=dtype)
    s1_out2 = torch.zeros((s_batch_size, s1_nb_outputs), device=device, dtype=dtype)
    s1_out_rec = [s1_out2]

    # Initialize synaptic and membrane potentials for SNN 2
    s2_syn = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)
    s2_mem = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    # Initialize recordings for membrane potentials and spikes for SNN 2
    s2_mem_rec = []
    s2_spk_rec = []
    
    test = 0

    # Initialize outputs for the hidden layer for SNN 2
    s2_out = torch.zeros((s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    # Buffer to store spikes from the last s2_interval timesteps for SNN 2
    s2_spike_buffer = torch.zeros((s2_interval, s_batch_size, s2_nb_hidden), device=device, dtype=dtype)

    for s1_t in range(nb_steps):
        # Compute hidden layer activity for SNN 1
        s1_h1 = s1_h1_from_input[:, s1_t] + torch.einsum("ab,bc->ac", (s1_out, s1_v1))
        s1_mthr = s1_mem - s1_thresholds_local
        s1_out = spike_fn(s1_mthr)
        s1_rst = torch.zeros_like(s1_mem)
        c = (s1_mthr > 0)
        s1_rst[c] = torch.ones_like(s1_mem)[c]

        s1_syn = s1_alpha_1_local * s1_syn + s1_h1
        s1_mem = s1_beta_1_local * (s1_mem - s1_rest_local) + s1_rest_local + (1 - s1_beta_1_local) * s1_syn - s1_rst * (s1_thresholds_local - s1_reset_local)

        s1_mem_rec.append(s1_mem)
        s1_spk_rec.append(s1_out)

        # Compute output for the readout layer for SNN 1
        s1_h2_t = torch.einsum("ab,bc->ac", (s1_out, s1_w2))
        s1_flt2 = s1_alpha_2_local * s1_flt2 + s1_h2_t
        s1_out2 = s1_beta_2_local * s1_out2 + (1 - s1_beta_2_local) * s1_flt2

        s1_out_rec.append(s1_out2)

        s1_vars = torch.cat((s1_alpha_1_local, s1_beta_1_local,
                             s1_thresholds_local, s1_reset_local, s1_rest_local,
                             s1_alpha_2_local, s1_beta_2_local), dim=1)

        s1_time_tensor = torch.full((s_batch_size, 1), s1_t, device=device, dtype=dtype)

        # Compute hidden layer activity for SNN 2 using spikes from SNN 1 as input
        s2_input_combined = torch.cat((s1_vars, s1_time_tensor, s1_inputs[:, s1_t, :], s1_out, s1_out2), dim=1)
        s2_h1 = torch.einsum("ab,bc->ac", (s2_input_combined, s2_w1))
        s2_mthr = s2_mem - s2_thresholds_1
        s2_out = spike_fn(s2_mthr)
        s2_rst = torch.zeros_like(s2_mem)
        c = (s2_mthr > 0)
        s2_rst[c] = torch.ones_like(s2_mem)[c]

        s2_new_syn = s2_alpha_hetero_1 * s2_syn + s2_h1
        s2_new_mem = s2_beta_hetero_1 * (s2_mem - s2_rest_1) + s2_rest_1 + (1 - s2_beta_hetero_1) * s2_syn - s2_rst * (s2_thresholds_1 - s2_reset_1)

        s2_mem_rec.append(s2_mem)
        s2_spk_rec.append(s2_out)

        s2_mem = s2_new_mem
        s2_syn = s2_new_syn

        # Update the spike buffer
        s2_spike_buffer[s1_t % s2_interval] = s2_out

        # Use spikes from the last s2_interval timesteps to modulate parameters of SNN 1
        if s1_t >= s2_interval and s1_t % s2_interval == 0:
            # Aggregate spikes over the last s2_interval timesteps
            recent_spikes = s2_spike_buffer.sum(dim=0)

            # Modulation neurons indices
            mod_indices = {
                's1_alpha_1': (0, 400),
                's1_beta_1': (400, 800),
                's1_thresholds': (800, 1200),
                's1_reset': (1200, 1600),
                's1_rest': (1600, 2000),
                's1_alpha_2': (2000, 2070),
                's1_beta_2': (2070, 2140)
            }
#             if test ==0:
# #                 test += 1
#                 print("PRE")
#                 print(s1_alpha_2_local)

            for param, (start_idx, end_idx) in mod_indices.items():
                # Calculate the midpoint
                mid_idx = start_idx + (end_idx - start_idx) // 2
                pos_mod = torch.sum(recent_spikes[:, start_idx:mid_idx], dim=1, keepdim=True)
                neg_mod = torch.sum(recent_spikes[:, mid_idx:end_idx], dim=1, keepdim=True)
                effect = pos_mod * mod_factors[:, start_idx:mid_idx] - neg_mod * mod_factors[:, mid_idx:end_idx]

                if param == 's1_alpha_1':
#                     print(1, " ", effect)
                    s1_alpha_1_local = s1_alpha_1_local + effect
                elif param == 's1_beta_1':
#                     print(2, " ", effect)
                    s1_beta_1_local = s1_beta_1_local + effect
                elif param == 's1_alpha_2':
                    s1_alpha_2_local = s1_alpha_2_local + effect
#                     if test == 0:
#                             test+=1
#                             print(3, " ", effect)
#                             print(s1_alpha_2_local)
                elif param == 's1_beta_2':
#                     print(4, " ", effect)
                    s1_beta_2_local = s1_beta_2_local + effect
                elif param == 's1_thresholds':
#                     print(5, " ", effect)
                    s1_thresholds_local = s1_thresholds_local + effect
                elif param == 's1_reset':
#                     print(6, " ", effect)
                    s1_reset_local = s1_reset_local + effect
                elif param == 's1_rest':
#                     print(7, " ", effect)
                    s1_rest_local = s1_rest_local + effect
#             if test ==0:
#                 test += 1
#                 print("POST")
#                 print(s1_alpha_2_local)

            # Ensure modulation effects are stable
            s1_alpha_1_local = torch.clamp(s1_alpha_1_local, min=0.367, max=0.995)
            s1_beta_1_local = torch.clamp(s1_beta_1_local, min=0.367, max=0.995)
            s1_alpha_2_local = torch.clamp(s1_alpha_2_local, min=0.367, max=0.995)
            s1_beta_2_local = torch.clamp(s1_beta_2_local, min=0.367, max=0.995)
            s1_thresholds_local = torch.clamp(s1_thresholds_local, min=0.5, max=1.5)
            s1_reset_local = torch.clamp(s1_reset_local, min=-0.5, max=0.5)
            s1_rest_local = torch.clamp(s1_rest_local, min=-0.5, max=0.5)

    # Stack recordings for output
    s1_mem_rec = torch.stack(s1_mem_rec, dim=1)
    s1_spk_rec = torch.stack(s1_spk_rec, dim=1)
    s1_out_rec = torch.stack(s1_out_rec[1:], dim=1)  # Skip the initial zero tensor

    s1_other_recs = [s1_mem_rec, s1_spk_rec]

    return s1_out_rec, s1_other_recs


def compute_classification_accuracy_double_hetero(x_data, y_data, s2_interval):
    """ Computes classification accuracy on supplied data in batches. """
    accs = []
    for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, s_batch_size, nb_steps, s1_nb_inputs, max_time, shuffle=False):
        output,_ = run_double_snn_hetero(x_local.to_dense(), s2_interval = s2_interval)
        m,_= torch.max(output,1) # max over time
        _,am=torch.max(m,1)      # argmax over output units
        tmp = np.mean((y_local==am).detach().cpu().numpy()) # compare to labels
        accs.append(tmp)
    return np.mean(accs)


def train_double_snn(x_data, y_data, lr=1e-3, nb_epochs=10, s2_interval=10):
    # Include all learnable parameters in the list
    params = [
        s1_w1, s1_w2, s1_v1,
        s2_w1, s2_v1,
        mod_factors,
        s1_thresholds_1, s1_reset_1, s1_rest_1,
        s1_alpha_hetero_1, s1_beta_hetero_1,
        s1_alpha_hetero_2, s1_beta_hetero_2,
        s2_thresholds_1, s2_reset_1, s2_rest_1,
        s2_alpha_hetero_1, s2_beta_hetero_1
    ]

    optimizer = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.999))

    loss_fn = nn.NLLLoss()
    log_softmax_fn = nn.LogSoftmax(dim=1)

    best_accuracy = 0
    loss_hist = []
    train_acc_hist = []
    test_acc_hist = []

    for epoch in range(nb_epochs):
        local_loss = []
        local_ground_loss = []
        local_reg_loss = []
        accs = []
#         print("MOD ", mod_factors)

        for x_local, y_local in sparse_data_generator_from_hdf5_spikes(x_data, y_data, s_batch_size, nb_steps, s1_nb_inputs, max_time):
            # Ensure x_local is detached from any previous computation graph
            x_local = x_local.to_dense().detach()

            output, recs = run_double_snn_hetero(x_local, s2_interval=s2_interval)
            _, spks = recs
            m, _ = torch.max(output, 1)

            _, am = torch.max(m, 1)  # argmax over output units
            tmp = np.mean((y_local == am).detach().cpu().numpy())  # compare to labels
            accs.append(tmp)

            log_p_y = log_softmax_fn(m)
            ground_loss = loss_fn(log_p_y, y_local)

            # L2 loss on spikes per neuron
#             reg_loss = 1e-6 * torch.sum(spks)  # L1 loss on total number of spikes
#             reg_loss += 1e-6 * torch.mean(torch.sum(torch.sum(spks, dim=0), dim=0)**2)  # L2 loss on spikes per neuron
            tot_num_neurons = s1_nb_hidden + s1_nb_outputs
            N_samp = 6412
            T = nb_steps
            sl = 1
            thetal = 0.01
            su = 0.06
            thetau = 100
            tmp = (torch.clamp((1 / T) * torch.sum(spks, 1) - thetal, min=0.)) ** 2
            L1_batch = torch.sum(tmp, (0, 1))
            reg_loss = (sl / (N_samp + tot_num_neurons)) * L1_batch

            tmp2 = (torch.clamp((1 / nb_hidden) * torch.sum(spks, (1, 2)) - thetau, min=0.)) ** 2
            L2_batch = torch.sum(tmp2)
            reg_loss += (su / N_samp) * L2_batch
            
            loss_val = ground_loss + reg_loss

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()
            
            mod_factors.data.clamp_(min=0)
            # Clamping the values
            with torch.no_grad():
                s1_alpha_hetero_1.clamp_(0.367, 0.995)
                s1_beta_hetero_1.clamp_(0.367, 0.995)
                s1_alpha_hetero_2.clamp_(0.367, 0.995)
                s1_beta_hetero_2.clamp_(0.367, 0.995)
                s1_thresholds_1.clamp_(0.5, 1.5)
            
            local_loss.append(loss_val.item())
            local_ground_loss.append(ground_loss.item())
            local_reg_loss.append(reg_loss.item())

        mean_loss = np.mean(local_loss)
        loss_hist.append(mean_loss)
        print(f"Epoch {epoch + 1}: loss={mean_loss:.5f}")
        print("ground_loss", np.mean(local_ground_loss))
        print("reg_loss", np.mean(local_reg_loss))
        train_accuracy = np.mean(accs)
        train_acc_hist.append(train_accuracy)

        test_accuracy = compute_classification_accuracy_double_hetero(x_test, y_test, s2_interval)
        test_acc_hist.append(test_accuracy)
        print(f"Epoch {epoch + 1}: Train= {train_accuracy:.5f} Test Accuracy={test_accuracy:.5f}")

        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print("best_accuracy", best_accuracy)

            # Create a dictionary of current parameters to save
            saved_weights_snn = {
                's1_w1': s1_w1,
                's1_w2': s1_w2,
                's1_v1': s1_v1,
                's2_w1': s2_w1,
                's2_v1': s2_v1,
                'mod_factors': mod_factors,
                's1_thresholds_1': s1_thresholds_1,
                's1_reset_1': s1_reset_1,
                's1_rest_1': s1_rest_1,
                's2_thresholds_1': s2_thresholds_1,
                's2_reset_1': s2_reset_1,
                's2_rest_1': s2_rest_1,
                's2_alpha_hetero_1': s2_alpha_hetero_1,
                's2_beta_hetero_1': s2_beta_hetero_1,
                's1_alpha_hetero_1': s1_alpha_hetero_1,
                's1_beta_hetero_1': s1_beta_hetero_1,
                's1_alpha_hetero_2': s1_alpha_hetero_2,
                's1_beta_hetero_2': s1_beta_hetero_2
            }

            directory = 'SSC_Double_Hetero'

            # Create the directory if it does not exist
            if not os.path.exists(directory):
                os.makedirs(directory)

            # Save parameters along with the current epoch and accuracy
            file_path = os.path.join(directory, 'best_double_hetero.pth')

            torch.save({
                'epoch': epoch + 1,
                'accuracy': test_accuracy,
                'params': saved_weights_snn,
                'loss': loss_hist,
                'train_acc_hist': train_acc_hist,
                'test_acc_hist': test_acc_hist
            }, file_path)
        else:
            print("best_accuracy", best_accuracy)

    return loss_hist

s_batch_size = 64

# SNN 1 Network
s1_nb_inputs  = 700
s1_nb_hidden  = 200
s1_nb_outputs = 35

s_group_size = 1

# SNN 2 Network
s2_nb_inputs = 2006
s2_nb_hidden = int(1070 * 2 / s_group_size)

# Initialize SNN 1 and SNN 2 parameters

# SNN 1 parameters
tau_syn = 10e-3
tau_mem = 20e-3

s1_thresholds_1 = nn.Parameter(torch.ones((1, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True))
s1_reset_1 = nn.Parameter(torch.zeros((1, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True))
s1_rest_1 = nn.Parameter(torch.zeros((1, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True))

const_alpha = float(np.exp(-time_step/tau_syn))
const_beta = float(np.exp(-time_step/tau_mem))

s1_alpha_hetero_1 = nn.Parameter(torch.full((1, s1_nb_hidden), const_alpha, device=device, dtype=dtype, requires_grad=True))
s1_beta_hetero_1 = nn.Parameter(torch.full((1, s1_nb_hidden), const_beta, device=device, dtype=dtype, requires_grad=True))

s1_alpha_hetero_2 = nn.Parameter(torch.full((1, s1_nb_outputs), const_alpha, device=device, dtype=dtype, requires_grad=True))
s1_beta_hetero_2 = nn.Parameter(torch.full((1, s1_nb_outputs), const_beta, device=device, dtype=dtype, requires_grad=True))

# SNN 2 parameters

s2_thresholds_1 = nn.Parameter(torch.ones((1, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True))
s2_reset_1 = nn.Parameter(torch.zeros((1, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True))
s2_rest_1 = nn.Parameter(torch.zeros((1, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True))

s2_alpha_hetero_1 = nn.Parameter(torch.full((1, s2_nb_hidden), const_alpha, device=device, dtype=dtype, requires_grad=True))
s2_beta_hetero_1 = nn.Parameter(torch.full((1, s2_nb_hidden), const_beta, device=device, dtype=dtype, requires_grad=True))

# Initialize modulation factors as a single tensor
mod_factors = nn.Parameter(torch.full((1, s2_nb_hidden), 0.01, device=device, dtype=dtype, requires_grad=True))

# SNN 1 weights
s1_weight_scale = 0.2

s1_w1 = torch.empty((s1_nb_inputs, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_w1, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_inputs))

s1_w2 = torch.empty((s1_nb_hidden, s1_nb_outputs), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_w2, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_hidden))

s1_v1 = torch.empty((s1_nb_hidden, s1_nb_hidden), device=device, dtype=dtype, requires_grad=True)
torch.nn.init.normal_(s1_v1, mean=0.0, std=s1_weight_scale/np.sqrt(s1_nb_hidden))

# SNN 2 weights
s2_weight_scale = 0.2

s2_w1 = nn.Parameter(torch.zeros((s2_nb_inputs, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True))
for i in range(min(s2_nb_hidden, s2_nb_inputs, s2_nb_hidden)):
    s2_w1.data[i, i] = 0.2

s2_v1 = nn.Parameter(torch.zeros((s2_nb_hidden, s2_nb_hidden), device=device, dtype=dtype, requires_grad=True))
for i in range(min(s2_nb_hidden, s2_nb_hidden)):
    s2_v1.data[i, i] = 0.2
    
# print(s2_v1)

loaded_weights_snn = torch.load('Python_Tests/SSC_hetero_no_reg/epochs_hetero/snn_13.pth', map_location=torch.device(device))

# Convert tensors to parameters and ensure they are leaf tensors by re-wrapping them
# Move tensors to device first and then wrap them as parameters
s1_w1 = torch.nn.Parameter(loaded_weights_snn['params']['w1'].to(device))
s1_w2 = torch.nn.Parameter(loaded_weights_snn['params']['w2'].to(device))
s1_v1 = torch.nn.Parameter(loaded_weights_snn['params']['v1'].to(device))
s1_alpha_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['alpha'].to(device))
s1_beta_hetero_1 = torch.nn.Parameter(loaded_weights_snn['params']['beta'].to(device))
s1_thresholds_1 = torch.nn.Parameter(loaded_weights_snn['params']['threshold'].to(device))
s1_reset_1 = torch.nn.Parameter(loaded_weights_snn['params']['reset'].to(device))
s1_rest_1 = torch.nn.Parameter(loaded_weights_snn['params']['rest'].to(device))
s1_alpha_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['alpha_2'].to(device))
s1_beta_hetero_2 = torch.nn.Parameter(loaded_weights_snn['params']['beta_2'].to(device))
print(loaded_weights_snn['accuracy'])

print(s1_beta_hetero_2)

# INTERVAL (DOUBLE) HERE
# s2_interval = 5
# nb_epochs_double = 50
# s_batch_size = 64
# lr_double = 2e-4
# loss_hist_snn = train_double_snn(x_train, y_train, lr=lr_double, nb_epochs=nb_epochs_double, s2_interval = s2_interval)

cuda
0.5602889150943396
Parameter containing:
tensor([[0.8435, 0.7279, 0.8964, 0.8699, 0.8599, 0.8616, 0.8820, 0.7949, 0.8599,
         0.8721, 0.8755, 0.9605, 0.7079, 0.8238, 0.7247, 0.4365, 0.6525, 0.7271,
         0.3876, 0.7894, 0.7537, 0.9026, 0.8108, 0.9147, 0.8894, 0.7979, 0.8469,
         0.7822, 0.6684, 0.8756, 0.8728, 0.9136, 0.7775, 0.8649, 0.7798]],
       device='cuda:0', requires_grad=True)
