In [1]:
import numpy as np
import pufferlib, pufferlib.vector
from pufferlib.environments import classic_control

In [2]:
# Number of parallel environment copies; also used to size the per-environment
# reward / log-prob buffers in the cells below.
num_envs = 12

In [3]:
# Build a vectorized CartPole-v1 environment with `num_envs` copies, using
# pufferlib's Multiprocessing backend. The resulting `vecenv` is shared as a
# global by the rollout and training cells below.
vecenv = pufferlib.vector.make(
    classic_control.env_creator("CartPole-v1"),
    num_envs=num_envs,
    backend=pufferlib.vector.Multiprocessing,
)

Process Process-4:
Process Process-7:
Process Process-5:
Process Process-1:
Process Process-2:
Process Process-11:
Process Process-6:
Process Process-12:
Process Process-3:
Process Process-10:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314

In [4]:
# Kept in this cell because it runs before the later import cell that also
# imports `count`.
from itertools import count

# Random-policy baseline: step the vectorized environment with uniformly
# sampled actions until at least 1,000 episodes have finished, then report the
# mean undiscounted return and the number of vectorized steps taken.

# One reward list per environment copy, holding the rewards of the episode
# that is currently in progress in that copy.
env_rewards = [list() for _ in range(num_envs)]

# Undiscounted return of every finished episode.
returns = []

obs, _ = vecenv.reset()
episodes = 0
for t in count():
    obs, rewards, dones, truncateds, _ = vecenv.step(vecenv.action_space.sample())
    for i, reward in enumerate(rewards):
        env_rewards[i].append(reward)
    # An episode is over when it either terminated or was truncated.
    for i in np.where(dones | truncateds)[0]:
        returns.append(sum(env_rewards[i]))
        env_rewards[i] = []
        episodes += 1
    # Several episodes can finish in the same step, so this may overshoot 1,000.
    if episodes >= 1_000:
        break

# Each entry of `returns` is already a scalar, so it can be averaged directly.
np.mean(returns), t

(22.166, 1946)

In [5]:
import torch
from torch import nn

In [6]:
class PolicyNetwork(nn.Module):
    """Small MLP policy: one hidden ReLU layer producing unnormalized action logits.

    Args:
        n_input: Size of a single observation vector.
        n_hiddens: Width of the hidden layer.
        n_actions: Number of discrete actions (size of the logits output).
    """

    def __init__(self, n_input, n_hiddens, n_actions):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        """Return action logits for a batch of observations.

        Args:
            x: Observations as a ``torch.Tensor`` or anything ``np.asarray``
                accepts (NumPy array, nested list, ...). The last dimension
                must be ``n_input``.

        Returns:
            A float32 tensor of logits on the same device as the network's
            parameters, with the last dimension of size ``n_actions``.
        """
        # Run on whatever device the parameters currently live on.
        device = next(self.mlp.parameters()).device

        if isinstance(x, torch.Tensor):
            # Tensors are moved/cast directly. Going through NumPy would fail
            # for CUDA tensors and for tensors that require grad.
            x = x.to(device=device, dtype=torch.float32)
        else:
            # torch.tensor copies the data, so the result does not alias the
            # caller's array.
            x = torch.tensor(np.asarray(x), dtype=torch.float32, device=device)

        return self.mlp(x)

In [7]:
import random

In [8]:
def select_actions(logits):
    """Sample one action per batch row from a categorical policy.

    Args:
        logits: Unnormalized action scores, shape ``[batch_size, n_actions]``.

    Returns:
        A 3-tuple ``(actions, log_probs, entropies)``, each of shape
        ``[batch_size]``:

        * ``actions``: sampled action indices, detached and moved to the CPU.
        * ``log_probs``: log-probability of each sampled action; still attached
          to the autograd graph of ``logits``.
        * ``entropies``: entropy of each row's action distribution; also still
          attached to the graph.
    """
    policy = torch.distributions.Categorical(logits=logits)
    sampled = policy.sample()  # one action index per batch row
    return sampled.detach().cpu(), policy.log_prob(sampled), policy.entropy()

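* We need to perform gradient ascent on the objective $J(\theta)$, whose gradient is estimated at each step by $\nabla_\theta J(\theta)\approx\nabla_\theta\log\pi_\theta(a_t|s_t)\,G_t$. Because optimizers minimize, we instead descend the gradient of the negated term, i.e. we minimize the loss $-\log\pi_\theta(a_t|s_t)\,G_t$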
* We have to store rewards at each step, as well as log probs
* Tracking entropy can be useful for diagnosing issues

In [9]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
from torch import optim

from itertools import count
from tqdm.auto import tqdm

import wandb

In [11]:
def train(
    n_episodes = 8_000,
    learning_rate = 1e-4,
    gamma = 0.99,
    entropy_coef = 0.01,
    min_update_batch_size = 16,
    leave_bar = True,
    suggestion_uuid = None,
):
    """Train a `PolicyNetwork` on the global `vecenv` with a REINFORCE-style update.

    Relies on these notebook globals being defined before the call: `vecenv`,
    `num_envs`, `n_input`, `n_hiddens`, `n_actions` and `device`.

    Each finished episode contributes one scalar loss. Once at least
    `min_update_batch_size` episode losses are collected, their mean is
    back-propagated, the optimizer takes one step, and all environments are
    reset so that every trajectory is generated by the current parameters.

    Args:
        n_episodes: Stop after at least this many episodes have finished.
        learning_rate: Adam learning rate.
        gamma: Discount factor for the per-step returns.
        entropy_coef: Weight of the entropy bonus subtracted from the loss.
        min_update_batch_size: Minimum number of finished episodes per
            optimizer step.
        leave_bar: Forwarded to tqdm's `leave` argument.
        suggestion_uuid: Currently unused; kept for interface compatibility.

    Returns:
        `(net, mean_last_100, max_last_100, t)`: the trained network, the mean
        and max of the (500-scaled) returns of the last up-to-100 finished
        episodes (NaN if no episode finished), and the index of the last
        vectorized environment step.
    """
    import os  # local import: only needed to create the checkpoint directory

    net = PolicyNetwork(n_input, n_hiddens, n_actions).to(device)

    if wandb.run is not None:
        wandb.watch(net, log="all")

    # Per-environment buffers for the episode currently in progress.
    env_log_probs = [list() for _ in range(num_envs)]
    env_entropies = [list() for _ in range(num_envs)]
    env_rewards = [list() for _ in range(num_envs)]
    # Scaled return of every finished episode. Kept for the whole run so the
    # "last 100" statistics really cover up to 100 episodes.
    returns = []
    # Episode losses waiting for the next optimizer step.
    losses = []
    # Exponential moving average of episode returns, used as a reference to
    # decide whether an episode was better than recent ones.
    trailing_return = None
    return_alpha = 0.99

    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    obs, _ = vecenv.reset()
    episodes = 0
    t = 0  # defined up front so an early interrupt still has a value to return
    bar = tqdm(
        desc="Episodes",
        total=n_episodes,
        initial=0,
        leave=leave_bar,
    )

    try:
        for t in count():
            logits = net(obs)
            actions, log_probs, entropies = select_actions(logits)

            obs, rewards, dones, truncateds, _ = vecenv.step(actions)

            for i, (log_prob, entropy, reward) in enumerate(zip(log_probs, entropies, rewards)):
                env_log_probs[i].append(log_prob)
                env_entropies[i].append(entropy)
                # Rewards are scaled by 1/500 before being stored.
                env_rewards[i].append(reward / 500.0)

            # An episode is over when it either terminated or was truncated.
            for i in np.where(dones | truncateds)[0]:
                full_return = sum(env_rewards[i])

                if trailing_return is None:
                    trailing_return = full_return
                    good_episode = True
                else:
                    trailing_return = trailing_return * return_alpha + full_return * (1 - return_alpha)
                    good_episode = full_return > trailing_return

                # Episodes above the trailing average are reinforced (negative
                # sign: minimizing raises their log-probs); the rest are
                # penalized.
                sign = -1 if good_episode else 1

                episode_return = 0.0
                loss = 0.0

                T = len(env_rewards[i])
                # A separate index `k` is used here so the outer step counter
                # `t` is not overwritten.
                for k in reversed(range(T)):
                    # Discounted return-to-go, accumulated from the end.
                    episode_return = env_rewards[i][k] + gamma * episode_return
                    # Note: This trick only works because there are no negative returns in CartPole!
                    loss += env_log_probs[i][k] * sign * episode_return
                    loss -= env_entropies[i][k] * entropy_coef

                loss /= T
                losses.append(loss)

                with torch.no_grad():
                    avg_entropy = sum(env_entropies[i]) / T

                if wandb.run is not None:
                    wandb.log(
                        data={
                            "return": full_return,
                            "avg_entropy": avg_entropy.item(),
                            "loss": abs(loss.item()),
                        },
                        commit=True,
                    )

                returns.append(full_return)

                env_log_probs[i] = []
                env_entropies[i] = []
                env_rewards[i] = []

                episodes += 1
                bar.update()

            if len(losses) >= min_update_batch_size:
                mean_loss = sum(losses) / len(losses)

                optimizer.zero_grad()
                mean_loss.backward()
                optimizer.step()

                # Drop unfinished trajectories: their log-probs came from the
                # parameters that were just updated.
                env_log_probs = [list() for _ in range(num_envs)]
                env_entropies = [list() for _ in range(num_envs)]
                env_rewards = [list() for _ in range(num_envs)]
                losses = []

                obs, _ = vecenv.reset()

            if episodes >= n_episodes:
                model_path = "./models/REINFORCE_scratch_CartPole_latest.pt"
                # torch.save does not create missing parent directories.
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                torch.save(net.state_dict(), model_path)
                # Same guard as the other wandb calls: only touch wandb while
                # a run is active.
                if wandb.run is not None:
                    wandb.log_model(model_path, "latest")
                    wandb.unwatch()
                net.eval()
                break

    except KeyboardInterrupt:
        print("Training halted manually.")
    finally:
        # Close the bar on both normal completion and manual interruption.
        bar.close()

    last_returns = returns[-100:]
    if last_returns:
        mean_last_100 = np.mean(last_returns)
        max_last_100 = np.max(last_returns)
    else:
        # No episode finished (e.g. interrupted immediately).
        mean_last_100 = max_last_100 = float("nan")
    return net, mean_last_100, max_last_100, t

In [12]:
# Network dimensions, read from the per-environment spaces of the vectorized env.
n_input = vecenv.single_observation_space.shape[0]
n_actions = vecenv.single_action_space.n
n_hiddens = 16

# Hyperparameters for this run.
n_episodes = 100_000
learning_rate = 1e-3
gamma = 0.99
entropy_coef = 0.0
min_update_batch_size = 16

project = "REINFORCE-scratch-CartPole"
# Everything that influences the run is logged. `gamma` is included because
# it is passed to `train` below.
config = {
    "num_envs": num_envs,
    "n_hiddens": n_hiddens,
    "n_episodes": n_episodes,
    "learning_rate": learning_rate,
    "gamma": gamma,
    "entropy_coef": entropy_coef,
    "min_update_batch_size": min_update_batch_size,
}

wandb.init(
    project=project,
    config=config,
)

net, mean_return, max_return, t = train(
    n_episodes=n_episodes,
    learning_rate=learning_rate,
    gamma=gamma,
    entropy_coef=entropy_coef,
    min_update_batch_size=min_update_batch_size,
)

# `train` stores rewards divided by 500, so multiply back to report raw returns.
print(f"Mean: {mean_return * 500:.2f}, Max: {max_return * 500:.2f}")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


Episodes:   0%|          | 0/100000 [00:00<?, ?it/s]

Training halted manually.
Mean: 500.00, Max: 500.00


In [13]:
import gymnasium as gym
import numpy as np
from datetime import datetime
from pathlib import Path

# Roll out the trained policy for 9 episodes in a single, non-vectorized
# CartPole environment while recording videos, print each episode's return,
# and upload the most recent video to the active wandb run (if any).
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/REINFORCE_scratch_CartPole_{timestamp}"

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, video_folder)

try:
    for _episode in range(9):
        ob, _ = env.reset()
        ret = 0
        done, truncated = False, False
        while not (done or truncated):
            # Inference only: no autograd graph is needed.
            with torch.no_grad():
                # The network expects a batch, so add a leading batch dimension.
                logits = net(np.expand_dims(ob, 0))
            actions, *_ = select_actions(logits)
            ob, reward, done, truncated, _ = env.step(actions[0].item())
            ret += reward

        print(ret)
finally:
    # Close the environment before looking for videos, and also when an
    # episode raises, so the recorder is always shut down.
    env.close()

if wandb.run is not None:
    latest_video = max(
        Path(video_folder).glob("*.mp4"),
        key=lambda x: x.stat().st_mtime,
        default=None,  # no video recorded: skip the upload instead of raising
    )
    if latest_video is not None:
        wandb.log({
            "video": wandb.Video(str(latest_video))
        })
    wandb.finish()

MoviePy - Building video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-0.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-0.mp4




[Ame_index:   0%|          | 0/500 [00:00<?, ?it/s, now=None]
[Ame_index:  24%|██▍       | 119/500 [00:00<00:00, 1182.12it/s, now=None]
[Ame_index:  53%|█████▎    | 263/500 [00:00<00:00, 1330.45it/s, now=None]
[Ame_index:  82%|████████▏ | 409/500 [00:00<00:00, 1388.85it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-0.mp4
500.0
MoviePy - Building video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-1.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-1.mp4




[Ame_index:   0%|          | 0/500 [00:00<?, ?it/s, now=None]
[Ame_index:  24%|██▎       | 118/500 [00:00<00:00, 1178.23it/s, now=None]
[Ame_index:  52%|█████▏    | 262/500 [00:00<00:00, 1328.72it/s, now=None]
[Ame_index:  80%|████████  | 402/500 [00:00<00:00, 1358.74it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-1.mp4
500.0
500.0
500.0
500.0
500.0
500.0
500.0
MoviePy - Building video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-8.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-8.mp4




[Ame_index:   0%|          | 0/500 [00:00<?, ?it/s, now=None]
[Ame_index:  25%|██▍       | 124/500 [00:00<00:00, 1234.57it/s, now=None]
[Ame_index:  53%|█████▎    | 267/500 [00:00<00:00, 1348.54it/s, now=None]
[Ame_index:  82%|████████▏ | 410/500 [00:00<00:00, 1381.91it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/REINFORCE_scratch_CartPole_20250217_084004/rl-video-episode-8.mp4
500.0


0,1
avg_entropy,▇▇▇███████▇▃▇▆▄▂▅▄▄▅▃▃▃▃▃▂▃▁▂▃▂▃▂▂▃▂▂▂▂▂
loss,▁▁▁▁▂▂▃▂▃▆▅█████████▇█▇█████▇▇▇█▇▇██████
return,▁▁▁▁▁▁▁▁▁▂▅▃▂█▅█████████████████████████

0,1
avg_entropy,0.54114
loss,0.08903
return,1.0
