# Installs

In [None]:
!pip install tensorboardX
!pip install pyglet==1.5.1
!pip install torchsummary
!pip install optuna
!pip install optuna-dashboard
!pip install torchrl
!pip install setuptools==65.5.1
!pip install gym==0.21.0
!pip install stable-baselines3[extra]
!pip install lz4
!sudo apt-get install -y xvfb
!pip install pyvirtualdisplay

In [None]:
# Show NVIDIA GPU / driver status (informational only; no later cell reads this output)
!nvidia-smi

# Imports

In [None]:
from pyvirtualdisplay import Display

# Start a hidden (visible=0) 1024x768 X display, backed by the xvfb package
# installed above, so rendering / video capture can run without a screen.
virtual_display = Display(visible=0, size=(1024, 768))
virtual_display.start()

In [None]:
import sys
import os

# Absolute path of the local gym-tetris checkout, one level up from this notebook
# (presumably the directory that contains the `gym_tetris` package imported below).
gym_tetris_parent_path = os.path.abspath(os.path.join('..', 'gym-tetris'))

# Make `gym_tetris` importable. Guarded so re-running this cell does not keep
# appending duplicate entries to sys.path.
if gym_tetris_parent_path not in sys.path:
    sys.path.append(gym_tetris_parent_path)

In [None]:
import random
import time
from distutils.util import strtobool

from gym import Wrapper, ObservationWrapper
from gym.wrappers import RecordEpisodeStatistics, RecordVideo, FrameStack
from gym.spaces import Box, Discrete

from nes_py.wrappers import JoypadSpace
from gym_tetris.actions import SIMPLE_MOVEMENT
from gym_tetris.tetris_env import TetrisEnv

import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter

# from stable_baselines3.common.buffers import ReplayBuffer

from tensordict import TensorDict
from torchrl.data import TensorDictPrioritizedReplayBuffer, LazyTensorStorage
# from torchrl.data import PrioritizedReplayBuffer, ListStorage, LazyMemmapStorage

from torchsummary import summary
from collections import deque

In [None]:
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
from optuna.visualization import plot_optimization_history, plot_param_importances

# Model

In [None]:
class ItaiQNetworkV1_1(nn.Module):
    """Q-network: three conv layers over stacked binary boards, then a 3-layer MLP head.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        input_dim: (height, width) of one board frame.
        checkpoint_path: state-dict file loaded at construction time; pass
            ``None`` to keep the freshly initialised weights.
    """

    def __init__(self, num_actions, input_dim=(20, 10),
                 checkpoint_path="Itai_Models/1-active-4stacked-v1-84"):
        super().__init__()
        self._frames = 4  # stacked frames expected as input channels
        self._num_actions = num_actions

        # CNN modeled off of Mnih et al.
        self.cnn = nn.Sequential(
            nn.Conv2d(self._frames, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            nn.ReLU()
        )

        # Flattened CNN feature count for one sample, found with a dry run.
        self.fc_layer_inputs = self.cnn_out_dim(input_dim)

        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 128, bias=True),
            nn.ReLU(),
            nn.Linear(128, self._num_actions))

        # Warm-start from an existing checkpoint. map_location="cpu" lets a
        # checkpoint saved from a CUDA run load on a CPU-only machine; the
        # module is moved to its target device afterwards (e.g. .to(DEVICE)).
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def cnn_out_dim(self, input_dim):
        """Return the number of features the CNN emits for one (frames, *input_dim) input."""
        with torch.no_grad():
            return self.cnn(torch.zeros(1, self._frames, *input_dim)).flatten().shape[0]

    def forward(self, x):
        """Map a (batch, frames, H, W) tensor to (batch, num_actions) Q-values."""
        cnn_out = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(cnn_out)

In [None]:
class ItaiQNetworkV1_2(nn.Module):
    """Q-network: three conv layers over stacked binary boards, then a 3-layer MLP head.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        input_dim: (height, width) of one board frame.
        checkpoint_path: state-dict file loaded at construction time; pass
            ``None`` to keep the freshly initialised weights.
    """

    def __init__(self, num_actions, input_dim=(20, 10),
                 checkpoint_path="Itai_Models/1-active-4stacked-v1-90"):
        super().__init__()
        self._frames = 4  # stacked frames expected as input channels
        self._num_actions = num_actions

        # CNN modeled off of Mnih et al.
        self.cnn = nn.Sequential(
            nn.Conv2d(self._frames, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            nn.ReLU()
        )

        # Flattened CNN feature count for one sample, found with a dry run.
        self.fc_layer_inputs = self.cnn_out_dim(input_dim)

        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 128, bias=True),
            nn.ReLU(),
            nn.Linear(128, self._num_actions))

        # Warm-start from an existing checkpoint. map_location="cpu" lets a
        # checkpoint saved from a CUDA run load on a CPU-only machine; the
        # module is moved to its target device afterwards (e.g. .to(DEVICE)).
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def cnn_out_dim(self, input_dim):
        """Return the number of features the CNN emits for one (frames, *input_dim) input."""
        with torch.no_grad():
            return self.cnn(torch.zeros(1, self._frames, *input_dim)).flatten().shape[0]

    def forward(self, x):
        """Map a (batch, frames, H, W) tensor to (batch, num_actions) Q-values."""
        cnn_out = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(cnn_out)

