## Environment Setup

In [None]:
!pip install tensorboardX
!pip install pyglet==1.5.1
!pip install torchsummary
!pip install optuna
!pip install optuna-dashboard

In [None]:
!pip install setuptools==65.5.1
!pip install gym==0.21.0
!pip install stable-baselines3[extra]

In [None]:
# !sudo apt-get install -y xvfb
!pip install pyvirtualdisplay

In [None]:
!nvidia-smi

## Imports

In [None]:
from pyvirtualdisplay import Display

# Start an invisible (visible=0) 1024x768 virtual display for this kernel session.
# NOTE(review): presumably required so the emulator can render / record video on a
# headless machine — confirm against the runtime environment.
virtual_display = Display(visible=0, size=(1024, 768))
virtual_display.start()

In [None]:
import sys
import os

# Resolve the absolute path of the sibling `gym-tetris` directory, i.e. `../gym-tetris`
# relative to the current working directory.
gym_tetris_parent_path = os.path.abspath(os.path.join('..', 'gym-tetris'))

# Append that directory to sys.path so modules inside it become importable.
# NOTE(review): presumably this makes `import gym_tetris` use the local checkout — confirm.
sys.path.append(gym_tetris_parent_path)

In [None]:
import random
import time
from distutils.util import strtobool

import gym
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

from nes_py.wrappers import JoypadSpace
import gym_tetris
from gym_tetris.actions import SIMPLE_MOVEMENT

from torchsummary import summary

## Model

In [None]:
# class QNetwork(nn.Module):
#     def __init__(self, actions_num):
#         super().__init__()
#         self.network =  nn.Sequential(
#             # (1, 84, 84)
#             nn.Conv2d(1, 32, 8, stride=4),
#             nn.ReLU(),
#             # (32, 20, 20)
#             nn.Conv2d(32, 64, 4, stride=2),
#             nn.ReLU(),
#             # (64, 9, 9)
#             nn.Conv2d(64, 64, 3, stride=1),
#             nn.ReLU(),
#             # (64, 7, 7)
#             nn.Flatten(),
#             # 3136
#             nn.Dropout(0.2),
#             nn.Linear(3136, 512),
#             nn.Dropout(0.2),
#             nn.ReLU(),
#             nn.Linear(512, actions_num),
#         )

#     def forward(self, x):
#         return self.network(x / 255.0)

