In [1]:
from typing import Union, Tuple, Optional, List
from collections import namedtuple

import gymnasium as gym

import torch
from torch.distributions import Normal
import torch.nn.functional as F
from torch import nn
from torch.optim import AdamW

In [2]:
# Device on which the networks and all rollout tensors are placed
DEVICE = torch.device('cpu')
# Standard deviation of the normal (Gaussian) action distribution
STD = 0.5
# Learning rate shared by the actor and critic optimizers
LR = 0.005
# Discount factor applied to V(next_state) in the one-step target
GAMMA = 0.95
# Clipping interval for the probability ratio: (1 - EPSILON, 1 + EPSILON)
EPSILON = 0.2
# Number of optimization passes performed over each collected batch
EPOCHS = 5
# Maximum timesteps per batch (may be exceeded: the limit is only
# checked between episodes, never in the middle of one)
MAX_TIMESTEPS_PER_BATCH = 4800
# Maximum timesteps per episode
MAX_TIMESTEPS_PER_EPISODE = 1600

In [3]:
# One environment step: the state it started from, the action taken there,
# the log-probability of that action under the policy that sampled it,
# the resulting next state and the reward received
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'log_prob', 'next_state', 'reward'],
)

In [4]:
# The policy network is an estimator for the mean of a normal distribution
class FeedForwardNN(nn.Module):
    '''
        Fully connected network with two ReLU hidden layers:
                        in_dim -> 64 -> 64 -> out_dim
        The output layer is linear (no activation is applied to it).
    '''
    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()

        # Width shared by both hidden layers
        hidden_dim = 64

        self.layer1 = nn.Linear(in_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''Run x through both hidden layers and the linear output layer.'''
        hidden = F.relu(self.layer1(x))
        hidden = F.relu(self.layer2(hidden))
        return self.layer3(hidden)

In [5]:
# Choose an environment with a continuous action space; the cells below
# assert that both its observation and action spaces are Box spaces
env = gym.make('MountainCarContinuous-v0')

In [6]:
# Both spaces must be continuous (Box). isinstance is used instead of an
# exact type comparison so that subclasses of Box are accepted as well.
assert isinstance(env.observation_space, gym.spaces.Box), \
    'observation space must be a gym.spaces.Box'
assert isinstance(env.action_space, gym.spaces.Box), \
    'action space must be a gym.spaces.Box'

In [7]:
# The dimensionality of the state representation (first axis of the
# observation space shape); used as the input size of both networks
dim_state = env.observation_space.shape[0]
# The rest of the notebook assumes a one-dimensional action
assert env.action_space.shape[0] == 1
# With the check passed, set the action dimensionality to 1
dim_action = 1

In [8]:
# The actor network: maps a state to the mean of the action distribution
actor = FeedForwardNN(dim_state, dim_action).to(DEVICE)
# The critic network: maps a state to a single scalar value estimate
critic = FeedForwardNN(dim_state, 1).to(DEVICE)

In [9]:
# Separate AdamW optimizers (AMSGrad variant) for the actor and the critic,
# both using the same learning rate LR
actor_optim = AdamW(actor.parameters(), lr=LR, amsgrad=True)
critic_optim = AdamW(critic.parameters(), lr=LR, amsgrad=True)

In [10]:
def get_action(obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    '''Sample an action for the given observation from the current policy.

    The policy is a normal distribution whose mean is produced by the
    actor network and whose standard deviation is the constant STD.

    Parameters
    ==========
    obs: Tensor holding the observation(s) to act on; it is passed to
         the actor network unchanged.

    Returns
    =======
    A tuple (action, log_prob): the sampled action and its log
    probability under the distribution it was sampled from. Neither
    tensor is attached to the autograd graph.
    '''
    # Sampling during rollouts never needs gradients: the stored log
    # probability is treated as a constant of the old policy, so no
    # autograd graph is built for the actor forward pass.
    with torch.no_grad():
        # The mean of the actor's gaussian distribution
        mean = actor(obs)
        # action ~ Normal(mean, STD); STD is the shared config constant
        dist = Normal(loc=mean, scale=STD)
        # Sample an action from the current normal distribution
        action = dist.sample()
        # Log probability of the sampled action
        log_prob = dist.log_prob(action)

    return action, log_prob

def evaluate(
    batch_state: torch.Tensor,
    batch_action: torch.Tensor,
    only_v: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    '''Computes the critic's state value for each state in the batch.
       If only_v is False, it also computes the log probabilities of
       the given actions under the current actor's gaussian
       distributions.

    Parameters
    ==========
    batch_state:  Tensor of states in the batch
    batch_action: Tensor of actions taken for each state in the batch
    only_v:       If True, skip the actor and return None in place of
                  the log probabilities.

    Returns
    =======
    A tuple (V, log_probs). V is the critic output with its second axis
    squeezed away. log_probs is the element-wise log probability of
    batch_action (same shape as the actor output), or None when only_v
    is True.
    '''
    # Will remain None if only_v is True
    log_probs = None
    V = critic(batch_state).squeeze(1)

    if not only_v:
        mean = actor(batch_state)
        # Use the shared STD constant so the distribution evaluated here
        # is the same family the actions were sampled from
        dist = Normal(loc=mean, scale=STD)
        log_probs = dist.log_prob(batch_action)

    return V, log_probs

def generate_batch() -> List[Transition]:
    '''Collect one batch of transitions by rolling out the current policy.

    Episodes are run back to back until at least MAX_TIMESTEPS_PER_BATCH
    timesteps have been gathered. The limit is only checked between
    episodes, so a batch may overshoot it. Each episode lasts at most
    MAX_TIMESTEPS_PER_EPISODE steps.

    Returns
    =======
    A list of Transition tuples in the order they were experienced.
    `next_state` is None when the environment reported termination; on
    truncation the real next state is kept.
    '''

    # All transitions in a batch
    batch_transitions = []
    # Timesteps collected so far in this batch
    t = 0

    # Collect transitions for the batch
    while t < MAX_TIMESTEPS_PER_BATCH:
        # Get the first state in the episode (the info dict is not needed)
        state, _ = env.reset()
        # Convert it into a torch tensor with a leading batch axis
        state = torch.tensor(state, dtype=torch.float32, device=DEVICE).unsqueeze(0)
        # Run until we reach the max timestep for an episode or the episode ends
        for _ in range(MAX_TIMESTEPS_PER_EPISODE):
            # Increment the timestep (t0 -> t1 -> t2 ...)
            t += 1
            # Get the action using the current actor network
            action, log_prob = get_action(state)
            # Move the action to host memory before converting to numpy so
            # the conversion also works when DEVICE is not the CPU
            observation, reward, terminated, truncated, _ = env.step(
                action.squeeze(0).detach().cpu().numpy()
            )
            # Pin the dtype: without it an integer reward would produce an
            # int64 tensor instead of matching the float32 states
            reward = torch.tensor([reward], dtype=torch.float32, device=DEVICE)
            # Did the episode end (for either reason)?
            done = terminated or truncated
            # Only a true terminal state is recorded as None; a truncated
            # episode keeps its real next state
            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32, device=DEVICE).unsqueeze(0)
            # Add the experience to the batch
            batch_transitions.append(Transition(state, action, log_prob, next_state, reward))
            # Move to the next state
            state = next_state

            if done:
                break # End of the episode

    return batch_transitions

def decompose(batch_transitions: List[Transition]) -> Tuple[torch.Tensor,...]:
    '''Stack a list of transitions into batch tensors and compute the
       one-step critic targets.

    Parameters
    ==========
    batch_transitions: A list of all transitions in a batch

    Returns
    =======
    A tuple (batch_state, batch_action, batch_log_prob, batch_target):
    the concatenated states, actions and stored log probabilities, and
    the detached targets reward + GAMMA * V(next_state), where
    V(next_state) is taken as 0 for terminal next states (those stored
    as None).
    '''

    # Batch size
    batch_size = len(batch_transitions)

    # Transpose the list of Transitions into one Transition of tuples
    batch = Transition(*zip(*batch_transitions))

    # True where the next state is not terminal
    non_final_mask = torch.tensor([s is not None for s in batch.next_state],
                                  device=DEVICE, dtype=torch.bool)

    # V(next_state) defaults to 0, which is the value used for terminal states
    next_state_values = torch.zeros(batch_size, device=DEVICE)

    # torch.cat rejects an empty list, so skip the critic entirely when
    # every transition in the batch ended in a terminal state
    if non_final_mask.any():
        non_final_next_states = torch.cat([s for s in batch.next_state
                                           if s is not None])
        # The targets are constants for the update, so no graph is needed
        with torch.no_grad():
            next_state_values[non_final_mask] = critic(non_final_next_states).squeeze(1)

    # Batch states
    batch_state = torch.cat(batch.state)

    # Batch actions (in each state)
    batch_action = torch.cat(batch.action)

    # Batch log probs (as stored when the actions were sampled)
    batch_log_prob = torch.cat(batch.log_prob)

    # Batch rewards (as a result of actions in each state)
    batch_reward = torch.cat(batch.reward)

    # Get the target for this batch (r + GAMMA * V(next_state))
    batch_target = (batch_reward + GAMMA * next_state_values).detach()

    return batch_state, batch_action, batch_log_prob, batch_target

def learn(total_timesteps: int) -> None:
    '''Trains the actor and the critic with the clipped surrogate objective.

    Batches are generated and learned from until the number of collected
    timesteps reaches total_timesteps. For every batch, EPOCHS update
    steps are taken on both networks.

    Parameters
    ==========
    total_timesteps: The maximum number of timesteps throughout the
                     entire learning process (checked between batches,
                     so the last batch may overshoot it).
    '''
    # Timesteps collected so far throughout the entire learning process
    t = 0

    # Number of batches generated so far
    iteration = 0

    while t < total_timesteps:

        # Generate a batch such that we can iterate through it later
        batch_transitions = generate_batch()
        batch_state, batch_action, batch_log_prob, batch_target = decompose(batch_transitions)

        iteration += 1

        print(f'Iteration {iteration} started ...')
        # Add the number of timesteps generated in the batch to t
        t += len(batch_transitions)

        # The advantage is a constant for the updates below, so it is
        # computed without building an autograd graph
        with torch.no_grad():
            # V(state) for all states in the batch
            V, _ = evaluate(batch_state, batch_action, only_v=True)
            # Normalized advantage function estimation
            A_hat = batch_target - V
            A_hat = (A_hat - A_hat.mean()) / (A_hat.std() + 1e-10)

        batch_size = A_hat.shape[0]

        for i in range(1, EPOCHS + 1):
            print(f'\t>>> Epoch {i} started ...')
            V, curr_log_prob = evaluate(batch_state, batch_action)
            # Collapse the per-action-dimension log ratios to ONE value per
            # sample. Without this the ratios keep a trailing axis, e.g.
            # (N, 1), and multiplying by the (N,) advantages broadcasts to
            # an (N, N) matrix that pairs every ratio with every advantage.
            log_ratio = (curr_log_prob - batch_log_prob).reshape(batch_size, -1).sum(dim=1)
            ratios = torch.exp(log_ratio)
            surr1 = ratios * A_hat
            surr2 = torch.clamp(ratios, 1 - EPSILON, 1 + EPSILON) * A_hat
            # CLIP loss for actor
            actor_loss = (-torch.min(surr1, surr2)).mean()
            # MSE loss for critic
            critic_loss = F.mse_loss(V, batch_target)

            # Calculate gradients and perform an optimization step for actor
            actor_optim.zero_grad()
            actor_loss.backward()
            actor_optim.step()

            # Calculate gradients and perform an optimization step for critic.
            # The graph is rebuilt by evaluate() every epoch, so it does not
            # have to be retained after this backward pass.
            critic_optim.zero_grad()
            critic_loss.backward()
            critic_optim.step()

        print(f'Iteration {iteration} finished!')
        print('***************************')

In [None]:
# Train until at least 200000 environment timesteps have been collected
learn(200000)