In [40]:
import snntorch as snn
from snntorch import spikegen
import torch, torch.nn as nn
from snntorch import surrogate
from torch.distributions import MultivariateNormal, Categorical


import numpy as np

In [41]:
hidden_size = 64  # Number of hidden neurons


class SNN_lrl(nn.Module):
    """Two-layer spiking network followed by a linear readout.

    Architecture: Linear -> snn.Leaky -> Linear -> snn.Leaky -> Linear readout.
    The hidden width is taken from the module-level ``hidden_size``.

    Parameters:
        input_size  - Number of input features per time step.
        output_size - Number of output neurons (and readout outputs).
        num_steps   - Number of simulation steps consumed by ``forward``.
    """

    def __init__(self, input_size, output_size, num_steps):
        super().__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        # Independent, randomly initialised decay rate for each output neuron;
        # it is made trainable through learn_beta=True below.
        beta2 = torch.rand(output_size, dtype=torch.float)

        # Hidden layer. The constant offset shifts the freshly initialised
        # weights upwards by 0.05; no_grad keeps the shift out of autograd
        # without going through the discouraged `.data` attribute.
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        with torch.no_grad():
            self.fc1.weight.add_(0.05)
        self.lif1 = snn.Leaky(beta=beta1, spike_grad=surrogate.fast_sigmoid())

        # Output spiking layer with a learnable per-neuron decay rate.
        self.fc2 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        with torch.no_grad():
            self.fc2.weight.add_(0.05)
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True, spike_grad=surrogate.fast_sigmoid())

        # Linear readout applied to the time-averaged output membrane potential.
        self.readout = nn.Linear(output_size, output_size)

    def forward(self, x):
        """Run the network over ``self.num_steps`` time steps.

        Parameters:
            x - Input of shape [batch_size, num_steps, input_size] or
                [num_steps, input_size]; a NumPy array is converted to a
                float tensor. Only the first ``self.num_steps`` steps are read.

        Returns:
            readout_output_mem - Readout of the output layer's membrane
                potential averaged over time, shape [batch_size, output_size]
                ([output_size] for unbatched input).
            spk2_stacked - Output-layer spikes, shape
                [batch_size, num_steps, output_size]
                ([num_steps, output_size] for unbatched input).
        """
        # Convert observation to tensor if it's a numpy array
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # [batch_size, num_steps, input_size] is 3D; anything else is treated
        # as a single sample and given a temporary batch dimension.
        is_batched = x.dim() == 3
        if not is_batched:
            x = x.unsqueeze(0)  # Shape becomes [1, num_steps, input_size]

        # Initialize membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()

        # Record the output layer's spikes and membrane potential per step
        spk2_rec = []
        mem2_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        # Shape: [batch_size, num_steps, output_size]
        spk2_stacked = torch.stack(spk2_rec, dim=1)

        # Average the membrane potential across time, then apply the readout.
        # Shape: [batch_size, output_size]
        avg_mem2 = torch.stack(mem2_rec, dim=1).mean(dim=1)
        readout_output_mem = self.readout(avg_mem2)

        if not is_batched:
            # Remove the batch dimension that was added above
            readout_output_mem = readout_output_mem.squeeze(0)  # [output_size]
            spk2_stacked = spk2_stacked.squeeze(0)  # [num_steps, output_size]

        return readout_output_mem, spk2_stacked

In [42]:
hidden_size = 64  # Number of hidden neurons

class SNN_small(nn.Module):
    """Two-layer spiking network without a readout layer.

    Architecture: Linear -> snn.Leaky -> Linear -> snn.Leaky, with the hidden
    width taken from the module-level ``hidden_size``. Both spiking layers use
    the same fixed decay rate.
    """

    def __init__(self, input_size, output_size, num_steps):
        super().__init__()
        self.num_steps = num_steps

        decay = 0.9  # fixed decay rate shared by both spiking layers

        # Input -> hidden; freshly initialised weights are shifted up by 0.01.
        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data.add_(0.01)
        self.lif1 = snn.Leaky(beta=decay, spike_grad=surrogate.fast_sigmoid())

        # Hidden -> output, same weight shift.
        self.fc2 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc2.weight.data.add_(0.01)
        self.lif2 = snn.Leaky(beta=decay, spike_grad=surrogate.fast_sigmoid())

    def forward(self, x):
        """Simulate ``self.num_steps`` steps and return the output layer's history.

        Parameters:
            x - Input of shape [batch_size, num_steps, input_size] or
                [num_steps, input_size]; NumPy arrays are converted to float
                tensors first.

        Returns:
            (output_spk, output_mem) - Output-layer spikes and membrane
            potentials, each of shape [batch_size, num_steps, output_size], or
            [num_steps, output_size] when the input was unbatched.
        """
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # Anything that is not 3D is treated as one sample and temporarily
        # given a leading batch dimension of size 1.
        unbatched = x.dim() != 3
        if unbatched:
            x = x.unsqueeze(0)

        hidden_mem = self.lif1.init_leaky()
        out_mem = self.lif2.init_leaky()

        spike_history, mem_history = [], []
        for t in range(self.num_steps):
            hidden_spk, hidden_mem = self.lif1(self.fc1(x[:, t, :]), hidden_mem)
            out_spk, out_mem = self.lif2(self.fc2(hidden_spk), out_mem)
            spike_history.append(out_spk)
            mem_history.append(out_mem)

        # Time becomes dimension 1: [batch_size, num_steps, output_size]
        output_spk = torch.stack(spike_history, dim=1)
        output_mem = torch.stack(mem_history, dim=1)

        if unbatched:
            # Drop the batch dimension that was added above
            return output_spk.squeeze(0), output_mem.squeeze(0)
        return output_spk, output_mem

