<a href="https://colab.research.google.com/github/ZhengyuanCui/DRL/blob/main/Run_gymnasium.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# System dependencies for headless rendering
!apt-get update
!apt-get install -y libosmesa6-dev libgl1-mesa-glx libglfw3 libglew-dev patchelf ffmpeg

# Only install gym with MuJoCo support, skip Box2D
!pip install "gymnasium[mujoco]" mujoco==2.3.7 imageio matplotlib

Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:3 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:4 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,853 kB]
Hit:5 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:6 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:7 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Get:8 https://r2u.stat.illinois.edu/ubuntu jammy/main all Packages [9,161 kB]
Hit:9 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Hit:10 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:11 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [3,207 kB]
Get:12 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Hit:13 https://ppa.launchpadcontent.net/

In [None]:
import os

# Headless rendering: have MuJoCo use its EGL backend
os.environ.update(MUJOCO_GL='egl')

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Categorical

In [None]:
!pip install pyvirtualdisplay

Collecting pyvirtualdisplay
  Downloading PyVirtualDisplay-3.0-py3-none-any.whl.metadata (943 bytes)
Downloading PyVirtualDisplay-3.0-py3-none-any.whl (15 kB)
Installing collected packages: pyvirtualdisplay
Successfully installed pyvirtualdisplay-3.0


In [None]:
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.distributions import Normal # Changed to Normal for continuous actions

