In [817]:
import snntorch.functional as SF
import snntorch.spikegen as spikegen
import snntorch as snn

import torch, torch.nn as nn
import numpy as np
import time

from torch.distributions import MultivariateNormal, Categorical


hidden_size = 300  # number of hidden neurons


In [818]:
"""
    This file contains a neural network module for us to
    define our actor and critic networks in PPO.
"""

import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


class FFNetwork(nn.Module):
    """
        A standard in_dim-64-64-out_dim Feed Forward Neural Network.
    """

    def __init__(self, in_dim, out_dim):
        """
            Initialize the network and set up the layers.

            Parameters:
                in_dim - input dimensions as an int
                out_dim - output dimensions as an int

                Return:
                None
        """
        super(FFNetwork, self).__init__()

        self.layer1 = nn.Linear(in_dim, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, out_dim)

    def forward(self, obs):
        """
            Runs a forward pass on the neural network.

            Parameters:
                obs - observation to pass as input

            Return:
                output - the output of our forward pass
        """

        # Convert observation to tensor if it's a numpy array
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float)


        activation1 = torch.relu(self.layer1(obs))
        activation2 = torch.relu(self.layer2(activation1))
        output = self.layer3(activation2)

        return output

In [819]:
hidden_size = 64  # Number of hidden neurons

class SNN_small(nn.Module):
    def __init__(self, input_size, output_size, num_steps):
        super(SNN_small, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        beta2 = torch.rand((output_size), dtype=torch.float)  # Independent decay rate for each output neuron

        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.0
        self.lif1 = snn.Leaky(beta=beta1)

        self.fc2 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc2.weight.data += 0.0
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True)

    def forward(self, x):

        # Convert observation to tensor if it's a numpy array
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # Determine if input is batched or not
        is_batched = x.dim() == 3  # [batch_size, num_steps, input_size] is 3D

        if not is_batched:
            # If not batched, add a batch dimension
            x = x.unsqueeze(0)  # Shape becomes [1, num_steps, input_size]


        batch_size = x.size(0)  # This is 1 if not batched, otherwise the actual batch size

        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the spikes from the last layer
        spk2_rec = []
        mem2_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)


        output_spk = torch.stack(spk2_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]
        output_mem = torch.stack(mem2_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]

        if not is_batched:
            # Remove the batch dimension if it was added
            output_spk = output_spk.squeeze(0)  # Shape becomes [num_steps, output_size]
            output_mem = output_mem.squeeze(0)  # Shape becomes [num_steps, output_size]

        #print("should not be none :", output_spk.grad_fn)  # This should not be None

        return output_spk, output_mem

In [820]:
hidden_size = 300  # Number of hidden neurons

class SNN(nn.Module):
    def __init__(self, input_size, output_size, num_steps):
        super(SNN, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        beta2 = 0.9
        beta3 = torch.rand((output_size), dtype=torch.float)  # Independent decay rate for each output neuron

        # Define layers
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.005
        self.lif1 = snn.Leaky(beta=beta1)

        self.fc2 = nn.Linear(hidden_size, hidden_size, dtype=torch.float)
        self.fc2.weight.data += 0.005
        self.lif2 = snn.Leaky(beta=beta2)

        self.fc3 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc3.weight.data += 0.005
        self.lif3 = snn.Leaky(beta=beta3, learn_beta=True)

    def forward(self, x):
        # Convert observation to tensor if it's a numpy array
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # Determine if input is batched or not
        is_batched = x.dim() == 3  # [batch_size, num_steps, input_size] is 3D

        if not is_batched:
            # If not batched, add a batch dimension
            x = x.unsqueeze(0)  # Shape becomes [1, num_steps, input_size]


        batch_size = x.size(0)  # This is 1 if not batched, otherwise the actual batch size

        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the spikes from the last layer
        spk3_rec = []
        mem3_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)
            mem3_rec.append(mem3)


        output_spk = torch.stack(spk3_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]
        output_mem = torch.stack(mem3_rec, dim=1)  # Shape: [batch_size, num_steps, output_size]

        if not is_batched:
            # Remove the batch dimension if it was added
            output_spk = output_spk.squeeze(0)  # Shape becomes [num_steps, output_size]
            output_mem = output_mem.squeeze(0)  # Shape becomes [num_steps, output_size]

        return output_spk, output_mem

