In [None]:
from typing import Dict, Any

import gymnasium as gym
import numpy as np

import torch

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import BaseCallback, StopTrainingOnMaxEpisodes, EvalCallback, CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.logger import Video

In [None]:
# Four parallel Pong environments with SB3's Atari preprocessing.
# The ALE's own frame skipping is disabled (frameskip=1) because the AtariWrapper
# applied by make_atari_env already skips frames; keeping both would compound the skip.
env = make_atari_env(
    "ALE/Pong-v5",
    n_envs=4,
    seed=42,
    env_kwargs={
        "full_action_space": False,
        "frameskip": 1,
    },
)
# Stack the 4 most recent frames so the policy can perceive motion.
env = VecFrameStack(env, n_stack=4)

In [None]:
class VideoRecoderCallback(BaseCallback):
    """
    Periodically roll out the current policy on ``eval_env`` and log the rendered
    frames to TensorBoard as a video (tag ``trajectory/video``).

    The class name keeps its original spelling so existing references keep working.
    """

    def __init__(
        self,
        eval_env: gym.Env,
        render_frequency: int,
        num_eval_episodes: int = 1,
        deterministic: bool = True,
        fps: int = 40,
    ):
        """
        Record a video of an agent's trajectory traversing ``eval_env`` and log it to TensorBoard.

        :param eval_env: Environment the trajectory is recorded from. Its ``render()`` is
            expected to return an H*W*C RGB array (e.g. created with ``render_mode="rgb_array"``);
            steps for which ``render()`` returns ``None`` are skipped.
        :param render_frequency: Record a video every ``render_frequency`` calls of the callback
        :param num_eval_episodes: Number of episodes to record per video
        :param deterministic: Whether to use the deterministic or the stochastic policy
        :param fps: Frame rate of the logged video
        """
        super().__init__()

        self._eval_env = eval_env
        self._render_frequency = render_frequency
        self._num_eval_episodes = num_eval_episodes
        self._deterministic = deterministic
        self._fps = fps

    def _on_step(self) -> bool:
        """
        Every ``render_frequency`` calls, evaluate the policy while capturing one frame
        per environment step, then log the frames as a video.

        :return: Always ``True`` (never requests that training stop).
        """
        if self.n_calls % self._render_frequency != 0:
            return True

        screens = []
        last_env_index = None

        def grab_screens(_locals: Dict[str, Any], _globals: Dict[str, Any]) -> None:
            """
            Render the environment in its current state and append the frame to ``screens``.

            :param _locals: Local variables of ``evaluate_policy`` at the time of the call
            :param _globals: Global variables of ``evaluate_policy``
            """
            nonlocal last_env_index
            # evaluate_policy calls this once per (still active) sub-env after every
            # env.step(), but render() shows the whole environment, so only the first
            # call of each step is captured. Sub-env indices increase within a step,
            # so a non-increasing index marks the start of a new step. With a single
            # env the index is always 0 and every call is captured.
            env_index = _locals.get("i", 0)
            is_new_step = last_env_index is None or env_index <= last_env_index
            last_env_index = env_index
            if not is_new_step:
                return

            screen = self._eval_env.render()
            if screen is None:
                # render_mode not set: nothing to record for this step.
                return
            # PyTorch uses C*H*W vs H*W*C gym (and tensorflow) image convention
            screens.append(np.asarray(screen).transpose(2, 0, 1))

        evaluate_policy(
            self.model,
            self._eval_env,
            callback=grab_screens,
            n_eval_episodes=self._num_eval_episodes,
            deterministic=self._deterministic,
        )

        if not screens:
            # Nothing was rendered; an empty list would not yield the 5-D video tensor Video expects.
            return True

        # Video expects an (N, T, C, H, W) uint8 tensor. Stacking through numpy first is
        # much faster than building a tensor from a nested Python list of arrays.
        frames = torch.from_numpy(np.asarray([screens], dtype=np.uint8))
        self.logger.record(
            "trajectory/video",
            Video(frames, fps=self._fps),
            exclude=("stdout", "log", "json", "csv"),
        )
        return True

