In [1]:
import numpy as np
import pufferlib, pufferlib.vector
from pufferlib.environments import classic_control

In [2]:
# Number of environment copies requested from `pufferlib.vector.make` below;
# also used to size the per-environment bookkeeping in the rollout cell.
num_envs = 12

In [3]:
# Build a vectorized CartPole-v1 environment with `num_envs` copies using
# pufferlib's Multiprocessing backend. Every later cell steps this one object.
vecenv = pufferlib.vector.make(
    classic_control.env_creator("CartPole-v1"),
    num_envs=num_envs,
    backend=pufferlib.vector.Multiprocessing,
)

In [4]:
from itertools import count

# Random-policy baseline: step every environment with sampled actions until at
# least 1,000 episodes have finished, recording each episode's undiscounted
# return.
env_rewards = [[] for _ in range(num_envs)]  # rewards of each env's episode in progress
returns = []                                  # one entry per finished episode
episodes = 0

obs, _ = vecenv.reset()
for t in count():
    obs, rewards, dones, truncateds, _ = vecenv.step(vecenv.action_space.sample())
    for i, reward in enumerate(rewards):
        env_rewards[i].append(reward)
    # An episode is over when it either terminates or is truncated.
    for i in np.where(dones | truncateds)[0]:
        returns.append(sum(env_rewards[i]))
        env_rewards[i] = []
        episodes += 1
    # Checked once per vectorized step, so slightly more than 1,000 episodes
    # can be recorded when several environments finish on the same step.
    if episodes >= 1_000:
        break

# Mean episode return of the random policy, and the index of the last step taken.
np.mean(returns), t

(22.007, 1924)

In [5]:
import torch
from torch import nn

In [6]:
class PPONetwork(nn.Module):
    """Small actor-critic MLP: one shared hidden layer feeding two linear heads.

    Args:
        n_input: Size of a single observation vector.
        n_hiddens: Width of the shared hidden layer.
        n_actions: Number of discrete actions (width of the logits output).
    """

    def __init__(self, n_input, n_hiddens, n_actions):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
        )

        self.policy_head = nn.Linear(n_hiddens, n_actions)
        self.value_head = nn.Linear(n_hiddens, 1)

    def forward(self, x):
        """Compute action logits and state values for a batch of observations.

        Args:
            x: Observations as a NumPy array, a (nested) sequence of arrays, or
                a torch tensor, with the last dimension of size `n_input`.
                Input is cast to float32 and moved to the module's device.

        Returns:
            Tuple `(logits, value)` with last dimensions `n_actions` and 1.
        """
        device = next(self.mlp.parameters()).device
        if not isinstance(x, torch.Tensor):
            # Collapse lists of arrays into one ndarray first; building a tensor
            # from a Python list of arrays is slow.
            x = np.asarray(x)
        # `as_tensor` leaves tensors attached to their autograd graph and on
        # their device, so tensors that require grad or live on the GPU work
        # too (a round-trip through NumPy would raise for those).
        x = torch.as_tensor(x, dtype=torch.float32, device=device)

        x = self.mlp(x)
        logits = self.policy_head(x)
        value = self.value_head(x)

        return logits, value

In [7]:
import random

In [8]:
def select_actions(logits):
    """Sample one action per row of `logits` from a categorical policy.

    Args:
        logits: Unnormalized action scores, shape [batch_size, n_actions].

    Returns:
        Tuple `(actions, probs)` of detached CPU tensors: `actions` has shape
        [batch_size] and holds the sampled action indices, `probs` has shape
        [batch_size, n_actions] and holds the full action distribution.
    """
    policy = torch.distributions.Categorical(logits=logits)

    sampled = policy.sample().detach().cpu()       # shape: [batch_size]
    action_probs = policy.probs.detach().cpu()     # shape: [batch_size, n_actions]

    return sampled, action_probs

In [9]:
# Use CUDA when PyTorch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
from torch import optim

from itertools import count
from tqdm.auto import tqdm

import wandb

