In [1]:
import numpy as np
import pufferlib, pufferlib.vector
from pufferlib.environments import classic_control

In [2]:
# Number of environment copies; passed as `num_envs` to `pufferlib.vector.make`
# and used to size the per-environment reward lists.
num_envs = 12

In [3]:
# Build a vectorized CartPole-v1 environment with `num_envs` copies using
# pufferlib's Multiprocessing backend.
vecenv = pufferlib.vector.make(
    classic_control.env_creator("CartPole-v1"),
    num_envs=num_envs,
    backend=pufferlib.vector.Multiprocessing,
)

Process Process-10:
Process Process-5:
Process Process-12:
Process Process-4:
Process Process-2:
Process Process-11:
Process Process-1:
Process Process-3:
Process Process-7:
Process Process-6:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/home/fitti/.conda/envs/puffer/lib/python3.11/multiprocessing/process.py", line 108

In [4]:
from itertools import count

# Random-policy baseline: keep sampling actions from the vectorized action
# space until at least 1,000 episodes have finished, then show the mean
# episodic return together with the index of the last step taken.
env_rewards = [[] for _ in range(num_envs)]
returns = []

dones = np.array([False] * num_envs)
truncateds = np.array([False] * num_envs)

obs, _ = vecenv.reset()
episodes = 0
for t in count():
    sampled_actions = vecenv.action_space.sample()
    obs, rewards, dones, truncateds, _ = vecenv.step(sampled_actions)

    # Accumulate this step's reward for every environment.
    for env_idx, step_reward in enumerate(rewards):
        env_rewards[env_idx].append(step_reward)

    # Close out every environment whose episode ended on this step.
    finished = np.where(dones | truncateds)[0]
    for env_idx in finished:
        returns.append(sum(env_rewards[env_idx]))
        env_rewards[env_idx] = []
    episodes += len(finished)

    if episodes >= 1_000:
        break

np.mean([np.sum(episode_return) for episode_return in returns]), t

(21.426, 1887)

In [5]:
import torch
from torch import nn

In [6]:
class QNetwork(nn.Module):
    """Two-layer MLP that maps a batch of observations to per-action Q-values.

    Args:
        n_input: Size of one observation vector.
        n_hiddens: Width of the single hidden layer.
        n_actions: Number of discrete actions (size of the output layer).
    """

    def __init__(self, n_input, n_hiddens, n_actions):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(n_input, n_hiddens),
            nn.ReLU(),
            nn.Linear(n_hiddens, n_actions),
        )

    def forward(self, x):
        """Return Q-values for ``x``.

        ``x`` may be a NumPy array, a sequence of per-sample arrays (for
        example a tuple of observations), or a ``torch.Tensor``. The input is
        converted to float32 on the same device as the network parameters.
        """
        device = next(self.mlp.parameters()).device
        if isinstance(x, torch.Tensor):
            # Keep tensors as tensors: a NumPy round-trip would fail for CUDA
            # tensors and for tensors that require grad, and would cut the
            # autograd graph.
            x = x.to(device=device, dtype=torch.float32)
        else:
            # np.array stacks a sequence of per-sample arrays into one batch
            # and copies, so the tensor never aliases the caller's buffer.
            x = torch.tensor(np.array(x), dtype=torch.float32, device=device)

        return self.mlp(x)

In [7]:
import random

In [8]:
def select_actions(q_values, epsilon):
    """Pick one epsilon-greedy action per row of ``q_values``.

    The explore/exploit decision is made independently for every row, so with
    ``0 < epsilon < 1`` a batch contains a mix of random and greedy actions
    instead of the whole batch exploring or exploiting together.

    Args:
        q_values: 2-D tensor of shape ``(batch_size, n_actions)``.
        epsilon: Probability of taking a uniformly random action for a row.

    Returns:
        1-D NumPy integer array of length ``batch_size`` with the chosen
        action indices.
    """
    batch_size, n_actions = q_values.shape
    greedy_actions = q_values.argmax(dim=-1).detach().cpu().numpy()
    if epsilon <= 0:
        # Pure exploitation; no random numbers are consumed.
        return greedy_actions

    # One coin flip per row decides whether that row explores.
    explore = np.array(
        [random.random() < epsilon for _ in range(batch_size)], dtype=bool
    )
    random_actions = np.random.choice(n_actions, batch_size)
    return np.where(explore, random_actions, greedy_actions)

