PPO
Half Cheetah

In [None]:
import os
# Select the EGL backend for MuJoCo and PyOpenGL. These variables are set in
# the very first cell, before gymnasium is imported further down the notebook.
# NOTE(review): presumably this is what lets the rgb_array evaluation
# rendering work on a headless (display-less) runtime — confirm for the
# target machine.
os.environ["MUJOCO_GL"] = "egl"
os.environ["PYOPENGL_PLATFORM"] = "egl"

In [None]:
!pip install gymnasium[mujoco]



In [None]:
import gymnasium as gym
import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import imageio
import os
from google.colab import files

In [None]:
def set_seed(seed=0):
    """Seed every random number generator this notebook draws from.

    Args:
        seed: Integer applied, in order, to Python's ``random`` module, NumPy's
            global RNG, PyTorch's default generator and all CUDA generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [None]:
# Actor Critic Network (shared backbone architecture)

class ActorCritic(nn.Module):
    """Actor-critic network with a shared trunk and tanh-squashed Gaussian policy.

    A two-layer tanh MLP feeds both a linear mean head and a linear value head.
    The log standard deviation is a single learnable vector (one entry per
    action dimension, independent of the state) initialised to -0.5.

    Args:
        state_dim: Size of the observation vector.
        action_dim: Size of the action vector.
        hidden: Width of both hidden layers.
    """

    # Margin that keeps squashed actions strictly inside (-1, 1) so that the
    # inverse tanh used to score them stays finite.
    _ACTION_EPS = 1e-6

    def __init__(self, state_dim, action_dim, hidden=256):
        super().__init__()

        self.shared = nn.Sequential(
            nn.Linear(state_dim, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
        )

        self.mu_head = nn.Linear(hidden, action_dim)
        self.log_std = nn.Parameter(torch.ones(action_dim) * -0.5)
        self.value_head = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(mu, std, value)`` for a state or a batch of states.

        ``mu`` is the pre-squash Gaussian mean, ``std`` is ``exp(log_std)`` with
        ``log_std`` clamped to ``[-5, 2]``, and ``value`` is the value head output
        with its trailing singleton dimension removed.
        """
        z = self.shared(x)
        mu = self.mu_head(z)

        log_std = torch.clamp(self.log_std, -5, 2)
        std = torch.exp(log_std)

        value = self.value_head(z).squeeze(-1)
        return mu, std, value

    def _squashed_log_prob(self, dist, actions):
        """Log-density of tanh-squashed ``actions`` under the pre-squash ``dist``.

        The pre-squash sample is recovered with an inverse tanh of the clamped
        action, scored under ``dist``, corrected by the tanh change of variables
        and summed over the action dimension. Both sampling and re-evaluation go
        through this one function so they agree numerically on the same action.
        """
        eps = self._ACTION_EPS
        actions = torch.clamp(actions, -1 + eps, 1 - eps)
        u = 0.5 * torch.log((1 + actions) / (1 - actions))

        logp = dist.log_prob(u) - torch.log(1 - actions.pow(2) + 1e-6)
        return logp.sum(-1)

    def get_action(self, state):
        """Sample a squashed action and return ``(action, logp, entropy, value)``.

        The returned action is clamped to ``[-1 + eps, 1 - eps]`` and its
        log-probability is computed from that clamped action. Previously the
        log-probability came from the raw pre-squash sample, which disagreed
        with ``evaluate_actions`` whenever tanh saturated to exactly +/-1 and
        made the PPO ratio differ from 1 before any parameter update.

        ``entropy`` is the entropy of the pre-squash Gaussian summed over action
        dimensions (the tanh correction is not included).
        """
        mu, std, value = self.forward(state)
        dist = torch.distributions.Normal(mu, std)

        eps = self._ACTION_EPS
        action = torch.clamp(torch.tanh(dist.rsample()), -1 + eps, 1 - eps)
        logp = self._squashed_log_prob(dist, action)

        entropy = dist.entropy().sum(-1)
        return action, logp, entropy, value

    def evaluate_actions(self, states, actions):
        """Score stored squashed ``actions`` under the current policy.

        Returns ``(logp, entropy, value)`` where ``logp`` uses the same
        computation as ``get_action`` and ``entropy`` is the pre-squash Gaussian
        entropy summed over action dimensions.
        """
        mu, std, value = self.forward(states)
        dist = torch.distributions.Normal(mu, std)

        logp = self._squashed_log_prob(dist, actions)

        entropy = dist.entropy().sum(-1)
        return logp, entropy, value

    def act_deterministic(self, state):
        """Return ``tanh(mu)`` for ``state`` without tracking gradients."""
        with torch.no_grad():
            mu, _, _ = self.forward(state)
            return torch.tanh(mu)


