# Exercise 2: Advantage Actor-Critic (A2C)

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import equinox as eqx
import optax
import cv2
from pyvirtualdisplay import Display

from a2c_utils import *

The Lunar Lander environment is a standard RL benchmark defined in the [Gymnasium library](https://gymnasium.farama.org/environments/box2d/lunar_lander/). In this exercise, we will use the advantage actor-critic (A2C) method to learn a policy that controls the lunar lander optimally.

#### Exercise 2.1: Define the Policy
In the first part of this exercise, implement the `Policy.__call__` method to define the model that we will train to control the lunar lander. This model outputs two things: a `MultivariateNormalDiag` object that defines a multivariate normal distribution over the action space (see its definition in `a2c_utils.py`), and an estimate of the value function for the input.

In [None]:
class Policy(eqx.Module):
    """Actor-critic network for the lunar lander.

    A shared two-layer MLP trunk (ReLU activations) feeds three linear heads:
    the action mean, the action standard deviation (made positive with a
    softplus), and a scalar state-value estimate.
    """

    # Shared embedding layers; each is followed by a ReLU in `__call__`.
    trunk_layers: list
    # Actor head producing the mean of the action distribution.
    action_mean_head: eqx.Module
    # Actor head whose softplus output is the per-dimension action std.
    action_std_head: eqx.Module
    # Critic head producing a single value estimate.
    value_head: eqx.Module

    def __init__(self, state_dim, action_dim, key):
        """
        Initialize all layers.

        Args:
            state_dim: dimension of the (unbatched) state vector
            action_dim: dimension of the action vector
            key: JAX PRNGKey used to initialize the layer weights
        """
        # `PRNGKey`s for initializing NN layers (one per layer).
        keys = jax.random.split(key, 5)

        # Embedding layers.
        self.trunk_layers = [
            eqx.nn.Linear(state_dim, 128, key=keys[0]),
            eqx.nn.Linear(128, 128, key=keys[1]),
        ]

        # Actor's layers.
        self.action_mean_head = eqx.nn.Linear(128, action_dim, key=keys[2])
        self.action_std_head = eqx.nn.Linear(128, action_dim, key=keys[3])

        # Critic's layers.
        self.value_head = eqx.nn.Linear(128, 1, key=keys[4])

    @jax.jit
    def __call__(self, x):
        """
        Evaluate the policy at the input `x`.

        Args:
            x: single (unbatched) state vector of length `state_dim`; use
                `jax.vmap` to evaluate a batch of states.
        Returns:
            A tuple with the first element being a MultivariateNormalDiag representing
            the policy for the input and the second element being the float value function
            estimate for the input.
        """
        # Shared trunk: linear layer followed by ReLU, for each embedding layer.
        hidden = x
        for layer in self.trunk_layers:
            hidden = jax.nn.relu(layer(hidden))

        # Actor: unconstrained mean; softplus keeps the std strictly positive.
        # The small floor avoids a degenerate (zero-width) distribution, whose
        # log-probability would be infinite.
        action_mean = self.action_mean_head(hidden)
        action_std = jax.nn.softplus(self.action_std_head(hidden)) + 1e-5
        # NOTE(review): assumes MultivariateNormalDiag from a2c_utils takes
        # (mean, std) positionally -- confirm against its definition.
        action_distribution = MultivariateNormalDiag(action_mean, action_std)

        # Critic: the head outputs shape (1,); index it to get a scalar.
        value = self.value_head(hidden)[0]
        return action_distribution, value

#### Exercise 2.2: Compute Episode Returns and Training Loss Function
Implement the following functions:
1. `compute_returns`: a function to compute the discounted tail returns for an episode
2. `train_loss_for_epsiode`: a function to compute the training loss for an episode

In [None]:
def compute_returns(rewards, discount_factor):
    """
    Compute the discounted returns from a sequence of rewards:
      rewards = [r_0, r_1, ..., r_{T-1}].

    Specifically, compute the list of discounted tail returns, [G_0, G_1, ..., G_{T-1}]
    where:
      G_t = sum_{k=t}^{T-1} γ^{k-t} r_k

    This uses the recursion G_t = r_t + γ G_{t+1} (with G_T = 0), which runs in
    O(T) instead of the O(T^2) direct double sum.

    Args:
        rewards: sequence of rewards, [r_0, r_1, ..., r_{T-1}] (may be empty)
        discount_factor: temporal discount factor γ for rewards
    Returns:
        NumPy float array of returns, [G_0, G_1, ..., G_{T-1}], same length as `rewards`.
        An array (not a list) is returned because callers use `.mean()` / `.var()`.
    """
    rewards = np.asarray(rewards, dtype=float)
    returns = np.empty_like(rewards)
    tail_return = 0.0
    # Walk backwards so each G_t reuses the already-computed G_{t+1}.
    for t in reversed(range(len(rewards))):
        tail_return = rewards[t] + discount_factor * tail_return
        returns[t] = tail_return
    return returns

def train_loss_for_epsiode(policy, states, actions, returns, num_steps):
    """
    Compute the loss function for the given episode data. Uses Monte Carlo estimates
    of the value function via the `returns` input to compute the advantage function:
    A_t = returns[t] - V(states[t]).

    `states`, `actions` and `returns` are padded along axis 0; only the first
    `num_steps` entries are real data, and padded entries are masked out of both
    the actor and the critic terms.

    Args:
        policy: policy model; called on a single state, it returns
            (action distribution with a `log_prob` method, scalar value estimate)
        states: array of states for the episode (padded)
        actions: array of actions taken in the episode (padded)
        returns: array of discounted tail returns for the episode (padded)
        num_steps: number of valid (unpadded) steps in the episode
    Returns:
        Float representing the actor loss + critic loss, each averaged over the
        valid steps.
    """
    mask = jnp.arange(len(states)) < num_steps
    # Average over valid steps; the floor of 1 avoids dividing by zero.
    num_valid = jnp.maximum(num_steps, 1)

    def _step_outputs(state, action):
        # Evaluate the policy on one state and score the action actually taken.
        action_distribution, value = policy(state)
        # jnp.sum makes this a scalar joint log-probability whether `log_prob`
        # already sums over action dimensions or returns per-dimension terms.
        return jnp.sum(action_distribution.log_prob(action)), value

    log_probs, values = jax.vmap(_step_outputs)(states, actions)

    # Zero out padded steps so they contribute neither loss nor gradient.
    advantages = jnp.where(mask, returns - values, 0.0)
    log_probs = jnp.where(mask, log_probs, 0.0)

    # Actor: policy-gradient term. stop_gradient keeps the actor loss from
    # propagating gradients into the critic through the advantages.
    actor_loss = -jnp.sum(log_probs * jax.lax.stop_gradient(advantages)) / num_valid
    # Critic: squared error between V(s_t) and the Monte Carlo return.
    critic_loss = jnp.sum(advantages ** 2) / num_valid
    return actor_loss + critic_loss

@jax.jit
def train_step_for_episode(opt_state, policy, states, actions, returns, num_steps):
    """
    Perform one optimizer step on `policy` using the episode loss.

    Differentiates `train_loss_for_epsiode` with respect to its first argument
    (the policy pytree), then applies the global `optimizer`'s update.

    Returns:
        (new optimizer state, updated policy)
    """
    loss_grad_fn = jax.grad(train_loss_for_epsiode)
    policy_grads = loss_grad_fn(policy, states, actions, returns, num_steps)
    param_updates, new_opt_state = optimizer.update(policy_grads, opt_state)
    new_policy = optax.apply_updates(policy, param_updates)
    return new_opt_state, new_policy

#### Exercise 2.3: Train Model
Run the code below to train the model.

In [None]:
# Training parameters.
discount_factor = 0.99  # Discount factor for computing tail returns.
ema_factor = 0.99  # Exponential moving average for standardizing returns.
reward_ema_factor = 0.95  # Exponential moving average for the logged episodic reward.
key = jax.random.PRNGKey(0)  # Random seed for NN initialization, action sampling.
max_steps = 600  # Max number of steps in a given episode
num_episodes = 300  # Number of episodes to train over
num_renders = 4  # Render a video roughly every num_episodes / num_renders episodes

# Set display (virtual display so the environment can render off-screen).
Display(visible=False, size=(1400, 900)).start()

# Define lunar lander environment
lunar_lander = LunarLander()

# Define policy object
key, policy_key = jax.random.split(key)
policy = Policy(lunar_lander.state_dim, lunar_lander.action_dim, policy_key)
optimizer = optax.adam(learning_rate=1e-3)
opt_state = optimizer.init(policy)

# Train
episodic_rewards_ema = []  # Exponential moving average of episodic rewards.
return_ema = None  # Exponential moving average of returns (i.e., critic targets).
return_emv = None  # Exponential moving variance of returns (i.e., critic targets).
# Floor at 1 so `i_episode % render_interval` never divides by zero when
# num_renders > num_episodes.
render_interval = max(1, num_episodes // num_renders)
for i_episode in range(num_episodes + 1):
    video_filename = (
        f"{lunar_lander.video_file_directory}/lunar_lander_episode_{i_episode}.mp4"
        if i_episode % render_interval == 0
        else None
    )

    # Reset environment at the start of each episode, and clear accumulators.
    state = lunar_lander.env.reset()[0]
    episodic_reward = 0
    states = []
    actions = []
    rewards = []

    # Sample a trajectory (episode) according to our stochastic policy/environment dynamics.
    for t in range(max_steps):
        states.append(state)
        action_distribution, _ = policy(state)
        key, sample_key = jax.random.split(key)
        action = np.array(action_distribution.sample(sample_key))  # Leave JAX to interact with gym.
        # Gymnasium's step returns (obs, reward, terminated, truncated, info);
        # the episode is over when either flag is set.
        state, reward, terminated, truncated, _ = lunar_lander.env.step(action)
        done = terminated or truncated
        episodic_reward += reward
        actions.append(action)
        rewards.append(reward)

        # Render video at specified interval of episodes
        if video_filename is not None:
            if t == 0:
                video = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*"mp4v"), 50, (600, 400))
            # env.render() yields RGB frames (assumes render_mode="rgb_array"),
            # while cv2.VideoWriter expects BGR; convert so colors are correct.
            frame = lunar_lander.env.render()
            video.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        if done:
            break

    if video_filename is not None:
        video.release()
        display_local_video_in_notebook(video_filename)

    # Compute (standardized) tail returns for the episode and update moving averages.
    returns = compute_returns(rewards, discount_factor)
    if i_episode == 0:
        episodic_rewards_ema.append(episodic_reward)
        return_ema = returns.mean()
        return_emv = returns.var()
    else:
        episodic_rewards_ema.append(
            reward_ema_factor * episodic_rewards_ema[-1] + (1 - reward_ema_factor) * episodic_reward
        )
        return_ema = ema_factor * return_ema + (1 - ema_factor) * returns.mean()
        return_emv = ema_factor * (return_emv + (1 - ema_factor) * np.mean((returns - return_ema)**2))
    # Note: standardizing returns using moving population statistics is reminiscent of batch normalization.
    standardized_returns = (returns - return_ema) / (np.sqrt(return_emv) + 1e-6)

    # Run a train step based on the episode's data.
    num_steps = len(states)
    # JAX prefers all arrays to be the same shape, so we pad to the batch size.
    opt_state, policy = train_step_for_episode(
        opt_state,
        policy,
        np.pad(states, ((0, max_steps - num_steps), (0, 0))),
        np.pad(actions, ((0, max_steps - num_steps), (0, 0))),
        np.pad(standardized_returns, ((0, max_steps - num_steps),)),
        num_steps,
    )

    # Periodically log results.
    if i_episode % 10 == 0:
        print(
            f"Episode {i_episode}\tLast reward: {episodic_reward:.2f}\tMoving average reward: {episodic_rewards_ema[-1]:.2f}"
        )

# Plot EMA rewards
plt.figure(figsize=(10, 8))
plt.plot(episodic_rewards_ema)
plt.xlabel('Episode')
plt.ylabel('Total Reward (exponential moving average)')
plt.title('Total Reward per Episode')