In [43]:
def generate_spike_trains(observation, num_steps, threshold, shift):
    """
    Generate rate-coded spike trains from a single observation.

    The observation is shifted, scaled by twice the threshold and clipped to
    [0, 1]; the result is passed to ``spikegen.rate`` as the per-step input.

    Parameters:
    - observation: The observation ([observation_dim]); NumPy array, tensor or sequence.
    - num_steps: The number of timesteps for the spike train.
    - threshold: Normalization threshold; a scalar or one value per observation
      dimension (NumPy array, tensor or Python number).
    - shift: Offset added to the observation before normalization, used to move
      negative values into the positive range; scalar or per-dimension.

    Returns:
    - spike_trains: NumPy array of spike trains as returned by ``spikegen.rate``.
    """
    # Convert every operand to a tensor explicitly. The previous version mixed
    # NumPy and torch operands and only worked when `threshold` happened to be
    # a torch tensor (the result needed a `.clamp` method).
    obs = torch.as_tensor(observation)
    shift = torch.as_tensor(shift)
    threshold = torch.as_tensor(threshold)

    # Shift, normalize (epsilon avoids division by zero) and halve.
    normalized_obs = (obs + shift) / (threshold + 1e-6) / 2

    # Clip values to be within [0, 1]
    normalized_obs = normalized_obs.clamp(0, 1)

    # Generate spike trains
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)

    return spike_trains.numpy()

In [44]:
def decode_first_spike(spike_trains):
    """
    Decodes spike trains with the 'time to first spike' method and rescales the result.

    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons), containing 0/1 values.

    Returns:
        decoded_vector - Tensor of shape (num_neurons,). A neuron whose first spike is at
                         step t (1-based) gets 2 - 2 * t / (num_steps + 1); a neuron that
                         never spikes gets 0. Gradients flow back to `spike_trains` when it
                         requires them.
    """
    spike_trains = torch.as_tensor(spike_trains)
    num_steps, num_neurons = spike_trains.shape

    # 1-based time index as a column so it broadcasts over neurons. It is a
    # constant: it must not require gradients itself (gradients reach
    # `spike_trains` through the products below) and it has to live on the
    # same device as the spikes.
    time_steps = torch.arange(1, num_steps + 1, dtype=torch.float32,
                              device=spike_trains.device).unsqueeze(1)

    # Spiking entries keep their time step; silent entries get num_steps + 1,
    # which is later than any real spike.
    spike_times = spike_trains * time_steps
    spike_times = spike_times + (1 - spike_trains) * (num_steps + 1)

    # Earliest entry per neuron, i.e. the first spike
    first_spike_times = spike_times.min(dim=0).values

    # Map times linearly so that earlier spikes give larger values and
    # "never spiked" gives 0 - a format better suited for a categorical.
    return (-2 / (num_steps + 1)) * first_spike_times + 2

In [45]:
def get_first_spike(spike_trains):
    """
    Returns the raw time step of the first spike of every neuron.

    Parameters:
        spike_trains - The spike trains with shape (num_steps, num_neurons), containing 0/1 values.

    Returns:
        first_spike_times - Tensor of shape (num_neurons,) with the 1-based step of each
                            neuron's first spike; neurons that never spike get num_steps + 1.
    """
    spike_trains = torch.as_tensor(spike_trains)
    num_steps, num_neurons = spike_trains.shape

    # Constant 1-based time index, created on the spikes' device. It does not
    # require gradients; gradients reach `spike_trains` through the products.
    time_steps = torch.arange(1, num_steps + 1, dtype=torch.float32,
                              device=spike_trains.device).unsqueeze(1)

    # Spiking entries keep their time step, silent entries get num_steps + 1.
    spike_times = spike_trains * time_steps
    spike_times = spike_times + (1 - spike_trains) * (num_steps + 1)

    # Earliest entry per neuron
    return spike_times.min(dim=0).values

