# PPO for Mujoco InvertedPendulum-v5

This notebook demonstrates Proximal Policy Optimization (PPO) for the Mujoco InvertedPendulum-v5 environment using vectorized environments and PyTorch. Each section is modular and can be run independently, with variables persisting across cells.

In [31]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Speed up: use vectorized envs
def make_env():
    return gym.make("InvertedPendulum-v5", render_mode=None, disable_env_checker=True)

envs = gym.vector.AsyncVectorEnv([make_env for _ in range(8)])  # 8 parallel envs
obs_dim = envs.single_observation_space.shape[0]
act_dim = envs.single_action_space.shape[0]

In [32]:
class ActorCritic(nn.Module):
    def __init__(self, obs_dim, act_dim):
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, act_dim)
        )
        self.critic = nn.Sequential(
            nn.Linear(obs_dim, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 1)
        )
        self.log_std = nn.Parameter(torch.zeros(act_dim))

    def forward(self, x):
        return self.actor(x), self.critic(x)

In [33]:
def get_action_and_value(model, obs, device):
    obs_tensor = torch.from_numpy(obs).float().to(device)
    mu, value = model(obs_tensor)
    std = model.log_std.exp()
    dist = torch.distributions.Normal(mu, std)
    action = dist.sample()
    log_prob = dist.log_prob(action).sum(-1)
    return action.cpu().detach().numpy(), log_prob.cpu().detach().numpy(), value.cpu().detach().numpy()

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print which device is being used
print(f"Using device: {device}")
model = ActorCritic(obs_dim, act_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=3e-4)

# PPO hyperparameters
clip_eps = 0.2
epochs = 10
steps_per_epoch = 2048
gamma = 0.99
lam = 0.95

Using device: cuda


In [35]:
num_updates = 50
reward_history = []
for update in range(num_updates):
    obs = envs.reset()[0]
    obs_buf, act_buf, logp_buf, rew_buf, val_buf, done_buf = [], [], [], [], [], []
    for step in range(steps_per_epoch):
        action, logp, value = get_action_and_value(model, obs, device)
        next_obs, reward, done, trunc, info = envs.step(action)
        obs_buf.append(obs)
        act_buf.append(action)
        logp_buf.append(logp)
        rew_buf.append(reward)
        val_buf.append(value)
        done_buf.append(done)
        obs = next_obs
    obs_buf = np.array(obs_buf)
    act_buf = np.array(act_buf)
    logp_buf = np.array(logp_buf)
    rew_buf = np.array(rew_buf)
    val_buf = np.array(val_buf)
    if val_buf.ndim == 3 and val_buf.shape[2] == 1:
        val_buf = val_buf.squeeze(-1)
    done_buf = np.array(done_buf)
    adv_buf = np.zeros_like(rew_buf)
    for env_idx in range(envs.num_envs):
        lastgaelam_env = 0
        for t in reversed(range(steps_per_epoch)):
            if t == steps_per_epoch - 1:
                nextnonterminal = 1.0 - done_buf[t, env_idx]
                nextvalue = val_buf[t, env_idx]
            else:
                nextnonterminal = 1.0 - done_buf[t+1, env_idx]
                nextvalue = val_buf[t+1, env_idx]
            delta = rew_buf[t, env_idx] + gamma * nextvalue * nextnonterminal - val_buf[t, env_idx]
            lastgaelam_env = delta + gamma * lam * nextnonterminal * lastgaelam_env
            adv_buf[t, env_idx] = lastgaelam_env
    returns = adv_buf + val_buf
    obs_flat = torch.from_numpy(obs_buf.reshape(-1, obs_dim)).float().to(device)
    act_flat = torch.from_numpy(act_buf.reshape(-1, act_dim)).float().to(device)
    logp_flat = torch.from_numpy(logp_buf.flatten()).float().to(device)
    adv_flat = torch.from_numpy(adv_buf.flatten()).float().to(device)
    ret_flat = torch.from_numpy(returns.flatten()).float().to(device)
    for _ in range(epochs):
        mu, value = model(obs_flat)
        std = model.log_std.exp()
        dist = torch.distributions.Normal(mu, std)
        new_logp = dist.log_prob(act_flat).sum(-1)
        ratio = (new_logp - logp_flat).exp()
        surr1 = ratio * adv_flat
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv_flat
        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = ((ret_flat - value.squeeze()) ** 2).mean()
        loss = policy_loss + 0.5 * value_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    mean_reward = np.mean(rew_buf)
    print(f"Update {update}: mean reward {mean_reward}")
    reward_history.append(mean_reward)

Update 0: mean reward 0.78369140625
Update 1: mean reward 0.7960205078125
Update 1: mean reward 0.7960205078125
Update 2: mean reward 0.81549072265625
Update 2: mean reward 0.81549072265625
Update 3: mean reward 0.835693359375
Update 3: mean reward 0.835693359375
Update 4: mean reward 0.8519287109375
Update 4: mean reward 0.8519287109375
Update 5: mean reward 0.871826171875
Update 5: mean reward 0.871826171875
Update 6: mean reward 0.88946533203125
Update 6: mean reward 0.88946533203125
Update 7: mean reward 0.9039306640625
Update 7: mean reward 0.9039306640625
Update 8: mean reward 0.916015625
Update 8: mean reward 0.916015625
Update 9: mean reward 0.92535400390625
Update 9: mean reward 0.92535400390625
Update 10: mean reward 0.93359375
Update 10: mean reward 0.93359375
Update 11: mean reward 0.9417724609375
Update 11: mean reward 0.9417724609375
Update 12: mean reward 0.9488525390625
Update 12: mean reward 0.9488525390625
Update 13: mean reward 0.954345703125
Update 13: mean reward 0

In [37]:
render_env = gym.make("InvertedPendulum-v5", render_mode="human", disable_env_checker=True)
obs = render_env.reset()[0]
done = False
while not done:
    action, _, _ = get_action_and_value(model, obs, device)
    obs, reward, done, trunc, info = render_env.step(action)
    if done or trunc:
        break
render_env.close()

/home/zle/miniconda3/envs/env_MFRL/lib/python3.13/site-packages/glfw/__init__.py:917: GLFWError: (65537) b'The GLFW library is not initialized'
