In [8]:
# define the mlp network for lunarlander
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import numpy as np
from typing import Tuple
import os
import gymnasium as gym
from typing import Callable

def layer_init(layer: nn.Module, std: float=np.sqrt(2), bias: float=0.01) -> nn.Module:
    # initialize the linear/convolutional layer
    if hasattr(layer, "weight") and layer.weight is not None:
        nn.init.orthogonal_(layer.weight, gain=std)
    if hasattr(layer, "bias") and layer.bias is not None:
        nn.init.constant_(layer.bias, bias)
    return layer

class LunarLanderMLP(nn.Module):
    # initialize the mlp
    def __init__(self, envs):
        super().__init__()
        # define the critic network
        self.critic = nn.Sequential(
            layer_init(nn.Linear(8, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 1), std=1.0)
        )
        # define the actor network
        self.actor = nn.Sequential(
            layer_init(nn.Linear(8, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, 64)),
            nn.Tanh(),
            layer_init(nn.Linear(64, envs.single_action_space.n), std=0.01)
        )

    def get_value(self, x):
        return self.critic(x)

    def get_action_and_value(self, x, action=None):
        logits = self.actor(x) # [batch, num_actions]
        dist = Categorical(logits=logits) # [batch, num_actions] after softmax to construct a distribution
        if action is None: # output actions in sample stage
            action = dist.sample() # tips: tensor.sample() is wrong!!!
        log_prob = dist.log_prob(action) # [batch]
        entropy = dist.entropy() # [batch]
        value = self.critic(x).squeeze(-1) # [batch]
        return action, log_prob, entropy, value

In [9]:
@torch.no_grad()
def compute_gae(
    rewards: torch.Tensor, 
    dones: torch.Tensor, 
    values: torch.Tensor, 
    next_value: torch.Tensor, 
    next_done: torch.Tensor, 
    gamma: float, 
    gae_lambda: float) -> Tuple[torch.Tensor, torch.Tensor]:
    # calculate GAE: general advantage estimination
    # need: next_obs(next_value, next_done), rewards, dones, values
    advantages = torch.zeros_like(rewards)
    next_advantage = 0
    num_steps = rewards.shape[0]
    for t in reversed(range(num_steps)):
        # if num_step == 100, t = 99, 98 ... 0
        # if done == 0, mask == 1, game didn't stop
        # next_done and next_value is the current state (internal vars)
        # check if it is the last step
        if t == num_steps - 1:
            mask = 1.0 - next_done
            next_value = next_value 
        else:
            # if it is not the last step, check the next step in the buffer
            mask = 1.0 - dones[t] # tips: to match the input dones, we should use dones[t] instead pf dones[t+1]
            next_value = values[t + 1]
        # calculate the td_error/delta
        delta = rewards[t] + mask * gamma * next_value - values[t]
        # calculate the GAE
        current_advantage = delta + gamma * gae_lambda * mask * next_advantage
        advantages[t] = current_advantage # save current advantage
        next_advantage = current_advantage
    returns = advantages + values
    return advantages, returns

In [10]:
# define hyper-params
env_id = "LunarLander-v3"
num_envs = 32
capture_video = False
run_name = "exp_1"  # extracted as a variable for consistency

# environment factory function
def make_env(env_id: str, idx: int, capture_video: bool, run_name: str) -> Callable[[], gym.Env]:
    def thunk() -> gym.Env:
        # create the basic env
        if capture_video and idx == 0: # only capture the env_id == 0
            env = gym.make(env_id, render_mode='rgb_array')
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        # add a wrapper to get the real rewards
        # what is a wrapper in gym?
        # wrapper can add functions for the env without modifying the source code
        # e.g., output more infos after env.step(action) or record videos
        env = gym.wrappers.RecordEpisodeStatistics(env)
        return env
    return thunk

# create the vector envs
envs = gym.vector.SyncVectorEnv(
    [make_env(env_id, i, capture_video, run_name) for i in range(num_envs)]
)   

assert isinstance(envs.single_action_space, gym.spaces.Discrete)
print(
    f"LunarLander envs are created! "
    f"obs_shape:{envs.single_observation_space.shape}; "
    f"action_n:{envs.single_action_space.n}"
)

LunarLander envs are created! obs_shape:(8,); action_n:4


