In [None]:
import torch
import snntorch
import snntorch as snn
from snntorch import spikegen
import torch.nn as nn

## SNN definition

In [None]:
hidden_size = 300  # number of hidden neurons


class SNN(nn.Module):
    """Two-layer spiking network: Linear -> Leaky -> Linear -> Leaky.

    The width of the hidden layer is read from the module-level
    ``hidden_size`` when the network is constructed.
    """

    def __init__(self, input_size, output_size, num_steps):
        """
        Parameters:
            input_size - Number of input features of the first linear layer.
            output_size - Number of neurons in the output layer.
            num_steps - Number of time steps simulated by each forward pass.
        """
        super(SNN, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9 # global decay rate for all leaky neurons in layer 1
        # independent decay rate for each leaky neuron in layer 2: [0, 1)
        # NOTE: drawn from torch's global RNG, so it changes between runs unless a seed is set
        beta2 = torch.rand((output_size), dtype = torch.float)

        # The dense layers come from torch.nn (snnTorch does not provide a
        # Linear layer); snnTorch supplies the leaky integrate-and-fire neurons.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=beta1)  # beta is the decay rate of the membrane potential
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = snn.Leaky(beta=beta2)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` time steps.

        The same input ``x`` is presented to the first layer at every step.

        Parameters:
            x - Input passed to the first linear layer
                (presumably shaped (..., input_size) - confirm against caller).

        Returns:
            A tuple ``(spikes, membranes)`` holding the output layer's spikes
            and membrane potentials of every step, stacked along a new
            leading (time) dimension.
        """
        # Initialize the membrane potentials for each forward pass
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = [] # record output spikes
        mem2_rec = [] # record output hidden states

        # Loop over time steps
        for step in range(self.num_steps):
            cur1 = self.fc1(x)
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2) # record spikes
            mem2_rec.append(mem2) # record membrane

        return torch.stack(spk2_rec), torch.stack(mem2_rec)

In [None]:
def decode_time_to_first_spike(spike_trains):
    """
    Decodes each spike train into the time step of its first spike.

    Parameters:
        spike_trains - Tensor whose first dimension indexes the individual
            trains; the entries of one train are scanned in order as time
            steps. A train counts as having fired only if its largest
            entry is exactly 1.

    Returns:
        decoded_values - Float tensor with one entry per train, on the same
            device as the input: the index of the first spike, or -1.0 for
            a train that never fired.
    """
    num_trains = spike_trains.size(0)
    decoded_values = torch.full((num_trains,), -1.0, device=spike_trains.device)

    # No trains, or trains without any time step: nothing can have fired
    if spike_trains.numel() == 0:
        return decoded_values

    flat_trains = spike_trains.reshape(num_trains, -1)
    num_steps = flat_trains.size(1)

    # A train fired only when its maximum entry is exactly 1
    fired = flat_trains.amax(dim=1) == 1

    # Earliest step holding a 1; non-spike positions get the sentinel
    # `num_steps` so the minimum skips them (no reliance on argmax tie-breaking)
    steps = torch.arange(num_steps, device=spike_trains.device)
    spike_steps = steps.expand_as(flat_trains).masked_fill(flat_trains != 1, num_steps)
    first_spike_time = spike_steps.amin(dim=1)

    decoded_values[fired] = first_spike_time[fired].float()
    return decoded_values

In [None]:
def spike_counting_decoding(spikes):
    """
    Decodes spikes into a one-hot action by picking the row with most spikes.

    Parameters:
        spikes - Tensor whose first dimension indexes the candidates
            (e.g. output neurons) and whose second dimension is summed to
            obtain each candidate's spike count.

    Returns:
        action - Float tensor of length spikes.size(0), on the same device as
            the input, with a 1 at the winning index and 0 elsewhere. Ties
            are broken uniformly at random among the tied candidates.
    """
    action = torch.zeros(spikes.size(0), device=spikes.device)
    if spikes.size(0) == 0:
        # torch.max is undefined on an empty tensor; there is nothing to pick
        return action

    spike_counts = torch.sum(spikes, dim=1)
    max_spike_count = torch.max(spike_counts)
    candidates = torch.where(spike_counts == max_spike_count)[0]

    if len(candidates) > 1:
        # Draw a uniform position within `candidates`, then map it back to the
        # actual index. (Sampling with the indices themselves as weights would
        # bias the draw and return a position rather than an index.)
        pick = torch.randint(len(candidates), (1,), device=candidates.device)
        winner = candidates[pick]
    else:
        winner = candidates

    action[winner] = 1
    return action

In [None]:
from snntorch import spikegen

In [None]:
def encode_to_spikes(data, num_steps):
    """
    Encodes analog signals into spike trains using rate encoding.

    The data is min-max normalized over all of its elements before encoding.
    Constant data (max == min) carries no relative information and is
    normalized to all zeros instead of producing NaN from a 0/0 division.

    Parameters:
        data - The continuous-valued data to be encoded
            (assumed to be a non-empty torch tensor - confirm against caller).
        num_steps - The number of time steps for the spike train.

    Returns:
        spike_train - The encoded spike train, as returned by spikegen.rate.
    """
    # Normalize the data to be between 0 and 1
    data_min = data.min()
    value_range = data.max() - data_min
    # Guard the divisor: a zero range would otherwise turn every entry into NaN
    safe_range = torch.where(value_range > 0, value_range, torch.ones_like(value_range))
    normalized_data = (data - data_min) / safe_range

    # Convert normalized data to spike trains
    # TODO rate vs latency vs delta
    spike_train = spikegen.rate(normalized_data, num_steps=num_steps)

    return spike_train