# Setup

In [None]:
# --- experiment configuration: edit these values, then Restart & Run All ---

# the game to play and the suffix that completes its gym environment id
game_name = 'SpaceInvaders'
env_name = '-v4'

# the agent to run; used as a key into the agent lookup table
agent_name = 'DoubleDeepQAgent'

# passed to the agent constructor as its ``render_mode`` keyword
render_mode = 'rgb_array'

# root folder that this experiment's outputs are written beneath
exp_directory = 'results'

In [None]:
# full gym environment id: the game name followed by the version suffix
game = ''.join((game_name, env_name))
game

### Global Modules

In [None]:
import os
import gym
import pandas as pd
%matplotlib inline

### Local Modules

In [None]:
import base
from src.agents import (
    DeepQAgent,
    DoubleDeepQAgent
)
from src.downsamplers import (
    downsample_pong,
    downsample_breakout,
    downsample_space_invaders
)

from src.util import JupyterCallback

## Constants

In [None]:
# lookup table: agent name (as set in ``agent_name``) -> agent class
agents = dict(
    DeepQAgent=DeepQAgent,
    DoubleDeepQAgent=DoubleDeepQAgent,
)

# lookup table: game name (as set in ``game_name``) -> down-sampler for that game
downsamplers = dict(
    Pong=downsample_pong,
    Breakout=downsample_breakout,
    SpaceInvaders=downsample_space_invaders,
)

In [None]:
# Build the output folder for this run as <root>/<game>/<agent>.
# The agent part is taken from the class selected by ``agent_name`` (rather
# than a hard-coded class) so the folder always matches the agent that the
# notebook actually instantiates.
run_suffix = '{}/{}'.format(game_name, agents[agent_name].__name__)
# Re-running this cell must not nest the suffix a second time, so the path is
# only extended when it has not been extended already.
if not exp_directory.endswith(run_suffix):
    exp_directory = '{}/{}'.format(exp_directory, run_suffix)
# exist_ok avoids the check-then-create race and keeps the cell re-runnable
os.makedirs(exp_directory, exist_ok=True)
# set up the weights file
weights_file = '{}/weights.h5'.format(exp_directory)

# Environment

In [None]:
# create the gym environment for the configured game id
env = gym.make(game)

In [None]:
# show the environment's observation space as the cell output
env.observation_space

In [None]:
# show the environment's action space as the cell output
env.action_space

# Agent

In [None]:
# look up the agent class and the down-sampler selected by the configuration
agent_class = agents[agent_name]
downsample = downsamplers[game_name]
# build the agent for this environment and show it as the cell output
agent = agent_class(env, downsample, render_mode=render_mode)
agent

## Initial

In [None]:
# let the agent play before any training and wrap what play() returns in a
# Series, so the name is only ever bound to one kind of object
initial = pd.Series(agent.play())
initial

In [None]:
# write the pre-training results into the experiment directory
initial_csv = '{}/initial.csv'.format(exp_directory)
initial.to_csv(initial_csv)

In [None]:
# summary statistics of the pre-training results
initial.describe()

In [None]:
# histogram of the pre-training results
initial.hist()

## Training

In [None]:
# NOTE(review): presumably gathers experience for the agent before training
# starts — confirm against the agent implementation in src.agents
agent.observe()

In [None]:
# keep a reference to the callback: its `scores` and `losses` attributes are
# read after training to save the training results
callback = JupyterCallback()
agent.train(callback=callback)

In [None]:
# save the training results: one CSV per series read from the callback
scores_csv = '{}/scores.csv'.format(exp_directory)
scores = pd.Series(callback.scores)
scores.to_csv(scores_csv)

losses_csv = '{}/losses.csv'.format(exp_directory)
losses = pd.Series(callback.losses)
losses.to_csv(losses_csv)

## Final

In [None]:
# let the trained agent play and wrap what play() returns in a Series, so the
# name is only ever bound to one kind of object
final = pd.Series(agent.play())
final

In [None]:
# write the post-training results into the experiment directory
final_csv = '{}/final.csv'.format(exp_directory)
final.to_csv(final_csv)

In [None]:
# summary statistics of the post-training results
final.describe()

In [None]:
# histogram of the post-training results
final.hist()

## Saving Weights

In [None]:
# Persist the model weights to the path prepared when the experiment
# directory was set up. Reusing ``weights_file`` (instead of rebuilding the
# same path by hand) keeps a single source of truth for the weights location.
agent.model.save_weights(weights_file, overwrite=True)