In [None]:
!pip install -r ../requirements.txt

In [16]:
# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqnpy
import os
import random
import time
from dataclasses import dataclass
from typing import Optional

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tyro
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


@dataclass
class Args:
    """Experiment configuration for this DQN notebook.

    The string after each field documents it (tyro uses these strings as CLI
    help text). In this notebook the defaults are used directly via ``Args()``.
    """

    exp_name: str = "dqn"
    """the name of this experiment"""
    seed: int = 1
    """seed of the experiment"""
    torch_deterministic: bool = True
    """value assigned to `torch.backends.cudnn.deterministic`"""
    cuda: bool = True
    """whether to use CUDA when it is available"""
    track: bool = False
    """if toggled, this experiment will be tracked with Weights and Biases"""
    wandb_project_name: str = "cleanRL"
    """the wandb's project name"""
    wandb_entity: Optional[str] = None
    """the entity (team) of wandb's project"""
    capture_video: bool = True
    """whether to capture videos of the agent performances (check out `videos` folder)"""
    save_model: bool = True
    """whether to save model into the `runs/{run_name}` folder"""
    upload_model: bool = False
    """whether to upload the saved model to huggingface"""
    hf_entity: str = ""
    """the user or org name of the model repository from the Hugging Face Hub"""

    # Algorithm specific arguments
    env_id: str = "CartPole-v1"
    """the id of the environment"""
    total_timesteps: int = 300000
    """total timesteps of the experiments"""
    learning_rate: float = 2.5e-4
    """the learning rate of the optimizer"""
    num_envs: int = 1
    """the number of parallel game environments"""
    buffer_size: int = 10000
    """the replay memory buffer size"""
    gamma: float = 0.99
    """the discount factor gamma"""
    tau: float = 1.0
    """the target network update rate"""
    target_network_frequency: int = 500
    """the number of timesteps between target network updates"""
    batch_size: int = 128
    """the batch size sampled from the replay memory"""
    start_e: float = 1
    """the starting epsilon for exploration"""
    end_e: float = 0.05
    """the ending epsilon for exploration"""
    exploration_fraction: float = 0.5
    """the fraction of `total-timesteps` over which epsilon decays from start-e to end-e"""
    learning_starts: int = 10000
    """timestep to start learning"""
    train_frequency: int = 10
    """train once every `train-frequency` timesteps"""

In [17]:
def make_env(env_id, seed, idx, capture_video, run_name):
    """Build a zero-argument factory that creates one environment.

    The returned ``thunk`` is what ``gym.vector.SyncVectorEnv`` expects. Only the
    sub-environment with ``idx == 0`` records video (into ``videos/{run_name}``),
    and only when ``capture_video`` is truthy. Every env is wrapped with
    ``RecordEpisodeStatistics``, and its action space is seeded with ``seed``.
    """
    should_record = capture_video and idx == 0

    def thunk():
        if not should_record:
            env = gym.make(env_id)
        else:
            # Video recording needs frames, hence rgb_array rendering.
            base_env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(base_env, f"videos/{run_name}")
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk

In [18]:
# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """Fully connected Q-network: flattened observation -> one value per action.

    ``env`` must expose ``single_observation_space.shape`` (its product is the
    input width) and ``single_action_space.n`` (the output width).
    """

    def __init__(self, env):
        super().__init__()
        in_features = np.array(env.single_observation_space.shape).prod()
        n_actions = env.single_action_space.n
        # Two hidden layers (120 and 84 units) with ReLU activations; the layer
        # order/indices inside `network` define the saved state_dict keys.
        layers = [
            nn.Linear(in_features, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for the batch of observations ``x``."""
        return self.network(x)

In [19]:
def linear_schedule(start_e: float, end_e: float, duration: float, t: int) -> float:
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    Args:
        start_e: value at ``t == 0``.
        end_e: value reached at ``t == duration`` and held afterwards.
        duration: number of steps the interpolation spans. If it is ``<= 0``
            the schedule is already finished and ``end_e`` is returned.
        t: current step.

    Returns:
        The interpolated value, clamped at ``end_e`` in the direction of travel
        (works for both decreasing and increasing schedules).
    """
    if duration <= 0:
        # Avoid ZeroDivisionError (e.g. exploration_fraction == 0).
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp so the value never overshoots end_e, whichever direction it moves.
    return max(value, end_e) if start_e >= end_e else min(value, end_e)

In [20]:
import stable_baselines3 as sb3

# Compare the numeric major version. A plain string comparison
# (`sb3.__version__ < "2.0"`) is lexicographic, so e.g. "10.0" would be
# treated as older than "2.0".
if int(sb3.__version__.split(".")[0]) < 2:
    raise ValueError(
        """Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
    )