In [None]:
class ItaiQNetworkV1_3(nn.Module):
    """Q-network: three conv layers over stacked binary boards, then a 3-layer MLP head.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        input_dim: (height, width) of one board frame.
        checkpoint_path: state-dict file loaded at construction time; pass
            ``None`` to keep the freshly initialised weights.
    """

    def __init__(self, num_actions, input_dim=(20, 10),
                 checkpoint_path="Itai_Models/5-active-4stacked-v1-102"):
        super().__init__()
        self._frames = 4  # stacked frames expected as input channels
        self._num_actions = num_actions

        # CNN modeled off of Mnih et al.
        self.cnn = nn.Sequential(
            nn.Conv2d(self._frames, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            nn.ReLU()
        )

        # Flattened CNN feature count for one sample, found with a dry run.
        self.fc_layer_inputs = self.cnn_out_dim(input_dim)

        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, 128, bias=True),
            nn.ReLU(),
            nn.Linear(128, self._num_actions))

        # Warm-start from an existing checkpoint. map_location="cpu" lets a
        # checkpoint saved from a CUDA run load on a CPU-only machine; the
        # module is moved to its target device afterwards (e.g. .to(DEVICE)).
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def cnn_out_dim(self, input_dim):
        """Return the number of features the CNN emits for one (frames, *input_dim) input."""
        with torch.no_grad():
            return self.cnn(torch.zeros(1, self._frames, *input_dim)).flatten().shape[0]

    def forward(self, x):
        """Map a (batch, frames, H, W) tensor to (batch, num_actions) Q-values."""
        cnn_out = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(cnn_out)

In [None]:
class ItaiQNetworkV2(nn.Module):
    """Q-network: three small conv layers over stacked binary boards, then a 2-layer MLP head.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        input_dim: (height, width) of one board frame.
        checkpoint_path: state-dict file loaded at construction time; pass
            ``None`` to keep the freshly initialised weights.
    """

    def __init__(self, num_actions, input_dim=(20, 10),
                 checkpoint_path="Itai_Models/8-active-4stacked-v2-79"):
        super().__init__()
        self._frames = 4  # stacked frames expected as input channels
        self.num_actions = num_actions

        # CNN modeled off of Mnih et al.
        self.cnn = nn.Sequential(
            nn.Conv2d(self._frames, 32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1),
            nn.ReLU()
        )

        # Flattened CNN feature count for one sample, found with a dry run.
        self.fc_layer_inputs = self.cnn_out_dim(input_dim)

        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, self.num_actions))

        # Warm-start from an existing checkpoint. map_location="cpu" lets a
        # checkpoint saved from a CUDA run load on a CPU-only machine; the
        # module is moved to its target device afterwards (e.g. .to(DEVICE)).
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def cnn_out_dim(self, input_dim):
        """Return the number of features the CNN emits for one (frames, *input_dim) input."""
        with torch.no_grad():
            return self.cnn(torch.zeros(1, self._frames, *input_dim)).flatten().shape[0]

    def forward(self, x):
        """Map a (batch, frames, H, W) tensor to (batch, num_actions) Q-values."""
        cnn_out = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(cnn_out)

In [None]:
class ItaiQNetworkV4(nn.Module):
    """Q-network: two conv layers over stacked binary boards, then a 2-layer MLP head.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        input_dim: (height, width) of one board frame.
        checkpoint_path: state-dict file loaded at construction time; pass
            ``None`` to keep the freshly initialised weights.
    """

    def __init__(self, num_actions, input_dim=(20,10),
                 checkpoint_path="Itai_Models/8-active-4stacked-v4-78"):
        super().__init__()
        self._frames = 4  # stacked frames expected as input channels
        self._num_actions = num_actions

        # CNN modeled off of Mnih et al.
        self.cnn = nn.Sequential(
            nn.Conv2d(self._frames, 32, kernel_size=4, stride=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=2, stride=1),
            nn.ReLU(),
        )

        # Flattened CNN feature count for one sample, found with a dry run.
        self.fc_layer_inputs = self.cnn_out_dim(input_dim)

        self.fully_connected = nn.Sequential(
            nn.Linear(self.fc_layer_inputs, 512, bias=True),
            nn.ReLU(),
            nn.Linear(512, self._num_actions))

        # Warm-start from an existing checkpoint. map_location="cpu" lets a
        # checkpoint saved from a CUDA run load on a CPU-only machine; the
        # module is moved to its target device afterwards (e.g. .to(DEVICE)).
        if checkpoint_path is not None:
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            self.load_state_dict(state_dict)

    def cnn_out_dim(self, input_dim):
        """Return the number of features the CNN emits for one (frames, *input_dim) input."""
        with torch.no_grad():
            return self.cnn(torch.zeros(1, self._frames, *input_dim)).flatten().shape[0]

    def forward(self, x):
        """Map a (batch, frames, H, W) tensor to (batch, num_actions) Q-values."""
        cnn_out = self.cnn(x).reshape(-1, self.fc_layer_inputs)
        return self.fully_connected(cnn_out)

