In [17]:
from datetime import datetime

import torch

from gymnasium import spaces

from stable_baselines3 import DQN
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.preprocessing import is_image_space
from stable_baselines3.dqn.policies import BaseFeaturesExtractor
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

from transformers import SwinModel, SwinConfig

In [4]:
# Swin Transformer hyper-parameters for stacked Atari frames.
# NOTE(review): SwinDQN builds its own SwinConfig with the same hyper-parameters
# (taking num_channels from the observation space) rather than reusing this one.
SWIN_CONFIG = SwinConfig(
    # Input: 4 stacked frames of 84x84 pixels.
    num_channels=4,
    image_size=84,
    # 3x3 patches give a 28x28 token grid; a window of 7 tiles the 28, 14 and 7
    # grids of the three stages (Swin halves the grid between stages).
    patch_size=3,
    window_size=7,
    # Three stages; the embedding width doubles per stage (96 -> 192 -> 384).
    embed_dim=96,
    depths=[2, 3, 2],
    num_heads=[3, 3, 6],
    mlp_ratio=4.0,
    # Stochastic-depth regularisation.
    drop_path_rate=0.1,
)

SwinModel(
  (embeddings): SwinEmbeddings(
    (patch_embeddings): SwinPatchEmbeddings(
      (projection): Conv2d(4, 96, kernel_size=(3, 3), stride=(3, 3))
    )
    (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): SwinEncoder(
    (layers): ModuleList(
      (0): SwinStage(
        (blocks): ModuleList(
          (0-1): 2 x SwinLayer(
            (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attention): SwinAttention(
              (self): SwinSelfAttention(
                (query): Linear(in_features=96, out_features=96, bias=True)
                (key): Linear(in_features=96, out_features=96, bias=True)
                (value): Linear(in_features=96, out_features=96, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): SwinSelfOutput(
                (dense): Linear(in_features=96, out_features=96, bias=True)
      

In [5]:
class SwinDQN(BaseFeaturesExtractor):
    """Swin Transformer feature extractor for SB3 image-based policies.

    Encodes channel-first image observations with a small Hugging Face
    ``SwinModel`` and returns its pooled output as the feature vector.

    :param observation_space: Image observation space, channel-first (C, H, W).
    :param features_dim: Size of the feature vector handed to the policy head.
        Must equal the Swin model's final hidden size (384 for the
        configuration built here); a mismatch raises ``ValueError``.
    :param normalized_image: Forwarded to ``is_image_space``; set to True when
        observations are already normalized (not uint8 in [0, 255]).
    :raises AssertionError: If the observation space is not an image ``Box``.
    :raises ValueError: If ``features_dim`` does not match the Swin output width.
    """

    def __init__(self, observation_space: spaces.Box, features_dim: int = 384, normalized_image: bool = False,):
        # Adjacent literals (no commas) so the message is one string, not a tuple.
        assert isinstance(observation_space, spaces.Box), (
            "SwinDQN must be used with a gym.spaces.Box "
            f"observation space, not {observation_space}"
        )
        super().__init__(observation_space, features_dim)
        # We assume CxHxW images (channels first)
        # Re-ordering will be done by pre-preprocessing or wrapper
        assert is_image_space(observation_space, check_channels=False, normalized_image=normalized_image), (
            "You should use SwinDQN "
            f"only with images not with {observation_space}\n"
            "(you are probably using `CnnPolicy` instead of `MlpPolicy` or `MultiInputPolicy`)\n"
            "If you are using `VecNormalize` or already normalized channel-first images "
            "you should pass `normalize_images=False`: \n"
            "https://stable-baselines3.readthedocs.io/en/master/guide/custom_env.html"
        )

        # is_image_space has verified a 3-D shape, so this unpacking is safe.
        num_input_channels, height, width = observation_space.shape
        config = SwinConfig(
            # Derived from the observation instead of hard-coding 84x84.
            image_size=(height, width),
            patch_size=3,
            num_channels=num_input_channels,
            embed_dim=96,
            depths=[2, 3, 2],
            num_heads=[3, 3, 6],
            window_size=7,
            mlp_ratio=4.0,
            drop_path_rate=0.1,
        )

        # pooler_output has config.hidden_size features; the Q-head is sized from
        # features_dim, so a mismatch would only surface later as a shape error.
        if config.hidden_size != features_dim:
            raise ValueError(
                f"features_dim={features_dim} does not match the Swin output "
                f"size {config.hidden_size}"
            )

        self.swin = SwinModel(config)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        """Return the pooled Swin embedding for a batch of observations."""
        return self.swin(pixel_values=observations).pooler_output

In [13]:
# Pong (v5) with the minimal action set and a single environment.
# ALE-level frame skipping is disabled (frameskip=1), presumably because
# make_atari_env's Atari wrapper already skips frames itself — TODO confirm.
env = VecFrameStack(
    make_atari_env(
        "ALE/Pong-v5",
        n_envs=1,
        seed=42,
        env_kwargs={"full_action_space": False, "frameskip": 1},
    ),
    # Stack the 4 most recent frames so motion is observable.
    n_stack=4,
)

NAME = "Swin_DQN_Pong-v5"

# Run timestamp used in checkpoint, log and save paths. It is year-first so run
# directories sort chronologically. It uses dashes instead of colons because ':'
# is not allowed in Windows file names.
current_datetime_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Periodically persist the model together with its replay buffer during training.
checkpoint_callback = CheckpointCallback(
    save_path=f"runs/checkpoints/{NAME}_{current_datetime_str}",
    name_prefix=NAME,
    save_freq=500_000,  # counted in callback calls; with n_envs=1 this equals env steps
    save_replay_buffer=True,
    verbose=2,
)

# Replace SB3's default CNN with the Swin feature extractor. An empty net_arch
# means no hidden layers between the Swin features and the Q-value output.
policy_kwargs = {
    "features_extractor_class": SwinDQN,
    "net_arch": [],
}

# DQN agent using the Swin feature extractor.
# `seed` matches the seed given to make_atari_env above, so network
# initialisation and epsilon-greedy exploration are reproducible as well.
# NOTE(review): buffer_size (10k) is 10x smaller than learning_starts (100k), so
# most warm-up transitions are presumably overwritten before learning starts.
# Confirm this is intentional (e.g. a memory constraint).
model = DQN(
    "CnnPolicy",
    env,
    policy_kwargs=policy_kwargs,
    seed=42,
    verbose=1,
    tensorboard_log="runs/logs/",
    batch_size=32,
    buffer_size=10_000,
    exploration_final_eps=0.01,
    exploration_fraction=0.1,
    gradient_steps=1,
    learning_rate=0.0001,
    learning_starts=100_000,
    # SB3's ReplayBuffer does not allow optimize_memory_usage together with
    # timeout handling, hence handle_timeout_termination=False.
    optimize_memory_usage=True,
    replay_buffer_kwargs={"handle_timeout_termination": False},
    target_update_interval=1000,
    train_freq=4,
)

Using cuda device
Wrapping the env in a VecTransposeImage.


In [9]:
# Train for 10M timesteps, checkpointing periodically via checkpoint_callback.
# The final save is in `finally`, so an interrupted run (stopping the kernel
# raises KeyboardInterrupt inside learn()) still writes the model and replay
# buffer. The exception is re-raised afterwards.
final_save_path = f"runs/final_saves/{NAME}_{current_datetime_str}/{NAME}"
try:
    model.learn(10_000_000, tb_log_name=f"{NAME}_{current_datetime_str}", callback=checkpoint_callback)
finally:
    # Same base path for both files; SB3 appends .zip / .pkl respectively.
    model.save(final_save_path)
    model.save_replay_buffer(final_save_path)

Logging to runs/logs/Swin_DQN_Pong-v5_11-04-2023_15:40:51_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.996    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 1591     |
|    time_elapsed     | 2        |
|    total_timesteps  | 3601     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.62e+03 |
|    ep_rew_mean      | -20.5    |
|    exploration_rate | 0.993    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1533     |
|    time_elapsed     | 4        |
|    total_timesteps  | 7199     |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 3.63e+03 |
|    ep_rew_mean      | -20.6    |
|    exploration_rate | 0.989 

KeyboardInterrupt: 