# In a notebook, sys.argv belongs to the Jupyter kernel, so the CLI parser
# (`tyro.cli(Args)`) is not used here; edit the `Args` defaults or pass keyword
# overrides (e.g. `Args(seed=2)`) instead.
args = Args()
assert args.num_envs == 1, "vectorized envs are not supported at the moment"
run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
if args.track:
    import wandb

    wandb.init(
        project=args.wandb_project_name,
        entity=args.wandb_entity,
        sync_tensorboard=True,
        config=vars(args),
        name=run_name,
        monitor_gym=True,
        save_code=True,
    )

In [21]:
# TensorBoard writer for this run; hyperparameters are logged once as a
# two-column markdown table.
writer = SummaryWriter(f"runs/{run_name}")
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n"
    + "\n".join(f"|{key}|{value}|" for key, value in vars(args).items()),
)

In [22]:
# TRY NOT TO MODIFY: seeding
# Seed Python's, NumPy's and PyTorch's global RNGs from the same value.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
# Mirror the config flag into cuDNN's determinism switch.
torch.backends.cudnn.deterministic = args.torch_deterministic

# Use the GPU only when it is both present and requested.
device = torch.device(
    "cuda" if (torch.cuda.is_available() and args.cuda) else "cpu"
)
print(device)

cuda


In [23]:
# env setup
envs = gym.vector.SyncVectorEnv(
    [
        make_env(args.env_id, args.seed + i, i, args.capture_video, run_name)
        for i in range(args.num_envs)
    ]
)
assert isinstance(
    envs.single_action_space, gym.spaces.Discrete
), "only discrete action space is supported"

# Online Q-network (optimised) and target network (a periodically synced copy
# used only to compute TD targets).
q_network = QNetwork(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
target_network = QNetwork(envs).to(device)
target_network.load_state_dict(q_network.state_dict())

# Truncations are handled manually below (the true final observation is
# stored), and only `terminations` are passed to the buffer as the done flag.
rb = ReplayBuffer(
    args.buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    handle_timeout_termination=False,
)
start_time = time.time()

# TRY NOT TO MODIFY: start the game
obs, _ = envs.reset(seed=args.seed)
for global_step in range(args.total_timesteps):
    # ALGO LOGIC: put action logic here
    # Epsilon decays linearly over the first `exploration_fraction` of training.
    epsilon = linear_schedule(
        args.start_e,
        args.end_e,
        args.exploration_fraction * args.total_timesteps,
        global_step,
    )
    if random.random() < epsilon:
        actions = np.array(
            [envs.single_action_space.sample() for _ in range(envs.num_envs)]
        )
    else:
        # Greedy action. This is inference only, so skip building an autograd graph.
        with torch.no_grad():
            q_values = q_network(torch.Tensor(obs).to(device))
        actions = torch.argmax(q_values, dim=1).cpu().numpy()

    # TRY NOT TO MODIFY: execute the game and log data.
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # TRY NOT TO MODIFY: record rewards for plotting purposes
    if "final_info" in infos:
        for info in infos["final_info"]:
            if info and "episode" in info:
                print(
                    f"global_step={global_step}, episodic_return={info['episode']['r']}"
                )
                writer.add_scalar(
                    "charts/episodic_return", info["episode"]["r"], global_step
                )
                writer.add_scalar(
                    "charts/episodic_length", info["episode"]["l"], global_step
                )

    # TRY NOT TO MODIFY: save data to replay buffer; handle `final_observation`
    # On truncation, store the real last observation from
    # infos["final_observation"] rather than `next_obs` (presumably already the
    # auto-reset observation of the next episode).
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

    # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
    obs = next_obs

    # ALGO LOGIC: training.
    if global_step > args.learning_starts:
        if global_step % args.train_frequency == 0:
            data = rb.sample(args.batch_size)
            with torch.no_grad():
                # TD target: r + gamma * max_a' Q_target(s', a'), zeroed when done.
                target_max, _ = target_network(data.next_observations).max(dim=1)
                td_target = data.rewards.flatten() + args.gamma * target_max * (
                    1 - data.dones.flatten()
                )
            old_val = q_network(data.observations).gather(1, data.actions).squeeze()
            loss = F.mse_loss(td_target, old_val)

            if global_step % 100 == 0:
                # Compute SPS once so the printed and logged values agree.
                sps = int(global_step / (time.time() - start_time))
                writer.add_scalar("losses/td_loss", loss.item(), global_step)
                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                print("SPS:", sps)
                writer.add_scalar("charts/SPS", sps, global_step)

            # optimize the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # update target network
        # Soft (Polyak) update with rate tau; tau == 1.0 makes it a hard copy.
        if global_step % args.target_network_frequency == 0:
            for target_network_param, q_network_param in zip(
                target_network.parameters(), q_network.parameters()
            ):
                target_network_param.data.copy_(
                    args.tau * q_network_param.data
                    + (1.0 - args.tau) * target_network_param.data
                )

Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)
Implementing implicit namespace packages (as specified in PEP 420) is preferred to `pkg_resources.declare_namespace`. See https://setuptools.pypa.io/en/latest/references/keywords.html#keyword-namespace-packages
  declare_namespace(pkg)


Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-0.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-0.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-0.mp4
global_step=28, episodic_return=[29.]
Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-1.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-1.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-1.mp4
global_step=38, episodic_return=[10.]
global_step=49, episodic_return=[11.]
global_step=85, episodic_return=[36.]
global_step=98, episodic_return=[13.]
global_step=114, episodic_return=[16.]
global_step=131, episodic_return=[17.]
global_step=150, episodic_return=[19.]
Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-8.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-8.mp4



                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-8.mp4