In [None]:
# FC network
class FC_QNetwork(nn.Module):
    """Fully connected Q-network over a flattened stack of board frames.

    Args:
        num_actions: number of discrete actions (size of the output layer).
        frames: number of stacked frames per observation.
        input_dim: (height, width) of one frame.

    The defaults (4 frames of 20x10) give the original 800 input features.
    """

    def __init__(self, num_actions, frames=4, input_dim=(20, 10)):
        super().__init__()
        in_features = frames * input_dim[0] * input_dim[1]
        self._network = nn.Sequential(
            # (batch, frames, H, W) -> (batch, frames * H * W)
            nn.Flatten(),
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_actions),
        )

    def forward(self, x):
        """Return (batch, num_actions) Q-values for a (batch, frames, H, W) input."""
        return self._network(x)

In [None]:
def get_model_class(model_name : str):
    """Map an architecture name (as sampled by optuna) to its Q-network class.

    Args:
        model_name: one of "FC", "Itai_v1_1", "Itai_v1_2", "Itai_v1_3",
            "Itai_v2", "Itai_v4".

    Returns:
        The matching ``nn.Module`` subclass (not an instance).

    Raises:
        ValueError: if ``model_name`` is not a known architecture.
    """
    if model_name == "FC":
        return FC_QNetwork
    elif model_name == "Itai_v1_1":
        return ItaiQNetworkV1_1
    elif model_name == "Itai_v1_2":
        return ItaiQNetworkV1_2
    elif model_name == "Itai_v1_3":
        return ItaiQNetworkV1_3
    elif model_name == "Itai_v2":
        return ItaiQNetworkV2
    elif model_name == "Itai_v4":
        return ItaiQNetworkV4
    # Fail here, where the bad name is known, instead of returning None and
    # crashing later with "'NoneType' object is not callable".
    raise ValueError(f"Not a valid architecture: {model_name!r}")

# Environment

In [None]:
# Device for the networks, the replay-buffer storage and observation tensors.
# CUDA auto-selection (commented out below) is disabled: everything runs on CPU.
# DEVICE_NAME = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE_NAME = "cpu"
DEVICE = torch.device(DEVICE_NAME)

In [None]:
# Frame Skip
class FrameSkipEnv(Wrapper):
    """Repeat each chosen action for ``skip`` consecutive frames of the wrapped env.

    Rewards of the repeated frames are summed. Stepping stops early as soon as the
    wrapped env reports ``done``. The observation and info come from the last frame
    actually played.

    Raises:
        ValueError: if ``skip`` is smaller than 1.
    """

    def __init__(self, env=None, skip=4):
        if skip < 1:
            # With skip == 0 the loop in step() never runs and obs/done/info
            # would be unbound.
            raise ValueError(f"skip must be >= 1, got {skip}")
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info

    def reset(self, **kwargs):
        # Forward any keyword arguments to the wrapped env, as gym.Wrapper.reset does.
        return self.env.reset(**kwargs)

In [None]:
# Board Constants
# GAME_BOX = (top, left, bottom, right) pixel bounds of the playfield in the
# rendered frame: BinaryBoard slices rows with [0]:[2] and columns with [1]:[3].
GAME_BOX = 47, 95, 209, 176
# Board size in cells: (rows, columns)
BOARD_SHAPE = 20, 10
# Pixel size of one cell along each axis (integer division; 8 px for both here)
y_step = (GAME_BOX[2] - GAME_BOX[0]) // BOARD_SHAPE[0]
x_step = (GAME_BOX[3] - GAME_BOX[1]) // BOARD_SHAPE[1]

