In [None]:
import gymnasium as gym

import os
import random
import time
from utils.params import Params
from utils.dqn_atari import QNetwork, make_env, linear_schedule

import torch
import torch.optim as optim

import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from stable_baselines3.common.buffers import ReplayBuffer
from tqdm import tqdm

from pathlib import Path

## Parameters

In [None]:
# Experiment configuration.
#   n_runs          total number of (vectorised) environment steps in the training loop
#   epsilon_decay   fraction of n_runs passed to linear_schedule as the epsilon
#                   annealing duration (initial_epsilon -> final_epsilon)
#   learning_starts gradient updates only happen once global_step exceeds this
#   tau             target-network blend factor; 1.0 makes every sync a hard copy
params = Params(
    exp_name="dqn-space-invaders",
    env_id="SpaceInvadersNoFrameskip-v4",
    n_runs=10_000_000,
    learning_rate=1e-4,
    num_envs=1,
    buffer_size=1_000_000,
    discount_factor=0.99,
    tau=1.0,
    target_network_frequency=1000,
    batch_size=32,
    initial_epsilon=1,
    final_epsilon=0.01,
    epsilon_decay=0.10,
    learning_starts=80_000,
    train_frequency=4,
    save_exp_folder=Path("experiments/SpaceInvadersNoFrameskip-v4/"),
    save_model=True,
    seed=1,
)

# Ensure the experiment output folder (TensorBoard runs, model checkpoints) exists.
os.makedirs(params.save_exp_folder, exist_ok=True)

### Seeding

In [None]:
# cuDNN determinism is opt-in through the config (trades speed for reproducibility).
torch.backends.cudnn.deterministic = params.torch_deterministic
# Seed every global RNG this notebook draws from with the same configured seed.
torch.manual_seed(params.seed)
np.random.seed(params.seed)
random.seed(params.seed)

In [None]:
# Use the GPU only when one is available AND the config allows it.
device = torch.device("cuda") if (torch.cuda.is_available() and params.cuda) else torch.device("cpu")
# Unique run identifier: <env>__<experiment>__<seed>__<launch time in Unix seconds>.
run_name = "__".join((params.env_id, params.exp_name, str(params.seed), str(int(time.time()))))

### TensorBoard initialization

In [None]:
# TensorBoard writer for this run; events go to <save_exp_folder>/runs/<run_name>.
writer = SummaryWriter(f"{params.save_exp_folder}/runs/{run_name}")
# Record every hyperparameter as a two-column markdown table (TEXT tab).
# Each data row is closed with a trailing "|" so it matches the header row.
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n" + "\n".join(f"|{key}|{value}|" for key, value in params._asdict().items()),
)

### Environment setup

In [None]:
# One environment factory per sub-environment; sub-environment `env_idx`
# is handed the offset seed `params.seed + env_idx` (how it is applied is
# up to make_env in utils.dqn_atari).
envs = gym.vector.SyncVectorEnv(
    [
        make_env(
            params.env_id,
            params.seed + env_idx,
            env_idx,
            params.capture_video,
            run_name,
            params.save_exp_folder,
        )
        for env_idx in range(params.num_envs)
    ]
)

# The agent picks actions by argmax over Q-values, so only discrete action spaces work.
assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

### Network and replay buffer setup

In [None]:
# Online Q-network (trained by the optimizer) and a target network with the
# same architecture, initialised as an exact copy of the online weights.
# The target is only refreshed from the online net inside the training loop.
q_network = QNetwork(envs).to(device)
target_network = QNetwork(envs).to(device)
target_network.load_state_dict(q_network.state_dict())
optimizer = optim.Adam(q_network.parameters(), lr=params.learning_rate)

# Experience-replay storage, shaped by the per-sub-env observation/action spaces.
# Timeout handling is disabled: the training loop itself stores `terminations`
# (not truncations) as dones and patches in the true final observation on
# truncation. SB3 also does not allow optimize_memory_usage=True together with
# handle_timeout_termination=True.
rb = ReplayBuffer(
    params.buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    handle_timeout_termination=False,
    optimize_memory_usage=True,
)

## Training loop

In [None]:
start_time = time.time()