In [9]:
from collections import deque

In [10]:
class ReplayBuffer:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` items are held, adding another discards the oldest one.
    """

    def __init__(self, capacity: int):
        self.capacity = capacity
        # A deque with maxlen evicts from the left automatically when full.
        self.queue = deque(maxlen=capacity)

    def store(self, transition):
        """Add a single transition to the buffer."""
        self.queue.append(transition)

    def extend(self, transitions):
        """Add every transition from an iterable, in iteration order."""
        self.queue.extend(transitions)

    def sample(self, n_samples: int):
        """Return a list of ``n_samples`` transitions drawn with replacement."""
        batch = []
        for _ in range(n_samples):
            batch.append(random.choice(self.queue))
        return batch

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.queue)

    def __repr__(self):
        """Use the underlying deque's representation."""
        return repr(self.queue)

In [11]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [12]:
from torch import optim

from itertools import count
from tqdm.auto import tqdm

import wandb

In [13]:
def train(
    n_target_update_eps = 60,
    n_episodes = 8_000,
    update_batch_size = 128,
    learning_rate = 1e-4,
    epsilon_start = 1.0,
    epsilon_end = 0.0,
    epsilon_decay_percent = 0.9,
    leave_bar = True,
    suggestion_uuid = None,
):
    """Train a DQN agent on the module-level vectorized environment.

    Relies on names defined elsewhere in the notebook: ``vecenv``,
    ``num_envs``, ``device``, ``n_input``, ``n_hiddens``, ``n_actions``,
    ``QNetwork``, ``ReplayBuffer`` and ``select_actions``. Metrics and the
    saved model are sent to Weights & Biases only when a run is active.

    Args:
        n_target_update_eps: Copy the online weights into the target network
            each time this many further episodes have finished.
        n_episodes: Stop once at least this many episodes have finished.
        update_batch_size: Transitions sampled per gradient step; updates
            start once the replay buffer holds at least this many.
        learning_rate: Adam learning rate.
        epsilon_start: Initial exploration rate.
        epsilon_end: Floor of the exploration rate.
        epsilon_decay_percent: Fraction of ``n_episodes`` over which epsilon
            is decayed linearly from ``epsilon_start`` to ``epsilon_end``.
        leave_bar: Forwarded to tqdm's ``leave`` argument.
        suggestion_uuid: Accepted for interface compatibility; not used here.

    Returns:
        ``(net, mean_last_100, max_last_100, t)``: the online network, the
        mean and max return over the last (up to) 100 finished episodes
        (both NaN if no episode finished), and the index of the last
        environment step taken.
    """
    import os  # local import: only needed to create the checkpoint directory

    net = QNetwork(n_input, n_hiddens, n_actions).to(device)
    target_net = QNetwork(n_input, n_hiddens, n_actions).to(device)
    target_net.load_state_dict(net.state_dict())
    target_net.eval()

    if wandb.run is not None:
        wandb.watch(net, log="all")

    env_rewards = [[] for _ in range(num_envs)]
    returns = []
    loss = None

    gamma = 1.0
    buffer_size = 1_000_000
    # Rewards are divided by this before forming TD targets.
    # NOTE(review): presumably CartPole-v1's 500-step episode limit - confirm.
    reward_scale = 500.

    epsilon_decay = epsilon_decay_percent * n_episodes
    if epsilon_decay > 0:
        epsilon_decay_rate = (epsilon_start - epsilon_end) / epsilon_decay
    else:
        # Degenerate decay window: fall to epsilon_end after the first episode
        # instead of dividing by zero.
        epsilon_decay_rate = epsilon_start - epsilon_end
    epsilon = epsilon_start

    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    loss_fn = nn.SmoothL1Loss()
    buffer = ReplayBuffer(capacity=buffer_size)

    obs, _ = vecenv.reset()
    # Copy so later environment steps cannot overwrite what we hold.
    # NOTE(review): assumes vecenv may reuse its observation buffer - confirm.
    obs = obs.copy()
    episodes = 0
    # Episode count at which the target network is refreshed next.
    next_target_sync = n_target_update_eps
    t = 0  # defined even if training is interrupted before the first step
    bar = tqdm(
        desc="Episodes",
        total=n_episodes,
        initial=0,
        leave=leave_bar,
    )

    try:
        for t in count():
            with torch.no_grad():
                q_values = net(obs)
                actions = select_actions(q_values, epsilon)

            next_obs, rewards, dones, truncateds, _ = vecenv.step(actions)
            next_obs = next_obs.copy()
            # Truncated episodes are stored as terminal too, so the TD target
            # does not bootstrap past a time-limit cut-off.
            episode_ended = dones | truncateds

            # Iterating the 1-D arrays yields independent scalars; the
            # observation rows are views into private copies that are never
            # mutated afterwards.
            buffer.extend(zip(obs, actions, rewards, next_obs, episode_ended))
            obs = next_obs

            for i, reward in enumerate(rewards):
                env_rewards[i].append(reward)

            if len(buffer) >= update_batch_size:
                obs_s, actions_s, rewards_s, next_obs_s, dones_s = zip(
                    *buffer.sample(update_batch_size)
                )

                actions_t = torch.tensor(
                    np.array(actions_s), dtype=torch.int64, device=device
                ).unsqueeze(1)
                rewards_t = torch.tensor(
                    np.array(rewards_s), dtype=torch.float32, device=device
                ) / reward_scale
                dones_t = torch.tensor(
                    np.array(dones_s), dtype=torch.bool, device=device
                )

                q_values_t = net(obs_s).gather(1, actions_t).squeeze(1)
                with torch.no_grad():
                    target_q_values_t = target_net(next_obs_s).max(dim=1).values

                # No bootstrapping from transitions that ended an episode.
                target = rewards_t + gamma * ~dones_t * target_q_values_t
                # Conventional (input, target) order; the value is unchanged
                # because SmoothL1 depends only on the absolute difference.
                loss = loss_fn(q_values_t, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            for i in np.where(episode_ended)[0]:
                ret = sum(env_rewards[i])
                returns.append(ret)
                env_rewards[i] = []
                episodes += 1
                if epsilon > epsilon_end:
                    # Clamp so the rate never drops below the configured floor.
                    epsilon = max(epsilon_end, epsilon - epsilon_decay_rate)
                bar.update()

                if wandb.run is not None:
                    wandb.log(
                        data={
                            "avg_return": ret,
                            "epsilon": epsilon,
                            # Log a plain number, not a graph-attached tensor.
                            "loss": None if loss is None else loss.item(),
                        },
                        commit=True,
                    )

            if episodes >= n_episodes:
                model_path = "./models/DQN_scratch_CartPole_latest.pt"
                os.makedirs(os.path.dirname(model_path), exist_ok=True)
                torch.save(net.state_dict(), model_path)
                # Same guard as the other wandb calls: training must also
                # finish cleanly when no run is active.
                if wandb.run is not None:
                    wandb.log_model(model_path, "latest")
                    wandb.unwatch()
                net.eval()
                break

            # Sync once per crossed threshold. A modulo test would fire on
            # every step while the count sits on a multiple (including 0) and
            # could be skipped when several environments finish together.
            if episodes >= next_target_sync:
                target_net.load_state_dict(net.state_dict())
                next_target_sync = (
                    episodes // n_target_update_eps + 1
                ) * n_target_update_eps

    except KeyboardInterrupt:
        print("Training halted manually.")
    finally:
        bar.close()

    recent_returns = returns[-100:]
    if recent_returns:
        mean_last_100 = np.mean(recent_returns)
        max_last_100 = np.max(recent_returns)
    else:
        # Interrupted before any episode finished.
        mean_last_100 = max_last_100 = float("nan")
    return net, mean_last_100, max_last_100, t

In [14]:
from carbs import CARBS, CARBSParams, LinearSpace, LogSpace, LogitSpace, ObservationInParam, ParamDictType, Param
import time

In [15]:
# Network dimensions. `train` looks these names up at module level.
n_input = vecenv.single_observation_space.shape[0]
n_actions = vecenv.single_action_space.n
n_hiddens = 16

# Hyperparameters for this run.
n_target_update_eps = 60
n_episodes = 8_000
update_batch_size = 256
learning_rate = 1e-4
epsilon_start = 0.9
epsilon_end = 0.0
epsilon_decay_percent = 0.8

# Gathered once so the logged config and the call to `train` cannot drift apart.
train_kwargs = dict(
    n_target_update_eps=n_target_update_eps,
    n_episodes=n_episodes,
    update_batch_size=update_batch_size,
    learning_rate=learning_rate,
    epsilon_start=epsilon_start,
    epsilon_end=epsilon_end,
    epsilon_decay_percent=epsilon_decay_percent,
)

project = "DQN-scratch-CartPole"
config = dict(num_envs=num_envs, n_hiddens=n_hiddens, **train_kwargs)

wandb.init(project=project, config=config)

net, mean_return, max_return, t = train(**train_kwargs)

mean_return, max_return

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mfitti[0m. Use [1m`wandb login --relogin`[0m to force relogin


Episodes:   0%|          | 0/8000 [00:00<?, ?it/s]

Training halted manually.


(368.57, 500.0)

In [16]:
import gymnasium as gym
import numpy as np
from datetime import datetime
from pathlib import Path

# Greedy evaluation of the trained network on a single CartPole-v1
# environment, recorded to a timestamped video folder.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
video_folder = f"./videos/DQN_scratch_CartPole_{timestamp}"

env = gym.make("CartPole-v1", render_mode="rgb_array")
env = gym.wrappers.RecordVideo(env, video_folder)

# epsilon = 0 makes select_actions purely greedy.
epsilon = 0.0

try:
    # The captured output of this cell shows videos written for episodes
    # 0, 1 and 8, so nine episodes reach the third recording.
    for _ in range(9):
        ob, _ = env.reset()
        # The network expects a batch dimension.
        ob = np.expand_dims(ob, 0)
        ret = 0
        done, truncated = False, False
        while not (done or truncated):
            with torch.no_grad():
                q_values = net(ob)
            actions = select_actions(q_values, epsilon)
            ob, reward, done, truncated, _ = env.step(actions[0])
            ob = np.expand_dims(ob, 0)
            ret += reward

        print(ret)
finally:
    # Close before looking for video files so the recorder has released them,
    # and so the environment is cleaned up even if an episode raised.
    env.close()

if wandb.run is not None:
    recorded_videos = list(Path(video_folder).glob("*.mp4"))
    if recorded_videos:
        latest_video = max(
            recorded_videos,
            key=lambda x: x.stat().st_mtime
        )
        wandb.log({
            "video": wandb.Video(str(latest_video))
        })
    # Finish the run even when no video was produced.
    wandb.finish()

MoviePy - Building video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-0.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-0.mp4




[Ame_index:   0%|          | 0/289 [00:00<?, ?it/s, now=None]
[Ame_index:  39%|███▉      | 113/289 [00:00<00:00, 1127.01it/s, now=None]
[Ame_index:  88%|████████▊ | 255/289 [00:00<00:00, 1294.73it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-0.mp4
289.0
MoviePy - Building video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-1.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-1.mp4




[Ame_index:   0%|          | 0/374 [00:00<?, ?it/s, now=None]
[Ame_index:  31%|███       | 115/374 [00:00<00:00, 1145.78it/s, now=None]
[Ame_index:  68%|██████▊   | 253/374 [00:00<00:00, 1279.77it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-1.mp4
374.0
326.0
279.0
326.0
330.0
311.0
295.0
MoviePy - Building video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-8.mp4.
MoviePy - Writing video /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-8.mp4




[Ame_index:   0%|          | 0/327 [00:00<?, ?it/s, now=None]
[Ame_index:  35%|███▌      | 115/327 [00:00<00:00, 1149.99it/s, now=None]
[Ame_index:  78%|███████▊  | 255/327 [00:00<00:00, 1292.90it/s, now=None]
[A                                                                       

MoviePy - Done !
MoviePy - video ready /home/fitti/projects/puffer/videos/DQN_scratch_CartPole_20250216_200453/rl-video-episode-8.mp4
327.0


0,1
avg_return,▁▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▁▁▂▂▄▁▃▃▂▄▂▃▂▅▅▃▅▅▆▆▇▆█▆
epsilon,███▇▇▇▇▇▇▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▃▃▃▃▂▂▂▁
loss,▁▁▁▃▂▂▄▂▂▂▃▄▂▅▅▅▃▂▅▃▄▃▃▃▇▅▅▅▂▅▇▅▂▇█▄▄▂▃▆

0,1
avg_return,348.0
epsilon,0.01955
loss,0.00293
