# DQN

https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html

### Install

In [1]:
%%bash
pip install -q -r requirements.txt


[notice] A new release of pip is available: 24.0 -> 24.1
[notice] To update, run: pip install --upgrade pip


### Setup

In [2]:
# Gym is an OpenAI toolkit for RL
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack
from gymnasium.wrappers import RecordVideo

# NES Emulator for OpenAI Gym
from nes_py.wrappers import JoypadSpace

# Super Mario environment for OpenAI Gym
import gym_super_mario_bros


from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from tensordict import TensorDict

import torch
from torch import nn
from torchvision import transforms as T

import os
import matplotlib.pyplot as plt
import numpy as np
import wandb
from tqdm import tqdm
from moviepy.editor import ImageSequenceClip
import shutil

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Hyperparameters

In [4]:
STACK_SIZE = 4
FRAMES_PER_BATCH = 1000
EPISODE = 1000

# DQN
BATCH_SIZE = 128
GAMMA = 0.99
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.005
LR = 1e-4

# Tensor Dict
KEY_OBSERVATION = "observation"
KEY_ACTION = "action"
KEY_NEXT_STATE = "next"
KEY_REWARD = "reward"
KEY_DONE = "done"
KEY_TERMINATED = "terminated"

VIDEO_DIR = "/Users/kakuayato/MasterResearch/Mario/Video"
print(f"VIDEO_DIR is set to: {VIDEO_DIR}")

VIDEO_DIR is set to: /Users/kakuayato/MasterResearch/Mario/Video


In [5]:
# remove files in video directory 
shutil.rmtree(VIDEO_DIR, ignore_errors=True)

### ENV

In [6]:
if gym.__version__ < '0.26':
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", new_step_api=True)
else:
    env = gym_super_mario_bros.make("SuperMarioBros-1-1-v0", render_mode='rgb_array', apply_api_compatibility=True,)

In [7]:
# env = RecordVideo(env, video_folder="./video/mario/dqn")

In [8]:
# Limit the action-space to
#   0. walk right
#   1. jump right
env = JoypadSpace(env, [["right"], ["right", "A"]])

In [9]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        """Return only every `skip`-th frame"""
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        """Repeat action, and sum reward"""
        total_reward = 0.0
        for i in range(self._skip):
            # Accumulate reward and repeat the same action
            obs, reward, done, trunk, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, trunk, info


# 画像をグレースケールに変換
class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def permute_orientation(self, observation):
        # permute [H, W, C] array to [C, H, W] tensor
        observation = np.transpose(observation, (2, 0, 1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation


# 0 ~ 255 -> 0 ~ 1 に変換
class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape = (shape, shape)
        else:
            self.shape = tuple(shape)

        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        transforms = T.Compose(
            [T.Resize(self.shape, antialias=True), T.Normalize(0, 255)]
        )
        observation = transforms(observation).squeeze(0)
        return observation

In [10]:
env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)
if gym.__version__ < '0.26':
    env = FrameStack(env, num_stack=STACK_SIZE, new_step_api=True)
else:
    env = FrameStack(env, num_stack=STACK_SIZE)

### Model

In [11]:
class CNN(nn.Module):
    def __init__(self, n_actions, channels):
        super(CNN, self).__init__()
        self.steps_done = 0
        self.model = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Dropout2d(0.25),
            nn.Flatten(),
            nn.Linear(576, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, n_actions),
            nn.LogSoftmax(dim=1)
        )

    def forward(self, x):
        return self.model(x)
    
    def action(self, state):
        self.steps_done += 1
        eps_threshold = EPS_END + (EPS_START - EPS_END) * np.exp(-1. * self.steps_done / EPS_DECAY)
        wandb.log({
            "policy/eps_threshold": eps_threshold
        })

        if np.random.rand() < eps_threshold:
            # 探索 
            action = env.action_space.sample()
            return torch.tensor([[action]], device=device, dtype=torch.long)
        else:
            # 活用
            with torch.no_grad():
                return self.forward(state).max(1)[1].view(1, 1)

In [12]:
model = CNN(env.action_space.n, env.observation_space.shape[0]).to(device)

### Replay Buffer

In [13]:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=FRAMES_PER_BATCH),
    sampler=SamplerWithoutReplacement(),
)

### Training

In [14]:
def colect_data(env,model,replay_buffer,device,current_epoch,):
    state, _ = env.reset()
    state = torch.tensor(np.array(state), device=device, dtype=torch.float32)
    done = False
    total_reward = 0
    steps = 0
    frames = []
    while not done:
        action = model.action(state.unsqueeze(0))
        next_state, reward, done, terminated, _ = env.step(action.item())
        total_reward += reward
        steps += 1

        next_state = torch.tensor(np.array(next_state), device=device, dtype=torch.float32)
        reward = torch.tensor([reward], device=device)
        done = torch.tensor([done], device=device, dtype=torch.bool)
        terminated = torch.tensor([terminated], device=device, dtype=torch.bool)

        frames.append(env.render())

        td = TensorDict({
            KEY_OBSERVATION: state.unsqueeze(0),
            KEY_ACTION: action,
            KEY_NEXT_STATE: {
                KEY_OBSERVATION: next_state.unsqueeze(0),
                KEY_REWARD: reward.unsqueeze(0),
                KEY_DONE: done.unsqueeze(0),
                KEY_TERMINATED: terminated.unsqueeze(0),
            },
        }, [1])
        
        replay_buffer.extend(td.reshape(-1).cpu())
        state = next_state

    wandb.log({
        "total_reward": total_reward,
        "steps": steps,
    })

    if current_epoch % 10 == 0:
        clip = ImageSequenceClip(frames, fps=30)
        clip.write_videofile(os.path.join(VIDEO_DIR, f"{current_epoch}.mp4"), verbose=False, logger=None)
        wandb.log({"video": wandb.Video(os.path.join(VIDEO_DIR, f"{current_epoch}.mp4"))})