In [None]:
def _discounted_returns(rewards, dones, bootstrap_values, gamma):
    """Compute bootstrapped discounted returns for a fixed-length rollout.

    Args:
        rewards: Sequence of length T of per-environment reward arrays.
        dones: Sequence of length T of per-environment boolean arrays; True
            where the episode ended on that step.
        bootstrap_values: Per-environment value estimates for the observation
            that follows the last step of the rollout.
        gamma: Discount factor.

    Returns:
        float32 array of shape [T, num_envs].
    """
    running = np.asarray(bootstrap_values, dtype=np.float32)
    returns = np.zeros((len(rewards),) + running.shape, dtype=np.float32)
    for t in reversed(range(len(rewards))):
        # Zero the carried return wherever an episode ended on step t, so that
        # reward from the next episode never leaks backwards.
        continuing = 1.0 - np.asarray(dones[t], dtype=np.float32)
        running = np.asarray(rewards[t], dtype=np.float32) + gamma * running * continuing
        returns[t] = running
    return returns


def train(
    n_epochs = 8_000,
    batch_size = 128,
    learning_rate = 1e-4,
    gamma = 0.99,
    clip_param = 0.2,
    leave_bar = True,
    n_update_passes = 4,
    value_coef = 0.5,
):
    """Train a `PPONetwork` on the global `vecenv` with the clipped PPO objective.

    Each epoch collects `batch_size` steps from every environment, computes
    discounted returns bootstrapped from the value of the final observation,
    and then takes `n_update_passes` gradient steps on the whole rollout.

    Reads the notebook globals `vecenv`, `device`, `n_input`, `n_hiddens` and
    `n_actions`. Metrics are logged to Weights & Biases only when a run is
    active. A KeyboardInterrupt stops training early; the model is still saved.

    Args:
        n_epochs: Number of rollout/update cycles.
        batch_size: Steps collected per environment in each rollout.
        learning_rate: Adam learning rate.
        gamma: Discount factor.
        clip_param: PPO clipping range for the probability ratio.
        leave_bar: Whether the progress bar stays visible when finished.
        n_update_passes: Gradient steps taken on each rollout (must be >= 1).
        value_coef: Weight of the value loss in the total loss.

    Returns:
        The trained network, switched to eval mode.

    Raises:
        ValueError: If `n_update_passes` is smaller than 1.
    """
    import os  # local import: only needed to create the checkpoint directory

    if n_update_passes < 1:
        raise ValueError("n_update_passes must be at least 1")

    net = PPONetwork(n_input, n_hiddens, n_actions).to(device)
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    if wandb.run is not None:
        wandb.watch(net, log="all")

    obs, _ = vecenv.reset()
    # Unscaled return accumulated so far by each environment's current episode.
    episode_returns = np.zeros(len(obs), dtype=np.float64)

    try:
        for epoch in tqdm(range(n_epochs), leave=leave_bar, desc="Epochs"):
            batch_states = []
            batch_values = []
            batch_actions = []
            batch_probs = []
            batch_rewards = []
            batch_dones = []
            finished_returns = []

            # Collect one rollout for each env
            for t in range(batch_size):
                with torch.no_grad():
                    logits, values = net(obs)
                    actions, probs = select_actions(logits)

                next_obs, rewards, dones, truncateds, _ = vecenv.step(actions.numpy())
                # Truncation is treated like termination: no bootstrapping
                # across the episode boundary.
                ended = dones | truncateds

                batch_states.append(obs.copy())
                batch_values.append(values.squeeze(-1).cpu())
                batch_actions.append(actions)
                batch_probs.append(probs)
                # Rewards are scaled down by 500 to keep value targets small.
                batch_rewards.append(rewards / 500.0)
                batch_dones.append(ended)

                episode_returns += rewards
                for i in np.where(ended)[0]:
                    finished_returns.append(episode_returns[i])
                    episode_returns[i] = 0.0

                obs = next_obs.copy()

            # Value of the observation after the rollout bootstraps the returns
            # of episodes that are still running.
            with torch.no_grad():
                _, final_values = net(obs)
            returns = _discounted_returns(
                batch_rewards,
                batch_dones,
                final_values.squeeze(-1).cpu().numpy(),
                gamma,
            )

            # Flatten [batch_size, num_envs, ...] into one step-major batch;
            # every tensor below uses the same ordering.
            states = np.concatenate(batch_states)
            actions_taken = torch.cat(batch_actions).to(device)
            old_probs = torch.cat(batch_probs).to(device)
            old_values = torch.cat(batch_values).to(device)
            target_returns = torch.as_tensor(returns.reshape(-1), device=device)

            # Log-probability of each taken action under the rollout policy.
            old_log_probs = torch.log(
                old_probs.gather(1, actions_taken.unsqueeze(1)).squeeze(1).clamp_min(1e-8)
            )
            advantages = target_returns - old_values

            for _ in range(n_update_passes):
                # States stay a NumPy array so the network builds a fresh graph.
                logits, values = net(states)
                dist = torch.distributions.Categorical(logits=logits)
                ratio = torch.exp(dist.log_prob(actions_taken) - old_log_probs)
                clipped_ratio = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)

                # Clipped surrogate: take the pessimistic of the two objectives.
                policy_loss = -torch.min(ratio * advantages, clipped_ratio * advantages).mean()
                value_loss = (values.squeeze(-1) - target_returns).pow(2).mean()
                loss = policy_loss + value_coef * value_loss

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if wandb.run is not None:
                metrics = {
                    "policy_loss": policy_loss.item(),
                    "value_loss": value_loss.item(),
                }
                if finished_returns:
                    metrics["mean_return"] = float(np.mean(finished_returns))
                wandb.log(metrics)

    except KeyboardInterrupt:
        print("Training halted manually.")

    # Finalize the run
    model_path = "./models/PPO_scratch_CartPole_latest.pt"
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    torch.save(net.state_dict(), model_path)
    if wandb.run is not None:
        # Both calls need an active run; skip them when training without wandb.
        wandb.log_model(model_path, "latest")
        wandb.unwatch()
    net.eval()

    return net

