# Imports

In [None]:
import gym
from utils import *
from agents import *
from environments import Snake

# Snake Environment

### Environment initialization

In [None]:
# The Snake agents below act over three discrete actions.
num_actions = 3
# 8x8 playing field. The agents below are built with state_shape=[8, 8, 5];
# its first two dims presumably must match grid_size (TODO confirm).
env = Snake(grid_size=(8, 8))

### NN training

In [None]:
# Baseline Snake agent: a classic DQN with two conv layers and a single
# 128-unit dense layer. Each conv entry is presumably
# [filters, kernel_size, stride] (TODO confirm against agents.py), and
# save_path/model_name presumably decide where checkpoints are written.
snake_agent = DQNAgent(
    env,
    num_actions,
    state_shape=[8, 8, 5],
    convs=[[16, 2, 1], [32, 1, 1]],
    fully_connected=[128],
    save_path="snake_models",
    model_name="dqn_8x8",
)

In [None]:
# Snake training hyperparameters. Exploration is presumably annealed towards
# final_eps over annealing_steps steps, and learning presumably starts once
# replay_start_size transitions are stored (TODO confirm in agents.py).
snake_agent.set_parameters(
    max_episode_length=1000,
    discount_factor=0.999,
    final_eps=0.01,
    annealing_steps=100000,
    replay_memory_size=100000,
    replay_start_size=10000,
)

In [None]:
# gpu_id=-1 trains on the CPU instead of a GPU; presumably a non-negative
# value selects that GPU index.
snake_agent.train(
    gpu_id=-1,
    exploration="boltzmann",
    save_freq=500000,
    max_num_epochs=1000,
)

### Other agents

In [None]:
# Alternative agent architectures for Snake.
#
# Previously this cell constructed all five agents one after another and kept
# only the last one. That built four networks that were immediately discarded.
# Each agent is now wrapped in a zero-argument factory, and only the one named
# by `snake_agent_type` is built. The default, "sac", is the agent the cell
# used to leave bound to `snake_agent`.
snake_agent_type = "sac"

snake_agent_factories = {
    # Classic deep Q-network
    "dqn": lambda: DQNAgent(
        env, num_actions, state_shape=[8, 8, 5],
        convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
        save_path="snake_models", model_name="dqn_8x8"),
    # Dueling deep Q-network
    "dueldqn": lambda: DuelDQNAgent(
        env, num_actions, state_shape=[8, 8, 5],
        convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[64],
        save_path="snake_models", model_name="dueldqn_8x8"),
    # Categorical deep Q-network (C51). v is presumably the (min, max) support
    # of the value distribution (TODO confirm in agents.py).
    "catdqn": lambda: CatDQNAgent(
        env, num_actions, state_shape=[8, 8, 5],
        convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
        v=(-5, 25), num_atoms=51,
        save_path="snake_models", model_name="catdqn_8x8"),
    # Quantile regression deep Q-network (QR-DQN)
    "quantdqn": lambda: QuantRegDQNAgent(
        env, num_actions, state_shape=[8, 8, 5],
        convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
        num_atoms=100, kappa=1.0,
        save_path="snake_models", model_name="quantdqn_8x8"),
    # Soft Actor-Critic
    "sac": lambda: SACAgent(
        env, num_actions, state_shape=[8, 8, 5],
        convs=[[16, 2, 1], [32, 1, 1]], fully_connected=[128],
        temperature=0.1,
        save_path="snake_models", model_name="sac_8x8"),
}

# Fail with a readable message instead of a bare KeyError on a typo.
if snake_agent_type not in snake_agent_factories:
    raise ValueError("Unknown snake_agent_type {!r}; expected one of {}".format(
        snake_agent_type, sorted(snake_agent_factories)))
snake_agent = snake_agent_factories[snake_agent_type]()

# Atari Environment

### Environment initialization

In [None]:
# Atari environment: Pong, passed through `wrap_deepmind` (presumably the
# DeepMind-style Atari preprocessing wrappers provided by `from utils import *`).
game_id = "PongNoFrameskip-v4"
env = wrap_deepmind(
    gym.make(game_id),
)
# Size of the discrete action space, read from the innermost (unwrapped) env.
num_actions = env.unwrapped.action_space.n

### NN training

In [None]:
# Baseline Atari agent: a classic DQN. The conv stack and 512-unit dense layer
# appear to mirror the Nature DQN architecture, reading each conv entry as
# [filters, kernel_size, stride] (TODO confirm against agents.py).
atari_agent = DQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]],
    fully_connected=[512],
    save_path="atari_models",
    model_name="dqn_boi",
)

In [None]:
# Atari training hyperparameters. frame_history_len=4 presumably stacks the
# last four frames, matching the 4 channels of state_shape=[84, 84, 4] above.
# NOTE(review): replay_start_size=50 is tiny next to replay_memory_size=1000000
# (the Snake setup uses 10000 for a 100000 buffer). Confirm this is not a
# leftover debug value or a typo for 50000.
atari_agent.set_parameters(
    max_episode_length=100000,
    discount_factor=0.99,
    final_eps=0.01,
    annealing_steps=1000000,
    replay_memory_size=1000000,
    replay_start_size=50,
    frame_history_len=4,
)

In [None]:
# Atari training run with epsilon-greedy exploration. gpu_id=-1 keeps it on
# the CPU.
atari_agent.train(
    gpu_id=-1,
    exploration="e-greedy",
    save_freq=50000,
    max_num_epochs=1000,
    performance_print_freq=50,
)

### Other agents

In [None]:
# Alternative agent architectures for Atari.
#
# Previously this cell constructed all four agents one after another and kept
# only the last one. That built three large networks that were immediately
# discarded. Each agent is now wrapped in a zero-argument factory, and only the
# one named by `atari_agent_type` is built. The default, "quantdqn", is the
# agent the cell used to leave bound to `atari_agent`.
atari_agent_type = "quantdqn"

atari_agent_factories = {
    # Classic deep Q-network
    "dqn": lambda: DQNAgent(
        env, num_actions, state_shape=[84, 84, 4],
        convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],
        save_path="atari_models", model_name="dqn_boi"),
    # Dueling deep Q-network
    "dueldqn": lambda: DuelDQNAgent(
        env, num_actions, state_shape=[84, 84, 4],
        convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[256],
        save_path="atari_models", model_name="dueldqn_boi"),
    # Categorical deep Q-network (C51). v is presumably the (min, max) support
    # of the value distribution (TODO confirm in agents.py).
    "catdqn": lambda: CatDQNAgent(
        env, num_actions, state_shape=[84, 84, 4],
        convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],
        v=(-10, 10), num_atoms=51,
        save_path="atari_models", model_name="catdqn_boi"),
    # Quantile regression deep Q-network (QR-DQN). kappa is written as a float
    # to match the Snake QR-DQN configuration.
    "quantdqn": lambda: QuantRegDQNAgent(
        env, num_actions, state_shape=[84, 84, 4],
        convs=[[32, 8, 4], [64, 4, 2], [64, 3, 1]], fully_connected=[512],
        num_atoms=200, kappa=1.0,
        save_path="atari_models", model_name="quantdqn_boi"),
}

# Fail with a readable message instead of a bare KeyError on a typo.
if atari_agent_type not in atari_agent_factories:
    raise ValueError("Unknown atari_agent_type {!r}; expected one of {}".format(
        atari_agent_type, sorted(atari_agent_factories)))
atari_agent = atari_agent_factories[atari_agent_type]()