In [153]:
import torch, torch.nn as nn
import snntorch as snn
from snntorch import spikegen
import numpy as np
import time

from torch.distributions import MultivariateNormal
import snntorch.functional as SF

In [154]:
# --- PPO hyperparameters ---
timesteps_per_batch = 4800        # timesteps collected per batch
max_timesteps_per_episode = 1600  # upper bound on the length of a single episode
n_updates_per_iteration = 5       # actor/critic update passes per iteration
lr = 0.005                        # learning rate passed to the Adam optimizers
gamma = 0.95                      # discount factor used when computing rewards-to-go
clip = 0.2                        # PPO ratio clipping threshold (0.2 is the usual recommendation)

# --- Miscellaneous parameters ---
render = True        # whether to render during rollouts
render_every_i = 10  # render only every n-th iteration
save_freq = 10       # save interval, in iterations
seed = None          # program seed for reproducibility; NOTE(review): not applied in the cells shown here

## SNN definitions

In [274]:
hidden_size = 300  # number of hidden neurons (also SNN's default hidden width)


class SNN(nn.Module):
    """Two-layer leaky integrate-and-fire network: Linear -> Leaky -> Linear -> Leaky.

    Parameters:
        input_size - number of input features per time step.
        output_size - number of output (leaky) neurons.
        num_steps - number of simulation time steps run by forward().
        hidden_size - number of hidden neurons. Defaults to 300 as a constructor argument
                      instead of reading the module-level ``hidden_size`` at construction
                      time: later cells rebind that global (the SNN_2 cell sets it to 64),
                      which used to silently change this network's width.
    """

    def __init__(self, input_size, output_size, num_steps, hidden_size=300):
        super(SNN, self).__init__()

        self.num_steps = num_steps
        beta1 = 0.9
        # independent decay rate for each output neuron, initialised uniformly in [0, 1) and learned
        beta2 = torch.rand((output_size), dtype=torch.float)

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=beta1)
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.lif2 = snn.Leaky(beta=beta2, learn_beta=True)

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps.

        Parameters:
            x - input indexed as x[:, step, :], presumably (batch, num_steps, input_size).

        Returns:
            (spikes, membrane): output spikes stacked on dim 1, i.e. (batch, num_steps, output_size),
            and output membrane potentials stacked on dim 0, i.e. (num_steps, batch, output_size).
            NOTE(review): the two layouts differ (batch-first vs time-first); kept for existing callers.
        """
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        spk2_rec = []  # output spikes, one entry per step
        mem2_rec = []  # output membrane potentials, one entry per step

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)

            spk2_rec.append(spk2)
            mem2_rec.append(mem2)

        return torch.stack(spk2_rec, dim=1), torch.stack(mem2_rec)

In [303]:
hidden_size = 64  # number of hidden neurons (also SNN_2's default hidden width)

class SNN_2(nn.Module):
    """Three-layer leaky integrate-and-fire network: (Linear -> Leaky) x 3.

    Parameters:
        input_size - number of input features per time step.
        output_size - number of output (leaky) neurons.
        num_steps - number of simulation time steps run by forward().
        hidden_size - width of both hidden layers. Defaults to 64 as a constructor argument
                      rather than being read from the module-level ``hidden_size`` at
                      construction time, so rebinding that global elsewhere cannot silently
                      resize this network.
    """

    def __init__(self, input_size, output_size, num_steps, hidden_size=64):
        super(SNN_2, self).__init__()
        self.num_steps = num_steps

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.lif1 = snn.Leaky(beta=0.9)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.lif2 = snn.Leaky(beta=0.9)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.lif3 = snn.Leaky(beta=0.9, learn_beta=True)  # only the output layer learns its decay

    def forward(self, x):
        """Simulate the network for ``self.num_steps`` steps.

        Parameters:
            x - input indexed as x[:, step, :], presumably (batch, num_steps, input_size).

        Returns:
            Output-layer spikes stacked on dim 1, i.e. (batch, num_steps, output_size).
        """
        # Initialize the membrane potentials
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Record the final layer
        spk3_rec = []

        for step in range(self.num_steps):
            cur1 = self.fc1(x[:, step, :])
            spk1, mem1 = self.lif1(cur1, mem1)
            cur2 = self.fc2(spk1)
            spk2, mem2 = self.lif2(cur2, mem2)
            cur3 = self.fc3(spk2)
            spk3, mem3 = self.lif3(cur3, mem3)

            spk3_rec.append(spk3)

        return torch.stack(spk3_rec, dim=1)

In [298]:
def decode_time_to_first_spike(spike_trains):
    """
    Decode spike trains into per-neuron time-to-first-spike latencies.

    Parameters:
        spike_trains - 2-D tensor indexed as [time step, neuron]; an entry equal to 1 is a spike.

    Returns:
        A CPU float32 tensor of length spike_trains.size(1). Entry j is the 1-based index of the
        first step at which neuron j spiked, or spike_trains.size(0) + 1 if it never spiked.
    """
    num_steps, num_neurons = spike_trains.size(0), spike_trains.size(1)
    no_spike_latency = num_steps + 1

    # Empty input: argmax cannot reduce over an empty time axis, and every neuron counts as silent.
    if spike_trains.numel() == 0:
        return torch.full((num_neurons,), float(no_spike_latency), dtype=torch.float32)

    spiked = spike_trains == 1
    # argmax returns the index of the *first* maximal value, i.e. the first step with a spike.
    # Done in one vectorized op instead of a per-neuron .item(), which synced with the device each time.
    first_step = spiked.float().argmax(dim=0)
    latencies = torch.where(
        spiked.any(dim=0),
        first_step + 1,
        torch.full_like(first_step, no_spike_latency),
    )
    # Always return on the CPU (like a list-built tensor) so the result combines with CPU tensors such as cov_mat.
    return latencies.to(device="cpu", dtype=torch.float32)

In [238]:
def decode_from_spikes_count(spikes):
    """
    Decode spike trains into a one-hot action: the neuron that spiked most often wins.

    Parameters:
        spikes - 2-D tensor indexed as [neuron, time step]; spikes are counted over dim 1.

    Returns:
        A float tensor of length spikes.size(0) with a single 1 at the chosen neuron.
        Ties between equally active neurons are broken uniformly at random.
    """
    spike_counts = torch.sum(spikes, dim=1)
    action = torch.zeros(spikes.size(0))
    max_spike_count = torch.max(spike_counts)
    candidates = torch.where(spike_counts == max_spike_count)[0]
    if len(candidates) > 1:
        # Sample a uniform *position* in candidates, then map it back to its neuron index.
        # The neuron indices themselves must not be used as sampling weights (index 0 would
        # never be chosen), and the sampled position is not itself a neuron index.
        pick = torch.randint(len(candidates), (1,)).item()
        action[int(candidates[pick])] = 1
    else:
        action[candidates] = 1
    return action

In [239]:
def encode_to_spikes(data, num_steps):
    """
    Rate-encode continuous values into spike trains.

    The whole input is min-max scaled into [0, 1) using its global minimum and maximum (not
    per row), and the scaled tensor is handed to snntorch's ``spikegen.rate`` together with
    ``num_steps``.

    Parameters:
        data - tensor of continuous values to encode.
        num_steps - number of time steps in the generated spike train.

    Returns:
        The spike train produced by ``spikegen.rate``.
    """
    smallest = data.min()
    value_range = data.max() - smallest
    # 1e-6 keeps the division finite when every element is equal (zero range)
    scaled = (data - smallest) / (value_range + 1e-6)

    # TODO rate vs latency vs delta
    return spikegen.rate(scaled, num_steps=num_steps)

In [240]:
def compute_rtgs(batch_rews, discount=None):
    """
        Compute the Reward-To-Go of each timestep in a batch given the rewards.

        Parameters:
            batch_rews - the rewards in a batch, Shape: (number of episodes, number of timesteps per episode)
            discount - discount factor applied per step; when None (the default) the
                       notebook-level ``gamma`` is used.

        Return:
            batch_rtgs - the rewards to go, Shape: (number of timesteps in batch)
    """
    if discount is None:
        discount = gamma  # notebook-level discount factor

    # Walk episodes and their rewards back-to-front, so each step's return builds on the
    # next step's (R_t = r_t + discount * R_{t+1}). Append in that reversed order, which is
    # O(1) per step, and flip once at the end; inserting at index 0 every step was quadratic.
    batch_rtgs = []
    for ep_rews in reversed(batch_rews):
        discounted_reward = 0  # return of the step after the current one (0 past episode end)
        for rew in reversed(ep_rews):
            discounted_reward = rew + discounted_reward * discount
            batch_rtgs.append(discounted_reward)
    batch_rtgs.reverse()

    # One entry per timestep across the whole batch, in original episode/timestep order
    return torch.tensor(batch_rtgs, dtype=torch.float)

In [241]:
# Sizes of the observation and action vectors
obs_dim, act_dim = 5, 5

# Number of simulation time steps each spike train spans
num_steps = 50

In [313]:
# Prefer CUDA, then Apple MPS, and fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

# Actor and critic networks (ALG STEP 1)
actor = SNN(obs_dim, act_dim, num_steps).to(device)
critic = SNN(obs_dim, 1, num_steps).to(device)

# One Adam optimizer per network
actor_optim = torch.optim.Adam(actor.parameters(), lr=lr, betas=(0.9, 0.999))
critic_optim = torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.9, 0.999))

# Fixed diagonal covariance for the Gaussian policy: variance 0.5 on every action dimension.
# Note: these tensors are created on the CPU, not on `device`.
cov_var = torch.full(size=(act_dim,), fill_value=0.5)
cov_mat = torch.diag(cov_var)

#### Observations must be encoded as spike trains

In [314]:
# First observation, as returned by the environment reset
obs = torch.tensor([300.0, 450.0, 0.0, 4.7124, 0.0])

In [315]:
# Observation -> spike trains (presumably time-first, (num_steps, obs_dim), as spikegen.rate produces)
obs_spike_trains = encode_to_spikes(obs, num_steps=num_steps)

# Add a leading batch axis of size 1 to get the (batch, num_steps, obs_dim) layout that SNN.forward
# indexes as x[:, step, :]. The shape comes from the config rather than a hard-coded [1, 50, 5], and
# the input is moved to `device` because the actor lives there.
reshaped = torch.reshape(obs_spike_trains, [1, num_steps, obs_dim]).to(device)
print(reshaped.shape)

torch.Size([1, 50, 5])


#### Action selection (mirrors the `get_action()` method)

In [316]:
# TODO spike trains -> mean action
# The SNN returns (output spikes, membrane potentials); keep only the spikes, detached from the graph.
mean_spike_trains = actor(reshaped)[0].detach()

print(mean_spike_trains.shape)


torch.Size([1, 50, 5])


In [317]:
# Decode the actor's output spikes into a mean action using time-to-first-spike latencies.
# Reshape with tensor ops instead of np.reshape: the NumPy round-trip fails for CUDA/MPS tensors.
# The (num_steps, act_dim) shape comes from the config instead of a hard-coded [50, 5].
mean = decode_time_to_first_spike(mean_spike_trains.reshape(num_steps, act_dim))

print(mean)

tensor([51., 51., 51., 51., 51.])


In [302]:

# Gaussian policy: centred on the decoded mean action, with the fixed covariance cov_mat.
# Background on how this distribution works (Andrew Ng's lecture):
# https://www.youtube.com/watch?v=JjB58InuTqM
dist = MultivariateNormal(mean, cov_mat)

action = dist.sample()            # draw an action from the policy
log_prob = dist.log_prob(action)  # log-probability of that action under the policy

for shown in (mean.detach(), action.detach().numpy(), log_prob.detach()):
    print(shown)

tensor([51., 51., 51., 51., 51.])
[50.26872  51.12235  49.84141  50.737286 50.650578]
tensor(-4.9450)


In [None]:
# rollout
# Hard-coded sample batch standing in for collected experience: 2 episodes x 3 timesteps
# (see batch_lens) = 6 timesteps, used to exercise the PPO update cells below.

# batch obs: one row per timestep, 5 values each
batch_obs = torch.tensor([[300.0000, 450.0000,   0.0000,   4.7124,   0.0000],
        [300.0667, 450.0000,   3.9801,   4.7124,   0.0000],
        [300.1997, 450.0000,   7.9404,   4.7124,   0.0000],
        [300.0000, 450.0000,   0.0000,   4.7124,   0.0000],
        [300.0667, 450.0000,   3.9801,   4.7124,   0.0000],
        [300.1997, 450.0000,   7.9404,   4.7124,   0.0000]])

# batch actions: one row per timestep, 5 values each
batch_acts =  torch.tensor([[13.1795, 10.1534, 10.7613,  9.3591, -7.8791],
        [13.7555,  9.6675, 10.0287,  9.6279, -6.4844],
        [13.1954, 10.4574,  9.9011,  9.5315, -7.1549],
        [13.4455,  9.4522,  9.5629, 10.1283, -6.6890],
        [14.7801,  8.4991,  8.6375,  9.9466, -6.9821],
        [13.8373, 10.2118, 10.8795,  8.7003, -6.5202]])

# batch log prob: one per timestep; used below as the old-policy log-probabilities in the PPO ratio
batch_log_probs = torch.tensor([-5.0741, -3.2771, -3.4157, -4.1163, -8.0700, -4.7627])

# batch rewards-to-go (not raw rewards): one per timestep; the critic's regression target below
batch_rtgs = torch.tensor([-3.1173, -2.1154, -1.0746, -3.1173, -2.1154, -1.0746])

# batch lengths: timesteps per episode (sums to the 6 rows above)
batch_lens = [3, 3]

In [None]:
# batch_obs -> spike trains. Each observation is encoded on its own, so it is min-max normalised
# by its own range exactly like the single-observation path (encode_to_spikes normalises over its
# whole input). Stacking gives presumably (batch, num_steps, obs_dim), the layout SNN.forward
# indexes as x[:, step, :]; the result is placed on the networks' device.
batch_obs_spike_trains = torch.stack(
    [encode_to_spikes(obs_row, num_steps=num_steps) for obs_row in batch_obs]
).to(device)

#### Decode spike-train actions / predictions into actions / predictions

In [None]:
def evaluate(batch_obs, batch_acts):
    """
        Estimate the values of each observation, and the log probs of
        each action in the most recent batch with the most recent
        iteration of the actor network. Should be called from learn.

        Parameters:
            batch_obs - the observations from the most recently collected batch as a tensor.
                        Shape: (number of timesteps in batch, dimension of observation)
            batch_acts - the actions from the most recently collected batch as a tensor.
                        Shape: (number of timesteps in batch, dimension of action)

        Return:
            V - the predicted values of batch_obs, on the CPU. Shape: (number of timesteps in batch)
            log_probs - the log probabilities of the actions taken in batch_acts given batch_obs
    """
    # Encode the batch_obs argument (not a global) into spike trains, one observation at a time
    # so each is normalised like the single-observation path. Presumably (batch, num_steps, obs_dim).
    obs_spikes = torch.stack(
        [encode_to_spikes(obs_row, num_steps=num_steps) for obs_row in batch_obs]
    ).to(device)

    # SNN.forward returns (spikes, membrane). The membrane record is stacked time-first, so
    # critic_mem[-1] is the final step's (batch, 1) membrane potential, a real-valued and
    # differentiable readout; squeeze to (batch,) to match batch_rtgs and move it to the CPU
    # where batch_rtgs lives.
    # NOTE(review): value readout choice (final membrane potential); confirm it is intended.
    _, critic_mem = critic(obs_spikes)
    V = critic_mem[-1].squeeze(-1).cpu()

    # Decode actor spikes into mean actions the same way as in get_action (time to first spike),
    # so the PPO ratio compares log-probs of the same policy parameterisation. Each sample is
    # (num_steps, act_dim) and decodes to a CPU vector, matching the CPU-resident cov_mat.
    # NOTE(review): this decoding is not differentiable (and the spikes are detached), so
    # log_probs carry no gradient w.r.t. the actor's parameters; the actor update needs a
    # differentiable decoding.
    actor_spikes, _ = actor(obs_spikes)
    mean = torch.stack([decode_time_to_first_spike(sample) for sample in actor_spikes.detach()])

    dist = MultivariateNormal(mean, cov_mat)
    log_probs = dist.log_prob(batch_acts)

    # Return the value vector V of each observation in the batch
    # and log probabilities log_probs of each action in the batch
    return V, log_probs

In [None]:
# Advantage at iteration k: rewards-to-go minus the critic's baseline (detached from the graph)
V = evaluate(batch_obs, batch_acts)[0]
A_k = batch_rtgs - V.detach()

print(V)
print(A_k)

In [None]:
# Normalise advantages to zero mean / unit std; the 1e-10 guards against a zero std
adv_mean = A_k.mean()
adv_std = A_k.std()
A_k = (A_k - adv_mean) / (adv_std + 1e-10)

print(A_k)

In [None]:
# Re-evaluate V_phi(s_t) and log pi_theta(a_t | s_t) under the current networks
V, curr_log_probs = evaluate(batch_obs, batch_acts)

# Ratio pi_theta(a_t | s_t) / pi_theta_k(a_t | s_t), computed in log space:
# exp(log p - log q) equals p / q, and log-probabilities are better behaved numerically.
# Background: https://cs.stackexchange.com/questions/70518/why-do-we-use-the-log-in-gradient-based-reinforcement-algorithms
log_ratio = curr_log_probs - batch_log_probs
ratios = torch.exp(log_ratio)

for shown in (V, curr_log_probs, ratios):
    print(shown)

In [None]:
# PPO surrogate terms: the ratio-weighted advantage, and the same with the ratio
# clamped to [1 - clip, 1 + clip]
clipped_ratios = torch.clamp(ratios, 1 - clip, 1 + clip)
surr1 = ratios * A_k
surr2 = clipped_ratios * A_k

print(surr1)
print(surr2)

In [None]:
# Actor loss: Adam minimises, so take the negative of the clipped-surrogate objective
# (minimising its negative maximises the objective). Critic loss: MSE of V against rewards-to-go.
actor_loss = (-torch.min(surr1, surr2)).mean()
mse = nn.MSELoss()
critic_loss = mse(V, batch_rtgs)

print(actor_loss)
print(critic_loss)

# One gradient step per network, actor first; the actor's backward retains its graph,
# the critic's does not.
for optimizer, loss, keep_graph in ((actor_optim, actor_loss, True), (critic_optim, critic_loss, False)):
    optimizer.zero_grad()
    loss.backward(retain_graph=keep_graph)
    optimizer.step()