In [1]:
from datetime import datetime

import torch

from gymnasium import spaces

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.dqn.policies import BaseFeaturesExtractor
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

from transformers import SwinModel, SwinConfig

In [2]:
# Module-level Swin hyper-parameters.
# NOTE(review): no other visible cell references SWIN_CONFIG; SwinDQN builds
# its own config with the same values (taking num_channels from the
# observation space instead of hard-coding 4). Confirm whether this constant
# is still needed.
SWIN_CONFIG = SwinConfig(
    # Input geometry: 84x84 images, 4 channels, 3x3 patches.
    image_size=84, patch_size=3, num_channels=4,
    # Stage layout: three stages, base embedding width 96.
    embed_dim=96, depths=[2, 3, 2], num_heads=[3, 3, 6],
    # Attention window and MLP expansion ratio.
    window_size=7, mlp_ratio=4.0,
    # Drop-path rate.
    drop_path_rate=0.1,
)

In [3]:
class SwinDQN(BaseFeaturesExtractor):
    """Swin Transformer features extractor for image observations.

    Builds a Hugging Face ``SwinModel`` from a fresh ``SwinConfig`` (no
    pretrained weights are loaded) and uses its pooled output as the feature
    vector handed to the policy head.

    :param observation_space: Channel-first image space; ``shape[0]`` is used
        as the number of input channels of the Swin patch embedding.
    :param features_dim: Size of the vector returned by ``forward``. The
        default (384) equals the Swin output width for the architecture
        configured below, so no extra layer is added. Any other value inserts
        a linear projection to that size.
    :param normalized_image: Forwarded to ``is_image_space`` when validating
        the observation space.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 384, normalized_image: bool = False,):
        # The message must be one string: with commas between the parts it
        # becomes a tuple and a failure prints the tuple repr.
        assert isinstance(observation_space, spaces.Box), (
            "SwinDQN must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by preprocessing or a wrapper
        assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
            "You should use SwinDQN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using `VecNormalize` or already normalized channel-first images "
            "you should pass `normalize_images=False`: \n"
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
        )

        num_input_channels = observation_space.shape[0]
        embed_dim = 96
        depths = [2, 3, 2]
        # NOTE(review): image_size is hard-coded to 84; this assumes 84x84
        # observations -- confirm before using with other frame sizes.
        config = SwinConfig(
            image_size=84,
            patch_size=3,
            num_channels=num_input_channels,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=[3, 3, 6],
            window_size=7,
            mlp_ratio=4.0,
            drop_path_rate=0.1,
        )

        self.swin = SwinModel(config)

        # Swin doubles the channel width at every stage transition, so the
        # pooled output has embed_dim * 2**(num_stages - 1) features
        # (96 * 4 = 384 here).
        swin_output_dim = embed_dim * 2 ** (len(depths) - 1)
        # Honour `features_dim`: the base class advertises it to the policy
        # head, so the tensor returned by `forward` must really have that
        # width. Identity adds no parameters, so the default setup keeps the
        # same state_dict keys as before.
        if features_dim == swin_output_dim:
            self.projection = torch.nn.Identity()
        else:
            self.projection = torch.nn.Linear(swin_output_dim, features_dim)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Encode a batch of observations into feature vectors.

        :param observations: Batch of channel-first images
            (assumed already preprocessed by the policy -- TODO confirm).
        :return: Pooled Swin output, projected to ``features_dim`` if needed.
        """
        pooled = self.swin(observations).pooler_output
        return self.projection(pooled)

In [4]:
# Single seed shared by the environment and the agent so a
# Restart & Run All reproduces the same run (as far as the hardware allows).
SEED = 42

# NOTE(review): frameskip=1 presumably leaves frame skipping to the Atari
# preprocessing applied by make_atari_env -- confirm.
env = make_atari_env("ALE/Pong-v5", n_envs=1, seed=SEED, env_kwargs={"full_action_space": False, "frameskip": 1})
# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)

NAME = "Swin_DQN_Pong-v5"

# Timestamp used to make log / checkpoint / save directories unique per run.
# NOTE(review): the ':' separators are not valid in Windows paths.
current_datetime_str = datetime.now().strftime("%d-%m-%Y_%H:%M:%S")

# Periodic snapshot of the model and its replay buffer.
checkpoint_callback = CheckpointCallback(
    save_freq=500_000,
    save_path=f"runs/checkpoints/{NAME}_{current_datetime_str}",
    name_prefix=NAME,
    save_replay_buffer=True,
    verbose=2,
)

# Use the Swin extractor; the empty net_arch means no hidden layers are
# added between the extracted features and the Q-value output.
policy_kwargs = dict(
    features_extractor_class=SwinDQN,
    net_arch=[],
)

model = DQN(
    "CnnPolicy",
    env,
    policy_kwargs=policy_kwargs,
    verbose=1,
    tensorboard_log="runs/logs/",
    batch_size=32,
    buffer_size=10_000,
    exploration_final_eps=0.01,
    exploration_fraction=0.1,
    gradient_steps=1,
    learning_rate=0.0001,
    learning_starts=100_000,
    optimize_memory_usage=True,
    replay_buffer_kwargs={"handle_timeout_termination": False},
    target_update_interval=1000,
    train_freq=4,
    # Previously only the env was seeded; the agent's own randomness
    # (exploration, replay sampling, weight init) was left unseeded.
    # NOTE(review): seeding may enable deterministic cuDNN and cost some speed.
    seed=SEED,
)

