In [None]:
# ! pip install tensorboard

In [None]:
import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.distributions import MultivariateNormal
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gymnasium as gym
import random
import time
import os

In [None]:
# Networks

class feedForwardNN(nn.Module):
    """Three-layer fully connected network: inDim -> 64 -> 64 -> outDim.

    Both hidden layers are followed by a ReLU; the output layer is linear
    (no activation), so outputs are unbounded.

    Args:
        inDim: number of input features.
        outDim: number of output features.
    """

    def __init__(self, inDim, outDim):
        super().__init__()

        hiddenSize = 64
        # Attribute names (layer1..layer3) define the state_dict keys; keep them stable.
        self.layer1 = nn.Linear(inDim, hiddenSize)
        self.layer2 = nn.Linear(hiddenSize, hiddenSize)
        self.layer3 = nn.Linear(hiddenSize, outDim)

    def forward(self, x):
        """Apply the network to a float tensor ``x`` whose last dimension is inDim."""
        for hiddenLayer in (self.layer1, self.layer2):
            x = F.relu(hiddenLayer(x))
        return self.layer3(x)

In [None]:
class PPO:
    """PPO agent skeleton for environments whose actions are real-valued vectors.

    The actor network outputs the mean of a multivariate Gaussian with a fixed
    diagonal covariance; the critic network outputs a single value per observation.

    Status: rollout collection, rewards-to-go, and critic evaluation are implemented.
    The policy/critic update step is not implemented yet (see ``learn``).
    """

    def __init__(self, env):
        """Build the actor/critic networks sized from ``env``'s spaces.

        Args:
            env: environment exposing ``observation_space.shape``,
                ``action_space.shape``, ``reset()`` and ``step(action)``.
                Both the gymnasium API (``reset`` -> ``(obs, info)``, 5-tuple
                ``step``) and the legacy gym API are accepted.
        """
        # Define the hyper parameters
        self._initHyperParameters()

        # Extract environment information
        self.env = env
        self.obsDim = env.observation_space.shape[0]
        self.actDim = env.action_space.shape[0]

        # Define the networks
        self.actor = feedForwardNN(self.obsDim, self.actDim)
        self.critic = feedForwardNN(self.obsDim, 1)

        # Fixed diagonal covariance for the action distribution. 0.5 is an
        # arbitrary per-dimension *variance* (stdev ~0.71), not a stdev.
        self.cov_var = torch.full(size=(self.actDim,), fill_value=0.5)
        self.cov_mat = torch.diag(self.cov_var)

    def _initHyperParameters(self):
        """Set the default hyperparameters."""
        self.timeStepsPerBatch = 4800       # minimum env steps collected per rollout()
        self.maxTimeStepsPerEpisode = 1600  # cap on steps within one episode
        self.gamma = 0.95                   # discount factor for rewards-to-go

    def getActions(self, obs):
        """Sample an action for one observation from the current policy.

        Args:
            obs: observation as a numpy array or tensor; converted to float32.

        Returns:
            (action, logProb): the sampled action as a numpy array, and its
            log-probability as a detached tensor.
        """
        # The env hands back numpy arrays, but the networks need float32 tensors.
        obsTensor = torch.as_tensor(obs, dtype=torch.float32)
        mean = self.actor(obsTensor)
        dist = MultivariateNormal(mean, self.cov_mat)
        action = dist.sample()
        logProb = dist.log_prob(action)

        return action.detach().numpy(), logProb.detach()

    def _resetEnv(self):
        """Reset the env and return only the initial observation."""
        result = self.env.reset()
        # gymnasium returns (obs, info); legacy gym returns the observation alone.
        if isinstance(result, tuple) and len(result) == 2:
            return result[0]
        return result

    def _stepEnv(self, action):
        """Step the env and return (obs, reward, done) for either gym API."""
        result = self.env.step(action)
        if len(result) == 5:
            # gymnasium: the episode ends on either termination or truncation.
            obs, reward, terminated, truncated, _info = result
            return obs, reward, terminated or truncated
        obs, reward, done, _info = result
        return obs, reward, done

    def rollout(self):
        """Collect at least ``timeStepsPerBatch`` env steps with the current policy.

        Episodes are always played to completion (or to ``maxTimeStepsPerEpisode``),
        so a batch may contain slightly more than ``timeStepsPerBatch`` steps.

        Returns:
            batchObs: float32 tensor, one row per step.
            batchActions: float32 tensor, one row per step.
            batchLogProbs: 1-D float32 tensor, one log-probability per step.
            batchRewardsToGo: 1-D float32 tensor, one value per step.
            batchEpisodeLengths: list of ints, one entry per episode.

        Raises:
            ValueError: if ``maxTimeStepsPerEpisode`` < 1 (no step could ever be taken).
        """
        if self.maxTimeStepsPerEpisode < 1:
            raise ValueError("maxTimeStepsPerEpisode must be >= 1")

        batchObs = []
        batchActions = []
        batchLogProbs = []
        batchRewards = []
        batchEpisodeLengths = []

        t = 0
        while t < self.timeStepsPerBatch:
            # Rewards per episode
            episodeRewards = []

            obs = self._resetEnv()
            episodeLength = 0

            for _ in range(self.maxTimeStepsPerEpisode):
                t += 1
                episodeLength += 1

                # Record the observation the action is conditioned on.
                batchObs.append(obs)

                action, logProb = self.getActions(obs)
                obs, reward, done = self._stepEnv(action)

                episodeRewards.append(reward)
                batchActions.append(action)
                batchLogProbs.append(logProb)

                if done:
                    break

            batchEpisodeLengths.append(episodeLength)
            batchRewards.append(episodeRewards)

        # Stack into one ndarray first: torch.tensor on a list of ndarrays is very slow.
        batchObs = torch.tensor(np.array(batchObs), dtype=torch.float32)
        batchActions = torch.tensor(np.array(batchActions), dtype=torch.float32)
        if batchLogProbs:
            batchLogProbs = torch.stack(batchLogProbs).to(torch.float32)
        else:
            batchLogProbs = torch.empty(0, dtype=torch.float32)

        batchRewardsToGo = self.computeRewardsToGo(batchRewards)

        return batchObs, batchActions, batchLogProbs, batchRewardsToGo, batchEpisodeLengths

    def computeRewardsToGo(self, batchRewards):
        """Compute discounted rewards-to-go for every step, flattened across episodes.

        Within each episode R_t = r_t + gamma * R_{t+1}, computed backwards and
        reset to 0 at each episode boundary.

        Args:
            batchRewards: list of per-episode reward lists.

        Returns:
            1-D float32 tensor aligned step-for-step with the collection order.
        """
        batchRewardsToGo = []

        for episodeRewards in batchRewards:
            episodeRewardsToGo = []
            discountedReward = 0
            for rew in reversed(episodeRewards):
                discountedReward = rew + discountedReward * self.gamma
                episodeRewardsToGo.append(discountedReward)
            # Built back-to-front; restore chronological order (avoids O(n^2) insert(0)).
            episodeRewardsToGo.reverse()
            batchRewardsToGo.extend(episodeRewardsToGo)

        return torch.tensor(batchRewardsToGo, dtype=torch.float32)

    def learn(self, totalSteps):
        """Collect rollouts until at least ``totalSteps`` env steps have been taken.

        NOTE: only data collection is implemented so far; the PPO update
        (advantages, clipped surrogate loss, optimizer steps) is still TODO.

        Raises:
            ValueError: if a rollout collects no steps (the loop could never finish).
        """
        t = 0
        while t < totalSteps:
            (batchObs, batchActions, batchLogProbs,
             batchRewardsToGo, batchEpisodeLengths) = self.rollout()

            stepsCollected = sum(batchEpisodeLengths)
            if stepsCollected == 0:
                raise ValueError("rollout() collected no steps; check timeStepsPerBatch")
            # Advance the step counter; without this the loop never terminates.
            t += stepsCollected

            # TODO: compute advantages via self.evaluate(batchObs) and run the PPO update.

    def evaluate(self, batchObs):
        """Return the critic's value estimate for each observation in ``batchObs``.

        ``squeeze(-1)`` drops only the trailing size-1 output dimension, so a batch
        of one stays shape (1,) instead of collapsing to a 0-d scalar.
        """
        V = self.critic(batchObs).squeeze(-1)

        return V