In [18]:
import torch
import snntorch
from torch import nn
import torch.nn.functional as F
import numpy as np

import gym
import numpy as np
import torch
from torch import nn
from torch.distributions import MultivariateNormal
from torch.optim.adam import Adam

## ANN definition

In [3]:
class FeedForwardNN(nn.Module):
    """
        A standard in_dim-64-64-out_dim Feed Forward Neural Network.
    """

    def __init__(self, in_dim, out_dim):
        """
            Initialize the network and set up the layers.

            Parameters:
                in_dim - input dimensions as an int
                out_dim - output dimensions as an int

                Return:
                None
        """
        super(FeedForwardNN, self).__init__()

        self.layer1 = nn.Linear(in_dim, 64)
        self.layer2 = nn.Linear(64, 64)
        self.layer3 = nn.Linear(64, out_dim)

    def forward(self, obs):
        """
            Runs a forward pass on the neural network.

            Parameters:
                obs - observation to pass as input

            Return:
                output - the output of our forward pass
        """

        print("state", obs)

        # Convert observation to tensor if it's a numpy array
        if isinstance(obs, np.ndarray):
            obs = torch.tensor(obs, dtype=torch.float)

        activation1 = torch.relu(self.layer1(obs))
        activation2 = torch.relu(self.layer2(activation1))
        output = self.layer3(activation2)

        return output

In [24]:
def compute_rtgs(self, batch_rews):
    """
        Compute the Reward-To-Go of each timestep in a batch given the rewards.

        Parameters:
            batch_rews - the rewards in a batch, Shape: (number of episodes, number of timesteps per episode)

        Return:
            batch_rtgs - the rewards to go, Shape: (number of timesteps in batch)
    """
    # The rewards-to-go (rtg) per episode per batch to return.
    # The shape will be (num timesteps per episode)
    batch_rtgs = []

    # Iterate through each episode
    for ep_rews in reversed(batch_rews):

        discounted_reward = 0  # The discounted reward so far

        # Iterate through all rewards in the episode. We go backwards for smoother calculation of each
        # discounted return (think about why it would be harder starting from the beginning)
        for rew in reversed(ep_rews):
            discounted_reward = rew + discounted_reward * self.gamma
            batch_rtgs.insert(0, discounted_reward)

    # Convert the rewards-to-go into a tensor
    batch_rtgs = torch.tensor(batch_rtgs, dtype=torch.float)

    return batch_rtgs

In [4]:
obs_dim = 5
act_dim = 5

In [8]:
# create an instance for the actor network
actor = FeedForwardNN(obs_dim, act_dim)  # ALG STEP 1

# create an instance for the critic network

critic = FeedForwardNN(obs_dim, 1)

cov_var = torch.full(size=(act_dim,), fill_value=0.5)
cov_mat = torch.diag(cov_var)

In [11]:
# first observation, from reset
obs = torch.tensor([300.0, 500.0, 20.0, 1.5, 0.0])

In [20]:
mean = actor(obs)

print(mean.detach())

state tensor([300.0000, 500.0000,  20.0000,   1.5000,   0.0000])
tensor([  9.6540,  -1.3000,   7.4884, -37.3375, -33.8045])


In [22]:
# Create a distribution with the mean action and std from the covariance matrix above.
# For more information on how this distribution works, check out Andrew Ng's lecture on it:
# https://www.youtube.com/watch?v=JjB58InuTqM
dist = MultivariateNormal(mean, cov_mat)

# Sample an action from the distribution
action = dist.sample()

# Calculate the log probability for that action
log_prob = dist.log_prob(action)

# Return the sampled action and the log probability of that action in our distribution
print(action.detach().numpy())
print(log_prob.detach())

MultivariateNormal(loc: torch.Size([5]), covariance_matrix: torch.Size([5, 5]))
[  9.041698   -2.1892595   7.27935   -37.093643  -32.856956 ]
tensor(-5.0285)


In [23]:
obs = [300.0, 500.0, 20.0, 1.5, 0.0]
rew = 1200
truncated = False
terminated = False

In [None]:
# batch rewards

# batch obs

# batch actions

# batch log prob

# batch rewards to go