In [821]:
def generate_spike_trains(observation, num_steps, threshold, shift):
    """
    Generate spike trains from a single observation using a fixed global threshold.
    
    Parameters:
    - observation: A tensor representing the observation ([observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    
    Returns:
    - spike_trains: Tensor of spike trains.
    """
    
    shift = shift.numpy()

    # Normalize and clip observation
    shifted_obs = np.add(observation, shift) 

    # torch version
    #shifted_obs = observation + shift


    normalized_obs = shifted_obs / (threshold + 1e-6)  # Avoid division by zero

    normalized_obs /= 2
    
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to be within [0, 1]

    
    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [822]:
def generate_spike_trains_batched(observations, num_steps, threshold, shift):
    """
    Generate spike trains from batched observations using a fixed global threshold.
    
    Parameters:
    - observations: A tensor representing the batched observations ([batch_size, observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    - shift: A value to shift the observation range to handle negative values.
    
    Returns:
    - spike_trains: Tensor of spike trains with shape (batch_size, num_steps, observation_dim).
    """

    shift = shift.numpy()


    # Normalize and shift observations
    normalized_obs = np.add(observations, shift) / (2 * (threshold + 1e-6))  # Avoid division by zero
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to [0, 1]

    # Generate spike trains for each observation in the batch
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # Rearrange the output to have shape (batch_size, num_steps, observation_dim)
    spike_trains = spike_trains.permute(1, 0, 2)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [823]:
