In [3]:
import torch
import snntorch as snn
import snntorch.functional as SF
import snntorch.spikegen as spikegen

In [4]:
def generate_spike_trains(observation, num_steps, threshold, shift):
    """
    Generate spike trains from a single observation using a fixed global threshold.
    
    Parameters:
    - observation: A tensor representing the observation ([observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    
    Returns:
    - spike_trains: Tensor of spike trains.
    """
    # Normalize and clip observation
    shifted_obs = observation + shift 

    normalized_obs = shifted_obs / (threshold + 1e-6)  # Avoid division by zero

    normalized_obs /= 2

    print("normalized obs ", normalized_obs)
    
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to be within [0, 1]

    
    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    return spike_trains

In [5]:
def generate_spike_trains_batched(observations, num_steps, threshold, shift):
    """
    Generate spike trains from batched observations using a fixed global threshold.
    
    Parameters:
    - observations: A tensor representing the batched observations ([batch_size, observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    - shift: A value to shift the observation range to handle negative values.
    
    Returns:
    - spike_trains: Tensor of spike trains with shape (batch_size, num_steps, observation_dim).
    """
    # Normalize and shift observations
    normalized_obs = (observations + shift) / (2 * (threshold + 1e-6))  # Avoid division by zero
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to [0, 1]

    # Generate spike trains for each observation in the batch
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # Rearrange the output to have shape (batch_size, num_steps, observation_dim)
    spike_trains = spike_trains.permute(1, 0, 2)
    
    return spike_trains

In [6]:
def get_spike_counts(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron.
    """
    spike_counts = torch.sum(spike_trains, dim=0)
    return spike_counts.tolist()

In [7]:
def get_spike_counts_batched(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps for batched spike trains.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [batch_size, num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron in each observation (shape: [batch_size, observation_dim]).
    """
    # Sum over the time dimension (dim=1) to get spike counts for each neuron in each observation
    spike_counts = torch.sum(spike_trains, dim=1)
    
    return spike_counts.tolist()  # Convert to a list for easier interpretation

In [8]:
def decode_first_spike_batched(spike_trains):
    """
    Decodes the first spike time from spike trains for batched data using 'time to first spike' method.

    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons).

    Returns:
        decoded_vector - The decoded first spike times with shape (batch_size, num_neurons).
    """
    batch_size = spike_trains.size(0)
    num_neurons = spike_trains.size(2)
    decoded_vectors = []

    for batch_idx in range(batch_size):
        decoded_vector = [spike_trains.size(1)+1] * num_neurons
        
        for neuron_idx in range(num_neurons):
            first_spike = (spike_trains[batch_idx, :, neuron_idx] == 1).nonzero(as_tuple=True)[0]
            if first_spike.nelement() != 0:
                decoded_vector[neuron_idx] = first_spike[0].item() + 1
        
        decoded_vectors.append(decoded_vector)

    return torch.FloatTensor(decoded_vectors)

In [9]:
def decode_first_spike(spike_trains):
    """
    Decodes the first spike time from spike trains using the 'time to first spike' method for non-batched data.

    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons).

    Returns:
        decoded_vector - The decoded first spike times with shape (num_neurons,).
    """
    num_steps = spike_trains.size(0)
    num_neurons = spike_trains.size(1)
    
    # Initialize decoded vector with default values greater than the maximum possible spike time
    decoded_vector = [num_steps + 1] * num_neurons

    # Iterate over each neuron to find the first spike time
    for neuron_idx in range(num_neurons):
        first_spike = (spike_trains[:, neuron_idx] == 1).nonzero(as_tuple=True)[0]
        if first_spike.nelement() != 0:
            decoded_vector[neuron_idx] = first_spike[0].item() + 1  # +1 to convert 0-based index to 1-based time step

    return torch.FloatTensor(decoded_vector)

In [10]:
def decode_count(spikes):
    spike_counts = torch.sum(spikes, dim=1)
    action = torch.zeros(spikes.size(0))
    max_spike_count = torch.max(spike_counts)
    candidates = torch.where(spike_counts == max_spike_count)[0]
    if len(candidates) > 1:
        action[torch.multinomial(candidates.float(), 1)] = 1
    else:
        action[candidates] = 1
    return action

In [15]:
# non-batch version
observation = torch.tensor([1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3])  # Example observation
threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains(observation, num_steps=100, threshold=threshold, shift=shift)

print("shape of spike trains",spike_trains)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts(spike_trains)
print("Spike counts:", spike_counts)

torch.Size([8])
normalized obs  tensor([0.9000, 0.6667, 0.7000, 0.0000, 0.6592, 0.0500, 0.9000, 0.6500])
shape of spike trains tensor([[1., 0., 0., 0., 0., 0., 1., 1.],
        [1., 0., 1., 0., 1., 1., 1., 1.],
        [1., 1., 1., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0., 1., 1.],
        [1., 0., 0., 0., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1., 0., 1., 0.],
        [1., 0., 0., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1., 0., 1., 1.],
        [0., 1., 0., 0., 0., 0., 1., 1.],
        [0., 1., 1., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0., 1., 1.],
        [1., 0., 1., 0., 0., 0., 1., 1.],
        [1., 1., 1., 0., 0., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0., 1., 1.],
        [1., 1., 1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0., 1., 0.],
        [1., 1., 1., 0., 1., 0., 1., 0.],
        [0., 1., 1., 0., 0., 0., 1., 1.],
        [1., 1., 1., 0., 1., 0., 1., 1.],
        [1., 0., 1., 0., 1., 0., 

In [12]:
# batch version
observation = torch.tensor([[1.5, -0.5, -5.0, -0.0, 1.0, -4.5, 0.8, 0.3],
                             [0.2, 0.5, -2.0, 1.0, 0.5, -1.5, -0.8, -0.3],
                             [-1.2, -0.5, -0.2, -0.3, -1.0, 0.4, 0.2, 0.1],
                             [0.5, 0.5, 0.2, 0.3, 0.1, 0.0, -0.2, 0.3]])

threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains_batched(observation, num_steps=100, threshold=threshold, shift=shift)

print(spike_trains.shape)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike_batched(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts_batched(spike_trains)

spike_counts = torch.tensor(spike_counts)
print("Spike counts:", spike_counts)

torch.Size([4, 8])
torch.Size([4, 100, 8])
First spike times: tensor([[  1.,   1., 101.,   2.,   1.,  40.,   1.,   2.],
        [  2.,   2.,   1.,   1.,   2.,   1.,   5.,   5.],
        [  1.,   1.,   2.,   1.,   3.,   1.,   1.,   1.],
        [  1.,   2.,   4.,   1.,   3.,   1.,   1.,   1.]])
Spike counts: tensor([[100.,  33.,   0.,  52.,  67.,   5.,  91.,  68.],
        [ 54.,  65.,  31.,  71.,  60.,  23.,  12.,  39.],
        [  6.,  38.,  45.,  46.,  36.,  55.,  65.,  51.],
        [ 68.,  62.,  47.,  48.,  54.,  48.,  50.,  56.]])