In [46]:
def decode_first_spike_batched(spike_trains):
    """
    Decodes batched spike trains with the 'time to first spike' method and rescales the result.

    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons),
                       containing 0/1 values.

    Returns:
        decoded_vector - Tensor of shape (batch_size, num_neurons). A neuron whose first spike
                         is at step t (1-based) gets 2 - 2 * t / (num_steps + 1); a neuron that
                         never spikes gets 0. Gradients flow back to `spike_trains` when it
                         requires them.
    """
    spike_trains = torch.as_tensor(spike_trains)
    batch_size, num_steps, num_neurons = spike_trains.shape

    # 1-based time index shaped (1, num_steps, 1) so it broadcasts over batch
    # and neurons. It is a constant: no gradient of its own, same device as
    # the spikes.
    time_steps = torch.arange(1, num_steps + 1, dtype=torch.float32,
                              device=spike_trains.device).view(1, num_steps, 1)

    # Spiking entries keep their time step; silent entries get num_steps + 1,
    # which is later than any real spike.
    spike_times = spike_trains * time_steps
    spike_times = spike_times + (1 - spike_trains) * (num_steps + 1)

    # Earliest entry along the time dimension for every sample and neuron
    first_spike_times = spike_times.min(dim=1).values

    # Map times linearly so that earlier spikes give larger values and
    # "never spiked" gives 0 - a format better suited for a categorical.
    return (-2 / (num_steps + 1)) * first_spike_times + 2

In [47]:
def generate_spike_trains_batched(observations, num_steps, threshold, shift):
    """
    Generate rate-coded spike trains from batched observations.

    Each observation is shifted, scaled by twice the threshold and clipped to
    [0, 1]; the result is passed to ``spikegen.rate`` as the per-step input.

    Parameters:
    - observations: The batched observations ([batch_size, observation_dim]);
      NumPy array, tensor or nested sequence.
    - num_steps: The number of timesteps for the spike train.
    - threshold: Normalization threshold; a scalar or one value per observation
      dimension (NumPy array, tensor or Python number).
    - shift: A value to shift the observation range to handle negative values;
      scalar or per-dimension.

    Returns:
    - spike_trains: NumPy array of spike trains with shape (batch_size, num_steps, observation_dim).
    """
    # Convert every operand to a tensor explicitly. The previous version mixed
    # NumPy and torch operands and only worked when `threshold` happened to be
    # a torch tensor (the result needed a `.clamp` method).
    obs = torch.as_tensor(observations)
    shift = torch.as_tensor(shift)
    threshold = torch.as_tensor(threshold)

    # Shift and normalize; the epsilon avoids division by zero.
    normalized_obs = (obs + shift) / (2 * (threshold + 1e-6))
    normalized_obs = normalized_obs.clamp(0, 1)  # Clip values to [0, 1]

    # Generate spike trains for each observation in the batch
    spike_trains = spikegen.rate(normalized_obs, num_steps=num_steps)

    # Swap the first two dimensions so the batch comes first:
    # (batch_size, num_steps, observation_dim)
    spike_trains = spike_trains.permute(1, 0, 2)

    return spike_trains.numpy()

In [48]:
def compute_spike_metrics(spk_output):
    """
    Compute the average first-spike time and the ratio of neurons that spike at least once.

    Handles both batched ([batch_size, num_steps, output_size]) and unbatched
    ([num_steps, output_size]) outputs.

    Parameters:
        spk_output: Spiking activity output from the actor network (0/1 values).
                    Shape can be either [batch_size, num_steps, output_size] or
                    [num_steps, output_size].

    Returns:
        avg_spike_time: Mean over all neurons (and samples) of the 1-based step of the first
                        spike; a neuron that never spikes counts as num_steps + 1.
        spike_ratio: The ratio of neurons that spike at least once.
        Both are detached scalar tensors.

    Raises:
        ValueError: If `spk_output` is neither 2- nor 3-dimensional.
    """
    if spk_output.dim() == 3:
        time_dim = 1  # [batch_size, num_steps, output_size]
    elif spk_output.dim() == 2:
        time_dim = 0  # [num_steps, output_size]
    else:
        raise ValueError("spk_output must have 2 or 3 dimensions, got shape: {}".format(spk_output.shape))

    # The first-spike computation is done locally: the batched branch used to
    # call `get_first_spike_batched`, which is not defined in this notebook.
    # NOTE(review): raw first-spike times are assumed for the batched case, by
    # analogy with `get_first_spike` - confirm if a different helper was intended.
    spike_times = _first_spike_times(spk_output, time_dim)
    avg_spike_time = torch.mean(spike_times)

    # A neuron counts as active when it emitted at least one spike over time
    spike_ratio = (spk_output.sum(dim=time_dim) > 0).float().mean()

    return avg_spike_time.detach(), spike_ratio.detach()