# Computing GAE

@torch.no_grad()
def compute_gae(rewards, dones, values, next_value, gamma=0.99, gae_lambda=0.95):
    """Generalised Advantage Estimation over one rollout.

    Args:
        rewards: Per-step rewards, indexed along the first dimension.
        dones: Per-step flags (1.0 where the episode ended on that step).
        values: Value estimates for the state at each step.
        next_value: Value estimate for the state following the last step.
        gamma: Discount factor.
        gae_lambda: GAE smoothing coefficient.

    Returns:
        ``(advantages, returns)`` where ``returns = advantages + values``; both
        are detached from any autograd graph.
    """
    horizon = rewards.shape[0]
    advantages = torch.zeros(horizon, device=rewards.device)

    running_gae = 0
    # Walk the rollout backwards so each step can reuse the accumulated
    # advantage of the step after it.
    for step in range(horizon - 1, -1, -1):
        not_done = 1.0 - dones[step]
        if step == horizon - 1:
            bootstrap = next_value
        else:
            bootstrap = values[step + 1]
        td_error = rewards[step] + gamma * bootstrap * not_done - values[step]
        running_gae = td_error + gamma * gae_lambda * not_done * running_gae
        advantages[step] = running_gae

    value_targets = advantages + values
    return advantages.detach(), value_targets.detach()


# Training setup

# One seed drives the Python/NumPy/PyTorch RNGs and both environments.
seed = 0
set_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Separate environments: `env` collects training rollouts, `eval_env` renders
# RGB frames for the periodic evaluation videos.
env = gym.make("HalfCheetah-v5")
eval_env = gym.make("HalfCheetah-v5", render_mode="rgb_array")

os.makedirs("videos", exist_ok=True)

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

initial_lr = 3e-4
model = ActorCritic(state_dim, action_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=initial_lr)

# PPO hyperparameters
gamma = 0.99
gae_lambda = 0.95
clip_eps = 0.2
vf_coef = 0.5
ent_coef = 0.01
max_grad_norm = 0.5
ppo_epochs = 10
mini_batch_size = 64

rollout_len = 2048
total_timesteps = 2_000_000
eval_interval = 20_000

global_step = 0

# Seeding the library RNGs alone does not make runs repeatable, because the
# environments keep their own generators. Gymnasium seeds an environment
# through the `seed` argument of `reset`; later seedless resets continue
# from that generator. The evaluation env gets a different seed so its
# initial states are not the ones seen in training.
eval_env.reset(seed=seed + 1)
state, _ = env.reset(seed=seed)

# PPO training loop