A.L.E: Arcade Learning Environment (version 0.8.1+53f58b7)
[Powered by Stella]


Using cuda device
Wrapping the env in a VecTransposeImage.


In [None]:
# Run identifier shared by the TensorBoard log and the final save location.
run_tag = f"{NAME}_{current_datetime_str}"
final_save_path = f"runs/final_saves/{run_tag}/{NAME}"

try:
    model.learn(10_000_000, tb_log_name=run_tag, callback=checkpoint_callback)
finally:
    # Save at the end of training. Doing it in `finally` also keeps the
    # current weights and replay buffer when this very long run is
    # interrupted (e.g. kernel interrupt) or fails part-way, instead of
    # falling back to the last periodic checkpoint.
    model.save(final_save_path)
    model.save_replay_buffer(final_save_path)

Logging to runs/logs/Swin_DQN_Pong-v5_11-04-2023_16:03:09_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.996    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1965     |
|    time_elapsed     | 1        |
|    total_timesteps  | 3601     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.993    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1974     |
|    time_elapsed     | 3        |
|    total_timesteps  | 7199     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.6    |
|    exploration_rate | 0.989 

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.67e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.917    |
| time/               |          |
|    episodes         | 92       |
|    fps              | 2010     |
|    time_elapsed     | 41       |
|    total_timesteps  | 83825    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.68e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.913    |
| time/               |          |
|    episodes         | 96       |
|    fps              | 2010     |
|    time_elapsed     | 43       |
|    total_timesteps  | 87849    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.68e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.909    |
| time/               |          |
|    episodes       

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.3    |
|    exploration_rate | 0.857    |
| time/               |          |
|    episodes         | 160      |
|    fps              | 134      |
|    time_elapsed     | 1073     |
|    total_timesteps  | 144267   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0402   |
|    n_updates        | 11066    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.64e+03 |
|    ep_rew_mean      | -20.3    |
|    exploration_rate | 0.853    |
| time/               |          |
|    episodes         | 164      |
|    fps              | 130      |
|    time_elapsed     | 1139     |
|    total_timesteps  | 148257   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00584  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.799    |
| time/               |          |
|    episodes         | 224      |
|    fps              | 105      |
|    time_elapsed     | 1921     |
|    total_timesteps  | 202615   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0027   |
|    n_updates        | 25653    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.796    |
| time/               |          |
|    episodes         | 228      |
|    fps              | 102      |
|    time_elapsed     | 2007     |
|    total_timesteps  | 206179   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00116  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.66e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.741    |
| time/               |          |
|    episodes         | 288      |
|    fps              | 78       |
|    time_elapsed     | 3338     |
|    total_timesteps  | 261380   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00206  |
|    n_updates        | 40344    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.67e+03 |
|    ep_rew_mean      | -20.4    |
|    exploration_rate | 0.738    |
| time/               |          |
|    episodes         | 292      |
|    fps              | 77       |
|    time_elapsed     | 3429     |
|    total_timesteps  | 265134   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00111  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.7e+03  |
|    ep_rew_mean      | -20.3    |
|    exploration_rate | 0.683    |
| time/               |          |
|    episodes         | 352      |
|    fps              | 67       |
|    time_elapsed     | 4778     |
|    total_timesteps  | 320666   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00128  |
|    n_updates        | 55166    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.71e+03 |
|    ep_rew_mean      | -20.3    |
|    exploration_rate | 0.679    |
| time/               |          |
|    episodes         | 356      |
|    fps              | 66       |
|    time_elapsed     | 4869     |
|    total_timesteps  | 324405   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00447  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.71e+03 |
|    ep_rew_mean      | -20.3    |
|    exploration_rate | 0.624    |
| time/               |          |
|    episodes         | 416      |
|    fps              | 60       |
|    time_elapsed     | 6238     |
|    total_timesteps  | 379981   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00418  |
|    n_updates        | 69995    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.72e+03 |
|    ep_rew_mean      | -20.2    |
|    exploration_rate | 0.62     |
| time/               |          |
|    episodes         | 420      |
|    fps              | 60       |
|    time_elapsed     | 6333     |
|    total_timesteps  | 383818   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00384  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.84e+03 |
|    ep_rew_mean      | -20.2    |
|    exploration_rate | 0.563    |
| time/               |          |
|    episodes         | 480      |
|    fps              | 56       |
|    time_elapsed     | 7772     |
|    total_timesteps  | 441373   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00751  |
|    n_updates        | 85343    |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.83e+03 |
|    ep_rew_mean      | -20.2    |
|    exploration_rate | 0.56     |
| time/               |          |
|    episodes         | 484      |
|    fps              | 56       |
|    time_elapsed     | 7862     |
|    total_timesteps  | 444937   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00181  |
|    n_updates      

----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.97e+03 |
|    ep_rew_mean      | -20.1    |
|    exploration_rate | 0.5      |
| time/               |          |
|    episodes         | 544      |
|    fps              | 53       |
|    time_elapsed     | 9390     |
|    total_timesteps  | 505069   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.0021   |
|    n_updates        | 101267   |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.99e+03 |
|    ep_rew_mean      | -20.1    |
|    exploration_rate | 0.496    |
| time/               |          |
|    episodes         | 548      |
|    fps              | 53       |
|    time_elapsed     | 9499     |
|    total_timesteps  | 509352   |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.00277  |
|    n_updates      