In [11]:
# train loop
# learning_rate, num_steps, total_timesteps, num_envs, batch_size, minibatch_size, update_epochs
# gamma, gae_lambda, eps, ent_coef, v_loss_coef, obs_dim, device
# rollout for buffer data
learning_rate = 2.5e-4
num_steps = 2048 # sample steps per env
total_timesteps = 10000000 # total training steps
batch_size = int(num_envs * num_steps)
minibatch_size = 512 # batch for every grad update
update_epochs = 4
gamma = 0.99
gae_lambda = 0.95
eps = 0.2
ent_coef = 0.01
v_loss_coef = 0.5
obs_dim = envs.single_observation_space.shape[0]
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# define buffer
obs = torch.zeros([num_steps, num_envs, obs_dim], dtype=torch.float32, device=device)
# why we don't store actions and log_probs like [num_steps, num_envs, envs.single_action_space.n]?
# because for ppo's grad backward update, we only use the selected action and can't reselect another new action
actions = torch.zeros([num_steps, num_envs], dtype=torch.long, device=device)
log_probs = torch.zeros([num_steps, num_envs], dtype=torch.float32, device=device)
rewards = torch.zeros([num_steps, num_envs], dtype=torch.float32, device=device)
dones = torch.zeros([num_steps, num_envs], dtype=torch.float32, device=device)
values = torch.zeros([num_steps, num_envs], dtype=torch.float32, device=device)

# Initialize Agent and Optimizer
agent = LunarLanderMLP(envs).to(device)
optimizer = torch.optim.Adam(agent.parameters(), lr=learning_rate, eps=1e-5)

# Initialize environment state
next_obs_np, _ = envs.reset(seed=42)
next_obs_tensor = torch.as_tensor(next_obs_np, dtype=torch.float32, device=device)
next_done = torch.zeros([num_envs], dtype=torch.float32, device=device)

num_updates = total_timesteps // batch_size
global_step = 0

print("Start Training...")

for update in range(1, num_updates + 1):
    ep_returns, ep_lengths = [], []

    # rollout
    for t in range(num_steps):
        global_step += num_envs
        with torch.no_grad():
            action_tensor, log_prob_tensor, entropy_tensor, value_tensor = agent.get_action_and_value(next_obs_tensor)

        obs[t] = next_obs_tensor
        actions[t] = action_tensor
        log_probs[t] = log_prob_tensor
        values[t] = value_tensor

        action_np = action_tensor.cpu().numpy()
        next_obs_np, reward, terminated, truncated, info = envs.step(action_np)
        
        # Handle episodic info
        if isinstance(info, dict) and "episode" in info:
            ep = info["episode"]
            done_mask = ep.get("_r", info.get("_episode", None))
            if done_mask is not None:
                for i in np.where(done_mask)[0]:
                    ep_returns.append(float(ep["r"][i]))
                    ep_lengths.append(int(ep["l"][i]))
                    
        next_obs_tensor = torch.as_tensor(next_obs_np, dtype=torch.float32, device=device)
        done_np = np.logical_or(terminated, truncated)

        rewards[t] = torch.as_tensor(reward, dtype=torch.float32, device=device)
        dones[t] = torch.as_tensor(done_np, dtype=torch.float32, device=device)
        next_done = dones[t]

        # episodic stats (alternative location)
        if isinstance(info, dict) and "final_info" in info:
            for finfo in info["final_info"]:
                if finfo and "episode" in finfo:
                    ep_returns.append(float(finfo["episode"]["r"]))
                    ep_lengths.append(int(finfo["episode"]["l"]))

    # GAE
    with torch.no_grad():
        next_value = agent.get_value(next_obs_tensor).squeeze(-1)
    advantages, returns = compute_gae(
        rewards=rewards, dones=dones, values=values,
        next_value=next_value, next_done=next_done,
        gamma=gamma, gae_lambda=gae_lambda
    )

    # flatten
    obs_batch = obs.reshape(-1, obs_dim)
    actions_batch = actions.reshape(-1)
    log_probs_batch = log_probs.reshape(-1)
    advantages_batch = advantages.reshape(-1)
    returns_batch = returns.reshape(-1)
    values_batch = values.reshape(-1)

    # normalize advantages once per update
    advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)

    # PPO update
    clipfracs = []
    last_pg_loss = last_v_loss = last_ent = None

    for epoch in range(update_epochs):
        b_inds = torch.randperm(batch_size, device=device)
        for start in range(0, batch_size, minibatch_size):
            mb_inds = b_inds[start:start + minibatch_size]

            _, new_log_prob, entropy, new_value = agent.get_action_and_value(
                obs_batch[mb_inds],
                action=actions_batch[mb_inds]
            )

            log_ratio = new_log_prob - log_probs_batch[mb_inds]
            ratio = log_ratio.exp()

            with torch.no_grad():
                clipfracs.append(((ratio - 1.0).abs() > eps).float().mean().item())

            mb_adv = advantages_batch[mb_inds]
            pg_loss1 = -mb_adv * ratio
            pg_loss2 = -mb_adv * torch.clamp(ratio, 1 - eps, 1 + eps)
            pg_loss = torch.max(pg_loss1, pg_loss2).mean()

            v_loss = 0.5 * (new_value - returns_batch[mb_inds]).pow(2).mean()

            ent = entropy.mean()
            loss = pg_loss - ent_coef * ent + v_loss_coef * v_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(agent.parameters(), 0.5)
            optimizer.step()

            last_pg_loss, last_v_loss, last_ent = pg_loss, v_loss, ent

    # explained variance (per update)
    y_pred = values_batch.detach().cpu().numpy()
    y_true = returns_batch.detach().cpu().numpy()
    var_y = np.var(y_true)
    explained_var = np.nan if var_y == 0 else 1 - np.var(y_true - y_pred) / var_y

    # print once per update
    if len(ep_returns) > 0:
        print(
            f"Update {update}/{num_updates} | step={global_step} | "
            f"ep_ret_mean={np.mean(ep_returns):.1f} ep_len_mean={np.mean(ep_lengths):.1f} | "
            f"loss={float(loss.detach().cpu()):.3f} pg={float(last_pg_loss.detach().cpu()):.3f} "
            f"v={float(last_v_loss.detach().cpu()):.3f} ent={float(last_ent.detach().cpu()):.3f} | "
            f"clipfrac={np.mean(clipfracs):.3f} exp_var={explained_var:.3f}"
        )
    else:
        print(
            f"Update {update}/{num_updates} | step={global_step} | "
            f"loss={float(loss.detach().cpu()):.3f} pg={float(last_pg_loss.detach().cpu()):.3f} "
            f"v={float(last_v_loss.detach().cpu()):.3f} ent={float(last_ent.detach().cpu()):.3f} | "
            f"clipfrac={np.mean(clipfracs):.3f} exp_var={explained_var:.3f}"
        )

