In [233]:
import snntorch.functional as SF
import snntorch.spikegen as spikegen
import snntorch as snn

import torch, torch.nn as nn
import numpy as np
import time

from torch.distributions import MultivariateNormal, Categorical


hidden_size = 300  # number of hidden neurons


In [234]:
"""
    This file contains a neural network module for us to
    define our actor and critic networks in PPO.
"""

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class FFNetwork(nn.Module):
    """
        A standard in_dim-64-64-out_dim Feed Forward Neural Network.
    """

    def __init__(self, in_dim, out_dim):
        """
            Initialize the network and set up the layers.

            Parameters:
                in_dim - input dimensions as an int
                out_dim - output dimensions as an int

                Return:
                None
        """
        super(FFNetwork, self).__init__()

        self.layer1 = nn.Linear(in_dim, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, out_dim)

    def forward(self, obs):
        """
            Runs a forward pass on the neural network.

            Parameters:
                obs - observation to pass as input

            Return:
                output - the output of our forward pass
        """

        # Convert observation to tensor if it's a numpy array
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float)


        activation1 = torch.relu(self.layer1(obs))
        activation2 = torch.relu(self.layer2(activation1))
        output = self.layer3(activation2)

        return output

In [235]:
hidden_size = 300  # Number of hidden neurons

class SNN(nn.Module):
    def __init__(self, input_size, output_size, num_steps):
        super(SNN, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        beta2 = 0.9
        beta3 = torch.rand((output_size), dtype=torch.float)  # Independent decay rate for each output neuron

        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.0075
        self.lif1 = snn.Leaky(beta=beta1)

        self.fc2 = nn.Linear(hidden_size, hidden_size, dtype=torch.float)
        self.fc2.weight.data += 0.0075
        self.lif2 = snn.Leaky(beta=beta2)

        self.fc3 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc3.weight.data += 0.0075
        self.lif3 = snn.Leaky(beta=beta3, learn_beta=True)

    def forward(self, x):
        # Convert observation to tensor if it's a numpy array
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # Determine if input is batched or not
        is_batched = x.dim() == 3  # [batch_size, num_steps, input_size] is 3D

        if not is_batched:
            # If not batched, add a batch dimension
            x = x.unsqueeze(0)  # Shape becomes [1, num_steps, input_size]


        batch_size = x.size(0)  # This is 1 if not batched, otherwise the actual batch size

        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the spikes from the last layer
        spk3_rec = []
        mem3_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)


        output_spk = torch.stack(spk3_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]
        output_mem = torch.stack(mem3_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]

        if not is_batched:
            # Remove the batch dimension if it was added
            output_spk = output_spk.squeeze(0)  # Shape becomes [num_steps, output_size]
            output_mem = output_mem.squeeze(0)  # Shape becomes [num_steps, output_size]

        return output_spk, output_mem