class QNetwork(nn.Module):
    """Convolutional Q-value estimator operating on stacked 84x84 frames.

    Args:
        actions_num: Number of discrete actions, i.e. the width of the output layer.
        frame_stack: Number of stacked frames, used as the input channel count.

    The layers live in a single ``nn.Sequential`` stored as ``self.network``;
    the layer order determines the parameter names in the ``state_dict``.
    """

    def __init__(self, actions_num, frame_stack=1):
        super().__init__()

        # Feature extractor. Shape comments assume an input of (frame_stack, 84, 84).
        conv_layers = [
            nn.Conv2d(frame_stack, 32, 8, stride=4),  # -> (32, 20, 20)
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),           # -> (64, 9, 9)
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),           # -> (64, 7, 7)
            nn.ReLU(),
        ]

        # Fully connected head: 64 * 7 * 7 = 3136 flattened features -> one value per action.
        head_layers = [
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, actions_num),
        ]

        self.network = nn.Sequential(*conv_layers, *head_layers)

    def forward(self, x):
        """Return one Q-value per action for each element of the batch ``x``.

        The input is divided by 255.0 before entering the network
        (presumably raw 0-255 pixel intensities — confirm against the env wrappers).
        """
        scaled = x / 255.0
        return self.network(scaled)


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly anneal a value from ``start_e`` towards ``end_e``.

    Args:
        start_e: Value returned at ``t == 0``.
        end_e: Lower bound of the schedule; the result never drops below it.
        duration: Number of steps over which the value moves from ``start_e``
            to ``end_e``.
        t: Current step.

    Returns:
        ``start_e + t * (end_e - start_e) / duration`` clamped from below by
        ``end_e``. When ``duration`` is zero or negative there is nothing to
        anneal over, so ``end_e`` is returned directly instead of dividing by
        zero.
    """
    if duration <= 0:
        # E.g. exploration_fraction == 0: skip the annealing phase entirely.
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

## Training

In [None]:
# from collections import deque
# class MaxAndSkipEnv(gym.Wrapper):
#     def __init__(self, env=None, skip=4):
#         super(MaxAndSkipEnv, self).__init__(env)
#         self._obs_buffer = deque(maxlen=2)
#         self._skip = skip

#     def step(self, action):
#         total_reward = 0.0
#         done = None
#         for _ in range(self._skip):
#             obs, reward, done, info = self.env.step(action)
#             self._obs_buffer.append(obs)
#             total_reward += reward
#             if done:
#                 break
#         max_frame = np.max(np.stack(self._obs_buffer), axis=0)
#         return max_frame, total_reward, done, info
    
#     def reset(self):
#         self._obs_buffer.clear()
#         obs = self.env.reset()
#         self._obs_buffer.append(obs)
#         return obs

In [None]:
class FrameSkipEnv(gym.Wrapper):
    """Wrapper that advances the wrapped env several frames per ``step`` call.

    Args:
        env: Environment to wrap.
        skip: Maximum number of inner frames played per outer step.
        only_first: When true, the requested action is applied on the first
            inner frame only and action ``0`` is sent on the remaining frames.
    """

    def __init__(self, env=None, skip=4, only_first=False):
        super().__init__(env)
        self._skip = skip
        self._only_first = only_first

    def step(self, action):
        """Play up to ``skip`` inner frames and return the last transition.

        Returns the observation, done flag and info of the last inner frame
        together with the sum of the rewards of all frames played. The loop
        stops early as soon as the wrapped env reports ``done``.
        """
        done = None
        reward_sum = 0.0
        for frame_idx in range(self._skip):
            if self._only_first and frame_idx > 0:
                # Action 0 is treated as NOOP on the follow-up frames.
                applied_action = 0
            else:
                applied_action = action
            obs, reward, done, info = self.env.step(applied_action)
            reward_sum += reward
            if done:
                break
        return obs, reward_sum, done, info

    def reset(self):
        """Reset the wrapped env and return its initial observation."""
        return self.env.reset()

In [None]:
# Crop window applied to raw frames, as (row_start, col_start, row_end, col_end).
BOX = (47, 95, 209, 176)


# Making an environment
def get_env(env_id, seed, capture_video, run_name, video_freq=100, frame_stack=4):
    """Build the wrapped and seeded Tetris environment.

    Args:
        env_id: Id passed to ``gym_tetris.make``.
        seed: Seed applied to the env, its action space and its observation space.
        capture_video: When true, episodes are recorded under ``videos/{run_name}``.
        run_name: Name of the run, used for the video directory.
        video_freq: An episode is recorded when its number is a multiple of this value.
        frame_stack: Number of frames stacked by ``gym.wrappers.FrameStack``.

    Returns:
        The fully wrapped environment.
    """

    def crop_to_box(frame):
        # Keep only the BOX window of the frame; the channel axis is left untouched.
        row_start, col_start, row_end, col_end = BOX
        return frame[row_start:row_end, col_start:col_end, :]

    def should_record(ep_num):
        return ep_num % video_freq == 0

    env = JoypadSpace(gym_tetris.make(env_id), SIMPLE_MOVEMENT)
    env = gym.wrappers.RecordEpisodeStatistics(env)
    if capture_video:
        # Recording is attached before the preprocessing wrappers below.
        env = gym.wrappers.RecordVideo(env, f"videos/{run_name}", episode_trigger=should_record)

    # Observation preprocessing: crop -> resize to 84x84 -> grayscale.
    env = gym.wrappers.TransformObservation(env, crop_to_box)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)

    # Inner skip: 16 frames per stacked frame, the action is applied on the first frame only.
    env = FrameSkipEnv(env, skip=16, only_first=True)
    env = gym.wrappers.FrameStack(env, frame_stack)
    # Outer skip: the chosen action is repeated for 2 stacked steps.
    env = FrameSkipEnv(env, skip=2, only_first=False)

    for seedable in (env, env.action_space, env.observation_space):
        seedable.seed(seed)
    return env

In [None]:
# Evaluation
def evaluate(
    model: torch.nn.Module,
    env_id: str,
    eval_episodes: int,
    run_name: str,
    seed: int,
    device: torch.device = torch.device("cpu"),
    capture_video: bool = True,
    video_frequency: int = 1,
    frame_stack: int = 1
):
    """Play ``eval_episodes`` greedy episodes with ``model`` and collect the scores.

    Args:
        model: Q-network; the action with the highest predicted value is played.
        env_id: Environment id forwarded to ``get_env``.
        eval_episodes: Number of episodes to play.
        run_name: Run name forwarded to ``get_env`` (used for the video directory).
        seed: Seed forwarded to ``get_env``.
        device: Device the observations are moved to before the forward pass.
        capture_video: Whether ``get_env`` should record videos.
        video_frequency: Episode recording frequency forwarded to ``get_env``.
        frame_stack: Number of stacked frames forwarded to ``get_env``.

    Returns:
        A list holding ``info["score"]`` of the final step of every episode.

    Note:
        The model is switched to eval mode and is left in that mode.
    """
    env = get_env(env_id, seed, capture_video, run_name, video_frequency, frame_stack)

    model.eval()

    scores = []
    try:
        for _ in range(eval_episodes):
            obs = env.reset()
            done = False
            while not done:
                # Add a batch dimension in front of the observation.
                batch = np.expand_dims(np.array(obs), axis=0)
                # Pure inference: do not track gradients.
                with torch.no_grad():
                    q_values = model(torch.Tensor(batch).to(device))
                action = int(torch.argmax(q_values))
                obs, _, done, info = env.step(action)

            print(f"eval_episode={len(scores)}, score={info.get('score')}, episodic_return={info.get('episode')['r']}")
            scores.append(info.get("score"))
    finally:
        # Always close the env so that an error or interrupt does not leak it.
        env.close()
    return scores

In [None]:
# Single env training without optuna - for simplicity
def train(args):
    """Train a DQN agent on a single environment and optionally save/evaluate it.

    Args:
        args: Configuration object. The attributes read here are ``env_id``,
            ``exp_name``, ``seed``, ``torch_deterministic``, ``cuda``, ``mps``,
            ``capture_video``, ``capture_inputs_video``, ``video_frequency``,
            ``frame_stack``, ``learning_rate``, ``buffer_size``, ``batch_size``,
            ``total_timesteps``, ``start_e``, ``end_e``, ``exploration_fraction``,
            ``learning_starts``, ``train_frequency``, ``gamma``, ``tau``,
            ``target_network_frequency``, ``backup_frequency``, ``save_model``
            and ``eval_episodes``. All attributes of ``args`` are also logged
            as a hyperparameter table.

    Side effects:
        Writes TensorBoard logs and model files under ``runs/{run_name}``,
        optionally writes ``episode{N}.mp4`` videos of the network inputs to
        the working directory, and prints progress to stdout.

    The environment, the TensorBoard writer and any open input-video writer are
    released even when training is interrupted or fails.
    """
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    prefix = ""

    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.backends.cudnn.deterministic = args.torch_deterministic

    device_name = "cuda" if torch.cuda.is_available() and args.cuda else "cpu"
    device_name = "mps" if torch.backends.mps.is_available() and args.mps else device_name
    device = torch.device(device_name)

    print("device_name:", device_name)

    # env setup
    env = get_env(args.env_id, args.seed, args.capture_video, run_name, args.video_frequency, args.frame_stack)

    # Settings of the optional debug video that records what the network sees.
    input_fps = 2
    input_size = (84, 84)

    def _open_inputs_video(episode_idx):
        """Open a grayscale video writer for the network inputs of one episode."""
        return cv2.VideoWriter(
            f'episode{episode_idx}.mp4',
            cv2.VideoWriter_fourcc(*'mp4v'),
            input_fps,
            (input_size[1], input_size[0]),
            False,
        )

    # Writer of the currently recorded episode, or None when nothing is recorded.
    out = None

    try:
        assert isinstance(env.action_space, gym.spaces.Discrete), "only discrete action space is supported"

        q_network = QNetwork(env.action_space.n, args.frame_stack).to(device)
        optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
        target_network = QNetwork(env.action_space.n, args.frame_stack).to(device)
        target_network.load_state_dict(q_network.state_dict())

        summary(q_network, input_size=(args.frame_stack, 84, 84), batch_size=args.batch_size, device=device_name)

        rb = ReplayBuffer(
            args.buffer_size,
            env.observation_space,
            env.action_space,
            device,
        )

        start_time = time.time()

        # Tracks the number of pieces we have played
        piece_count = 0
        explore = True
        info = None
        # Defined up front so it can always be logged, even if an episode ends
        # before the first explore/exploit decision below.
        epsilon = args.start_e

        episode_cnt = 0

        if args.capture_inputs_video:
            out = _open_inputs_video(0)

        # TRY NOT TO MODIFY: start the game
        obs = env.reset()
        for global_step in range(args.total_timesteps):

            # `out` is only set for episodes that are being recorded.
            if out is not None:
                img = np.array(obs).astype('uint8')
                if img.ndim == 3:
                    # Stacked observation (also when frame_stack == 1): write the first frame only.
                    img = img[0]
                out.write(img)

            # If a new piece has been generated, decide whether we will explore or exploit for this piece
            if (info is not None) and (piece_count != info.get("piece_count")):
                piece_count = info.get("piece_count")
                epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
                explore = (random.random() < epsilon)

            if explore:
                action = env.action_space.sample()
            else:
                state = np.expand_dims(np.array(obs), axis=0)
                state = torch.Tensor(state).to(device)
                # Action selection is pure inference: do not build an autograd graph.
                with torch.no_grad():
                    q_values = q_network(state)
                action = int(torch.argmax(q_values))

            # Play a step with the given action
            next_obs, reward, done, info = env.step(action)

            # Store every transition, including the terminal one. Without terminal
            # transitions the (1 - done) mask in the TD target is always 1 and the
            # outcome of the final step of an episode is never learned.
            rb.add(obs, next_obs, np.array([action]), [reward], [done], [info])

            if not done:
                obs = next_obs
            else:
                print(f"Episode {episode_cnt} completed: {prefix}global_step={global_step},\tepisodic_return={info.get('episode')['r']:.1f},\tscore={info.get('score')}")
                writer.add_scalar("charts/episodic_return", info.get("episode")["r"], global_step)
                writer.add_scalar("charts/episodic_length", info.get("episode")["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                writer.add_scalar("charts/score", info.get("score"), global_step)

                obs = env.reset()
                episode_cnt += 1

                if args.capture_inputs_video:
                    # Finalize the previous file before switching to the next episode.
                    if out is not None:
                        out.release()
                    if episode_cnt % args.video_frequency == 0:
                        out = _open_inputs_video(episode_cnt)
                    else:
                        out = None

            # Training Logic
            if global_step > args.learning_starts:
                if global_step % args.train_frequency == 0:
                    data = rb.sample(args.batch_size)
                    with torch.no_grad():
                        target_max, _ = target_network(data.next_observations).max(dim=1)
                        td_target = data.rewards.flatten() + args.gamma * target_max * (1 - data.dones.flatten())
                    # squeeze(1) keeps the batch dimension even when batch_size == 1.
                    old_val = q_network(data.observations).gather(1, data.actions).squeeze(1)
                    loss = F.mse_loss(old_val, td_target)

                    if global_step % 100 == 0:
                        writer.add_scalar("losses/td_loss", loss, global_step)
                        writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                    # optimize the model
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                # update target network (soft update weighted by tau)
                if global_step % args.target_network_frequency == 0:
                    for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
                        target_network_param.data.copy_(
                            args.tau * q_network_param.data + (1.0 - args.tau) * target_network_param.data
                        )

                # Periodic checkpoint of the online network
                if global_step % args.backup_frequency == 0:
                    model_backup_path = f"runs/{run_name}/{args.exp_name}.backup"
                    torch.save(q_network.state_dict(), model_backup_path)

        # Finalize the last input video before the (possibly long) evaluation.
        if out is not None:
            out.release()
            out = None

        if args.save_model:
            model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
            torch.save(q_network.state_dict(), model_path)
            print(f"{prefix}model saved to {model_path}")

            scores = evaluate(
                q_network,
                args.env_id,
                args.eval_episodes,
                run_name=f"{run_name}-eval",
                seed=args.seed,
                device=device,
                capture_video=args.capture_video,
                frame_stack=args.frame_stack
            )

            print("Eval Scores:", scores)
    finally:
        # Release resources on normal completion, errors and notebook interrupts alike.
        if out is not None:
            out.release()
        env.close()
        writer.close()

## Main

In [None]:
class Args:
    """Default configuration for a training run.

    Every option is stored as a plain instance attribute, so ``vars(Args())``
    yields all of them in the order listed below.
    """

    def __init__(self):
        # Settings
        settings = {
            "exp_name": "Tetris_DQN",
            "torch_deterministic": True,
            "cuda": True,
            "mps": False,
            "capture_video": True,
            "capture_inputs_video": True,
            "save_model": True,
            "eval_episodes": 1,
            "video_frequency": 50,
            "backup_frequency": 10000,
        }

        # Hyper-Parameters
        hyper_parameters = {
            "env_id": "TetrisA-v5",
            "frame_stack": 4,
            "seed": 2,
            "total_timesteps": 1_000_000,
            "learning_rate": 1e-4,
            "buffer_size": 50_000,
            "gamma": 0.99,
            "tau": 0.999,
            "target_network_frequency": 2000,
            "batch_size": 32,
            "start_e": 1,
            "end_e": 0.05,
            "exploration_fraction": 0.2,
            "learning_starts": 40_000,
            "train_frequency": 1,
        }

        # Settings first, then hyper-parameters, each in declaration order.
        for group in (settings, hyper_parameters):
            for option_name, option_value in group.items():
                setattr(self, option_name, option_value)
In [None]:
!rm -r runs/* videos/* images/* episode*.mp4

In [None]:
# Build the default configuration and launch a training run with it.
args = Args()
train(args)

In [None]:
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
# device = 'cuda'
# state_dict = torch.load('SavedModels/score_1k_incosistent.backup', device)
# q_network = QNetwork(6,4).to(device)
# q_network.load_state_dict(state_dict)

# scores = evaluate(
#     q_network,
#     args.env_id,
#     args.eval_episodes,
#     run_name=f"temp-eval",
#     seed=args.seed,
#     device=device,
#     capture_video=args.capture_video,
#     frame_stack=4
# )

# print("Eval Scores:", scores)