while global_step < total_timesteps:

    # LR annealing: decay linearly from initial_lr towards 0 over training.
    frac = 1.0 - (global_step / total_timesteps)
    lr_now = initial_lr * frac
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr_now

    states, actions, rewards, dones, values, logps = [], [], [], [], [], []
    rollout_rewards = []

    # Rollouts
    for _ in range(rollout_len):

        state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
        # Everything gathered here is used as a constant during the update
        # phase, so skip building an autograd graph for every env step.
        with torch.no_grad():
            action, logp, _, value = model.get_action(state_tensor)

        next_state, reward, terminated, truncated, _ = env.step(
            action.cpu().numpy()
        )

        done = terminated or truncated

        # A time-limit truncation is not a real terminal state: the episode
        # would have continued, so the value of the state it was cut off in
        # still counts. The done flag below zeroes the bootstrap term inside
        # compute_gae, so fold gamma * V(next_state) into the stored reward.
        # The raw reward is still what gets logged.
        learn_reward = float(reward)
        if truncated and not terminated:
            with torch.no_grad():
                _, _, truncated_value = model.forward(
                    torch.tensor(next_state, dtype=torch.float32, device=device)
                )
            learn_reward += gamma * truncated_value.item()

        states.append(state_tensor)
        actions.append(action.detach())
        rewards.append(torch.tensor(learn_reward, dtype=torch.float32, device=device))
        dones.append(torch.tensor(float(done), dtype=torch.float32, device=device))
        values.append(value.detach())
        logps.append(logp.detach())

        rollout_rewards.append(reward)

        state = next_state
        global_step += 1

        if done:
            state, _ = env.reset()

        if global_step >= total_timesteps:
            break

    # Mean per-step environment reward over this rollout (not an episode return).
    mean_rollout_reward = np.mean(rollout_rewards)

    states = torch.stack(states)
    actions = torch.stack(actions)
    rewards = torch.stack(rewards)
    dones = torch.stack(dones)
    values = torch.stack(values)
    logps_old = torch.stack(logps)

    # Value of the state the rollout stopped in, used to bootstrap the last step.
    with torch.no_grad():
        last_state_tensor = torch.tensor(state, dtype=torch.float32, device=device)
        _, _, next_value = model.forward(last_state_tensor)

    advantages, returns = compute_gae(
        rewards, dones, values, next_value, gamma, gae_lambda
    )

    # Advantage normalisation over the whole rollout.
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

    dataset_size = states.size(0)

    # Logging metrics
    approx_kl_total = 0
    clip_frac_total = 0
    explained_var_total = 0
    update_steps = 0

    # PPO main training loop
    for _ in range(ppo_epochs):

        # Fresh shuffle each epoch so minibatches differ between passes.
        indices = torch.randperm(dataset_size, device=device)

        for start in range(0, dataset_size, mini_batch_size):
            end = start + mini_batch_size
            batch_idx = indices[start:end]

            mb_states = states[batch_idx]
            mb_actions = actions[batch_idx]
            mb_advantages = advantages[batch_idx]
            mb_returns = returns[batch_idx]
            mb_logps_old = logps_old[batch_idx]
            mb_values_old = values[batch_idx]

            logps_new, entropy, values_new = model.evaluate_actions(
                mb_states, mb_actions
            )

            ratio = torch.exp(logps_new - mb_logps_old)

            surr1 = ratio * mb_advantages
            surr2 = torch.clamp(ratio, 1 - clip_eps, 1 + clip_eps) * mb_advantages      # PPO clipping

            policy_loss = -torch.min(surr1, surr2).mean()

            # Value function clipping: limit how far the new estimate may move
            # from the rollout-time estimate, and take the worse of the two losses.
            value_clipped = mb_values_old + torch.clamp(
                values_new - mb_values_old,
                -clip_eps,
                clip_eps,
            )

            v_loss1 = (values_new - mb_returns).pow(2)
            v_loss2 = (value_clipped - mb_returns).pow(2)
            value_loss = 0.5 * torch.max(v_loss1, v_loss2).mean()

            entropy_loss = entropy.mean()

            # PPO Loss
            loss = policy_loss + vf_coef * value_loss - ent_coef * entropy_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)      # Gradient clipping
            optimizer.step()

            # Logging stats (computed from the pre-step forward pass)
            approx_kl = (mb_logps_old - logps_new).mean().item()
            clip_frac = ((ratio - 1.0).abs() > clip_eps).float().mean().item()
            explained_var = 1 - torch.var(mb_returns - values_new) / (
                torch.var(mb_returns) + 1e-8
            )

            approx_kl_total += approx_kl
            clip_frac_total += clip_frac
            explained_var_total += explained_var.item()
            update_steps += 1

    print(
        f"Step: {global_step} | "
        f"LR: {lr_now:.6f} | "
        f"Reward: {mean_rollout_reward:.2f} | "
        f"KL: {approx_kl_total/update_steps:.5f} | "
        f"ClipFrac: {clip_frac_total/update_steps:.3f} | "
        f"ExplVar: {explained_var_total/update_steps:.3f}"
    )

    # Evaluation: runs on the first rollout boundary at or after each multiple
    # of eval_interval, since global_step only advances in rollout-sized jumps.
    if global_step % eval_interval < rollout_len:

        s, _ = eval_env.reset()
        done = False
        ep_return = 0
        frames = []

        while not done:
            s_tensor = torch.tensor(s, dtype=torch.float32, device=device)
            action = model.act_deterministic(s_tensor)

            # Deterministic evaluation
            s, r, terminated, truncated, _ = eval_env.step(
                action.cpu().numpy()
            )
            done = terminated or truncated
            ep_return += r

            frame = eval_env.render()
            if frame is not None:
                frames.append(frame)

        print(f"[Eval @ {global_step}] Return: {ep_return:.2f}")

        # Writing an empty frame list would fail, so only export when the
        # renderer actually produced frames.
        if frames:
            video_path = f"videos/ppo_halfcheetah_eval_{global_step}.mp4"
            imageio.mimsave(video_path, frames, fps=30)
            print(f"Saved video: {video_path}")
        else:
            print("No frames rendered; skipping video export")

