# Imports

In [None]:
import gym
from utils import *
from agents import *

# Environment

In [None]:
# Build the Seaquest environment in two explicit steps: create the raw gym
# environment first, then pass it through `wrap_deepmind` (imported via the
# wildcard imports above; presumably the DeepMind-style Atari preprocessing
# wrappers -- TODO confirm against `utils`).
game_id = "SeaquestNoFrameskip-v4"
env = gym.make(game_id)
env = wrap_deepmind(env)

# Number of discrete actions, read from the unwrapped (base) environment.
num_actions = env.unwrapped.action_space.n

# Agent training

In [None]:
# Instantiate the Gaussian DQN agent that the following cells configure and
# train. One keyword per line so individual hyper-parameters are easy to tweak.
atari_agent = GaussDQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    # NOTE(review): each conv entry looks like [filters, kernel, stride] --
    # confirm against the agent implementation.
    convs=[
        [32, 8, 4],
        [64, 4, 2],
        [64, 3, 1],
    ],
    fully_connected=[512],
    base_sigma=0.05,
    sigma_coef=10,
    loss_function="kl_divergence",
    save_path="atari_models",
    model_name="gaussian_boi",
)

In [None]:
# Configure the training hyper-parameters on the agent created above.
# NOTE(review): units (steps vs. frames) are defined by the agent
# implementation -- confirm there before changing these values.
atari_agent.set_parameters(
    max_episode_length=100000,
    discount_factor=0.99,
    final_eps=0.01,
    replay_memory_size=1000000,
    replay_start_size=50000,
    annealing_steps=1000000,
)

In [None]:
# Train the agent that was created and configured in the cells above. The
# agent is bound to `atari_agent` throughout this notebook, so that is the
# name used here.
# NOTE(review): gpu_id=1 presumably selects the second GPU -- confirm this is
# intended on single-GPU machines.
atari_agent.train(
    gpu_id=1,
    exploration="e-greedy",
    save_freq=1000000,
    max_num_epochs=1000,
    performance_print_freq=50,
)

# Other agents

In [None]:
# Catalogue of the alternative agent types. Every statement below rebinds
# `atari_agent`, so running this cell as-is constructs all four agents in turn
# and leaves only the last one (the Gaussian DQN) bound to the name.

# Classic deep Q-network
atari_agent = DQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    convs=[
        [32, 8, 4],
        [64, 4, 2],
        [64, 3, 1],
    ],
    fully_connected=[512],
    save_path="atari_models",
    model_name="dqn_boi",
)

# Dueling deep Q-network
atari_agent = DuelDQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    convs=[
        [32, 8, 4],
        [64, 4, 2],
        [64, 3, 1],
    ],
    fully_connected=[256],
    save_path="atari_models",
    model_name="dueldqn_boi",
)

# Categorical deep Q-network (C51)
atari_agent = CatDQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    convs=[
        [32, 8, 4],
        [64, 4, 2],
        [64, 3, 1],
    ],
    fully_connected=[512],
    v=(-10, 10),
    num_atoms=51,
    save_path="atari_models",
    model_name="catdqn_boi",
)

# Gaussian deep Q-network
atari_agent = GaussDQNAgent(
    env,
    num_actions,
    state_shape=[84, 84, 4],
    convs=[
        [32, 8, 4],
        [64, 4, 2],
        [64, 3, 1],
    ],
    fully_connected=[512],
    base_sigma=0.05,
    sigma_coef=10,
    loss_function="kl_divergence",
    save_path="atari_models",
    model_name="gaussian_boi",
)