In [1]:
from typing import Union, Tuple
from collections import namedtuple

import gymnasium as gym

import torch
from torch.distributions import Normal
import torch.nn.functional as F
from torch import nn
from torch.optim import AdamW

In [2]:
# Device on which tensors and networks are placed
device = 'cpu'

# Standard deviation of the policy's normal distribution
STD = 0.5
# Learning rate handed to the optimizers
LR = 0.005
# Discount factor applied to future rewards
GAMMA = 0.95
# Half-width of the clipping interval (1 - EPSILON, 1 + EPSILON)
EPSILON = 0.2
# Number of learning passes made over each collected batch
EPOCHS = 5

# Record describing a single environment step
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'log_prob', 'next_state', 'reward'],
)

In [3]:
# Network architecture shared by the actor (which estimates the mean of a
# normal distribution over actions) and the critic (which estimates state values)
class FeedForwardNN(nn.Module):
    """Fully connected network with two hidden layers of 64 ReLU units.

    Maps inputs of width ``in_dim`` to outputs of width ``out_dim``; the
    final layer is linear (no activation is applied to the output).
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()

        hidden_units = 64
        # Layers are created in forward order so that seeded initialisation
        # is reproducible.
        self.layer1 = nn.Linear(in_dim, hidden_units)
        self.layer2 = nn.Linear(hidden_units, hidden_units)
        self.layer3 = nn.Linear(hidden_units, out_dim)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Pass ``obs`` through the three layers and return the raw output."""
        features = obs
        for hidden_layer in (self.layer1, self.layer2):
            features = F.relu(hidden_layer(features))
        return self.layer3(features)

In [4]:
# Create the gymnasium environment with id 'MountainCarContinuous-v0';
# the batch-collection code further down steps this single instance.
env = gym.make('MountainCarContinuous-v0')

In [5]:
# Both spaces must be Box spaces: the cells below read their `.shape` and
# pass real-valued action arrays to the environment. isinstance is used
# (rather than an exact type comparison) so Box subclasses are accepted too.
assert isinstance(env.observation_space, gym.spaces.Box)
assert isinstance(env.action_space, gym.spaces.Box)

In [6]:
# Width of the observation vector (input size of both networks)
obs_dim = env.observation_space.shape[0]
# Width of the action vector, read from the environment instead of being
# hard-coded; the rest of the notebook only supports a single component.
act_dim = env.action_space.shape[0]
assert act_dim == 1, f'expected a 1-dimensional action space, got {act_dim}'

In [7]:
# Policy network: observation -> mean of the action distribution
actor = FeedForwardNN(in_dim=obs_dim, out_dim=act_dim).to(device)
# Value network: observation -> single scalar estimate
critic = FeedForwardNN(in_dim=obs_dim, out_dim=1).to(device)

In [8]:
# One AdamW optimizer (AMSGrad variant) per network, both using LR
actor_optim = AdamW(params=actor.parameters(), lr=LR, amsgrad=True)
critic_optim = AdamW(params=critic.parameters(), lr=LR, amsgrad=True)

