In [112]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import spikegen
import numpy as np
import time

from torch.distributions import MultivariateNormal
import snntorch.functional as SF

## SNN def

In [113]:
hidden_size = 300  # number of hidden neurons


class SNN(nn.Module):
    """
    Two-layer fully connected spiking network built from leaky integrate-and-fire neurons.

    Layout: fc1 -> LIF (fixed beta of 0.9) -> fc2 -> LIF (one randomly initialised,
    learnable beta per output neuron). Both linear layers get +0.01 added to their
    freshly initialised weights.
    """

    def __init__(self, input_size, output_size, num_steps, hidden_size=hidden_size):
        """
        Parameters:
            input_size - number of input features per time step
            output_size - number of output neurons
            num_steps - number of time steps processed by forward()
            hidden_size - number of hidden neurons. The default is bound when the class
                          is defined, so rebinding the module-level `hidden_size` later
                          (another cell does this) no longer silently resizes this network.
        """
        super(SNN, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        # Independent decay rate for each leaky neuron in layer 2, drawn from [0, 1)
        beta2 = torch.rand((output_size), dtype=torch.float)

        self.fc1 = nn.Linear(input_size, hidden_size, dtype=torch.float)
        self.fc1.weight.data += 0.01
        self.lif1 = snn.Leaky(beta=beta1)
        self.fc2 = nn.Linear(hidden_size, output_size, dtype=torch.float)
        self.fc2.weight.data += 0.01
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True)

    def forward(self, x):
        """
        Run the network over `self.num_steps` time steps.

        Parameters:
            x - input spike train, indexed as x[:, step, :], i.e. laid out as
                (batch, num_steps, input_size)

        Returns:
            spikes - output spikes stacked along dim 1: (batch, num_steps, output_size)
            membranes - output membrane potentials stacked along dim 0:
                        (num_steps, batch, output_size). Note the different axis order.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = []  # output spikes, one entry per time step
        mem2_rec = []  # output membrane potentials, one entry per time step

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=1), torch.stack(mem2_rec)

In [114]:
hidden_size = 64


class SNN_2(nn.Module):
    """
    Three-layer fully connected spiking network (fc -> LIF, three times).

    All LIF layers start with beta = 0.9; only the output layer's beta is learnable.
    """

    def __init__(self, input_size, output_size, num_steps, hidden_size=hidden_size):
        """
        Parameters:
            input_size - number of input features per time step
            output_size - number of output neurons
            num_steps - number of time steps processed by forward()
            hidden_size - width of both hidden layers. The default is bound when the
                          class is defined, so later changes to the module-level
                          `hidden_size` do not affect this network.
        """
        super(SNN_2, self).__init__()
        self.num_steps = num_steps

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.9)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.lif2 = snn.Leaky(beta=0.9)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.lif3 = snn.Leaky(beta=0.9, learn_beta=True)

    def forward(self, x):
        """
        Run the network over `self.num_steps` time steps.

        Parameters:
            x - input spike train, indexed as x[:, step, :], i.e. laid out as
                (batch, num_steps, input_size)

        Returns:
            Output-layer spikes stacked along dim 1: (batch, num_steps, output_size)
        """
        # Initialize the membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Only the spikes of the final layer are recorded
        spk3_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)

        return torch.stack(spk3_rec, dim=1)

### Simulation set up


In [115]:
# Sizes of the observation and action vectors
obs_dim, act_dim = 5, 5

# Number of time steps simulated for every spike train
num_steps = 50

### Hyper-parameters

In [116]:
# --- Batch / episode schedule ---
max_timesteps_per_episode = 1600  # upper bound on the length of one episode
timesteps_per_batch = 4800  # timesteps to run per batch
n_updates_per_iteration = 5  # actor/critic updates per iteration

# --- Optimisation ---
gamma = 0.95  # discount factor applied when calculating the rewards-to-go
clip = 0.2  # clipping threshold for the probability ratio (0.2 is the recommended value)
lr = 0.005  # learning rate handed to the Adam optimizers

# --- Miscellaneous ---
seed = None  # program seed, used for reproducibility of results
save_freq = 10  # save interval, in iterations
render = True  # whether to render during rollout
render_every_i = 10  # render only every n-th iteration

### Initialization
#### actor and critic networks
#### optimizers
#### covariance matrix

In [117]:
# Pick the compute backend: CUDA first, then Apple MPS, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Actor produces act_dim outputs, the critic a single one  (ALG STEP 1)
actor = SNN(obs_dim, act_dim, num_steps).to(device)
critic = SNN(obs_dim, 1, num_steps).to(device)

# One Adam optimizer per network, both using the shared learning rate
actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, betas=(0.9, 0.999))
critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.9, 0.999))

# Diagonal covariance matrix with a variance of 0.5 for every action dimension
cov_var = torch.full(size=(act_dim,), fill_value=0.5)
cov_mat = torch.diag(cov_var)

### Methods
#### rewards to go calculation
#### evaluation


In [118]:
def decode_first_spike_batched(spike_trains):
    """
    Decodes the first spike time from spike trains for batched data using the
    'time to first spike' method.

    Spike times are 1-based: a neuron that fires on the first step decodes to 1.
    A neuron that never fires decodes to num_steps + 1.

    Parameters:
        spike_trains - The batched spike trains with shape (batch_size, num_steps, num_neurons).
                       An entry counts as a spike when it equals 1.

    Returns:
        decoded_vector - The decoded first spike times as a float32 CPU tensor with
                         shape (batch_size, num_neurons). The result is built from
                         comparisons and integer indices, so it carries no gradient
                         back to `spike_trains`.
    """
    batch_size, num_steps, num_neurons = spike_trains.shape
    no_spike = num_steps + 1

    # Nothing to scan: every neuron is "silent"
    if num_steps == 0:
        return torch.full((batch_size, num_neurons), float(no_spike))

    is_spike = spike_trains == 1

    # 1-based step numbers laid out along the time axis so they broadcast over batch and neurons
    step_numbers = torch.arange(1, num_steps + 1, device=spike_trains.device).view(1, num_steps, 1)
    sentinel = torch.tensor(no_spike, device=spike_trains.device)

    # Replace non-spikes by the sentinel; the minimum over time is then the first spike
    first_spike_times = torch.where(is_spike, step_numbers, sentinel).min(dim=1).values

    return first_spike_times.to(dtype=torch.float, device="cpu")

def decode_from_spikes_count(spikes):
    """
    Decodes a one-hot action from spike counts: the neuron that spiked most wins.

    Ties are broken uniformly at random among the tied neurons.

    Parameters:
        spikes - spike trains summed over dim 1; the code treats dim 0 as the
                 neuron axis, i.e. (num_neurons, num_steps)

    Returns:
        action - float tensor of length spikes.size(0) with a single 1 at the winner
    """
    spike_counts = torch.sum(spikes, dim=1)
    action = torch.zeros(spikes.size(0))
    candidates = torch.where(spike_counts == torch.max(spike_counts))[0]

    if len(candidates) > 1:
        # Draw a position inside `candidates` and map it back to the neuron index.
        # The indices must not be used as sampling weights, nor the drawn position
        # as a neuron index: that would mark a neuron that did not win.
        winner = candidates[torch.randint(len(candidates), (1,))]
    else:
        winner = candidates

    action[winner] = 1
    return action

def encode_to_spikes(data, num_steps):
    """
    Rate-encodes a continuous-valued tensor as a spike train.

    The data is min-max normalised over the whole tensor (a single global minimum
    and maximum), clamped to [0, 1] and handed to `spikegen.rate`.

    Parameters:
        data - The continuous-valued data to be encoded.
        num_steps - The number of time steps for the spike train.

    Returns:
        spike_train - The encoded spike train, as returned by `spikegen.rate`.
    """
    # Small constant that keeps the denominator non-zero for constant data
    eps = 1e-6

    lowest = data.min()
    value_range = data.max() - lowest + eps
    scaled = ((data - lowest) / value_range).clamp(0, 1)

    # TODO rate vs latency vs delta
    return spikegen.rate(scaled, num_steps=num_steps)

def encode_to_spikes_batched(data, num_steps):
    """
    Rate-encodes a batch of continuous-valued rows as spike trains.

    Every row is min-max normalised on its own (minimum and maximum taken along
    dim 1), clamped to [0, 1] and rate-encoded. The first two axes of the encoder
    output are then swapped.

    Parameters:
        data - The continuous-valued data to be encoded, one sample per row.
        num_steps - The number of time steps for the spike train.

    Returns:
        spike_train - The encoded spike train with dims 0 and 1 transposed; for the
                      (6, 5) batch used in this notebook the result is (6, 50, 5).
    """
    # Small constant that keeps the denominator non-zero for constant rows
    eps = 1e-6

    row_low = data.min(dim=1, keepdim=True).values
    row_high = data.max(dim=1, keepdim=True).values

    scaled = torch.clamp((data - row_low) / (row_high - row_low + eps), 0, 1)

    return spikegen.rate(scaled, num_steps=num_steps).transpose(0, 1)



def compute_rtgs(batch_rews):
    """
        Compute the Reward-To-Go of each timestep in a batch given the rewards.

        Uses the module-level discount factor `gamma`.

        Parameters:
            batch_rews - the rewards in a batch, Shape: (number of episodes, number of timesteps per episode)

        Return:
            batch_rtgs - the rewards to go as a float tensor, in the same episode and
                         timestep order as batch_rews, Shape: (number of timesteps in batch)
    """
    # Filled back-to-front (last timestep of the last episode first) and reversed once
    # at the end. Appending keeps this linear; insert(0, ...) would shift the whole
    # list on every timestep.
    batch_rtgs = []

    for ep_rews in reversed(batch_rews):

        discounted_reward = 0  # The discounted reward so far, reset for every episode

        # Walk the episode backwards so each return reuses the one that follows it
        for rew in reversed(ep_rews):
            discounted_reward = rew + discounted_reward * gamma
            batch_rtgs.append(discounted_reward)

    batch_rtgs.reverse()

    # Convert the rewards-to-go into a tensor
    return torch.tensor(batch_rtgs, dtype=torch.float)

def evaluate(batch_obs_ts, batch_acts):
    """
        Estimate the values of each observation, and the log probs of
        each action in the most recent batch with the most recent
        iteration of the actor network. Should be called from learn.

        Parameters:
            batch_obs_ts - the spike-encoded observations of the most recently collected batch.
                        Shape: (number of timesteps in batch, num_steps, dimension of observation)
            batch_acts - the actions from the most recently collected batch as a tensor.
                        Shape: (number of timesteps in batch, dimension of action)

        Return:
            V - the predicted values of the batch, Shape: (number of timesteps in batch)
            log_probs - the log probabilities of the actions taken in batch_acts given the observations
    """
    # The networks may live on an accelerator; feed them inputs on the same device.
    batch_obs_ts = batch_obs_ts.to(next(critic.parameters()).device)

    # Query critic network for a value V for each observation. Shape of V should be same as batch_rtgs
    V_ts = critic(batch_obs_ts)[0]

    # squeeze(-1) drops only the single output-neuron axis, so a batch holding one
    # timestep still yields shape (1,) instead of a 0-d tensor.
    # NOTE(review): the first-spike decode yields a tensor with no autograd history and
    # requires_grad_(True) turns it into a fresh leaf, so gradients of a loss on V stop
    # here and do not reach the critic's parameters.
    V = decode_first_spike_batched(V_ts).squeeze(-1).requires_grad_(True)

    # Calculate the log probabilities of batch actions using most recent actor network.
    # This segment of code is similar to that in get_action()
    mean_ts = actor(batch_obs_ts)[0].detach()

    mean = decode_first_spike_batched(mean_ts)

    dist = MultivariateNormal(mean, cov_mat)
    # NOTE(review): same as for V. The actor output is detached above, so log_probs is
    # a leaf tensor that is not connected to the actor's parameters.
    log_probs = dist.log_prob(batch_acts).requires_grad_(True)

    # Return the value vector V of each observation in the batch
    # and log probabilities log_probs of each action in the batch
    return V, log_probs

def get_action(obs):
    """
        Queries an action from the actor network, should be called from rollout.

        Parameters:
            obs - the observation at the current timestep (tensor, numpy array or sequence)

        Return:
            action - the action to take, as a numpy array with a leading batch axis of size 1
            log_prob - the log probability of the selected action in the distribution
    """

    # as_tensor accepts tensors, arrays and lists alike. Unlike torch.tensor() it does
    # not warn when handed an existing tensor. It also pins the dtype to float32 (what
    # the linear layers use) and places the data on the actor's device.
    obs = torch.as_tensor(obs, dtype=torch.float, device=next(actor.parameters()).device)

    obs_st = encode_to_spikes_batched(obs.unsqueeze(0), num_steps=num_steps)

    # Query the actor network for a mean action
    mean_st = actor(obs_st)[0].detach()

    mean = decode_first_spike_batched(mean_st)

    # Create a distribution with the mean action and std from the covariance matrix above.
    # For more information on how this distribution works, check out Andrew Ng's lecture on it:
    # https://www.youtube.com/watch?v=JjB58InuTqM
    dist = MultivariateNormal(mean, cov_mat)

    # Sample an action from the distribution
    action = dist.sample()

    # Calculate the log probability for that action
    log_prob = dist.log_prob(action)

    # Return the sampled action and the log probability of that action in our distribution
    return action.detach().numpy(), log_prob.detach()

### Reset, First observation

In [119]:
# Initial observation, as obtained from reset
obs = torch.tensor([300.0, 450.0, 0.0, 4.7124, 0.0])

print(obs, obs.shape, sep="\n")

tensor([300.0000, 450.0000,   0.0000,   4.7124,   0.0000])
torch.Size([5])


In [120]:
# Add the batch axis expected by the batched encoder. reshape(1, -1) gives the same
# (1, 5) result as unsqueeze(0) on the first run but, unlike unsqueeze, does not stack
# another leading axis every time this cell is re-executed.
obs = obs.reshape(1, -1)

print(obs)

# Rate-encode the single-row batch into a spike train
obs_st = encode_to_spikes_batched(obs, num_steps=num_steps)

# print(obs_st)

# print(obs_st.shape)

tensor([[300.0000, 450.0000,   0.0000,   4.7124,   0.0000]])


### Get action and log prob of first observation
This is one iteration of the get_action() method:

- Get network output to be used as a mean for the distribution.

- Create a distribution with the mean action and std from the covariance matrix above. <br/>
For more information on how this distribution works, check out Andrew Ng's lecture on it: <br/>
https://www.youtube.com/watch?v=JjB58InuTqM <br/>

- Sample an action from the distribution

- Calculate the log probability for that action


In [121]:
# Output spikes of the actor (first element of its return value), cut off from autograd
mean_spike_trains = actor(obs_st)[0].detach()
print(mean_spike_trains)

# Decode to first-spike times, then flip them: a neuron that never fires
# (decoded as num_steps + 1) ends up with a mean of 0.
mean = (num_steps + 1) - decode_first_spike_batched(mean_spike_trains)

print(mean)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
        

In [122]:
# Gaussian policy centred on the decoded mean, using the fixed covariance matrix
dist = MultivariateNormal(loc=mean, covariance_matrix=cov_mat)

# Draw one action and score it under the same distribution
action = dist.sample()
log_prob = dist.log_prob(action)

print("mean of distribution ", mean.detach())
print("action to take ", action.detach().numpy())
print("log probability of the action ", log_prob.detach())

mean of distribution  tensor([[0., 0., 0., 0., 0.]])
action to take  [[-0.02905154  0.7660263   0.5499304   0.5236275   0.03144384]]
log probability of the action  tensor([-4.0271])


In [123]:
# Two identical three-step episodes, stacked row-wise into one batch
batch_obs = torch.tensor(
    [
        [300.0000, 450.0000, 0.0000, 4.7124, 0.0000],
        [299.9333, 450.0000, 3.9801, 4.7124, 0.0000],
        [299.8003, 450.0000, 7.9404, 4.7124, 0.0000],
    ]
    * 2
)

# Query the actor for the first observation of the batch only
action, log_prob = get_action(batch_obs[0])

print(action, log_prob, sep="\n")


torch.Size([1, 50, 5])
[[51.251507 50.864086 51.058483 51.350163 51.643894]]
tensor([-3.4842])


  obs = torch.tensor(obs)


### Rollout

- batch observations collected from simulation, first obs O_0 is from reset [0, n-1]
- batch actions collected from querying the network given observations [1, n]
- batch log probabilities collected from querying the network given observations [1, n]
- batch rewards collected from simulation after taking an action. [1, n]
- batch lengths store the batch and episode lengths

In [124]:
# rollout

# Observations: two identical three-step episodes, one row per timestep
batch_obs = torch.tensor(
    [
        [300.0000, 450.0000, 0.0000, 4.7124, 0.0000],
        [299.9333, 450.0000, 3.9801, 4.7124, 0.0000],
        [299.8003, 450.0000, 7.9404, 4.7124, 0.0000],
    ]
    * 2
)

# Rewards, one inner list per episode, and the matching episode lengths
batch_rews = [
    [-1.1209449646284, -1.134160306418572, -1.1539176437616128],
    [-1.1209449646284, -1.134160306418572, -1.1539176437616128],
]
batch_lens = [3, 3]

batch_rtgs = compute_rtgs(batch_rews)

# Query the actor one observation at a time, since get_action() adds the batch axis itself
# (open question: could this be done without a loop?)
batch_acts, batch_log_probs = [], []
for i in range(batch_obs.size(0)):
    action, log_prob = get_action(batch_obs[i])
    batch_acts.append(action)
    batch_log_probs.append(log_prob)

# Collapse the per-step results into tensors without the singleton batch axis
batch_acts = torch.tensor(np.array(batch_acts), dtype=torch.float).squeeze()
batch_log_probs = torch.tensor(np.array(batch_log_probs), dtype=torch.float).squeeze()

print("batch acts ", batch_acts)
print("batch log probs ", batch_log_probs)
print("batch rtg ", batch_rtgs)

  obs = torch.tensor(obs)


torch.Size([1, 50, 5])
torch.Size([1, 50, 5])
torch.Size([1, 50, 5])
torch.Size([1, 50, 5])
torch.Size([1, 50, 5])
torch.Size([1, 50, 5])
batch acts  tensor([[51.8653, 50.4827, 50.7976, 51.7156, 50.3940],
        [51.5114, 51.0334, 51.5198, 50.7115, 50.1022],
        [50.8452, 50.5212, 51.2185, 51.5430, 51.2571],
        [50.4020, 50.0194, 51.5051, 51.1050, 50.8650],
        [50.3344, 50.7350, 51.2924, 51.7496, 50.4155],
        [50.1316, 51.3448, 50.5148, 51.2064, 51.5824]])
batch log probs  tensor([-4.7985, -4.2839, -3.5237, -4.4654, -4.3640, -4.3521])
batch rtg  tensor([-3.2398, -2.2304, -1.1539, -3.2398, -2.2304, -1.1539])


In [169]:
# Rate-encode the whole observation batch in a single call
batch_obs_ts = encode_to_spikes_batched(batch_obs, num_steps=num_steps)

print(batch_obs_ts.size())

torch.Size([6, 50, 5])


### Evaluate one iteration (for demonstration)
- Query critic network for a value V for each batch_obs. Shape of V should be same as batch_rtgs
- Calculate the log probabilities of batch actions using most recent actor network.
This segment of code is similar to that in get_action()

In [170]:
# Critic pass: output spikes -> first-spike times -> one value per timestep
V_ts = critic(batch_obs_ts)[0].detach()
V = decode_first_spike_batched(V_ts).squeeze()

# Actor pass: decoded first-spike times serve as the mean of the action distribution
mean_ts = actor(batch_obs_ts)[0].detach()
mean = decode_first_spike_batched(mean_ts)

# Score the actions collected during rollout under the current policy
dist = MultivariateNormal(loc=mean, covariance_matrix=cov_mat)
log_probs = dist.log_prob(batch_acts)

print(mean.shape, V, log_probs, sep="\n")

torch.Size([6, 5])
tensor([51., 51., 51., 51., 51., 51.])
tensor([ -6.1439,  -7.2954,  -6.0614,  -6.0476, -16.3022,  -4.9679])


In [171]:
# Advantage at the k-th iteration: rewards-to-go minus the critic's value prediction
V, _ = evaluate(batch_obs_ts, batch_acts)
A_k = torch.sub(batch_rtgs, V.detach())

print(V, A_k, sep="\n")

tensor([51., 51., 51., 51., 51., 51.], requires_grad=True)
tensor([-54.2398, -53.2304, -52.1539, -54.2398, -53.2304, -52.1539])


In [172]:
# Standardise the advantages to zero mean and unit std; 1e-10 guards against a zero std
A_k = A_k.sub(A_k.mean()).div(A_k.std() + 1e-10)

print(A_k)

tensor([-1.1059, -0.0240,  1.1298, -1.1059, -0.0240,  1.1298])


### Update the network through a number of iterations

- Calculate V_phi and pi_theta(a_t | s_t)
- Calculate the ratio pi_theta(a_t | s_t) / pi_theta_k(a_t | s_t) <br/>
NOTE: we just subtract the logs, which is the same as<br/>
dividing the values and then canceling the log with e^log.<br/>
For why we use log probabilities instead of actual probabilities,<br/>
here's a great explanation:<br/>
https://cs.stackexchange.com/questions/70518/why-do-we-use-the-log-in-gradient-based-reinforcement-algorithms<br/>
TL;DR makes gradient ascent easier behind the scenes.<br/>
- Calculate surrogate losses.
- Calculate actor and critic losses. <br/>
NOTE: we take the negative min of the surrogate losses because we're trying to maximize <br/>
the performance function, but Adam minimizes the loss. So minimizing the negative <br/>
performance function maximizes it. <br/>
- Calculate gradients and perform backward propagation for actor and critic network



In [173]:
num_updates_per_iteration = 1

# Loss history, one detached scalar tensor per update
actor_loss_arr = []
critic_loss_arr = []

# NOTE(review): evaluate() detaches the actor output and calls requires_grad_(True) on
# freshly decoded tensors. backward() below therefore stops at V / curr_log_probs and
# does not populate gradients on the parameters held by actor_optim / critic_optim.

for _ in range(num_updates_per_iteration):

    print("## update start ##")

    # Current value estimates and log-probabilities of the collected actions
    V, curr_log_probs = evaluate(batch_obs_ts, batch_acts)

    # pi_theta(a_t | s_t) / pi_theta_k(a_t | s_t), computed in log space
    ratios = torch.exp(curr_log_probs - batch_log_probs)

    # Unclipped and clipped surrogate objectives
    surr1 = ratios * A_k
    surr2 = torch.clamp(ratios, 1 - clip, 1 + clip) * A_k

    # Negated because the optimizer minimises while the objective is to be maximised
    actor_loss = (-torch.min(surr1, surr2)).mean()
    critic_loss = nn.MSELoss()(V, batch_rtgs)

    # Store detached copies so the history does not keep every autograd graph alive
    actor_loss_arr.append(actor_loss.detach())
    critic_loss_arr.append(critic_loss.detach())

    # The two losses are built from separate tensors, so neither backward pass needs
    # the graph to be retained for the other.
    actor_optim.zero_grad()
    actor_loss.backward()
    actor_optim.step()

    critic_optim.zero_grad()
    critic_loss.backward()
    critic_optim.step()


    print("* V \n",V.detach())
    print("* curr_log_probs \n",curr_log_probs.detach())
    print("* ratios \n",ratios.detach())
    print("* surr1 \n",surr1.detach())
    print("* surr2 \n",surr2.detach())

print(actor_loss)
print(critic_loss)

## update start ##
* V 
 tensor([51., 51., 51., 51., 51., 51.])
* curr_log_probs 
 tensor([ -6.1439,  -7.2954,  -6.0614,  -6.0476, -16.3022,  -4.9679])
* ratios 
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
* surr1 
 tensor([-1.1059, -0.0240,  1.1298, -1.1059, -0.0240,  1.1298])
* surr2 
 tensor([-1.1059, -0.0240,  1.1298, -1.1059, -0.0240,  1.1298])
tensor(1.4106e-06, grad_fn=<MeanBackward0>)
tensor(2831.8203, grad_fn=<MseLossBackward0>)