In [236]:
class SNN_NonBatched(nn.Module):
    def __init__(self, input_size, output_size, num_steps):
        super(SNN_NonBatched, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        beta2 = 0.9
        beta3 = torch.rand((output_size), dtype=torch.float)  # Independent decay rate for each output neuron

        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.01
        self.lif1 = snn.Leaky(beta=beta1)

        self.fc2 = nn.Linear(hidden_size, hidden_size, dtype=torch.float)
        self.fc2.weight.data += 0.01
        self.lif2 = snn.Leaky(beta=beta2)

        self.fc3 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc3.weight.data += 0.01
        self.lif3 = snn.Leaky(beta=beta3, learn_beta=True)

    def forward(self, x):
        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the spikes from the last layer
        spk3_rec = []
        mem3_rec = []


        for step in range(self.num_steps):
            cur1 = self.fc1(x[step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)


        return torch.stack(spk3_rec, dim=0), torch.stack(mem3_rec, dim=1)  # Shape: [num_steps, output_size]

In [237]:
hidden_size = 300  # Number of hidden neurons

class SNN_Batched(nn.Module):
    def __init__(self, input_size, output_size, num_steps):
        super(SNN_Batched, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        beta2 = 0.9
        beta3 = torch.rand((output_size), dtype=torch.float)  # Independent decay rate for each output neuron

        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.01
        self.lif1 = snn.Leaky(beta=beta1)

        self.fc2 = nn.Linear(hidden_size, hidden_size, dtype=torch.float)
        self.fc2.weight.data += 0.01
        self.lif2 = snn.Leaky(beta=beta2)

        self.fc3 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc3.weight.data += 0.01
        self.lif3 = snn.Leaky(beta=beta3, learn_beta=True)

    def forward(self, x):
        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the spikes from the last layer
        spk3_rec = []
        mem3_rec = [] # record output hidden states


        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)


        return torch.stack(spk3_rec, dim=1), torch.stack(mem3_rec, dim=1) # Shape: [batch_size, num_steps, output_size]



In [238]:
def generate_spike_trains(observation, num_steps, threshold, shift):
    """
    Generate spike trains from a single observation using a fixed global threshold.
    
    Parameters:
    - observation: A tensor representing the observation ([observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    
    Returns:
    - spike_trains: Tensor of spike trains.
    """
    
    shift = shift.numpy()

    # Normalize and clip observation
    shifted_obs = np.add(observation, shift) 

    # torch version
    #shifted_obs = observation + shift


    normalized_obs = shifted_obs / (threshold + 1e-6)  # Avoid division by zero

    normalized_obs /= 2
    
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to be within [0, 1]

    
    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [239]:
def generate_spike_trains_batched(observations, num_steps, threshold, shift):
    """
    Generate spike trains from batched observations using a fixed global threshold.
    
    Parameters:
    - observations: A tensor representing the batched observations ([batch_size, observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    - shift: A value to shift the observation range to handle negative values.
    
    Returns:
    - spike_trains: Tensor of spike trains with shape (batch_size, num_steps, observation_dim).
    """

    shift = shift.numpy()


    # Normalize and shift observations
    normalized_obs = np.add(observations, shift) / (2 * (threshold + 1e-6))  # Avoid division by zero
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to [0, 1]

    # Generate spike trains for each observation in the batch
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # Rearrange the output to have shape (batch_size, num_steps, observation_dim)
    spike_trains = spike_trains.permute(1, 0, 2)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [240]:
def get_spike_counts(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron.
    """
    spike_counts = torch.sum(spike_trains, dim=0)
    return spike_counts

In [241]:
def get_spike_counts_batched(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps for batched spike trains.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [batch_size, num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron in each observation (shape: [batch_size, observation_dim]).
    """
    # Sum over the time dimension (dim=1) to get spike counts for each neuron in each observation
    spike_counts = torch.sum(spike_trains, dim=1)
    
    return spike_counts

In [242]:
def decode_first_spike_batched(spike_trains):
    """
    Decodes the first spike time from batched spike trains using the 'time to first spike' method.
    
    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons).
    
    Returns:
        decoded_vector - A tensor representing the first spike times for each neuron in each batch with gradients retained.
    """
    batch_size, num_steps, num_neurons = spike_trains.shape

    # Create a tensor with time steps and retain gradients
    time_tensor = torch.arange(1, num_steps + 1, dtype=torch.float32, requires_grad=True).unsqueeze(0).unsqueeze(2).expand(batch_size, num_steps, num_neurons)

    # Multiply spike_trains by the time tensor, masking out non-spike entries
    spike_times = spike_trains * time_tensor

    # Set all zero entries (no spike) to a very high value (greater than num_steps)
    spike_times = spike_times + (1 - spike_trains) * (num_steps+1)

    # Find the minimum value in each column (i.e., first spike) for each batch
    first_spike_times, _ = spike_times.min(dim=1)

    # Ensure that this tensor retains gradients
    return first_spike_times

In [243]:
def decode_first_spike(spike_trains):
    """
    Decodes the first spike time from spike trains using the 'time to first spike' method.
    
    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons).
    
    Returns:
        decoded_vector - A tensor representing the first spike times for each neuron with gradients retained.
    """
    num_steps, num_neurons = spike_trains.shape

    # Create a tensor with time steps and retain gradients
    time_tensor = torch.arange(1, num_steps + 1, dtype=torch.float32, requires_grad=True).unsqueeze(1).expand(num_steps, num_neurons)

    # Multiply spike_trains by the time tensor, masking out non-spike entries
    spike_times = spike_trains * time_tensor

    # Set all zero entries (no spike) to a very high value (greater than num_steps)
    spike_times = spike_times + (1 - spike_trains) * (num_steps+1)

    # Find the minimum value in each column (i.e., first spike)
    first_spike_times, _ = spike_times.min(dim=0)

    # Ensure that this tensor retains gradients
    return first_spike_times

In [244]:
def decode_first_spike_batched_archived(spike_trains):
    """
    Decodes the first spike time from spike trains for batched data using 'time to first spike' method.

    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons).

    Returns:
        decoded_vector - The decoded first spike times with shape (batch_size, num_neurons).
    """
    batch_size = spike_trains.size(0)
    num_neurons = spike_trains.size(2)
    decoded_vectors = []

    for batch_idx in range(batch_size):
        decoded_vector = [spike_trains.size(1)+1] * num_neurons
        
        for neuron_idx in range(num_neurons):
            first_spike = (spike_trains[batch_idx, :, neuron_idx] == 1).nonzero(as_tuple=True)[0]
            if first_spike.nelement() != 0:
                decoded_vector[neuron_idx] = first_spike[0].item() + 1
        
        decoded_vectors.append(decoded_vector)

    return torch.FloatTensor(decoded_vectors)

In [245]:
def decode_first_spike_archived(spike_trains):
    """
    Decodes the first spike time from spike trains using the 'time to first spike' method for non-batched data.

    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons).

    Returns:
        decoded_vector - The decoded first spike times with shape (num_neurons,).
    """
    num_steps = spike_trains.size(0)
    num_neurons = spike_trains.size(1)
    
    # Initialize decoded vector with default values greater than the maximum possible spike time
    decoded_vector = [num_steps + 1] * num_neurons

    # Iterate over each neuron to find the first spike time
    for neuron_idx in range(num_neurons):
        first_spike = (spike_trains[:, neuron_idx] == 1).nonzero(as_tuple=True)[0]
        if first_spike.nelement() != 0:
            decoded_vector[neuron_idx] = first_spike[0].item() + 1  # +1 to convert 0-based index to 1-based time step

    return torch.FloatTensor(decoded_vector)

In [246]:
def decode_count(spikes):
    spike_counts = torch.sum(spikes, dim=1)
    action = torch.zeros(spikes.size(0))
    max_spike_count = torch.max(spike_counts)
    candidates = torch.where(spike_counts == max_spike_count)[0]
    if len(candidates) > 1:
        action[torch.multinomial(candidates.float(), 1)] = 1
    else:
        action[candidates] = 1
    return action

In [247]:
# non-batch version
observation = np.array([1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3])  # Example observation
threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains(observation, num_steps=100, threshold=threshold, shift=shift)

spike_trains = torch.tensor(spike_trains, dtype=torch.float)

print("shape of spike trains",spike_trains.shape)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts(spike_trains)
print("Spike counts:", spike_counts)

(8,)
shape of spike trains torch.Size([100, 8])
First spike times: tensor([  1.,   2.,   1., 101.,   2.,   9.,   1.,   1.],
       grad_fn=<MinBackward0>)
Spike counts: tensor([87., 64., 62.,  0., 68.,  4., 90., 54.])


In [248]:
# batch version
batch_observation = np.array([[1.5, -0.5, -5.0, -0.0, 1.0, -4.5, 0.8, 0.3],
                             [0.2, 0.5, -2.0, 1.0, 0.5, -1.5, -0.8, -0.3],
                             [-1.2, -0.5, -0.2, -0.3, -1.0, 0.4, 0.2, 0.1],
                             [0.5, 0.5, 0.2, 0.3, 0.1, 0.0, -0.2, 0.3],
                             [-1.5, -1.5, -5, -5, -3.14, -5, -1, -1],
                             [1.5, 1.5, 5, 5, 3.14, 5, 1, 1]])

threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains_batched(batch_observation, num_steps=100, threshold=threshold, shift=shift)

spike_trains = torch.tensor(spike_trains, dtype=torch.float)

print(spike_trains.shape)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike_batched(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts_batched(spike_trains)

print("Spike counts:", spike_counts)

(8,)
torch.Size([6, 100, 8])
First spike times: tensor([[  1.,   2., 101.,   1.,   1.,  11.,   1.,   1.],
        [  2.,   1.,   6.,   1.,   1.,   1.,  23.,   2.],
        [  4.,   1.,   4.,   2.,   1.,   1.,   1.,   1.],
        [  1.,   2.,   4.,   1.,   1.,   1.,   2.,   1.],
        [101., 101., 101., 101., 101., 101., 101., 101.],
        [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]],
       grad_fn=<MinBackward0>)
Spike counts: tensor([[100.,  23.,   0.,  50.,  71.,   6.,  89.,  71.],
        [ 54.,  70.,  27.,  59.,  56.,  33.,   9.,  32.],
        [ 14.,  33.,  50.,  44.,  33.,  62.,  65.,  54.],
        [ 69.,  59.,  52.,  52.,  51.,  50.,  47.,  69.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [100., 100., 100., 100., 100., 100., 100., 100.]])


## SNN code rundown

### init

In [249]:
num_steps = 50

obs_dim = 8
act_dim = 4

actor_SNN = SNN(obs_dim, act_dim, num_steps)
critic_SNN = SNN(obs_dim, 1, num_steps)

actor_ANN = FFNetwork(obs_dim, act_dim)
critic_ANN = FFNetwork(obs_dim, 1)

cov_var = torch.full(size=(act_dim,), fill_value=0.5)
cov_mat = torch.diag(cov_var)

threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 0, 0])


### making sure the data is well distributed after

In [250]:
batch_observation = np.array([[-1.5, -1.5, -5, -5, -3.14, -5, 0, 0],
                              [-1.4, -1.4, -4, -4, -3.0, -4, 1, 1],
                              [-1.2, -1.2, -2, -2, -2.5, -2, 1, 0],
                              [-1.0, -1.0, -0, -0, -1.5, 2, 1, 0],
                              [1.2, 1.3, 2, 2, 2.1, 4, 1, 0],

                              [-1.4, -1.5, -0, -2, -3.0, -4, 1, 0],
                              [1.2, -1.4, 4, -4, 1.12, -4, 0, 0],
                              [1.1, 1.1, 0, 1, 0.2, 2, 0, 0],
                              [1.5, 1.3, 3, -4, -3.0, 1, 0.0, 0.0],
                              [1.4, 1.4, -4, -4, -3.0, -4, 1, 1],
                             [1.5, 1.5, 5, 5, 3.14, 5, 1, 1]])

print(batch_observation.shape)
batch_observation_st = generate_spike_trains_batched(batch_observation, num_steps=50, threshold=threshold, shift=shift)
batch_observation_st_tensor = torch.tensor(batch_observation_st, dtype=torch.float)

print("obs spike trains shape", batch_observation_st_tensor.shape)  # [num_steps, observation_dim]
print("obs First spike times:", decode_first_spike_batched(batch_observation_st_tensor))
print("obs Spike counts:", get_spike_counts_batched(batch_observation_st_tensor))

action = actor_SNN(batch_observation_st)[0].detach()

print("///////////////////////////////////")

print("action spike trains shape", action.shape)  # [num_steps, observation_dim]
print("action First spike times:", decode_first_spike_batched(action))
print("action Spike counts:", get_spike_counts_batched(action))



(11, 8)
obs spike trains shape torch.Size([11, 50, 8])
obs First spike times: tensor([[51., 51., 51., 51., 51., 51., 51., 51.],
        [25., 22.,  6.,  7., 51.,  3.,  1.,  1.],
        [ 3., 17.,  9.,  1.,  8.,  2.,  2., 51.],
        [13.,  9.,  1.,  1.,  3.,  1.,  1., 51.],
        [ 2.,  1.,  3.,  3.,  2.,  1.,  1., 51.],
        [12., 51.,  2.,  2., 51.,  8.,  2., 51.],
        [ 1., 19.,  1.,  8.,  4.,  8., 51., 51.],
        [ 1.,  1.,  2.,  1.,  1.,  1., 51., 51.],
        [ 1.,  1.,  1.,  3., 21.,  1., 51., 51.],
        [ 1.,  1., 27.,  6., 16.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  4.,  1.]], grad_fn=<MinBackward0>)
obs Spike counts: tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 1.,  4.,  6.,  5.,  0., 11., 24., 18.],
        [ 5.,  4.,  8., 15.,  4., 17., 25.,  0.],
        [ 4.,  9., 20., 26., 10., 37., 22.,  0.],
        [47., 48., 32., 34., 36., 46., 29.,  0.],
        [ 4.,  0., 25., 17.,  0.,  3., 26.,  0.],
        [46.,  2., 44.,  4., 

In [18]:
batch_obs_st = generate_spike_trains_batched(batch_observation, num_steps=num_steps, threshold=threshold, shift=threshold)

batch_obs_st = torch.tensor(batch_obs_st, dtype=torch.float)


print(batch_obs_st.shape)
print("obs first spikes:", decode_first_spike_batched(batch_obs_st))
print("obs spike counts:", get_spike_counts_batched(batch_obs_st))


action = actor_SNN(batch_obs_st)[0].detach()

print(action.shape)

print("action first spikes:", decode_first_spike_batched(action))
print("action spike counts:", get_spike_counts_batched(action))




torch.Size([6, 100, 8])
obs first spikes: tensor([[  1.,   1., 101.,   1.,   1.,   7.,   1.,   1.],
        [  1.,   2.,   2.,   2.,   1.,   2.,  20.,   8.],
        [ 32.,   1.,   3.,   1.,   2.,   2.,   2.,   1.],
        [  2.,   1.,   1.,   1.,   1.,   1.,   2.,   1.],
        [101., 101., 101., 101., 101., 101., 101., 101.],
        [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]],
       grad_fn=<MinBackward0>)
obs spike counts: tensor([[100.,  41.,   0.,  43.,  67.,   5.,  92.,  67.],
        [ 59.,  73.,  26.,  58.,  60.,  41.,   9.,  31.],
        [  3.,  30.,  51.,  52.,  37.,  54.,  62.,  56.],
        [ 73.,  66.,  52.,  52.,  51.,  41.,  45.,  65.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [100., 100., 100., 100., 100., 100., 100., 100.]])
torch.Size([6, 100, 4])
action first spikes: tensor([[ 4.,  4.,  5.,  3.],
        [ 4.,  4.,  4.,  4.],
        [ 4.,  5.,  5.,  4.],
        [ 3.,  5.,  4.,  5.],
        [20.,  8., 24., 24.],
        [ 2.,  3

In [58]:
obs_st= generate_spike_trains(observation, num_steps=num_steps, threshold=threshold, shift=threshold)

obs_st = torch.tensor(obs_st, dtype=torch.float)

print(obs_st.shape)
print("obs first spikes:", decode_first_spike(obs_st))
print("obs spike counts:", get_spike_counts(obs_st))


action = actor_SNN(obs_st)[0]

assert action.shape == torch.Size([100, 4])

print(action.shape)

print("action first spikes:", decode_first_spike(action))
print("action spike counts:", get_spike_counts(action))


torch.Size([100, 4])
action first spikes: tensor([3., 4., 3., 3.], grad_fn=<MinBackward0>)
action spike counts: tensor([95., 79., 84., 81.], grad_fn=<SumBackward1>)


### Evaluate

In [20]:
V_st = critic_SNN(batch_obs_st)[0]
V = decode_first_spike_batched(V_st).squeeze()

print(V)


tensor([3., 3., 4., 3., 8., 2.], grad_fn=<SqueezeBackward0>)


In [21]:
V = critic_ANN(batch_observation).squeeze()
print(V)

tensor([0.3530, 0.1756, 0.1277, 0.0211, 0.4791, 0.1556],
       grad_fn=<SqueezeBackward0>)


In [22]:
mean = actor_ANN(batch_observation)
dist = MultivariateNormal(mean, cov_mat)

print(mean)
print(dist)

tensor([[-0.0122, -0.1634, -0.4114,  0.1835],
        [ 0.0653, -0.1383, -0.2141,  0.0582],
        [ 0.0077, -0.0445, -0.1448, -0.0263],
        [-0.0033, -0.0150, -0.0698,  0.0331],
        [ 0.2616, -0.0920, -0.0461,  0.3458],
        [-0.0077,  0.2907, -0.1952, -0.0573]], grad_fn=<AddmmBackward0>)
MultivariateNormal(loc: torch.Size([6, 4]), covariance_matrix: torch.Size([6, 4, 4]))


In [23]:
mean_st = actor_SNN(batch_obs_st)[0]
mean = decode_first_spike_batched(mean_st)
dist = MultivariateNormal(mean, cov_mat)
print(mean)
print(dist)

tensor([[ 4.,  4.,  5.,  3.],
        [ 4.,  4.,  4.,  4.],
        [ 4.,  5.,  5.,  4.],
        [ 3.,  5.,  4.,  5.],
        [20.,  8., 24., 24.],
        [ 2.,  3.,  3.,  3.]], grad_fn=<MinBackward0>)
MultivariateNormal(loc: torch.Size([6, 4]), covariance_matrix: torch.Size([6, 4, 4]))


In [24]:
logits = actor_ANN(batch_observation)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([[-0.0122, -0.1634, -0.4114,  0.1835],
        [ 0.0653, -0.1383, -0.2141,  0.0582],
        [ 0.0077, -0.0445, -0.1448, -0.0263],
        [-0.0033, -0.0150, -0.0698,  0.0331],
        [ 0.2616, -0.0920, -0.0461,  0.3458],
        [-0.0077,  0.2907, -0.1952, -0.0573]], grad_fn=<AddmmBackward0>)
Categorical(logits: torch.Size([6, 4]))
tensor([1, 2, 1, 2, 3, 2])


In [25]:
logits_st = actor_SNN(batch_obs_st)[0]
logits = decode_first_spike_batched(logits_st)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([[ 4.,  4.,  5.,  3.],
        [ 4.,  4.,  4.,  4.],
        [ 4.,  5.,  5.,  4.],
        [ 3.,  5.,  4.,  5.],
        [20.,  8., 24., 24.],
        [ 2.,  3.,  3.,  3.]], grad_fn=<MinBackward0>)
Categorical(logits: torch.Size([6, 4]))
tensor([2, 1, 1, 3, 3, 3])


### Get Action

In [26]:
mean = actor_ANN(observation)
dist = MultivariateNormal(mean, cov_mat)

print(mean)
print(dist)

tensor([ 0.2900,  0.1700,  0.0757, -0.0335], grad_fn=<ViewBackward0>)
MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4]))


In [27]:
mean_st = actor_SNN(obs_st)[0]
mean = decode_first_spike(mean_st)
dist = MultivariateNormal(mean, cov_mat)
print(mean)
print(dist)

tensor([4., 4., 4., 4.], grad_fn=<MinBackward0>)
MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4]))


In [28]:
logits = actor_ANN(observation)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([ 0.2900,  0.1700,  0.0757, -0.0335], grad_fn=<ViewBackward0>)
Categorical(logits: torch.Size([4]))
tensor(1)


In [29]:
logits_st = actor_SNN(obs_st)[0]
logits = decode_first_spike(logits_st)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([4., 4., 4., 4.], grad_fn=<MinBackward0>)
Categorical(logits: torch.Size([4]))
tensor(2)


In [31]:
def testing(observation, num_steps, threshold, shift):
    """
    Generate spike trains from a single observation using a fixed global threshold.
    
    Parameters:
    - observation: A tensor representing the observation ([observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    
    Returns:
    - spike_trains: Tensor of spike trains.
    """
    
    shift = shift.numpy()

    # Normalize and clip observation
    shifted_obs = np.add(observation, shift) 

    # torch version
    #shifted_obs = observation + shift


    normalized_obs = shifted_obs / (threshold + 1e-6)  # Avoid division by zero

    normalized_obs /= 2
    
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to be within [0, 1]

    
    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [32]:
np_obs = np.array([[1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3]])



np_obs_st = generate_spike_trains_batched(np_obs, 100, threshold, shift)

#print(np_obs_st.requires_grad)


print(np_obs_st.shape)

yeet = actor_SNN(np_obs_st)[0]

print(type(yeet))
print(yeet.shape)

bruh = decode_first_spike_batched(yeet)

print("huh", bruh)
print("huhhhhh", bruh.requires_grad)


print(get_spike_counts_batched(yeet))

(6, 100, 8)
<class 'torch.Tensor'>
torch.Size([6, 100, 4])
huh tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 4., 4., 4.],
        [3., 3., 3., 3.],
        [4., 4., 5., 4.],
        [4., 4., 4., 4.]], grad_fn=<MinBackward0>)
huhhhhh True
tensor([[96., 83., 91., 87.],
        [94., 77., 90., 83.],
        [91., 77., 87., 80.],
        [94., 78., 84., 81.],
        [92., 79., 84., 81.],
        [90., 80., 87., 84.]], grad_fn=<SumBackward1>)