In [15]:
def optimize(
        policy_net,
        target_net,
        replay_buffer,
        optimizer,
        device,
):
    if len(replay_buffer) < BATCH_SIZE:
        return

    batch = replay_buffer.sample(BATCH_SIZE).to(device)

    # Q(s, a)
    state_values = policy_net(batch[KEY_OBSERVATION])
    state_action_values = state_values.gather(1, batch[KEY_ACTION])

    # Q(s', a') = max_a' Q(s', a')
    # s' が終端状態である場合は、Q(s', a') = 0 とする
    next_state_values = torch.zeros(len(batch), device=device)
    with torch.no_grad():
        _next_state_values = target_net(batch[KEY_NEXT_STATE][KEY_OBSERVATION]).max(1)[0].detach()
    non_final_mask = torch.nonzero(~batch[KEY_NEXT_STATE][KEY_DONE].squeeze(1))
    next_state_values[non_final_mask] = _next_state_values[non_final_mask]

    # Q*(s, a) = r + γ * max_a' Q(s', a')
    expected_state_action_values = (next_state_values.unsqueeze(1) * GAMMA) + batch[KEY_NEXT_STATE][KEY_REWARD]

    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values)
    wandb.log({"loss": loss.item()})

    optimizer.zero_grad()
    loss.backward()

    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()

In [16]:
policy_net = CNN(env.action_space.n, env.observation_space.shape[0]).to(device)
target_net = CNN(env.action_space.n, env.observation_space.shape[0]).to(device)
target_net.load_state_dict(policy_net.state_dict())

optimizer = torch.optim.Adam(policy_net.parameters(), lr=LR, amsgrad=True)

In [17]:
if wandb.run is None:
    wandb.init(
        project="mario",
        tags=["mario", "dqn"],
        monitor_gym=True,
    )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

In [None]:
wandb.watch(policy_net)
for i_episode in tqdm(range(EPISODE)):
    colect_data(env, policy_net, replay_buffer, device, current_epoch=i_episode)
    optimize(policy_net, target_net, replay_buffer, optimizer, device)

    # Soft update
    target_net_state_dict = target_net.state_dict()
    policy_net_state_dict = policy_net.state_dict()
    for key in target_net_state_dict:
        target_net_state_dict[key] = TAU * policy_net_state_dict[key] + (1 - TAU) * target_net_state_dict[key]
    target_net.load_state_dict(target_net_state_dict)

  0%|          | 0/1000 [00:02<?, ?it/s]


OSError: [Errno 32] Broken pipe

MoviePy error: FFMPEG encountered the following error while writing file video/mario/dqn/0.mp4:

 b'video/mario/dqn/0.mp4: No such file or directory\n'

## Playgrounds

In [None]:
state, _ = env.reset()
state = torch.tensor(np.array(state), device=device).unsqueeze(0)
model(state)

In [None]:
model.action(state)

In [None]:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=1024),
    sampler=SamplerWithoutReplacement(),
)

# collect data
state, _ = env.reset()
state = torch.tensor(np.array(state), device=device, dtype=torch.float32)
is_done = False
while not is_done:
    action = model.action(state.unsqueeze(0))
    next_state, reward, done, terminated, _ = env.step(action.item())
    is_done = done or terminated
    next_state = torch.tensor(np.array(next_state), device=device, dtype=torch.float32)
    reward = torch.tensor([reward], device=device)
    done = torch.tensor([done], device=device, dtype=torch.bool)
    terminated = torch.tensor([terminated], device=device, dtype=torch.bool)

    td = TensorDict({
        KEY_OBSERVATION: state.unsqueeze(0),
        KEY_ACTION: action,
        KEY_NEXT_STATE: {
            KEY_OBSERVATION: next_state.unsqueeze(0),
            KEY_REWARD: reward.unsqueeze(0),
            KEY_DONE: done.unsqueeze(0),
            KEY_TERMINATED: terminated.unsqueeze(0),
        },
    }, [1])
    
    replay_buffer.extend(td.reshape(-1).cpu())
    state = next_state

In [None]:
values = torch.zeros(3, 5)
random_tensor = torch.randn(3, 5)
tensor = torch.tensor([[True],
                       [False],
                       [True]])
mask = torch.nonzero(tensor.squeeze(1))
values[mask] = random_tensor[mask]
values

In [None]:
batch_size = 1024
batch = replay_buffer.sample(batch_size).to(device)
non_final_mask = torch.nonzero(~batch[KEY_NEXT_STATE][KEY_DONE].squeeze(1))

state_values = policy_net(batch[KEY_OBSERVATION])
state_action_values = state_values.gather(1, batch[KEY_ACTION])

next_state_values = torch.zeros(len(batch), device=device)
with torch.no_grad():
    _next_state_values = target_net(batch[KEY_NEXT_STATE][KEY_OBSERVATION]).max(1)[0].detach()
    next_state_values[non_final_mask] = _next_state_values[non_final_mask]

# Q(s, a) = r + γ * max_a' Q(s', a')
expected_state_action_values = (next_state_values.unsqueeze(1) * GAMMA) + batch[KEY_NEXT_STATE][KEY_REWARD]

criterion = nn.SmoothL1Loss()
loss = criterion(state_action_values, expected_state_action_values)