In [None]:
# Binary Board
class BinaryBoard(ObservationWrapper):
    """Reduce a colour frame (channels last) to a BOARD_SHAPE grid with one sample per cell.

    Each cell is represented by a single pixel taken half a cell into it. The
    channel-averaged intensity of that pixel is capped at 1, so empty (black)
    cells are 0 and bright cells are 1.
    """

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = Box(0, 1, BOARD_SHAPE)

    def observation(self, obs):
        top, left, bottom, right = GAME_BOX
        # Collapse the colour channels into one intensity plane.
        intensity = np.mean(obs, axis=-1)
        # Start half a cell inside the playfield, then jump one cell at a time.
        row_picks = slice(top + (y_step // 2), bottom, y_step)
        col_picks = slice(left + (x_step // 2), right, x_step)
        board = intensity[row_picks, col_picks]
        assert board.shape == BOARD_SHAPE
        # Saturate anything above 1 down to exactly 1.
        return np.minimum(board, 1)

In [None]:
# Tensor Wrapper
class TensorWrapper(ObservationWrapper):
    """Convert each observation to a float tensor with a leading batch axis of 1, on DEVICE."""

    def __init__(self, env=None):
        super().__init__(env)

    def observation(self, obs):
        # Materialise the (possibly lazy) observation and add the batch axis.
        batched = np.stack([np.asarray(obs)])
        return torch.Tensor(batched).to(DEVICE)

In [None]:
# Get Environment
# Number of consecutive frames each chosen action is repeated for
FRAME_SKIP = 6


def get_env(args, run_name : str = "run", capture_video : bool = False):
    """Create a seeded Tetris environment wrapped for the DQN agent.

    Wrappers are applied innermost first:
    - episode statistics;
    - optional video recording of every ``args.video_frequency``-th episode into
      ``videos/<run_name>``;
    - joypad action mapping (``args.movement``);
    - frame skipping (``FRAME_SKIP``);
    - binary-board extraction;
    - stacking of ``args.frame_stack`` frames;
    - conversion to a batched tensor.

    The env and both of its spaces are seeded with ``args.seed``.
    """
    reward_weight_names = (
        "line_weight",
        "height_weight",
        "cost_weight",
        "holes_weight",
        "bumpiness_weight",
        "col_transitions_weight",
        "row_transitions_weight",
    )
    reward_weights = {name: getattr(args, name) for name in reward_weight_names}
    wrapped = RecordEpisodeStatistics(TetrisEnv(**reward_weights))

    if capture_video:
        def record_this_episode(ep_num):
            return ep_num % args.video_frequency == 0
        wrapped = RecordVideo(wrapped, f"videos/{run_name}", episode_trigger=record_this_episode)

    wrapped = JoypadSpace(wrapped, args.movement)
    wrapped = FrameSkipEnv(wrapped, skip=FRAME_SKIP)
    wrapped = BinaryBoard(wrapped)
    wrapped = FrameStack(wrapped, args.frame_stack)
    wrapped = TensorWrapper(wrapped)

    wrapped.seed(args.seed)
    for space in (wrapped.action_space, wrapped.observation_space):
        space.seed(args.seed)
    return wrapped

# Training

In [None]:
# Epsilon scheduling
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly move from ``start_e`` to ``end_e`` over ``duration`` steps, then hold ``end_e``.

    Args:
        start_e: value at ``t == 0``.
        end_e: value reached at ``t == duration`` and kept afterwards.
        duration: number of steps the ramp takes; a non-positive value jumps
            straight to ``end_e``.
        t: current step, counted from the start of the schedule.

    Works for decaying (``end_e < start_e``) as well as increasing schedules.
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp at end_e on whichever side the schedule is heading towards.
    if end_e <= start_e:
        return max(value, end_e)
    return min(value, end_e)

In [None]:
# Video Constants
# Playback rate of the input videos: one video frame per agent step, i.e. the
# emulator frame rate divided by FRAME_SKIP (assumes a 60 frames/s emulator).
FPS = 60 / FRAME_SKIP
# Integer factor by which each board cell is enlarged when writing video frames
SCALE_UP = 10

In [None]:
# Write the current observation into the video
def write_video_frame(out, obs, frame_stack : int = 1, scale_up : int = 1):
    """Write the newest board contained in ``obs`` as one grayscale video frame.

    Args:
        out: writer exposing ``write(frame)`` (a ``cv2.VideoWriter`` in this notebook).
        obs: observation tensor with a leading batch axis of size 1, followed by
            either (H, W) or (frames, H, W); values are expected to be 0/1.
        frame_stack: kept for backward compatibility. Whether the observation holds
            stacked frames is now detected from the array rank, so
            ``FrameStack(env, 1)`` observations are handled too.
        scale_up: integer factor by which every cell is repeated along both axes.
    """
    # Cast to uint8 then stretch 0/1 to 0/255 for display.
    img = np.array(obs.cpu(), dtype='uint8')[0] * 255
    if img.ndim == 3:
        # (frames, H, W): keep only the most recent frame so the writer gets (H, W).
        img = img[-1]
    if scale_up > 1:
        img = np.repeat(np.repeat(img, scale_up, axis=0), scale_up, axis=1)
    out.write(img)

In [None]:
# Evaluation
def evaluate(args, model: torch.nn.Module, eval_name : str = "eval", no_video : bool = False):
    """Play ``args.eval_episodes`` greedy episodes and return the mean lines cleared.

    Args:
        args: ``Args``-like object (env settings, ``eval_episodes``, video flags).
        model: Q-network; each action is the argmax of its output.
        eval_name: run name for the env built by ``get_env`` (used for recorded videos).
        no_video: when True, skip the per-episode input videos
            (``eval_episode{i}.mp4``) even if ``args.capture_inputs_video`` is set.

    Returns:
        Mean of ``info["lines"]`` over the evaluation episodes.

    Raises:
        ValueError: if ``args.eval_episodes`` is not positive.

    The model's train/eval mode is restored on exit. The env and any video
    writer are closed even if an episode raises.
    """
    if args.eval_episodes <= 0:
        raise ValueError(f"eval_episodes must be positive, got {args.eval_episodes}")

    record_inputs = args.capture_inputs_video and not no_video
    env = get_env(args, run_name=eval_name, capture_video=args.capture_eval_video)

    was_training = model.training
    model.eval()

    total_lines = 0.0
    try:
        # Greedy action selection needs no autograd graph.
        with torch.no_grad():
            for episode in range(args.eval_episodes):
                out = None
                if record_inputs:
                    out = cv2.VideoWriter(f'eval_episode{episode}.mp4', cv2.VideoWriter_fourcc(*'mp4v'), FPS, (BOARD_SHAPE[1]*SCALE_UP, BOARD_SHAPE[0]*SCALE_UP), False)
                try:
                    obs = env.reset()
                    done = False
                    while not done:
                        if out is not None:
                            write_video_frame(out, obs, args.frame_stack, SCALE_UP)
                        q_values = model(obs)
                        action = int(torch.argmax(q_values))
                        obs, _, done, info = env.step(action)
                finally:
                    # release() finalises the mp4 file on disk.
                    if out is not None:
                        out.release()
                total_lines += info.get('lines')
    finally:
        env.close()
        model.train(was_training)

    return total_lines / args.eval_episodes

In [None]:
def write_episode_scalars(writer, global_step, info, epsilon):
    """Log the finished episode's return, length and lines, plus the current epsilon.

    ``info["episode"]`` is expected to hold the ``"r"`` (return) and ``"l"``
    (length) entries added by RecordEpisodeStatistics.
    """
    episode_stats = info.get("episode")
    log = writer.add_scalar
    log("charts/episodic_return", episode_stats["r"], global_step)
    log("charts/episodic_length", episode_stats["l"], global_step)
    log("charts/epsilon", epsilon, global_step)
    log("charts/lines", info.get("lines"), global_step)

In [None]:
def _open_input_video_writer(path: str):
    """Open a grayscale mp4 writer sized for the up-scaled binary board."""
    return cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'mp4v'), FPS,
                           (BOARD_SHAPE[1] * SCALE_UP, BOARD_SHAPE[0] * SCALE_UP), False)


def _dqn_update(args, q_network, target_network, optimizer, rb):
    """Do one gradient step on a prioritized minibatch and refresh its priorities.

    Returns:
        (loss, model_output): the scalar MSE TD loss and the Q-values of the
        actions that were taken in the sampled transitions.
    """
    data = rb.sample()
    with torch.no_grad():
        target_max, _ = target_network(torch.squeeze(data.get("next_obs"))).max(dim=1)
        # (1 - done) removes the bootstrap term on terminal transitions.
        td_target = data.get("reward").flatten() + args.gamma * target_max * (1 - data.get("done").flatten())
    model_output = q_network(torch.squeeze(data.get("obs"))).gather(1, data.get("action")).squeeze()

    # NOTE(review): importance-sampling weights (data.get("_weight")) are not
    # applied, so args.beta does not influence this loss — TODO confirm intent.
    loss = F.mse_loss(model_output, td_target)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # New priorities are the absolute TD errors of this minibatch.
    td_error = torch.abs(model_output.detach() - td_target).unsqueeze(1)
    data.set("td_error", td_error)
    rb.update_tensordict_priority(data)
    return loss, model_output


def _soft_update(target_network, q_network, tau):
    """Blend online weights into the target: target <- tau * online + (1 - tau) * target."""
    for target_network_param, q_network_param in zip(target_network.parameters(), q_network.parameters()):
        target_network_param.data.copy_(
            tau * q_network_param.data + (1.0 - tau) * target_network_param.data
        )


# Single-env DQN training; optionally reports to (and can be pruned by) an optuna trial
def train(args, start_model_path=None, trial=None):
    """Train a DQN agent with a prioritized replay buffer on the Tetris environment.

    Args:
        args: ``Args``-like object holding settings and hyper-parameters.
        start_model_path: optional state-dict file used to initialise the Q-network.
        trial: optional optuna trial. When given, the agent is evaluated every
            ``args.eval_frequency`` steps. Each result is reported to the trial
            (enabling pruning) and stored in ``trial.user_attrs["mean_lines"]``.

    Raises:
        optuna.exceptions.TrialPruned: when the trial's pruner stops the run.

    Checkpoints (``.best``, ``.backup``, ``.final``) and TensorBoard scalars are
    written under ``runs/<run_name>/``.
    """
    run_name = f"{args.exp_name}__{args.seed}__{args.run_id}"
    prefix = ""
    if trial:
        run_name += f"_trial_{trial.number}"
        prefix = f"trial_{trial.number}: "

    # Created inside the try; initialised here so cleanup never raises NameError
    # and masks the original exception.
    writer = None
    env = None
    out = None  # VideoWriter for the current episode's inputs, or None when not recording
    try:
        writer = SummaryWriter(f"runs/{run_name}")
        writer.add_text(
            "hyperparameters",
            "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
        )

        # TRY NOT TO MODIFY: seeding
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)
        torch.backends.cudnn.deterministic = args.torch_deterministic

        # env setup
        env = get_env(args, run_name=f"{run_name}_0", capture_video=args.capture_video)
        assert isinstance(env.action_space, Discrete), "only discrete action space is supported"

        q_network = args.model(env.action_space.n).to(DEVICE)
        if start_model_path:
            state_dict = torch.load(start_model_path, map_location=DEVICE)
            q_network.load_state_dict(state_dict)

        optimizer = optim.Adam(q_network.parameters(), lr=args.learning_rate)
        target_network = args.model(env.action_space.n).to(DEVICE)
        target_network.load_state_dict(q_network.state_dict())

        rb = TensorDictPrioritizedReplayBuffer(
            alpha=args.alpha,
            beta=args.beta,
            storage=LazyTensorStorage(args.buffer_size, device=DEVICE),
            batch_size=args.batch_size,
            prefetch=args.prefetch
        )

        obs = env.reset()

        episode_cnt = 0          # number of finished episodes
        piece_count = 0          # last seen info["piece_count"]
        explore = True           # play random moves (explore) or follow the model (exploit)
        epsilon = args.start_e   # defined up-front so episode logging never sees it unbound
        info = None

        # Track the best scoring models
        episode_lines = deque(maxlen=args.mean_lines_count)
        best_mean_lines = -1.0

        eval_idx = 1

        if args.capture_inputs_video:
            out = _open_input_video_writer('episode0.mp4')

        sps_time = time.time()

        for global_step in range(args.total_timesteps):

            if global_step > 0 and global_step % 1000 == 0:
                curr_time = time.time()
                writer.add_scalar("charts/SPS", 1000 / (curr_time - sps_time), global_step)
                sps_time = curr_time

            # out is only open for episodes selected by args.video_frequency.
            if out is not None:
                write_video_frame(out, obs, args.frame_stack, SCALE_UP)

            # When a new piece appears, decide whether to explore or exploit for that whole piece.
            if global_step > 0 and piece_count != info.get("piece_count"):
                piece_count = info.get("piece_count")
                if global_step < args.learning_starts:
                    epsilon = args.start_e
                else:
                    duration = args.exploration_fraction * (args.total_timesteps - args.learning_starts)
                    epsilon = linear_schedule(args.start_e, args.end_e, duration, global_step - args.learning_starts)
                explore = (random.random() < epsilon)

            # Find the next action to play
            if explore:
                action = env.action_space.sample()
            else:
                with torch.no_grad():
                    q_values = q_network(obs)
                action = int(torch.argmax(q_values))

            # Play a step with the given action
            next_obs, reward, done, info = env.step(action)

            # Evaluate and report the agent periodically. The upper bound avoids an
            # intermediate evaluation in the last few steps, which the final
            # evaluation below covers.
            if trial and global_step > 0 and global_step % args.eval_frequency == 0 and global_step < (args.total_timesteps - args.total_evaluations):
                # Record input videos for every eval_video_frequency-th evaluation, starting with the first.
                no_video = (eval_idx - 1) % args.eval_video_frequency != 0
                eval_mean_lines = evaluate(args, model=q_network, eval_name=f"{run_name}-eval-{eval_idx}", no_video=no_video)
                q_network.train()
                print(f"{prefix}evaluation_{eval_idx} mean_lines={eval_mean_lines}")
                trial.set_user_attr("mean_lines", eval_mean_lines)
                # The pruner decides from reported intermediate values only.
                trial.report(eval_mean_lines, step=global_step)
                eval_idx += 1

                # Check if the trial should be pruned
                if trial.should_prune():
                    print(f"Pruning Run: {run_name}")
                    raise optuna.exceptions.TrialPruned()

            # Store every transition, including the terminal one, so the (1 - done)
            # mask in the TD target actually sees episode ends.
            data = TensorDict({"obs" : obs,
                               "next_obs" : next_obs,
                               "action" : [action],
                               "reward" : [reward],
                               "done" : [int(done)]},
                              batch_size=1, device=DEVICE)
            rb.add(data)

            if not done:
                obs = next_obs
            else:
                write_episode_scalars(writer, global_step, info, epsilon)

                episode_cnt += 1

                episode_lines.append(info.get("lines"))
                # The rolling window is full once mean_lines_count episodes have finished.
                if episode_cnt >= args.mean_lines_count:
                    curr_mean_lines = sum(episode_lines) / len(episode_lines)
                    if curr_mean_lines > best_mean_lines:
                        best_mean_lines = curr_mean_lines
                        if global_step > args.learning_starts:
                            print(f"New best mean lines: {curr_mean_lines}")
                            # Keep a backup of the best scoring model
                            best_model_path = f"runs/{run_name}/{args.exp_name}.best"
                            torch.save(q_network.state_dict(), best_model_path)

                if args.capture_inputs_video:
                    # Finalise the previous episode's video before starting a new one.
                    if out is not None:
                        out.release()
                        out = None
                    if episode_cnt % args.video_frequency == 0:
                        out = _open_input_video_writer(f'episode{episode_cnt}.mp4')

                if episode_cnt % args.reload_env_frequency == 0:
                    num_reloads = episode_cnt // args.reload_env_frequency
                    env.close()
                    env = None  # never close the same env twice if get_env fails
                    env = get_env(args, run_name=f"{run_name}_{num_reloads}", capture_video=args.capture_video)

                obs = env.reset()

            # Training Logic
            if global_step > args.learning_starts:

                if global_step % args.train_frequency == 0:
                    loss, model_output = _dqn_update(args, q_network, target_network, optimizer, rb)

                    # Log training statistics
                    if global_step % (100 * args.train_frequency) == 0:
                        writer.add_scalar("losses/td_loss", loss.item(), global_step)
                        writer.add_scalar("losses/q_values", model_output.mean().item(), global_step)

                # update target network
                if global_step % args.target_network_frequency == 0:
                    _soft_update(target_network, q_network, args.tau)

                if global_step % args.backup_frequency == 0:
                    backup_model_path = f"runs/{run_name}/{args.exp_name}.backup"
                    torch.save(q_network.state_dict(), backup_model_path)

        if args.save_model:
            final_model_path = f"runs/{run_name}/{args.exp_name}.final"
            torch.save(q_network.state_dict(), final_model_path)
            print(f"{prefix}model saved to {final_model_path}")

        # A trial always needs a final score, even when the model is not saved.
        if args.save_model or trial:
            final_mean_lines = evaluate(
                args,
                model=q_network,
                eval_name=f"{run_name}-eval"
            )
            print(f"{prefix}evaluation_{eval_idx}: mean_lines={final_mean_lines}")
            if trial:
                trial.set_user_attr("mean_lines", float(final_mean_lines))

    finally:
        if out is not None:
            out.release()
        if env is not None:
            env.close()
        if writer is not None:
            writer.close()

# Optuna

In [None]:
# Optuna constants
N_TRIALS = 100              # Maximum number of trials passed to study.optimize
N_TIMESTEPS = 1_000_000       # Environment steps per trial (Args.total_timesteps)
N_JOBS = 4                  # Number of trials study.optimize runs in parallel
N_STARTUP_TRIALS = 8        # Random trials before TPE starts modelling; also MedianPruner's warm-up trial count

In [None]:
class Args:
    """Settings and hyper-parameters read by get_env, train and evaluate.

    Attributes left as ``None`` are filled in per trial by ``objective`` from
    ``sample_params``.
    """
    def __init__(self):
        # Settings
        self.exp_name = "Tetris_DQN"
        self.movement = SIMPLE_MOVEMENT          # action set passed to JoypadSpace
        self.run_id = int(time.time())           # part of the run directory name
        self.torch_deterministic = True
        self.capture_video = False               # RecordVideo during training
        self.capture_eval_video = True           # RecordVideo during evaluation
        self.capture_inputs_video = False        # mp4 of the binary-board model inputs
        self.save_model = True
        self.backup_frequency = 50000            # steps between .backup checkpoints
        self.mean_lines_count = 50               # episodes in the rolling mean used for .best
        self.video_frequency = 100               # record every N-th episode
        self.eval_video_frequency = 5            # used by train() to choose which evaluations write input videos
        self.reload_env_frequency = 49           # rebuild the env every N episodes
        self.prefetch = 3                        # replay-buffer prefetch

        # Constant Hyper-Parameters
        self.seed = 2
        self.total_timesteps = N_TIMESTEPS
        self.buffer_size = 100_000
        self.learning_starts = 50_000            # steps before gradient updates begin
        self.train_frequency = 1                 # steps between gradient updates
        self.start_e = 1.0
        self.batch_size = 32
        self.gamma = 0.98
        self.tau = 0.999                         # target <- tau * online + (1 - tau) * target
        self.target_network_frequency = 1000     # steps between target updates
        self.exploration_fraction = 0.4          # fraction of post-learning_starts steps over which epsilon decays
        self.frame_stack = 4

        # Evaluation
        self.total_evaluations = 3               # intermediate evaluations per trial
        self.eval_episodes = 10
        self.eval_frequency = self.total_timesteps // self.total_evaluations

        # Optimizable Hyper-Parameters
        self.model = None
        self.learning_rate = None
        self.alpha = None
        self.beta = None
        self.end_e = None
        # Previously tuned, now fixed above:
        # self.gamma = None
        # self.tau = None
        # self.frame_stack = None
        # self.target_network_frequency = None

        # Reward weights
        self.line_weight = None
        self.height_weight = None
        self.cost_weight = None
        self.holes_weight = None
        self.bumpiness_weight = None
        self.col_transitions_weight = None
        self.row_transitions_weight = None

# Module-level instance; other cells read args.exp_name and args.seed from it.
args = Args()

In [None]:
def sample_params(trial: optuna.Trial) -> dict:
    """Draw one hyper-parameter configuration from ``trial``.

    Returns a dict keyed by ``Args`` attribute name. ``"model"`` holds the
    network class (not its name). Parameters are suggested in a fixed order:
    architecture, learning rate, PER alpha/beta, final epsilon, then the reward
    weights.
    """
    architecture = trial.suggest_categorical("model", ["FC", "Itai_v1_1", "Itai_v1_2", "Itai_v1_3", "Itai_v2", "Itai_v4"])
    params = {
        "model": get_model_class(architecture),
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "alpha": trial.suggest_float("alpha", 0, 1),
        "beta": trial.suggest_float("beta", 0, 1),
        "end_e": trial.suggest_float("end_e", 0, 0.3),
    }

    # Reward shaping weights: (name, upper bound); every lower bound is 0.
    reward_weight_bounds = (
        ("line_weight", 30),
        ("height_weight", 3),
        ("cost_weight", 3),
        ("holes_weight", 3),
        ("bumpiness_weight", 3),
        ("col_transitions_weight", 3),
        ("row_transitions_weight", 3),
    )
    for weight_name, upper in reward_weight_bounds:
        params[weight_name] = trial.suggest_float(weight_name, 0, upper)
    return params

In [None]:
def objective(trial: optuna.Trial):
    """Optuna objective: train one agent with sampled hyper-parameters and return its mean lines.

    Returns NaN, which optuna records as a failed trial, in two cases: training
    raises an AssertionError, or no ``"mean_lines"`` result was stored on the
    trial. ``optuna.exceptions.TrialPruned`` propagates to optuna unchanged.
    """
    args = Args()
    for key, value in sample_params(trial).items():
        setattr(args, key, value)

    try:
        train(args, trial=trial)
    except AssertionError as e:
        # Sometimes, random hyperparams can generate NaN
        print(e)
        # Tell the optimizer that the trial failed
        return float("nan")

    # Without any evaluation the attribute is missing; report a failed trial
    # instead of crashing the whole study with a KeyError.
    return trial.user_attrs.get("mean_lines", float("nan"))

# Main

In [None]:
# Start from a clean slate. DESTRUCTIVE: deletes the TensorBoard runs, recorded
# videos, input-video mp4 files and the optuna SQLite database.
!rm -r runs/* videos/* *.mp4 db.sqlite3

In [None]:
# Base study name; the study-creation cell appends "-{study_num}" and then bumps study_num.
study_name = f"{args.exp_name}_study"
study_num = 1

In [None]:
# Set pytorch num threads to 1 for faster training
# (presumably to avoid CPU oversubscription with N_JOBS parallel trials — TODO confirm)
torch.set_num_threads(1)
# Select the sampler, can be random, TPESampler, CMAES, ...
sampler = TPESampler(n_startup_trials=N_STARTUP_TRIALS, seed=args.seed)
# Pruner to stop bad runs (compares values reported via trial.report)
pruner = MedianPruner(n_startup_trials=N_STARTUP_TRIALS)
# Create the study and start the hyperparameter optimization
study = optuna.create_study(study_name=f"{study_name}-{study_num}", storage="sqlite:///db.sqlite3", sampler=sampler, pruner=pruner, direction="maximize")
# Bump the counter so re-running this cell creates a study with a fresh name
study_num += 1

In [None]:
# !optuna-dashboard sqlite:///db.sqlite3

In [32]:
# Run the search. Interrupting the kernel stops it early; trials that already
# finished stay recorded in the study.
try:
    study.optimize(objective, n_trials=N_TRIALS, n_jobs=N_JOBS)
except KeyboardInterrupt:
    pass

In [None]:
# Load the TensorBoard notebook extension (training logs are written under runs/).
%load_ext tensorboard

In [None]:
# Summarise the search: trial count, then value, parameters and user attributes of the best trial.
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
best_trial = study.best_trial
print(f"\tValue: {best_trial.value}")
for heading, entries in (("\tParams: ", best_trial.params), ("\tUser attrs:", best_trial.user_attrs)):
    print(heading)
    for key, value in entries.items():
        print(f"\t\t{key}: {value}")

In [None]:
# Write report. The file is named after the study that was actually optimized;
# study_num has already been incremented by the study-creation cell, so it no
# longer matches that study.
study.trials_dataframe().to_csv(f"study_results_dqn_{study.study_name}.csv")

fig1 = plot_optimization_history(study)
fig2 = plot_param_importances(study)

fig1.show()
fig2.show()

In [None]:
# class Args:
#     def __init__(self):
#         # Settings
#         self.exp_name = "Tetris_DQN"
#         self.run_id = int(time.time())
#         self.torch_deterministic = True
#         self.cuda = True
#         self.mps = False
#         self.capture_video = True
#         self.capture_inputs_video = True
#         self.save_model = True
#         self.eval_episodes = 5
#         self.backup_frequency = 50000
#         self.mean_score_count = 50
#         self.video_frequency = 200
#         self.reload_env_frequency = 49
#         self.prefetch = None

#         # Constant Hyper-Parameters
#         self.seed = 2
#         self.total_timesteps = 1_000_000
#         self.buffer_size = 50_000
#         self.learning_starts = 50_000
#         self.train_frequency = 1
#         self.start_e = 1
#         self.batch_size = 32

#         # Optimizable Hyper-Parameters
#         self.model = get_model_class('small')
#         self.learning_rate = 2e-4
#         self.gamma = 0.99
#         self.tau = 0.999
#         self.alpha = 0.6
#         self.beta = 0.5
#         self.frame_stack = 4
#         self.target_network_frequency = 2000
#         self.end_e = 0.1
#         self.exploration_fraction = 0.5

#         # Reward weights
#         self.line_weight = 10.0
#         self.height_weight = 1.0
#         self.cost_weight = 1.0
#         self.holes_weight = 1.0
#         self.bumpiness_weight = 1.0
#         self.col_transitions_weight = 1.0
#         self.row_transitions_weight = 1.0


# args = Args()