In [None]:
class MeanRewardPerEpisodeMetricCallback(BaseCallback):
    """
    Log the mean reward of the most recent episodes against the number of completed
    episodes (instead of environment timesteps), as ``rollout/ep_rew_mean_per_ep``.

    Completed episodes are detected through the ``"episode"`` entry that SB3's ``Monitor``
    wrapper adds to ``info`` (``make_atari_env`` applies ``Monitor``). Every sub-env that
    finished in the same step is counted, not just one.
    """

    def __init__(self, verbose=0, window_size: int = 100):
        """
        :param verbose: Verbosity level; ``> 0`` prints a line whenever episodes finish.
        :param window_size: Number of most recent episodes averaged (100 mirrors the
            default window SB3 uses for ``rollout/ep_rew_mean``).
        """
        from collections import deque

        super().__init__(verbose)
        self.episode_count = 0
        self._recent_rewards = deque(maxlen=window_size)
        # TensorBoard writer, resolved at training start; None if TensorBoard is not configured.
        self._tb_writer = None

    def _on_training_start(self) -> None:
        """Locate the TensorBoard writer so this metric can use its own (episode) x-axis."""
        from stable_baselines3.common.logger import TensorBoardOutputFormat

        self._tb_writer = next(
            (
                output_format.writer
                for output_format in self.logger.output_formats
                if isinstance(output_format, TensorBoardOutputFormat)
            ),
            None,
        )

    def _on_step(self) -> bool:
        """
        Record the windowed mean episode reward whenever at least one episode finished.

        :return: Always ``True`` (never requests that training stop).
        """
        finished_rewards = [
            info["episode"]["r"]
            for info in self.locals.get("infos", [])
            if "episode" in info
        ]
        if not finished_rewards:
            return True

        self.episode_count += len(finished_rewards)
        self._recent_rewards.extend(finished_rewards)
        mean_reward = float(np.mean(self._recent_rewards))

        if self.verbose > 0:
            print(f"episodes={self.episode_count} ep_rew_mean={mean_reward:.2f}")

        if self._tb_writer is not None:
            # Write straight to TensorBoard with the episode count as the step. Calling
            # logger.dump() here would also flush (and clear) every value the model has
            # recorded since its last dump under this step, and print a table per episode.
            self._tb_writer.add_scalar("rollout/ep_rew_mean_per_ep", mean_reward, self.episode_count)
        else:
            # No TensorBoard output configured: let the model's regular dump emit the value.
            self.logger.record("rollout/ep_rew_mean_per_ep", mean_reward)
        return True

In [None]:
# DQN with a CNN policy on the stacked Atari frames.
model = DQN(
    "CnnPolicy",
    env,
    verbose=1,
    tensorboard_log="runs/logs/",
    batch_size=32,
    # NOTE(review): the buffer holds only 10k transitions while learning starts after
    # 100k steps; RL Zoo's Atari DQN preset uses 100_000 — consider raising if memory allows.
    buffer_size=10_000,
    exploration_final_eps=0.01,
    exploration_fraction=0.1,
    gradient_steps=1,
    learning_rate=0.0001,
    learning_starts=100_000,
    optimize_memory_usage=True,
    # SB3's ReplayBuffer does not support optimize_memory_usage together with timeout handling.
    replay_buffer_kwargs={"handle_timeout_termination": False},
    target_update_interval=1000,
    train_freq=4,
    # Same seed as the environments, so network init and exploration are reproducible.
    seed=42,
)

In [None]:
# Smoke test: push one random observation through the policy to check shapes and device.
# The model's own device and observation shape are used instead of a hard-coded "cuda"
# and (4, 84, 84), so this cell also runs on CPU-only machines.
with torch.no_grad():
    img = torch.rand((1, *model.observation_space.shape), device=model.device)
    output = model.policy(img)
output

In [None]:
# Train for 10M timesteps. Checkpoints are written periodically to runs/checkpoints/ as
# cnn_dqn_<num_timesteps>_steps.zip, so an interrupted run is not lost.
checkpoint_callback = CheckpointCallback(
    # save_freq counts callback calls; each call advances env.num_envs timesteps.
    save_freq=max(500_000 // env.num_envs, 1),
    save_path="runs/checkpoints/",
    name_prefix="cnn_dqn",
)
model.learn(total_timesteps=10_000_000, tb_log_name="cnn_dqn", callback=checkpoint_callback)

In [None]:
# Name the checkpoint after the timesteps actually trained rather than a hard-coded count,
# so re-running the notebook does not mislabel the file.
model.save(f"runs/checkpoints/cnn_dqn_{model.num_timesteps}_steps")

In [None]:
# Name the replay-buffer dump after the timesteps actually trained rather than a
# hard-coded count, so it stays consistent with the model checkpoint.
model.save_replay_buffer(f"runs/checkpoints/cnn_dqn_rb_{model.num_timesteps}_steps")

In [None]:
# Current fill level of the replay buffer, as reported by SB3's size().
replay_buffer = model.replay_buffer
replay_buffer.size()