In [1]:
import numpy as np
import pufferlib, pufferlib.vector
from pufferlib.environments import classic_control

In [2]:
# Number of parallel CartPole sub-environments in `vecenv`; it also sizes the
# per-environment bookkeeping lists used by the rollout loops further down.
num_envs = 12

In [3]:
# Vectorised CartPole-v1: `num_envs` copies stepped together through
# pufferlib's `Multiprocessing` vector backend.
cartpole_creator = classic_control.env_creator("CartPole-v1")
vecenv = pufferlib.vector.make(
    cartpole_creator,
    backend=pufferlib.vector.Multiprocessing,
    num_envs=num_envs,
)

In [4]:
# Random-policy baseline: step every sub-environment with uniformly random
# actions until at least N_BASELINE_EPISODES episodes have finished, then show
# the mean undiscounted episode return and the index of the last vector step.
from itertools import count

N_BASELINE_EPISODES = 1_000

random_env_rewards = [[] for _ in range(num_envs)]  # rewards of each env's in-progress episode
random_returns = []  # total (undiscounted) return of every finished episode

vecenv.reset()
episodes = 0
for t in count():
    _, rewards, dones, truncateds, _ = vecenv.step(vecenv.action_space.sample())
    for i, reward in enumerate(rewards):
        random_env_rewards[i].append(reward)
    # An episode ends on either termination or truncation.
    for i in np.flatnonzero(dones | truncateds):
        random_returns.append(sum(random_env_rewards[i]))
        random_env_rewards[i] = []
        episodes += 1
    if episodes >= N_BASELINE_EPISODES:
        break

np.mean(random_returns), t

(22.064, 1933)

In [5]:
import torch
from torch import nn

In [6]:
class PolicyNetwork(nn.Module):
    """Small MLP policy: observation vector -> one unnormalised logit per action.

    Layout (kept in ``self.mlp``):
        Linear(n_input, n_hiddens) -> ReLU -> Linear(n_hiddens, n_actions)
    """

    def __init__(self, n_input, n_hiddens, n_actions):
        super().__init__()
        layers = [
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits for ``x``.

        ``x`` may be anything ``np.array`` accepts (e.g. a NumPy batch of
        observations or a list of observation arrays). It is converted to a
        float32 tensor on the same device as the network's parameters.
        """
        param_device = self.mlp[0].weight.device
        batch = np.array(x)
        return self.mlp(torch.tensor(batch, dtype=torch.float32, device=param_device))

In [7]:
import random

In [8]:
def select_actions(logits, greedy=False):
    """Choose one discrete action per row of ``logits`` from a categorical policy.

    Args:
        logits: unnormalised action scores whose last dimension indexes the
            actions (e.g. shape ``[batch_size, n_actions]``).
        greedy: if True, take the highest-probability action instead of
            sampling (useful for evaluation). Defaults to False (sample).

    Returns:
        (actions, log_probs, entropies):
            actions   -- chosen action indices, moved to the CPU, shape ``logits.shape[:-1]``
            log_probs -- log-probability of each chosen action (same device as ``logits``)
            entropies -- entropy of each row's distribution (same device as ``logits``)
    """
    dist = torch.distributions.Categorical(logits=logits)

    if greedy:
        actions = dist.logits.argmax(dim=-1)  # shape: logits.shape[:-1], i.e. [batch_size]
    else:
        actions = dist.sample()  # shape: logits.shape[:-1], i.e. [batch_size]
    log_probs = dist.log_prob(actions)
    entropies = dist.entropy()

    # Only the actions go to the CPU: the environment consumes them, while the
    # log-probs/entropies may still be needed on the model's device.
    return actions.cpu(), log_probs, entropies

In [9]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
from torch import optim

from itertools import count
from tqdm.auto import tqdm

import wandb

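* REINFORCE performs gradient ascent on the expected return $J(\theta)$, whose gradient is $\nabla_\theta J(\theta)=\mathbb{E}_{\pi_\theta}\left[\sum_t \nabla_\theta \log\pi_\theta(a_t\mid s_t)\,G_t\right]$, where $G_t$ is the return from step $t$ onward. In practice we run gradient *descent* on the loss $-\sum_t \log\pi_\theta(a_t\mid s_t)\,G_t$.
* At every step we must keep the reward and the log-probability of the action taken (or enough information to recompute it), so that $G_t$ and the loss can be computed once the episode ends.
* Tracking the policy's entropy is useful for diagnosing issues, e.g. a policy collapsing to near-deterministic actions too early.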

In [None]:
def _discounted_returns(rewards, gamma):
    """Return the reward-to-go G_t = sum_{k>=t} gamma**(k-t) * r_k for each step of one episode."""
    returns_to_go = []
    running = 0.0
    for reward in reversed(rewards):
        running = reward + gamma * running
        returns_to_go.append(running)
    returns_to_go.reverse()
    return returns_to_go


def _reinforce_step(net, optimizer, observations, actions, rewards, gamma):
    """Apply one REINFORCE gradient step using a single finished episode.

    Log-probabilities are recomputed here with gradients enabled, because
    action sampling ran under ``torch.no_grad()``. The loss is
    ``-sum_t log pi(a_t | s_t) * G_t``, so gradient descent on it is gradient
    ascent on the policy objective.

    Returns:
        (loss, mean_entropy) as Python floats, for logging.
    """
    param_device = next(net.parameters()).device
    dist = torch.distributions.Categorical(logits=net(observations))
    actions_t = torch.as_tensor(actions, device=param_device)
    returns_t = torch.tensor(
        _discounted_returns(rewards, gamma), dtype=torch.float32, device=param_device
    )
    loss = -(dist.log_prob(actions_t) * returns_t).sum()
    mean_entropy = dist.entropy().mean().item()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item(), mean_entropy


def _save_latest_model(net, model_path="./models/REINFORCE_scratch_CartPole_latest.pt"):
    """Save ``net``'s weights to ``model_path``, creating the folder if needed.

    When a W&B run is active, also log the file as a model artifact and stop
    watching the network.
    """
    from pathlib import Path

    Path(model_path).parent.mkdir(parents=True, exist_ok=True)
    torch.save(net.state_dict(), model_path)
    if wandb.run is not None:
        wandb.log_model(model_path, "latest")
        wandb.unwatch()


def train(
    n_episodes = 8_000,
    learning_rate = 1e-4,
    leave_bar = True,
    suggestion_uuid = None,
    gamma = 1.0,
):
    """Train a ``PolicyNetwork`` on ``vecenv`` with vanilla REINFORCE (Monte-Carlo policy gradient).

    Every sub-environment of ``vecenv`` plays its own episodes. Whenever one
    finishes, a single policy-gradient step is taken on that episode. Because
    other episodes may be mid-flight during an update, their early steps were
    sampled from a slightly older policy; this mild off-policyness is accepted
    here.

    Relies on notebook globals: ``vecenv``, ``num_envs``, ``n_input``,
    ``n_hiddens``, ``n_actions``, ``device``, ``PolicyNetwork`` and
    ``select_actions``.

    Args:
        n_episodes: stop once this many episodes (summed over all sub-envs) have finished.
        learning_rate: Adam learning rate.
        leave_bar: forwarded to tqdm's ``leave``.
        suggestion_uuid: accepted for hyperparameter-sweep callers; not used here.
        gamma: discount factor for the reward-to-go (1.0 = undiscounted).

    Returns:
        (net, mean_last_100, max_last_100, t):
            net           -- the trained network
            mean_last_100 -- mean undiscounted return over the last (up to) 100 episodes, NaN if none finished
            max_last_100  -- max of the same, NaN if none finished
            t             -- index of the last vector step taken
    """
    net = PolicyNetwork(n_input, n_hiddens, n_actions).to(device)
    if wandb.run is not None:
        wandb.watch(net, log="all")
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)

    # Trajectory of the episode currently in progress in each sub-env.
    env_obs = [[] for _ in range(num_envs)]
    env_actions = [[] for _ in range(num_envs)]
    env_rewards = [[] for _ in range(num_envs)]
    returns = []  # undiscounted return of every finished episode
    episodes = 0
    t = 0
    completed = False

    obs, _ = vecenv.reset()
    obs = obs.copy()
    bar = tqdm(
        desc="Episodes",
        total=n_episodes,
        initial=0,
        leave=leave_bar,
    )

    try:
        for t in count():
            # Sampling needs no autograd graph. Log-probs are recomputed with
            # gradients in `_reinforce_step` once an episode is complete.
            with torch.no_grad():
                actions, _, _ = select_actions(net(obs))
            actions = actions.numpy()
            next_obs, rewards, dones, truncateds, _ = vecenv.step(actions)

            for i in range(num_envs):
                env_obs[i].append(obs[i])
                env_actions[i].append(int(actions[i]))
                env_rewards[i].append(float(rewards[i]))
            # NOTE(review): assumes the vector backend auto-resets finished
            # sub-envs, so next_obs[i] already begins the next episode.
            obs = next_obs.copy()

            for i in np.flatnonzero(np.logical_or(dones, truncateds)):
                loss, entropy = _reinforce_step(
                    net, optimizer, env_obs[i], env_actions[i], env_rewards[i], gamma
                )
                ret = sum(env_rewards[i])
                returns.append(ret)
                env_obs[i], env_actions[i], env_rewards[i] = [], [], []
                episodes += 1
                bar.update()

                if wandb.run is not None:
                    wandb.log(
                        data={
                            "avg_return": ret,
                            "loss": loss,
                            "entropy": entropy,
                        },
                        commit=True,
                    )

            if episodes >= n_episodes:
                completed = True
                break

    except KeyboardInterrupt:
        print("Training halted manually.")
    finally:
        bar.close()

    if completed:
        _save_latest_model(net)
        net.eval()

    recent = returns[-100:]
    mean_last_100 = float(np.mean(recent)) if recent else float("nan")
    max_last_100 = float(np.max(recent)) if recent else float("nan")
    return net, mean_last_100, max_last_100, t

In [None]:
from carbs import CARBS, CARBSParams, LinearSpace, LogSpace, LogitSpace, ObservationInParam, ParamDictType, Param
import time

In [None]:
# Network sizes. `train` reads n_input, n_hiddens and n_actions as notebook
# globals, so they must be defined before it is called.
n_input = vecenv.single_observation_space.shape[0]
n_actions = vecenv.single_action_space.n
n_hiddens = 16

# REINFORCE hyperparameters. Only arguments that `train` actually accepts are
# passed; the earlier DQN settings (target-net period, batch size, epsilon
# schedule) do not apply to a policy-gradient method.
n_episodes = 8_000
learning_rate = 1e-4

project = "REINFORCE-scratch-CartPole"
config = {
    "num_envs": num_envs,
    "n_hiddens": n_hiddens,
    "n_episodes": n_episodes,
    "learning_rate": learning_rate,
}

wandb.init(
    project=project,
    config=config,
)

net, mean_return, max_return, t = train(
    n_episodes=n_episodes,
    learning_rate=learning_rate,
)

mean_return, max_return

In [None]:
import gymnasium as gym
import numpy as np
from datetime import datetime
from pathlib import Path

N_EVAL_EPISODES = 9

timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/REINFORCE_scratch_CartPole_{timestamp}"

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, video_folder)

# Roll out the trained policy (sampling from it, as during training) and
# print each episode's undiscounted return.
for _ in range(N_EVAL_EPISODES):
    ob, _ = env.reset()
    ret = 0
    done, truncated = False, False
    while not (done or truncated):
        with torch.no_grad():
            logits = net(np.expand_dims(ob, 0))  # add a batch dimension of 1
        actions, _, _ = select_actions(logits)
        # The single-env API expects a plain int action, not a batched tensor.
        ob, reward, done, truncated, _ = env.step(int(actions[0]))
        ret += reward

    print(ret)

# Close first so any recording still in progress is flushed to disk before upload.
env.close()

if wandb.run is not None:
    latest_video = max(
        Path(video_folder).glob("*.mp4"),
        key=lambda path: path.stat().st_mtime,
        default=None,
    )
    if latest_video is not None:
        wandb.log({
            "video": wandb.Video(str(latest_video))
        })
    wandb.finish()