def get_spike_counts(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron.
    """

    num_steps, num_neurons = spike_trains.shape


    spike_counts = torch.sum(spike_trains, dim=0)

    # TODO restore
    #spike_counts = spike_counts/num_steps

    return spike_counts

In [824]:
def get_spike_counts_batched(spike_trains):
    """
    Get the total number of spikes for each neuron over all timesteps for batched spike trains.
    
    Parameters:
    - spike_trains: Tensor of spike trains with shape [batch_size, num_steps, observation_dim].
    
    Returns:
    - Array of spike counts for each neuron in each observation (shape: [batch_size, observation_dim]).
    """
    batch_size, num_steps, num_neurons = spike_trains.shape


    # Sum over the time dimension (dim=1) to get spike counts for each neuron in each observation
    spike_counts = torch.sum(spike_trains, dim=1)
    
    # TODO restore
    #spike_counts = spike_counts/num_steps


    return spike_counts

In [825]:
def decode_first_spike_batched(spike_trains):
    """
    Decodes the first spike time from batched spike trains using the 'time to first spike' method.
    
    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons).
    
    Returns:
        decoded_vector - A tensor representing the first spike times for each neuron in each batch with gradients retained.
    """
    batch_size, num_steps, num_neurons = spike_trains.shape

    # Create a tensor with time steps and retain gradients
    time_tensor = torch.arange(1, num_steps + 1, dtype=torch.float32, requires_grad=True).unsqueeze(0).unsqueeze(2).expand(batch_size, num_steps, num_neurons)

    # Multiply spike_trains by the time tensor, masking out non-spike entries
    spike_times = spike_trains * time_tensor

    # Set all zero entries (no spike) to a very high value (greater than num_steps)
    spike_times = spike_times + (1 - spike_trains) * (num_steps+1)

    # Find the minimum value in each column (i.e., first spike) for each batch
    first_spike_times, _ = spike_times.min(dim=1)

    # TODO restore
    # Transform the spike times into a format better suited for a categorical 
    #first_spike_times = (-2/(num_steps+1))*first_spike_times + 2


    # Ensure that this tensor retains gradients
    return first_spike_times

In [826]:
def decode_first_spike(spike_trains):
    """
    Decodes the first spike time from spike trains using the 'time to first spike' method.
    
    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons).
    
    Returns:
        decoded_vector - A tensor representing the first spike times for each neuron with gradients retained.
    """
    num_steps, num_neurons = spike_trains.shape

    # Create a tensor with time steps and retain gradients
    time_tensor = torch.arange(1, num_steps + 1, dtype=torch.float32, requires_grad=True).unsqueeze(1).expand(num_steps, num_neurons)

    # Multiply spike_trains by the time tensor, masking out non-spike entries
    spike_times = spike_trains * time_tensor

    # Set all zero entries (no spike) to a very high value (greater than num_steps)
    spike_times = spike_times + (1 - spike_trains) * (num_steps+1)

    # Find the minimum value in each column (i.e., first spike)
    first_spike_times, _ = spike_times.min(dim=0)

    # TODO restore
    # Transform the spike times into a format better suited for a categorical 
    #first_spike_times = (-2/(num_steps+1))*first_spike_times + 2


    # Ensure that this tensor retains gradients
    return first_spike_times

In [827]:
def compute_spike_metrics(spk_output):
    """
    Compute the average spike time and ratio of neurons that spike at least once.

    Handles both batched ([batch_size, num_steps, output_size]) and unbatched ([num_steps, output_size]) outputs.

    Parameters:
        spk_output: Spiking activity output from the actor network.
                    Shape can be either [batch_size, num_steps, output_size] or [num_steps, output_size].

    Returns:
        avg_spike_time: The average time at which spikes occur
        spike_ratio: The ratio of neurons that spike at least once
    """
    if spk_output.dim() == 3:
        # Batched case: [batch_size, num_steps, output_size]
        spike_times = decode_first_spike_batched(spk_output)
        avg_spike_time = torch.mean(spike_times)  # Average spike time

        # Calculate the ratio of neurons that spiked at least once per batch
        spike_ratio = (spk_output.sum(dim=1) > 0).float().mean()

    elif spk_output.dim() == 2:
        # Unbatched case: [num_steps, output_size]
        spike_times = decode_first_spike(spk_output)
        avg_spike_time = torch.mean(spike_times)  # Average spike time

        # Calculate the ratio of neurons that spiked at least once
        spike_ratio = (spk_output.sum(dim=0) > 0).float().mean()

    else:
        raise ValueError("spk_output must have 2 or 3 dimensions, got shape: {}".format(spk_output.shape))

    return avg_spike_time.detach(), spike_ratio.detach()

In [828]:
# non-batch version
observation = np.array([1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3])  # Example observation
threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains(observation, num_steps=100, threshold=threshold, shift=shift)

spike_trains = torch.tensor(spike_trains, dtype=torch.float)

print("shape of spike trains",spike_trains.shape)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts(spike_trains)
print("Spike counts:", spike_counts)

(8,)
shape of spike trains torch.Size([100, 8])
First spike times: tensor([  1.,   1.,   1., 101.,   1.,   9.,   1.,   1.],
       grad_fn=<MinBackward0>)
Spike counts: tensor([94., 73., 71.,  0., 69.,  8., 86., 66.])


In [829]:
# batch version
batch_observation = np.array([[1.5, -0.5, -5.0, -0.0, 1.0, -4.5, 0.8, 0.3],
                             [0.2, 0.5, -2.0, 1.0, 0.5, -1.5, -0.8, -0.3],
                             [-1.2, -0.5, -0.2, -0.3, -1.0, 0.4, 0.2, 0.1],
                             [0.5, 0.5, 0.2, 0.3, 0.1, 0.0, -0.2, 0.3],
                             [-1.5, -1.5, -5, -5, -3.14, -5, -1, -1],
                             [1.5, 1.5, 5, 5, 3.14, 5, 1, 1]])

threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
print(observation.shape)
spike_trains = generate_spike_trains_batched(batch_observation, num_steps=100, threshold=threshold, shift=shift)

spike_trains = torch.tensor(spike_trains, dtype=torch.float)

print(spike_trains.shape)  # [num_steps, observation_dim]
# Get the first spike times as an array
first_spike_times = decode_first_spike_batched(spike_trains)
print("First spike times:", first_spike_times)

# Get the spike counts as an array
spike_counts = get_spike_counts_batched(spike_trains)

print("Spike counts:", spike_counts)

(8,)
torch.Size([6, 100, 8])
First spike times: tensor([[  1.,   4., 101.,   1.,   1.,  29.,   2.,   1.],
        [  1.,   1.,   1.,   5.,   1.,   3.,  22.,   1.],
        [ 12.,   7.,   5.,   2.,   5.,   1.,   1.,   1.],
        [  2.,   1.,   3.,   5.,   1.,   1.,   3.,   1.],
        [101., 101., 101., 101., 101., 101., 101., 101.],
        [  1.,   1.,   1.,   1.,   1.,   1.,   1.,   1.]],
       grad_fn=<MinBackward0>)
Spike counts: tensor([[100.,  32.,   0.,  61.,  69.,   6.,  90.,  74.],
        [ 59.,  68.,  26.,  63.,  58.,  27.,   5.,  34.],
        [ 12.,  33.,  53.,  36.,  33.,  57.,  64.,  52.],
        [ 65.,  73.,  54.,  48.,  44.,  52.,  45.,  58.],
        [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
        [100., 100., 100., 100., 100., 100., 100., 100.]])


## SNN code rundown

### init

In [830]:
num_steps = 50

obs_dim = 8
act_dim = 4

actor_SNN = SNN_small(obs_dim, act_dim, num_steps)
critic_SNN = SNN_small(obs_dim, 1, num_steps)

actor_ANN = FFNetwork(obs_dim, act_dim)
critic_ANN = FFNetwork(obs_dim, 1)

cov_var = torch.full(size=(act_dim,), fill_value=0.5)
cov_mat = torch.diag(cov_var)

threshold = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 1, 1])
shift = torch.tensor([1.5, 1.5, 5, 5, 3.14, 5, 0, 0])


### making sure the data is well distributed after

In [831]:
batch_observation = np.array([[-1.5, -1.5, -5, -5, -3.14, -5, 0, 0],
                              [-1.4, -1.4, -4, -4, -3.0, -4, 1, 1],
                              [-1.2, -1.2, -2, -2, -2.5, -2, 1, 0],
                              [-1.0, -1.0, -0, -0, -1.5, 2, 1, 0],
                              [1.2, 1.3, 2, 2, 2.1, 4, 1, 0],

                              [-1.4, -1.5, -0, -2, -3.0, -4, 1, 0],
                              [1.2, -1.4, 4, -4, 1.12, -4, 0, 0],
                              [1.1, 1.1, 0, 1, 0.2, 2, 0, 0],
                              [1.5, 1.3, 3, -4, -3.0, 1, 0.0, 0.0],
                              [1.4, 1.4, -4, -4, -3.0, -4, 1, 1],
                             [1.5, 1.5, 5, 5, 3.14, 5, 1, 1]])

print(batch_observation.shape)
batch_observation_st = generate_spike_trains_batched(batch_observation, num_steps=50, threshold=threshold, shift=shift)
batch_observation_st_tensor = torch.tensor(batch_observation_st, dtype=torch.float)

print("obs spike trains shape", batch_observation_st_tensor.shape)  # [num_steps, observation_dim]
print("obs First spike times:", decode_first_spike_batched(batch_observation_st_tensor))
print("obs Spike counts:", get_spike_counts_batched(batch_observation_st_tensor))

action = actor_SNN(batch_observation_st)[0].detach()

print("///////////////////////////////////")

print("action spike trains shape", action.shape)  # [num_steps, observation_dim]
print("action First spike times:", decode_first_spike_batched(action))
print("action Spike counts:", get_spike_counts_batched(action))



(11, 8)
obs spike trains shape torch.Size([11, 50, 8])
obs First spike times: tensor([[51., 51., 51., 51., 51., 51., 51., 51.],
        [ 8., 14.,  5.,  3., 51., 40.,  2.,  1.],
        [11.,  3.,  4.,  2.,  3.,  3.,  1., 51.],
        [ 7.,  7.,  3.,  2.,  3.,  2.,  3., 51.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  2., 51.],
        [15., 51.,  1.,  4., 51.,  4.,  1., 51.],
        [ 1.,  4.,  2., 16.,  1.,  9., 51., 51.],
        [ 1.,  1.,  2.,  1.,  2.,  3., 51., 51.],
        [ 1.,  1.,  3., 13., 29.,  2., 51., 51.],
        [ 1.,  1.,  4.,  1., 23.,  3.,  5.,  2.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  3.,  4.]], grad_fn=<MinBackward0>)
obs Spike counts: tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 3.,  1.,  6.,  5.,  0.,  3., 24., 23.],
        [ 6.,  5., 12., 12.,  5., 10., 28.,  0.],
        [ 8.,  7., 25., 22., 18., 40., 15.,  0.],
        [44., 47., 36., 39., 45., 47., 28.,  0.],
        [ 3.,  0., 20., 18.,  0.,  8., 21.,  0.],
        [46.,  2., 40.,  4., 

In [832]:
batch_obs_st = generate_spike_trains_batched(batch_observation, num_steps=num_steps, threshold=threshold, shift=threshold)

batch_obs_st = torch.tensor(batch_obs_st, dtype=torch.float)


print(batch_obs_st.shape)
print("obs first spikes:", decode_first_spike_batched(batch_obs_st))
print("obs spike counts:", get_spike_counts_batched(batch_obs_st))


action = actor_SNN(batch_obs_st)[0].detach()

print(action.shape)

print("action first spikes:", decode_first_spike_batched(action))
print("action spike counts:", get_spike_counts_batched(action))




torch.Size([11, 50, 8])
obs first spikes: tensor([[51., 51., 51., 51., 51., 51.,  1.,  2.],
        [33., 37., 22.,  5., 34.,  9.,  1.,  1.],
        [ 7.,  1.,  3.,  9., 12.,  1.,  1.,  3.],
        [17.,  3.,  1.,  2.,  4.,  1.,  1.,  1.],
        [ 1.,  1.,  2.,  1.,  1.,  1.,  1.,  1.],
        [19., 51.,  1.,  4., 17.,  7.,  1.,  1.],
        [ 1., 39.,  2., 23.,  1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  2.,  3.],
        [ 1.,  1.,  1.,  1., 33.,  1.,  2.,  1.],
        [ 1.,  1.,  6.,  3., 18.,  5.,  1.,  1.],
        [ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]], grad_fn=<MinBackward0>)
obs spike counts: tensor([[ 0.,  0.,  0.,  0.,  0.,  0., 27., 26.],
        [ 3.,  1.,  1.,  6.,  1.,  5., 50., 50.],
        [ 7.,  5., 20., 16.,  4., 12., 50., 20.],
        [ 9., 10., 23., 17., 14., 33., 50., 27.],
        [49., 43., 34., 34., 44., 47., 50., 27.],
        [ 1.,  0., 25., 23.,  2.,  5., 50., 21.],
        [47.,  2., 40.,  4., 42.,  5., 28., 27.],
        [45., 4

In [833]:
obs_st= generate_spike_trains(observation, num_steps=num_steps, threshold=threshold, shift=threshold)

obs_st = torch.tensor(obs_st, dtype=torch.float)

print(obs_st.shape)
print("obs first spikes:", decode_first_spike(obs_st))
print("obs spike counts:", get_spike_counts(obs_st))


action = actor_SNN(obs_st)[0]

assert action.shape == torch.Size([50, 4])

print(action.shape)

print("action first spikes:", decode_first_spike(action))
print("action spike counts:", get_spike_counts(action))


torch.Size([50, 8])
obs first spikes: tensor([ 2.,  2.,  2., 51.,  3., 50.,  2.,  2.], grad_fn=<MinBackward0>)
obs spike counts: tensor([46., 36., 34.,  0., 30.,  1., 46., 32.])
torch.Size([50, 4])
action first spikes: tensor([51., 51.,  7., 51.], grad_fn=<MinBackward0>)
action spike counts: tensor([0., 0., 7., 0.], grad_fn=<SumBackward1>)


### Evaluate

In [834]:
V_st = critic_SNN(batch_obs_st)[0]
V = decode_first_spike_batched(V_st).squeeze()

print(V)


tensor([51., 51., 51., 51., 51., 51., 51., 51., 51., 51., 51.],
       grad_fn=<SqueezeBackward0>)


In [835]:
V = critic_ANN(batch_observation).squeeze()
print(V)

tensor([ 0.3381,  0.3932,  0.2107,  0.1316,  0.0187,  0.0992,  0.0625,  0.1254,
         0.1375,  0.3738, -0.2545], grad_fn=<SqueezeBackward0>)


In [836]:
mean = actor_ANN(batch_observation)
dist = MultivariateNormal(mean, cov_mat)

print(mean)
print(dist)

tensor([[-0.2182,  0.0463,  0.5205, -0.1196],
        [-0.1424,  0.0238,  0.4317, -0.1323],
        [-0.0184,  0.0606,  0.2266, -0.0647],
        [ 0.0735,  0.1262,  0.0394,  0.0267],
        [ 0.1990,  0.1087,  0.0016,  0.0638],
        [-0.0008,  0.2298,  0.2183,  0.0100],
        [ 0.1589,  0.3151, -0.0251, -0.2038],
        [ 0.1514,  0.0056, -0.0074,  0.0558],
        [-0.0426,  0.2294,  0.0213,  0.0985],
        [-0.3371,  0.1757,  0.2767, -0.0221],
        [ 0.4478,  0.2150,  0.1370,  0.0591]], grad_fn=<AddmmBackward0>)
MultivariateNormal(loc: torch.Size([11, 4]), covariance_matrix: torch.Size([11, 4, 4]))


In [837]:
mean_st = actor_SNN(batch_obs_st)[0]
mean = decode_first_spike_batched(mean_st)
dist = MultivariateNormal(mean, cov_mat)
print(mean)
print(dist)

tensor([[51., 51., 10., 51.],
        [51., 51.,  7., 51.],
        [51., 51.,  6., 51.],
        [51., 51., 10., 51.],
        [51., 51., 10., 51.],
        [51., 51.,  6., 51.],
        [51., 51.,  8., 51.],
        [51., 51., 10., 51.],
        [51., 51.,  6., 51.],
        [51., 51.,  9., 51.],
        [51., 51.,  7., 51.]], grad_fn=<MinBackward0>)
MultivariateNormal(loc: torch.Size([11, 4]), covariance_matrix: torch.Size([11, 4, 4]))


In [838]:
logits = actor_ANN(batch_observation)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([[-0.2182,  0.0463,  0.5205, -0.1196],
        [-0.1424,  0.0238,  0.4317, -0.1323],
        [-0.0184,  0.0606,  0.2266, -0.0647],
        [ 0.0735,  0.1262,  0.0394,  0.0267],
        [ 0.1990,  0.1087,  0.0016,  0.0638],
        [-0.0008,  0.2298,  0.2183,  0.0100],
        [ 0.1589,  0.3151, -0.0251, -0.2038],
        [ 0.1514,  0.0056, -0.0074,  0.0558],
        [-0.0426,  0.2294,  0.0213,  0.0985],
        [-0.3371,  0.1757,  0.2767, -0.0221],
        [ 0.4478,  0.2150,  0.1370,  0.0591]], grad_fn=<AddmmBackward0>)
Categorical(logits: torch.Size([11, 4]))
tensor([0, 3, 2, 3, 2, 3, 1, 1, 1, 3, 1])


In [839]:
logits_st = actor_SNN(batch_obs_st)[0]
logits = decode_first_spike_batched(logits_st)

dist = Categorical(logits=logits)

print(logits)
print(dist)

m = nn.Softmax(dim=1)

print("percentages ", m(logits))

print(dist.sample())

avg_spike_time, spike_ratio = compute_spike_metrics(logits_st)

print(avg_spike_time)
print(spike_ratio)

tensor([[51., 51., 10., 51.],
        [51., 51.,  7., 51.],
        [51., 51.,  6., 51.],
        [51., 51., 10., 51.],
        [51., 51., 10., 51.],
        [51., 51.,  6., 51.],
        [51., 51.,  8., 51.],
        [51., 51., 10., 51.],
        [51., 51.,  6., 51.],
        [51., 51.,  9., 51.],
        [51., 51.,  7., 51.]], grad_fn=<MinBackward0>)
Categorical(logits: torch.Size([11, 4]))
percentages  tensor([[3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 2.5937e-20, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 7.0504e-20, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 1.9165e-19, 3.3333e-01],
        [3.

In [840]:
logits_st = actor_SNN(batch_obs_st)[0]

spike_counts = get_spike_counts_batched(logits_st)

print("logits:", spike_counts)

m = nn.Softmax(dim=1)

print("percentages ", m(logits))

dist = Categorical(logits=logits)

print(dist.sample())

avg_spike_time, spike_ratio = compute_spike_metrics(logits_st)

print(avg_spike_time)
print(spike_ratio)


logits: tensor([[0., 0., 5., 0.],
        [0., 0., 7., 0.],
        [0., 0., 6., 0.],
        [0., 0., 6., 0.],
        [0., 0., 5., 0.],
        [0., 0., 8., 0.],
        [0., 0., 5., 0.],
        [0., 0., 4., 0.],
        [0., 0., 6., 0.],
        [0., 0., 5., 0.],
        [0., 0., 9., 0.]], grad_fn=<SumBackward1>)
percentages  tensor([[3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 2.5937e-20, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 7.0504e-20, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 5.2096e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 9.5417e-21, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 1.9165e-19, 3.3333e-01],
        [3.3333e-01, 3.3333e-01, 2.5937e-20, 3.3333e-01]],
       grad_fn=<SoftmaxBackwa

### Get Action

In [841]:
mean = actor_ANN(observation)
dist = MultivariateNormal(mean, cov_mat)

print(mean)
print(dist)

tensor([ 0.0357,  0.1744, -0.0581, -0.1847], grad_fn=<ViewBackward0>)
MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4]))


In [842]:
mean_st = actor_SNN(obs_st)[0]
mean = decode_first_spike(mean_st)
dist = MultivariateNormal(mean, cov_mat)
print(mean)
print(dist)

tensor([51., 51.,  7., 51.], grad_fn=<MinBackward0>)
MultivariateNormal(loc: torch.Size([4]), covariance_matrix: torch.Size([4, 4]))


In [843]:
logits = actor_ANN(observation)
dist = Categorical(logits=logits)

print(logits)
print(dist)

print(dist.sample())

tensor([ 0.0357,  0.1744, -0.0581, -0.1847], grad_fn=<ViewBackward0>)
Categorical(logits: torch.Size([4]))
tensor(0)


In [844]:
logits_st = actor_SNN(obs_st)[0]
logits = decode_first_spike(logits_st)
logits_2 = get_spike_counts(logits_st)

dist = Categorical(logits=logits)

avg_spike_time, spike_ratio = compute_spike_metrics(logits_st)

print(logits_st)
print(logits)
print(logits_2)

print(dist.sample())

print(avg_spike_time.item())
print(spike_ratio.item())

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0.,

In [32]:
def testing(observation, num_steps, threshold, shift):
    """
    Generate spike trains from a single observation using a fixed global threshold.
    
    Parameters:
    - observation: A tensor representing the observation ([observation_dim]).
    - num_steps: The number of timesteps for the spike train.
    - threshold: A single global threshold value to be used for normalization.
    
    Returns:
    - spike_trains: Tensor of spike trains.
    """
    
    shift = shift.numpy()

    # Normalize and clip observation
    shifted_obs = np.add(observation, shift) 

    # torch version
    #shifted_obs = observation + shift


    normalized_obs = shifted_obs / (threshold + 1e-6)  # Avoid division by zero

    normalized_obs /= 2
    
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to be within [0, 1]

    
    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)
    
    # torch version
    #return spike_trains

    return spike_trains.numpy()

In [32]:
np_obs = np.array([[1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3],
                  [1.2, 0.5, 2.0, -5, 1.0, -4.5, 0.8, 0.3]])



np_obs_st = generate_spike_trains_batched(np_obs, 100, threshold, shift)

#print(np_obs_st.requires_grad)


print(np_obs_st.shape)

yeet = actor_SNN(np_obs_st)[0]

print(type(yeet))
print(yeet.shape)

bruh = decode_first_spike_batched(yeet)

print("huh", bruh)
print("huhhhhh", bruh.requires_grad)


print(get_spike_counts_batched(yeet))

(6, 100, 8)
<class 'torch.Tensor'>
torch.Size([6, 100, 4])
huh tensor([[3., 3., 3., 3.],
        [3., 3., 3., 3.],
        [3., 4., 4., 4.],
        [3., 3., 3., 3.],
        [4., 4., 5., 4.],
        [4., 4., 4., 4.]], grad_fn=<MinBackward0>)
huhhhhh True
tensor([[96., 83., 91., 87.],
        [94., 77., 90., 83.],
        [91., 77., 87., 80.],
        [94., 78., 84., 81.],
        [92., 79., 84., 81.],
        [90., 80., 87., 84.]], grad_fn=<SumBackward1>)