env.close()
eval_env.close()


Device: cuda
Step: 2048 | LR: 0.000300 | Reward: -0.22 | KL: 0.07874 | ClipFrac: 0.551 | ExplVar: 0.026
Step: 4096 | LR: 0.000300 | Reward: -0.07 | KL: 0.07960 | ClipFrac: 0.543 | ExplVar: 0.096
Step: 6144 | LR: 0.000299 | Reward: -0.12 | KL: 0.05712 | ClipFrac: 0.496 | ExplVar: 0.169
Step: 8192 | LR: 0.000299 | Reward: -0.09 | KL: 0.05150 | ClipFrac: 0.453 | ExplVar: 0.233
Step: 10240 | LR: 0.000299 | Reward: -0.07 | KL: 0.04710 | ClipFrac: 0.466 | ExplVar: 0.273
Step: 12288 | LR: 0.000298 | Reward: 0.05 | KL: 0.04600 | ClipFrac: 0.435 | ExplVar: 0.297
Step: 14336 | LR: 0.000298 | Reward: -0.02 | KL: 0.03612 | ClipFrac: 0.416 | ExplVar: 0.310
Step: 16384 | LR: 0.000298 | Reward: 0.01 | KL: 0.04392 | ClipFrac: 0.415 | ExplVar: 0.368
Step: 18432 | LR: 0.000298 | Reward: -0.02 | KL: 0.03709 | ClipFrac: 0.413 | ExplVar: 0.383
Step: 20480 | LR: 0.000297 | Reward: -0.06 | KL: 0.03929 | ClipFrac: 0.399 | ExplVar: 0.374
[Eval @ 20480] Return: 1027.99
Saved video: videos/ppo_halfcheetah_eval_2

In [None]:
!zip -r videos.zip /content/videos

  adding: content/videos/ (stored 0%)
  adding: content/videos/ppo_halfcheetah_eval_221184.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_20480.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_2000000.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_380928.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_1120256.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_1161216.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_81920.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_1021952.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_1880064.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_342016.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_501760.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_141312.mp4 (deflated 1%)
  adding: content/videos/ppo_halfcheetah_eval_1761280.mp4 (deflated 1%)
  adding: content/videos/ppo_halfch

In [None]:
# Send the zipped evaluation videos to the browser as a download. `files` is
# the helper imported from google.colab at the top of the notebook, so this
# cell depends on running inside Colab.
files.download('/content/videos.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>