def _first_spike_times(spike_trains, time_dim):
    """Return the 1-based step of the first spike along `time_dim` (num_steps + 1 if none)."""
    num_steps = spike_trains.shape[time_dim]

    # Time index shaped to broadcast along every dimension except `time_dim`
    view_shape = [1] * spike_trains.dim()
    view_shape[time_dim] = num_steps
    time_steps = torch.arange(1, num_steps + 1, dtype=torch.float32,
                              device=spike_trains.device).view(view_shape)

    # Spiking entries keep their time step, silent entries get num_steps + 1
    spike_times = spike_trains * time_steps + (1 - spike_trains) * (num_steps + 1)
    return spike_times.min(dim=time_dim).values

In [49]:
# Problem dimensions and simulation length
obs_dim = 8       # features per observation
act_dim = 4       # number of discrete actions / actor outputs
num_steps = 32    # time steps per spike train

# Actor produces one output per action; critic produces a single value.
# The actor is built first so the random initialisation order is unchanged.
actor = SNN_lrl(input_size=obs_dim, output_size=act_dim, num_steps=num_steps)
critic = SNN_lrl(input_size=obs_dim, output_size=1, num_steps=num_steps)

# Per-feature normalisation constants for the spike encoders: `shift` is added
# to each observation before it is divided by twice `threshold`.
threshold = torch.tensor([1.5, 1.5, 5.0, 5.0, 3.14, 5.0, 1.0, 1.0])
shift = np.array([1.5, 1.5, 5.0, 5.0, 3.14, 5.0, 0.0, 0.0])

In [50]:
# Three sample observations (one per row, eight features each) used to
# exercise the encoders and networks below.
obs_batch = np.array(
    [
        [0.11220751, 1.069325, 0.41773742, -0.7382385, -0.02831368, -0.08749236, 0.0, 0.0],
        [0.11624928, 1.0521305, 0.40627155, -0.7642402, -0.03038585, -0.04144341, 0.0, 0.0],
        [0.12014542, 1.0351937, 0.39247927, -0.75280756, -0.03321088, -0.05650074, 0.0, 0.0],
    ]
)

In [51]:

# Encode the first observation on its own and the whole batch at once.
obs_st = generate_spike_trains(obs_batch[0], num_steps=num_steps, threshold=threshold, shift=shift)
obs_st_batch = generate_spike_trains_batched(obs_batch, num_steps=num_steps, threshold=threshold, shift=shift)

# Actor forward pass on the single observation: the first result is the
# readout vector, the second the output layer's spike trains.
spk_output, spikes = actor(obs_st)
print("action ", spk_output)
print("action ", spikes)

# Summary statistics of the output spikes
avg_spike_time, spike_ratio = compute_spike_metrics(spikes)
print("avg spike time ", avg_spike_time)
print("spike ratio ", spike_ratio)
print("---------------")

# Use the readout vector directly as logits of a categorical policy,
# sample an action and evaluate its log probability.
logits = spk_output
dist = Categorical(logits=logits)
action = dist.sample()
log_prob = dist.log_prob(action)


action  tensor([ 1.1359,  1.0008, -0.4039, -0.1630], grad_fn=<SqueezeBackward1>)
action  tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.],
        [1., 1., 1., 0.],
        [1., 0., 1., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.],
        [1., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 1., 1., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.],
        [1., 1., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 1., 0.],
        [1., 0., 1., 0.],
        [1., 1., 1., 0.],
        [1., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 1., 0.],
        [1., 0., 0., 0.],
        [1., 1., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 1., 1., 0.],
        [1., 0., 0., 0.],
        [1., 0., 1., 0.]], grad_fn=<SqueezeBackward1>)
avg spike time  tensor(10.7500)
spike ratio  tensor(0.7500)
---------------

In [52]:


# Critic forward pass on the batched spike trains: `bb` is the readout output,
# `zz` the output layer's spike trains (unused here).
bb, zz = critic(obs_st_batch)

# One value per observation in the batch
print(bb.shape)

torch.Size([3, 1])