global_step=187, episodic_return=[37.]
global_step=214, episodic_return=[27.]
global_step=229, episodic_return=[15.]
global_step=242, episodic_return=[13.]
global_step=265, episodic_return=[23.]
global_step=275, episodic_return=[10.]
global_step=304, episodic_return=[29.]
global_step=322, episodic_return=[18.]
global_step=333, episodic_return=[11.]
global_step=344, episodic_return=[11.]
global_step=390, episodic_return=[46.]
global_step=401, episodic_return=[11.]
global_step=433, episodic_return=[32.]
global_step=480, episodic_return=[47.]
global_step=496, episodic_return=[16.]
global_step=522, episodic_return=[26.]
global_step=556, episodic_return=[34.]
global_step=572, episodic_return=[16.]
global_step=591, episodic_return=[19.]
Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-27.mp

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-27.mp4
global_step=613, episodic_return=[22.]
global_step=637, episodic_return=[24.]
global_step=652, episodic_return=[15.]
global_step=663, episodic_return=[11.]
global_step=680, episodic_return=[17.]
global_step=693, episodic_return=[13.]
global_step=723, episodic_return=[30.]
global_step=739, episodic_return=[16.]
global_step=805, episodic_return=[66.]
global_step=836, episodic_return=[31.]
global_step=855, episodic_return=[19.]
global_step=880, episodic_return=[25.]
global_step=901, episodic_return=[21.]
global_step=912, episodic_return=[11.]
global_step=952, episodic_return=[40.]
global_step=985, episodic_return=[33.]
global_step=1003, episodic_return=[18.]
global_step=1018, episodic_return=[15.]
global_step=1066, episodic_return=[48.]
global_step=1090, episodic_return=[24.]
global_step=1102, episodic_return=[12.]
global_step=1115, episodic_return=[13.

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-64.mp4
global_step=1375, episodic_return=[14.]
global_step=1393, episodic_return=[18.]
global_step=1416, episodic_return=[23.]
global_step=1464, episodic_return=[48.]
global_step=1486, episodic_return=[22.]
global_step=1498, episodic_return=[12.]
global_step=1512, episodic_return=[14.]
global_step=1524, episodic_return=[12.]
global_step=1549, episodic_return=[25.]
global_step=1593, episodic_return=[44.]
global_step=1618, episodic_return=[25.]
global_step=1632, episodic_return=[14.]
global_step=1656, episodic_return=[24.]
global_step=1674, episodic_return=[18.]
global_step=1687, episodic_return=[13.]
global_step=1702, episodic_return=[15.]
global_step=1724, episodic_return=[22.]
global_step=1740, episodic_return=[16.]
global_step=1764, episodic_return=[24.]
global_step=1780, episodic_return=[16.]
global_step=1790, episodic_return=[10.]
global_step=1800, epis

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-125.mp4
global_step=2574, episodic_return=[12.]
global_step=2594, episodic_return=[20.]
global_step=2609, episodic_return=[15.]
global_step=2625, episodic_return=[16.]
global_step=2644, episodic_return=[19.]
global_step=2671, episodic_return=[27.]
global_step=2707, episodic_return=[36.]
global_step=2733, episodic_return=[26.]
global_step=2752, episodic_return=[19.]
global_step=2765, episodic_return=[13.]
global_step=2777, episodic_return=[12.]
global_step=2807, episodic_return=[30.]
global_step=2825, episodic_return=[18.]
global_step=2857, episodic_return=[32.]
global_step=2898, episodic_return=[41.]
global_step=2913, episodic_return=[15.]
global_step=2923, episodic_return=[10.]
global_step=2940, episodic_return=[17.]
global_step=2961, episodic_return=[21.]
global_step=2989, episodic_return=[28.]
global_step=3000, episodic_return=[11.]
global_step=3019, epi

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-216.mp4
global_step=4606, episodic_return=[13.]
global_step=4629, episodic_return=[23.]
global_step=4674, episodic_return=[45.]
global_step=4692, episodic_return=[18.]
global_step=4703, episodic_return=[11.]
global_step=4721, episodic_return=[18.]
global_step=4735, episodic_return=[14.]
global_step=4751, episodic_return=[16.]
global_step=4763, episodic_return=[12.]
global_step=4800, episodic_return=[37.]
global_step=4834, episodic_return=[34.]
global_step=4858, episodic_return=[24.]
global_step=4895, episodic_return=[37.]
global_step=4913, episodic_return=[18.]
global_step=4945, episodic_return=[32.]
global_step=4968, episodic_return=[23.]
global_step=4982, episodic_return=[14.]
global_step=5001, episodic_return=[19.]
global_step=5013, episodic_return=[12.]
global_step=5036, episodic_return=[23.]
global_step=5071, episodic_return=[35.]
global_step=5095, epi



global_step=6010, episodic_return=[49.]
global_step=6026, episodic_return=[16.]
global_step=6039, episodic_return=[13.]
global_step=6063, episodic_return=[24.]
global_step=6090, episodic_return=[27.]
global_step=6132, episodic_return=[42.]
global_step=6145, episodic_return=[13.]
global_step=6159, episodic_return=[14.]
global_step=6187, episodic_return=[28.]
global_step=6203, episodic_return=[16.]
global_step=6244, episodic_return=[41.]
global_step=6254, episodic_return=[10.]
global_step=6335, episodic_return=[81.]
global_step=6365, episodic_return=[30.]
global_step=6382, episodic_return=[17.]
global_step=6439, episodic_return=[57.]
global_step=6449, episodic_return=[10.]
global_step=6545, episodic_return=[96.]
global_step=6571, episodic_return=[26.]
global_step=6602, episodic_return=[31.]
global_step=6616, episodic_return=[14.]
global_step=6639, episodic_return=[23.]
global_step=6669, episodic_return=[30.]
global_step=6692, episodic_return=[23.]
global_step=6714, episodic_return=[22.]


                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-343.mp4
global_step=7653, episodic_return=[12.]
global_step=7671, episodic_return=[18.]
global_step=7723, episodic_return=[52.]
global_step=7736, episodic_return=[13.]
global_step=7750, episodic_return=[14.]
global_step=7780, episodic_return=[30.]
global_step=7813, episodic_return=[33.]
global_step=7825, episodic_return=[12.]
global_step=7836, episodic_return=[11.]
global_step=7871, episodic_return=[35.]
global_step=7939, episodic_return=[68.]
global_step=7961, episodic_return=[22.]
global_step=7993, episodic_return=[32.]
global_step=8016, episodic_return=[23.]
global_step=8034, episodic_return=[18.]
global_step=8060, episodic_return=[26.]
global_step=8084, episodic_return=[24.]
global_step=8111, episodic_return=[27.]
global_step=8138, episodic_return=[27.]
global_step=8220, episodic_return=[82.]
global_step=8260, episodic_return=[40.]
global_step=8280, epi



global_step=8737, episodic_return=[51.]
global_step=8747, episodic_return=[10.]
global_step=8831, episodic_return=[84.]
global_step=8851, episodic_return=[20.]
global_step=8874, episodic_return=[23.]
global_step=8890, episodic_return=[16.]
global_step=8915, episodic_return=[25.]
global_step=8928, episodic_return=[13.]
global_step=8940, episodic_return=[12.]
global_step=8951, episodic_return=[11.]
global_step=8964, episodic_return=[13.]
global_step=9014, episodic_return=[50.]
global_step=9028, episodic_return=[14.]
global_step=9041, episodic_return=[13.]
global_step=9057, episodic_return=[16.]
global_step=9068, episodic_return=[11.]
global_step=9081, episodic_return=[13.]
global_step=9106, episodic_return=[25.]
global_step=9182, episodic_return=[76.]
global_step=9193, episodic_return=[11.]
global_step=9211, episodic_return=[18.]
global_step=9231, episodic_return=[20.]
global_step=9273, episodic_return=[42.]
global_step=9302, episodic_return=[29.]
global_step=9364, episodic_return=[62.]


                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-512.mp4
global_step=12037, episodic_return=[27.]
global_step=12062, episodic_return=[25.]
SPS: 3164
global_step=12102, episodic_return=[40.]
global_step=12125, episodic_return=[23.]
global_step=12159, episodic_return=[34.]
global_step=12197, episodic_return=[38.]
SPS: 3165
global_step=12214, episodic_return=[17.]
global_step=12230, episodic_return=[16.]
global_step=12248, episodic_return=[18.]
global_step=12260, episodic_return=[12.]
global_step=12288, episodic_return=[28.]
SPS: 3167
global_step=12306, episodic_return=[18.]
global_step=12330, episodic_return=[24.]
global_step=12364, episodic_return=[34.]
global_step=12377, episodic_return=[13.]
SPS: 3167
global_step=12403, episodic_return=[26.]
global_step=12417, episodic_return=[14.]
global_step=12428, episodic_return=[11.]
global_step=12457, episodic_return=[29.]
global_step=12475, episodic_return=[18.]




SPS: 3167
global_step=12516, episodic_return=[41.]
global_step=12543, episodic_return=[27.]
global_step=12561, episodic_return=[18.]
global_step=12584, episodic_return=[23.]
SPS: 3167
global_step=12614, episodic_return=[30.]
global_step=12661, episodic_return=[47.]
global_step=12674, episodic_return=[13.]
global_step=12688, episodic_return=[14.]
global_step=12699, episodic_return=[11.]
SPS: 3169
global_step=12724, episodic_return=[25.]
global_step=12771, episodic_return=[47.]
SPS: 3171
global_step=12820, episodic_return=[49.]
global_step=12856, episodic_return=[36.]
global_step=12868, episodic_return=[12.]
global_step=12892, episodic_return=[24.]
SPS: 3173
global_step=12902, episodic_return=[10.]
global_step=12928, episodic_return=[26.]
global_step=12948, episodic_return=[20.]
global_step=12959, episodic_return=[11.]
global_step=12981, episodic_return=[22.]
global_step=12997, episodic_return=[16.]
SPS: 3170
global_step=13014, episodic_return=[17.]
global_step=13052, episodic_return=[38

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-729.mp4
global_step=17135, episodic_return=[23.]
global_step=17164, episodic_return=[29.]
global_step=17191, episodic_return=[27.]
SPS: 3095
global_step=17231, episodic_return=[40.]
global_step=17296, episodic_return=[65.]
SPS: 3097
global_step=17360, episodic_return=[64.]
SPS: 3096
global_step=17404, episodic_return=[44.]
global_step=17437, episodic_return=[33.]
global_step=17480, episodic_return=[43.]
SPS: 3094
global_step=17522, episodic_return=[42.]
global_step=17538, episodic_return=[16.]




global_step=17552, episodic_return=[14.]
global_step=17580, episodic_return=[28.]
SPS: 3092
global_step=17616, episodic_return=[36.]
global_step=17629, episodic_return=[13.]
global_step=17646, episodic_return=[17.]
global_step=17691, episodic_return=[45.]
SPS: 3090
global_step=17734, episodic_return=[43.]
global_step=17755, episodic_return=[21.]
global_step=17778, episodic_return=[23.]
SPS: 3089
global_step=17810, episodic_return=[32.]
global_step=17835, episodic_return=[25.]
global_step=17876, episodic_return=[41.]
global_step=17885, episodic_return=[9.]
SPS: 3088
global_step=17908, episodic_return=[23.]
global_step=17968, episodic_return=[60.]
global_step=17986, episodic_return=[18.]
SPS: 3090
global_step=18050, episodic_return=[64.]
global_step=18066, episodic_return=[16.]
global_step=18077, episodic_return=[11.]
global_step=18096, episodic_return=[19.]
SPS: 3091
global_step=18125, episodic_return=[29.]
global_step=18156, episodic_return=[31.]
global_step=18181, episodic_return=[25.

                                                   

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-1000.mp4
global_step=25051, episodic_return=[32.]
global_step=25068, episodic_return=[17.]
global_step=25093, episodic_return=[25.]
SPS: 3026
global_step=25141, episodic_return=[48.]
global_step=25150, episodic_return=[9.]
global_step=25182, episodic_return=[32.]
SPS: 3024
global_step=25234, episodic_return=[52.]
global_step=25278, episodic_return=[44.]
SPS: 3022
global_step=25303, episodic_return=[25.]
global_step=25322, episodic_return=[19.]
global_step=25348, episodic_return=[26.]




global_step=25389, episodic_return=[41.]
SPS: 3020
global_step=25402, episodic_return=[13.]
global_step=25451, episodic_return=[49.]
global_step=25461, episodic_return=[10.]
global_step=25474, episodic_return=[13.]
SPS: 3018
global_step=25526, episodic_return=[52.]
global_step=25544, episodic_return=[18.]
global_step=25590, episodic_return=[46.]
SPS: 3016
global_step=25607, episodic_return=[17.]
global_step=25623, episodic_return=[16.]
global_step=25698, episodic_return=[75.]
SPS: 3012
global_step=25721, episodic_return=[23.]
global_step=25770, episodic_return=[49.]
global_step=25789, episodic_return=[19.]
SPS: 3010
SPS: 3008
global_step=25916, episodic_return=[127.]
global_step=25970, episodic_return=[54.]
global_step=25999, episodic_return=[29.]
SPS: 3006
global_step=26013, episodic_return=[14.]
global_step=26049, episodic_return=[36.]
global_step=26060, episodic_return=[11.]
global_step=26098, episodic_return=[38.]
SPS: 3005
global_step=26189, episodic_return=[91.]
SPS: 3002
global_

                                                               

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861/rl-video-episode-2000.mp4
global_step=92729, episodic_return=[192.]
SPS: 2558
SPS: 2557
global_step=92934, episodic_return=[205.]
SPS: 2557
SPS: 2556
global_step=93164, episodic_return=[230.]
SPS: 2556
SPS: 2555
global_step=93354, episodic_return=[190.]
SPS: 2554
SPS: 2554
global_step=93578, episodic_return=[224.]
SPS: 2553
SPS: 2553
SPS: 2552
global_step=93846, episodic_return=[268.]
SPS: 2552
SPS: 2551
global_step=94085, episodic_return=[239.]
SPS: 2551
SPS: 2550
global_step=94291, episodic_return=[206.]
SPS: 2550
global_step=94369, episodic_return=[78.]
SPS: 2550
SPS: 2550
SPS: 2549
global_step=94633, episodic_return=[264.]
SPS: 2549
SPS: 2548
global_step=94829, episodic_return=[196.]
SPS: 2547
global_step=94990, episodic_return=[161.]
SPS: 2547
SPS: 2546
global_step=95132, episodic_return=[142.]
SPS: 2546
SPS: 2545
global_step=95316, episodic_return=[184.]
SPS: 2545
glo

In [31]:
import random
from typing import Callable

import gymnasium as gym
import numpy as np
import torch


def evaluate(
    model_path: str,
    make_env: Callable,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    Model: torch.nn.Module,
    device: torch.device = torch.device("cpu"),
    epsilon: float = 0.05,
    capture_video: bool = True,
):
    """Roll out a saved Q-network epsilon-greedily and collect episodic returns.

    Args:
        model_path: path of a ``state_dict`` saved with ``torch.save``.
        make_env: factory called as ``make_env(env_id, seed, idx, capture_video,
            run_name)`` that returns a zero-argument env constructor.
        env_id: environment id forwarded to ``make_env``.
        eval_episodes: number of finished episodes to collect.
        run_name: forwarded to ``make_env``.
        Model: network class, instantiated as ``Model(envs)``.
        device: device for the model and the observations fed to it.
        epsilon: probability of taking a uniformly random action at each step.
        capture_video: forwarded to ``make_env``.

    Returns:
        A list with one ``info["episode"]["r"]`` entry per finished episode; the
        loop stops once at least ``eval_episodes`` entries are collected.
    """
    envs = gym.vector.SyncVectorEnv([make_env(env_id, 0, 0, capture_video, run_name)])
    try:
        model = Model(envs).to(device)
        # NOTE(review): torch.load unpickles the checkpoint; only load trusted
        # files. On torch versions that support it, `weights_only=True` limits this.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        obs, _ = envs.reset()
        episodic_returns = []
        while len(episodic_returns) < eval_episodes:
            if random.random() < epsilon:
                actions = np.array(
                    [envs.single_action_space.sample() for _ in range(envs.num_envs)]
                )
            else:
                # Inference only: skip building an autograd graph.
                with torch.no_grad():
                    q_values = model(torch.Tensor(obs).to(device))
                actions = torch.argmax(q_values, dim=1).cpu().numpy()
            next_obs, _, _, _, infos = envs.step(actions)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    # Guard against None/empty entries, as the training loop does.
                    if not info or "episode" not in info:
                        continue
                    print(
                        f"eval_episode={len(episodic_returns)}, episodic_return={info['episode']['r']}"
                    )
                    episodic_returns += [info["episode"]["r"]]
            obs = next_obs
    finally:
        # Always release the evaluation env (and its wrappers, e.g. RecordVideo).
        envs.close()

    return episodic_returns

In [32]:
if args.save_model:
    model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
    torch.save(q_network.state_dict(), model_path)
    print(f"model saved to {model_path}")

    # Evaluate the saved checkpoint epsilon-greedily and log each episode's
    # return to TensorBoard, indexed by evaluation episode.
    episodic_returns = evaluate(
        model_path,
        make_env,
        args.env_id,
        eval_episodes=10,
        run_name=f"{run_name}-eval",
        Model=QNetwork,
        device=device,
        epsilon=0.05,
    )
    for idx, episodic_return in enumerate(episodic_returns):
        writer.add_scalar("eval/episodic_return", episodic_return, idx)

    if args.upload_model:
        from cleanrl_utils.huggingface import push_to_hub

        repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
        repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
        push_to_hub(
            args,
            episodic_returns,
            repo_id,
            "DQN",
            f"runs/{run_name}",
            f"videos/{run_name}-eval",
        )

# Always release the training envs and flush/close the TensorBoard writer,
# whether or not the model was saved.
envs.close()
writer.close()

model saved to runs/CartPole-v1__dqn__1__1732462861/dqn.cleanrl_model


  model.load_state_dict(torch.load(model_path, map_location=device))


Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-0.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-0.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-0.mp4
eval_episode=0, episodic_return=[421.]
Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-1.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-1.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-1.mp4
eval_episode=1, episodic_return=[482.]
eval_episode=2, episodic_return=[430.]
eval_episode=3, episodic_return=[382.]
eval_episode=4, episodic_return=[500.]
eval_episode=5, episodic_return=[383.]
eval_episode=6, episodic_return=[500.]
eval_episode=7, episodic_return=[500.]
Moviepy - Building video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-8.mp4.
Moviepy - Writing video /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-8.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/emasuriano/Git/rn-group-7/dqn/videos/CartPole-v1__dqn__1__1732462861-eval/rl-video-episode-8.mp4
eval_episode=8, episodic_return=[500.]
eval_episode=9, episodic_return=[325.]