# Initial observation for every sub-environment.
obs, _ = envs.reset(seed=params.seed)

for global_step in tqdm(range(params.n_runs)):
    # --- epsilon-greedy action selection ---------------------------------
    # epsilon is scheduled over the first `epsilon_decay * n_runs` steps.
    epsilon = linear_schedule(params.initial_epsilon, params.final_epsilon, params.epsilon_decay * params.n_runs, global_step)
    if random.random() < epsilon:
        actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
    else:
        # Acting needs no gradients; skipping graph construction saves time and memory.
        with torch.no_grad():
            q_values = q_network(torch.Tensor(obs).to(device))
        actions = torch.argmax(q_values, dim=1).cpu().numpy()

    # --- environment step ------------------------------------------------
    next_obs, rewards, terminations, truncations, infos = envs.step(actions)

    # --- episode statistics ----------------------------------------------
    # NOTE(review): assumes the gymnasium<1.0 vector API, where "final_info"
    # holds one entry per sub-env and entries for sub-envs that did not finish
    # this step may be None — confirm against the installed gymnasium version.
    if "final_info" in infos:
        for info in infos["final_info"]:
            if not info or "episode" not in info:
                continue
            writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
            writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
            writer.add_scalar("charts/epsilon", epsilon, global_step)
            # Only the first finished sub-env is logged per step.
            break

    # --- replay buffer ---------------------------------------------------
    # On truncation the stored next observation is replaced by the sub-env's
    # entry in infos["final_observation"] (presumably the pre-autoreset
    # observation — same gymnasium-version caveat as above).
    real_next_obs = next_obs.copy()
    for idx, trunc in enumerate(truncations):
        if trunc:
            real_next_obs[idx] = infos["final_observation"][idx]
    rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

    obs = next_obs

    # --- learning --------------------------------------------------------
    if global_step > params.learning_starts:
        if global_step % params.train_frequency == 0:
            data = rb.sample(params.batch_size)

            # One-step TD target from the target network; `dones` come from
            # `terminations`, so truncated episodes still bootstrap.
            with torch.no_grad():
                target_max, _ = target_network(data.next_observations).max(dim=1)
                td_target = data.rewards.flatten() + params.discount_factor * target_max * (1 - data.dones.flatten())
            # Q-value of the action actually taken; squeeze only the action
            # axis so batch_size=1 still yields a 1-D tensor.
            old_val = q_network(data.observations).gather(1, data.actions).squeeze(1)
            loss = F.mse_loss(td_target, old_val)

            if global_step % 100 == 0:
                writer.add_scalar("losses/td_loss", loss.item(), global_step)
                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                # Environment steps per second since training began.
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # Gradient step on the online network.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Blend online weights into the target network (tau=1.0 is a hard copy).
        if global_step % params.target_network_frequency == 0:
            for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                target_network_param.data.copy_(
                    params.tau * q_network_param.data + (1.0 - params.tau) * target_network_param.data
                )

## Save model

In [None]:
if params.save_model:
    # Persist the online Q-network's weights inside this run's log directory.
    model_path = "{}/runs/{}/{}.cleanrl_model".format(params.save_exp_folder, run_name, params.exp_name)
    torch.save(q_network.state_dict(), model_path)
    print("model saved to", model_path)

    # Imported only here because evaluation is only needed when a checkpoint was written.
    from utils.dqn_eval import evaluate

    # evaluate() is a project helper; presumably it reloads the checkpoint into
    # `Model` and plays `eval_episodes` episodes with the given epsilon,
    # returning one return per episode — confirm in utils.dqn_eval.
    episodic_returns = evaluate(
        model_path,
        make_env,
        params.env_id,
        Model=QNetwork,
        device=device,
        run_name=f"{run_name}-eval",
        save_exp_folder=params.save_exp_folder,
        eval_episodes=10,
        epsilon=0.05,
        capture_video=True,
    )

    # One scalar per evaluation episode, indexed by its position in the results.
    for episode_idx, episodic_return in enumerate(episodic_returns):
        writer.add_scalar("eval/episodic_return", episodic_return, episode_idx)

## Clean up

In [None]:
# Close the TensorBoard writer (flushing buffered events) and release the environments.
writer.close()
envs.close()