# save model params
os.makedirs("models", exist_ok=True)
model_path = f"models/{run_name}_final.pth"
torch.save(agent.state_dict(), model_path)
print(f"Training Completed. Model saved to: {model_path}")

Start Training...
Update 1/152 | step=65536 | ep_ret_mean=-184.6 ep_len_mean=91.4 | loss=286.192 pg=-0.021 v=572.452 ent=1.379 | clipfrac=0.025 exp_var=-0.007
Update 2/152 | step=131072 | ep_ret_mean=-156.0 ep_len_mean=96.8 | loss=188.643 pg=-0.096 v=377.507 ent=1.357 | clipfrac=0.042 exp_var=-0.011
Update 3/152 | step=196608 | ep_ret_mean=-143.6 ep_len_mean=106.0 | loss=134.204 pg=0.027 v=268.380 ent=1.325 | clipfrac=0.053 exp_var=-0.001
Update 4/152 | step=262144 | ep_ret_mean=-132.5 ep_len_mean=115.3 | loss=112.786 pg=-0.067 v=225.731 ent=1.281 | clipfrac=0.059 exp_var=-0.000
Update 5/152 | step=327680 | ep_ret_mean=-103.1 ep_len_mean=132.3 | loss=103.987 pg=-0.005 v=208.009 ent=1.227 | clipfrac=0.106 exp_var=-0.000
Update 6/152 | step=393216 | ep_ret_mean=-80.7 ep_len_mean=147.9 | loss=101.929 pg=0.008 v=203.866 ent=1.179 | clipfrac=0.064 exp_var=-0.002
Update 7/152 | step=458752 | ep_ret_mean=-49.7 ep_len_mean=181.3 | loss=79.546 pg=0.004 v=159.107 ent=1.133 | clipfrac=0.051 exp_v

In [12]:
# Evaluation and Video Recording
print("Starting Evaluation and Recording...")

video_dir = f"videos/{run_name}-final"
os.makedirs(video_dir, exist_ok=True)

# Create evaluation environment (Evaluation needs only 1 env)
# Set max_episode_steps to avoid early truncation during recording
eval_env = gym.make(env_id, render_mode="rgb_array", max_episode_steps=5000)
eval_env = gym.wrappers.RecordVideo(
    eval_env,
    video_folder=video_dir,
    episode_trigger=lambda episode_id: True, # Record all episodes (we run only 1 here)
    name_prefix="final",
)
eval_env = gym.wrappers.RecordEpisodeStatistics(eval_env)

obs, _ = eval_env.reset(seed=123)
done = False
ep_ret = 0.0

while not done:
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)  # [1, obs_dim]
    with torch.no_grad():
        logits = agent.actor(obs_t)
        action = int(torch.argmax(logits, dim=-1).item())  # use argmax for deterministic

    obs, reward, terminated, truncated, info = eval_env.step(action)
    done = bool(terminated or truncated)
    ep_ret += float(reward)

eval_env.close()
print("final eval return:", ep_ret, "video saved to:", video_dir)

Starting Evaluation and Recording...
final eval return: 214.93616379451146 video saved to: videos/exp_1-final