# Define the policy network for continuous actions
class PolicyNetwork(nn.Module):
    """Gaussian policy head: maps a state to the mean and std of a diagonal Normal.

    Two 64-unit ReLU hidden layers feed two linear heads, one for the action
    mean and one for the log standard deviation.

    Args:
        input_size: Number of observation features.
        output_size: Number of action dimensions.
        log_std_min: Lower bound applied to the predicted log-std.
        log_std_max: Upper bound applied to the predicted log-std.
    """

    def __init__(self, input_size, output_size, log_std_min=-20.0, log_std_max=2.0):
        super(PolicyNetwork, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        # Output mean and standard deviation for a Normal distribution
        self.fc_mean = nn.Linear(64, output_size)
        self.fc_log_std = nn.Linear(64, output_size) # Predict log standard deviation
        self.log_std_min = log_std_min
        self.log_std_max = log_std_max

    def forward(self, x):
        """Return ``(mean, std)`` for the given state tensor; ``std`` is strictly positive and finite."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        mean = self.fc_mean(x)
        # Bound log-std before exponentiating: an unbounded head can overflow
        # exp() to inf (or collapse towards 0), which turns log-probs into NaN/inf.
        log_std = torch.clamp(self.fc_log_std(x), self.log_std_min, self.log_std_max)
        std = torch.exp(log_std) # Exponentiate log_std to get a positive std
        return mean, std

# Define the value network (remains the same)
class ValueNetwork(nn.Module):
    """State-value critic: two 64-unit ReLU hidden layers followed by a single-output linear head."""

    def __init__(self, input_size):
        super().__init__()
        hidden_units = 64
        self.fc1 = nn.Linear(input_size, hidden_units)
        self.fc2 = nn.Linear(hidden_units, hidden_units)
        self.fc3 = nn.Linear(hidden_units, 1)

    def forward(self, x):
        """Return the value estimate for ``x``; the last dimension of the result has size 1."""
        hidden = self.fc1(x).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)

# PPO Algorithm for continuous actions
class PPO:
    """Proximal Policy Optimization (clipped surrogate objective) for continuous actions.

    Args:
        env: Gymnasium-style environment: ``reset()`` returns ``(obs, info)`` and
            ``step(action)`` returns ``(obs, reward, terminated, truncated, info)``.
        policy: Module mapping a float32 state tensor to ``(mean, std)`` of a Normal.
        value_net: Module mapping a float32 state tensor to a value estimate whose
            last dimension has size 1.
        gamma: Discount factor.
        clip_eps: Clipping range for the probability ratio.
        lr: Learning rate of both Adam optimizers.
        gae_lambda: Lambda of Generalized Advantage Estimation.
    """

    def __init__(self, env, policy, value_net, gamma=0.99, clip_eps=0.2, lr=3e-4, gae_lambda=0.95):
        self.env = env
        self.policy = policy
        self.value_net = value_net
        self.gamma = gamma
        self.clip_eps = clip_eps
        self.gae_lambda = gae_lambda
        self.optimizer_policy = optim.Adam(policy.parameters(), lr=lr)
        self.optimizer_value = optim.Adam(value_net.parameters(), lr=lr)

    @staticmethod
    def _as_float_tensor(data):
        """Convert a tensor, array or (nested) list to a float32 tensor.

        Lists are first packed into one NumPy array: calling ``torch.tensor`` on a
        list of ndarrays is extremely slow and emits a UserWarning.
        """
        if isinstance(data, torch.Tensor):
            return data.to(dtype=torch.float32)
        return torch.as_tensor(np.asarray(data, dtype=np.float32))

    def compute_advantages(self, rewards, values, next_values, dones, terminateds=None):
        """Compute GAE advantages for one rollout.

        Args:
            rewards: Per-step rewards.
            values: Value estimates of the visited states.
            next_values: Value estimates of the successor states.
            dones: Per-step episode-end flags (terminated or truncated); they stop
                the advantage from leaking across episode boundaries.
            terminateds: Optional per-step true-termination flags. When given, only
                these zero the bootstrap value, so a time-limit truncation still
                bootstraps from ``next_values``. Defaults to ``dones``.

        Returns:
            List of advantages, one per step, in time order.
        """
        if terminateds is None:
            terminateds = dones
        n_steps = len(rewards)
        advantages = [0.0] * n_steps  # filled back-to-front; avoids O(n^2) insert(0)
        gae = 0.0
        for t in reversed(range(n_steps)):
            bootstrap_mask = 1.0 - float(terminateds[t])
            carry_mask = 1.0 - float(dones[t])
            delta = rewards[t] + self.gamma * next_values[t] * bootstrap_mask - values[t]
            gae = delta + self.gamma * self.gae_lambda * carry_mask * gae
            advantages[t] = gae
        return advantages

    def update_policy(self, states, actions, log_probs_old, advantages):
        """Take one gradient step on the clipped PPO surrogate.

        Args:
            states: Batch of states, shape ``(N, obs_dim)``.
            actions: Batch of actions taken, shape ``(N, act_dim)``.
            log_probs_old: Log-probabilities of ``actions`` under the behaviour policy, shape ``(N,)``.
            advantages: Advantage estimates, shape ``(N,)``.

        Returns:
            The scalar policy loss as a float.
        """
        states = self._as_float_tensor(states)
        actions = self._as_float_tensor(actions) # Actions are float for continuous spaces
        log_probs_old = self._as_float_tensor(log_probs_old)
        advantages = self._as_float_tensor(advantages)

        # Compute new action probabilities (log_probs)
        mean, std = self.policy(states)
        dist = Normal(mean, std)
        log_probs = dist.log_prob(actions).sum(dim=-1) # Sum log_probs for multi-dimensional actions

        # Compute the ratio (pi_theta / pi_theta_old)
        ratio = torch.exp(log_probs - log_probs_old)

        # Compute the clipped objective function
        obj = ratio * advantages
        obj_clipped = torch.clamp(ratio, 1 - self.clip_eps, 1 + self.clip_eps) * advantages
        loss = -torch.min(obj, obj_clipped).mean()

        # Update policy
        self.optimizer_policy.zero_grad()
        loss.backward()
        self.optimizer_policy.step()
        return loss.item()

    def update_value(self, states, returns):
        """Take one gradient step on the mean-squared value error.

        Args:
            states: Batch of states, shape ``(N, obs_dim)``.
            returns: Regression targets, shape ``(N,)``.

        Returns:
            The scalar value loss as a float.
        """
        states = self._as_float_tensor(states)
        returns = self._as_float_tensor(returns)

        # Drop the trailing size-1 dimension: (N,) - (N, 1) would silently
        # broadcast to an (N, N) matrix and train against the wrong targets.
        values = self.value_net(states).squeeze(-1)
        value_loss = (returns - values).pow(2).mean()

        # Update value network
        self.optimizer_value.zero_grad()
        value_loss.backward()
        self.optimizer_value.step()
        return value_loss.item()

    def _update_from_rollout(self, states, actions, rewards, dones, terminateds,
                             log_probs, next_states, batch_size, n_epochs):
        """Run ``n_epochs`` passes of shuffled minibatch updates over one rollout."""
        states_t = self._as_float_tensor(states)
        actions_t = self._as_float_tensor(actions)
        log_probs_t = self._as_float_tensor(log_probs)
        next_states_t = self._as_float_tensor(next_states)

        # Value estimates are targets/baselines here, so no autograd graph is needed
        with torch.no_grad():
            values = self.value_net(states_t).reshape(-1).tolist()
            next_values = self.value_net(next_states_t).reshape(-1).tolist()

        advantages = self.compute_advantages(rewards, values, next_values, dones, terminateds)
        advantages_t = self._as_float_tensor(advantages)
        # Critic target = advantage + baseline (lambda-return), not the raw reward
        returns_t = advantages_t + self._as_float_tensor(values)
        # Normalising over the rollout keeps minibatch gradient scales comparable
        if advantages_t.numel() > 1:
            advantages_t = (advantages_t - advantages_t.mean()) / (advantages_t.std() + 1e-8)

        n_samples = states_t.shape[0]
        for _ in range(n_epochs):
            order = torch.randperm(n_samples)
            for start in range(0, n_samples, batch_size):
                idx = order[start:start + batch_size]
                self.update_policy(states_t[idx], actions_t[idx], log_probs_t[idx], advantages_t[idx])
                self.update_value(states_t[idx], returns_t[idx])

    def train(self, total_timesteps=1000000, batch_size=64, update_every=2000, n_epochs=10):
        """Collect experience and update the networks every ``update_every`` steps.

        The rollout buffer persists across episode boundaries, so every update
        sees all ``update_every`` transitions collected since the previous one.
        Steps left over after the last full rollout are not used for an update.

        Args:
            total_timesteps: Number of environment steps to run.
            batch_size: Minibatch size used for each gradient step.
            update_every: Number of environment steps per rollout/update.
            n_epochs: Number of passes over each rollout.

        Raises:
            ValueError: If ``batch_size``, ``update_every`` or ``n_epochs`` is below 1.
        """
        if min(batch_size, update_every, n_epochs) < 1:
            raise ValueError("batch_size, update_every and n_epochs must all be >= 1")

        states, actions, rewards, log_probs, next_states = [], [], [], [], []
        dones, terminateds = [], []
        state, _ = self.env.reset() # reset returns a tuple in gymnasium

        for timestep in range(1, total_timesteps + 1):
            # Acting needs no gradients; the update recomputes log-probs itself
            with torch.no_grad():
                mean, std = self.policy(torch.as_tensor(state, dtype=torch.float32))
                dist = Normal(mean, std)
                action = dist.sample()
                log_prob = dist.log_prob(action).sum(dim=-1) # Sum log_probs for multi-dimensional actions
            action_np = action.numpy()

            next_state, reward, terminated, truncated, _ = self.env.step(action_np) # step returns 5 values in gymnasium
            done = terminated or truncated

            states.append(state)
            actions.append(action_np) # Store actions as numpy arrays
            rewards.append(reward)
            dones.append(done)
            terminateds.append(terminated)
            log_probs.append(log_prob.item())
            next_states.append(next_state)

            if done:
                state, _ = self.env.reset()
            else:
                state = next_state

            if timestep % update_every == 0:
                self._update_from_rollout(states, actions, rewards, dones, terminateds,
                                          log_probs, next_states, batch_size, n_epochs)

                # Clear collected experience after update
                states, actions, rewards, log_probs, next_states = [], [], [], [], []
                dones, terminateds = [], []

        # Optionally, save the model here after training

# Seed the RNGs so a Restart & Run All reproduces the same training run
SEED = 0
torch.manual_seed(SEED)

# Create environment
env = gym.make('Ant-v5')
env.reset(seed=SEED)  # seeds the environment's RNG; later unseeded resets continue from it

# Initialize networks
# The output size of the policy network should match the action space dimension
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.shape[0]
policy = PolicyNetwork(input_size=obs_dim, output_size=act_dim)
value_net = ValueNetwork(input_size=obs_dim)

# Create PPO agent
ppo_agent = PPO(env, policy, value_net)

# Train the agent; release the simulator even if training is interrupted
try:
    ppo_agent.train(total_timesteps=1000000)
finally:
    env.close()

  states = torch.tensor(states, dtype=torch.float32)


In [None]:
# Setup display for headless rendering
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display

# Create the off-screen display and start it in a single expression
display = Display(visible=0, size=(1400, 900)).start()

# Import necessary libraries for rendering
import imageio
import base64
from pathlib import Path

# Function to display video
def display_video(video_path):
    """Embed an MP4 file inline in the notebook output as a base64 data URL.

    Args:
        video_path: Path of the MP4 file to show.
    """
    # Read through a context manager so the file handle is closed right away
    with open(video_path, 'rb') as video_file:
        mp4 = video_file.read()
    data_url = "data:video/mp4;base64," + base64.b64encode(mp4).decode()
    ipythondisplay.display(ipythondisplay.HTML(f'<video width="600" controls><source src="{data_url}" type="video/mp4"></video>'))

# Create environment with rendering enabled
env_eval = gym.make('Ant-v5', render_mode='rgb_array')

# Run a few episodes with the trained policy and render
num_eval_episodes = 3
video_filename = "ant_v5_evaluation.mp4"
# Encode at the frame rate the environment advertises so playback speed matches
# the simulation; fall back to 60 fps if the metadata does not provide one.
video_fps = env_eval.metadata.get("render_fps", 60)

try:
    with imageio.get_writer(video_filename, fps=video_fps) as video:
        for episode in range(num_eval_episodes):
            state, _ = env_eval.reset()
            done = False
            while not done:
                # Get action from the policy (using the trained policy from the previous cell)
                # Ensure the state is a tensor and add a batch dimension
                state_tensor = torch.as_tensor(state, dtype=torch.float32).unsqueeze(0)
                # Inference only: skip building an autograd graph
                with torch.no_grad():
                    mean, std = ppo_agent.policy(state_tensor)
                    dist = Normal(mean, std)
                    action = dist.sample().squeeze(0).numpy() # Sample action and convert to numpy

                next_state, reward, terminated, truncated, _ = env_eval.step(action)
                frame = env_eval.render()
                video.append_data(frame)

                state = next_state
                done = terminated or truncated

    # Display the generated video (the writer is closed, so the file is complete)
    display_video(video_filename)
finally:
    # Release the environment and the virtual display even if the rollout failed
    env_eval.close()
    display.stop()

<pyvirtualdisplay.display.Display at 0x7aba5072ca50>