In [40]:
def get_action(obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    '''Sample an action from the current policy for the given observation.

    The policy is a normal distribution whose mean is the actor's output
    for ``obs`` and whose standard deviation is the constant ``STD``.

    Parameters
    ==========
    obs: Observation tensor fed to the actor network.

    Returns
    =======
    A tuple ``(action, log_prob)``: the sampled action and its log
    probability under the distribution it was sampled from. Neither
    tensor is attached to the autograd graph.
    '''
    # Nothing here is ever backpropagated through, so skip building an
    # autograd graph instead of building one and detaching afterwards.
    with torch.no_grad():
        # Mean of the actor's gaussian distribution
        mean = actor(obs)
        # Use the shared STD constant rather than repeating the literal
        dist = Normal(loc=mean, scale=STD)
        # Draw an action and evaluate its log probability
        action = dist.sample()
        log_prob = dist.log_prob(action)
    return action, log_prob
    
def generate_batch() -> Tuple[torch.Tensor,...]:
    '''Collect one batch of transitions by running the current policy.

    Episodes are run back to back until at least MAX_TIMESTEPS_PER_BATCH
    steps have been gathered. The episode in progress is always played to
    its end (or cut at MAX_TIMESTEPS_PER_EPISODE), so a batch may be
    somewhat larger than that limit.

    Returns
    =======
    A tuple ``(states, actions, log_probs, targets)`` with one row per
    collected step:
      states:    the observations the actions were taken in
      actions:   the sampled actions
      log_probs: log probability of each action at sampling time
      targets:   one-step TD targets ``r + GAMMA * V(next_state)``, with
                 ``V(next_state)`` taken as 0 when the step terminated the
                 episode. Detached from the autograd graph.
    '''
    # Maximum timesteps per batch (may be exceeded to finish an episode)
    MAX_TIMESTEPS_PER_BATCH = 4800
    # Maximum timesteps per episode
    MAX_TIMESTEPS_PER_EPISODE = 1600

    # All transitions collected for this batch
    transitions = []
    # Number of timesteps collected so far
    t = 0

    while t < MAX_TIMESTEPS_PER_BATCH:
        # Start a new episode and turn its first observation into a (1, obs) tensor
        state, _ = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)

        for _ in range(MAX_TIMESTEPS_PER_EPISODE):
            t += 1
            # Sample an action from the current actor
            action, log_prob = get_action(state)
            # .cpu() keeps the numpy conversion valid when `device` is not the CPU
            observation, reward, terminated, truncated, _ = env.step(
                action.squeeze(0).detach().cpu().numpy())
            # Pin the dtype so rewards always match the float32 network outputs
            reward = torch.tensor([reward], dtype=torch.float32, device=device)

            # Only a true termination has no successor state; a truncated
            # episode still has one, so its value is bootstrapped below.
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32,
                                          device=device).unsqueeze(0)

            transitions.append(Transition(state, action, log_prob, next_state, reward))
            state = next_state

            if terminated or truncated:
                break  # End of episode

    batch_size = len(transitions)
    # Transpose the list of transitions into one Transition of tuples
    batch = Transition(*zip(*transitions))

    # True where the step did NOT end in a terminal state
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=device, dtype=torch.bool)
    non_final_next_states = [s for s in batch.next_state if s is not None]

    # V(next_state) stays 0 for terminal next states
    next_state_values = torch.zeros(batch_size, device=device)
    # torch.cat rejects an empty list, so only query the critic when there
    # is something to evaluate; no_grad because targets are constants.
    if non_final_next_states:
        with torch.no_grad():
            next_state_values[non_final_mask] = critic(
                torch.cat(non_final_next_states)).squeeze(1)

    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    log_prob_batch = torch.cat(batch.log_prob)
    reward_batch = torch.cat(batch.reward)

    # One-step TD target: r + GAMMA * V(next_state).
    # Subtracting V(state) from it gives an estimate of the advantage A_t.
    target = reward_batch + GAMMA * next_state_values

    return state_batch, action_batch, log_prob_batch, target.detach()

def learn(total_timesteps: int) -> None:
    '''Train the actor and the critic with a clipped-surrogate (PPO) update.

    Repeatedly collects a batch with ``generate_batch`` and then makes
    ``EPOCHS`` optimisation passes over it, updating the actor through the
    clipped probability-ratio objective and the critic through a
    mean-squared error against the TD targets.

    Parameters
    ==========
    total_timesteps: The maximum number of timesteps throughout the
                     entire learning process. Collection stops once at
                     least this many steps have been gathered; the last
                     batch may overshoot it.
    '''
    # Timesteps consumed so far across all batches
    t = 0

    while t < total_timesteps:
        # generate_batch returns a plain tuple of four tensors
        state_batch, action_batch, old_log_probs, target = generate_batch()
        # Count environment steps, not the number of tuple fields
        t += state_batch.shape[0]

        # Advantage estimate A_t = target - V(state), computed once with the
        # critic as it was when the batch was collected.
        with torch.no_grad():
            advantage = target - critic(state_batch).squeeze(1)

        for _ in range(EPOCHS):
            # Current value estimates for the batch states
            V = critic(state_batch).squeeze(1)
            # Log probabilities of the stored actions under the current policy
            mean = actor(state_batch)
            dist = Normal(loc=mean, scale=STD)
            log_probs = dist.log_prob(action_batch)

            # pi_new(a|s) / pi_old(a|s); squeeze to 1-D so it lines up
            # element-wise with `advantage` instead of broadcasting.
            ratio = torch.exp(log_probs - old_log_probs).squeeze(1)
            unclipped = ratio * advantage
            clipped = torch.clamp(ratio, 1 - EPSILON, 1 + EPSILON) * advantage
            # Maximise the clipped surrogate == minimise its negative
            actor_loss = -torch.min(unclipped, clipped).mean()
            critic_loss = F.mse_loss(V, target)

            actor_optim.zero_grad()
            actor_loss.backward()
            actor_optim.step()

            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()