In [None]:
# Network dimensions, read by `train` as notebook globals.
n_input = vecenv.single_observation_space.shape[0]
n_actions = vecenv.single_action_space.n
n_hiddens = 16

# Hyperparameters: exactly the keyword arguments that `train` accepts.
n_epochs = 8_000
batch_size = 128
learning_rate = 1e-3
gamma = 0.99
clip_param = 0.2

# NOTE(review): the project name still says REINFORCE although this notebook
# trains PPO; left unchanged so runs keep landing in the same W&B project.
project = "REINFORCE-scratch-CartPole"
config = {
    "num_envs": num_envs,
    "n_hiddens": n_hiddens,
    "n_epochs": n_epochs,
    "batch_size": batch_size,
    "learning_rate": learning_rate,
    "gamma": gamma,
    "clip_param": clip_param,
}

wandb.init(
    project=project,
    config=config,
)

# `train` returns only the trained network.
net = train(
    n_epochs=n_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    gamma=gamma,
    clip_param=clip_param,
)

In [None]:
import gymnasium as gym
import numpy as np
from datetime import datetime
from pathlib import Path

# Roll out the trained policy in a single video-recording environment and
# upload the most recent recording to the active W&B run, if there is one.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/REINFORCE_scratch_CartPole_{timestamp}"

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, video_folder)

try:
    for episode in range(9):
        ob, _ = env.reset()
        ret = 0
        done, truncated = False, False
        while not (done or truncated):
            # The network expects a batch and returns (logits, value); only the
            # logits are needed to pick an action.
            with torch.no_grad():
                logits, _ = net(np.expand_dims(ob, 0))
            actions, *_ = select_actions(logits)
            ob, reward, done, truncated, _ = env.step(actions[0].item())
            ret += reward

        print(ret)
finally:
    # Close before looking for videos so the last recording is flushed to disk.
    env.close()

if wandb.run is not None:
    videos = list(Path(video_folder).glob("*.mp4"))
    if videos:
        latest_video = max(videos, key=lambda x: x.stat().st_mtime)
        wandb.log({
            "video": wandb.Video(str(latest_video))
        })
    wandb.finish()