# Setup

In [None]:
# base directory under which this notebook writes its run artifacts
output_dir = 'results'
# name of the game to play; a name containing 'SuperMarioBros' selects the
# NES environment builder later on, any other name uses the Atari builder
game_name = 'Breakout'

### Global Modules

In [None]:
import os
import datetime
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline
from gym.wrappers import Monitor

### Local Modules

In [None]:
import base
from src.environment.atari import build_atari_environment
from src.environment.nes import build_nes_environment
from src.agents import DeepQAgent
from src.util import BaseCallback

#### Output Directory

In [None]:
# timestamp the run (minute resolution) so each run gets its own directory
now = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M')
# NOTE: this rebinds output_dir from the base directory to the run directory,
# so re-running this cell nests a new run directory inside the previous one
output_dir = '{}/{}/DeepQAgent/{}'.format(output_dir, game_name, now)
# exist_ok=True creates the directory tree without the check-then-create race
# of a separate os.path.exists test
os.makedirs(output_dir, exist_ok=True)
# display the resolved run directory in the cell output
output_dir

In [None]:
# location of the weights file inside this run's output directory
weights_template = '{}/weights.h5'
weights_file = weights_template.format(output_dir)
# display the resolved path in the cell output
weights_file

# Environment

In [None]:
# Super Mario Bros. titles go through the NES builder; every other game name
# falls back to the Atari builder
is_nes_game = 'SuperMarioBros' in game_name
build_environment = build_nes_environment if is_nes_game else build_atari_environment
env = build_environment(game_name)
# wrap the environment in a gym Monitor that writes into the run directory
# NOTE(review): force=True presumably lets the monitor reuse a directory that
# already holds monitor files — confirm against the gym Monitor docs
monitor_dir = '{}/monitor_train'.format(output_dir)
env = Monitor(env, monitor_dir, force=True)

In [None]:
# display the environment's observation space in the cell output
env.observation_space

In [None]:
# display the environment's action space in the cell output
env.action_space

# Agent

In [None]:
# build the agent around the (Monitor-wrapped) environment; every other
# constructor argument is left at DeepQAgent's defaults
agent = DeepQAgent(env)
# display the agent's repr in the cell output
agent

In [None]:
# persist the agent's repr (a summary of its hyperparameters) in the run
# directory so the configuration of this run can be inspected later
agent_path = '{}/agent.py'.format(output_dir)
with open(agent_path, 'w') as agent_file:
    agent_description = repr(agent)
    agent_file.write(agent_description)

## Training

In [None]:
# run the agent's observe step before training, with replay_start_size=1000
# NOTE(review): presumably this pre-fills the replay memory with 1000
# experiences — confirm against DeepQAgent.observe
agent.observe(replay_start_size=1000)

In [None]:
# callback handed to agent.train; its `scores` and `losses` attributes are
# read once training finishes
# NOTE(review): presumably it also checkpoints the model to weights_file —
# confirm against BaseCallback
callback = BaseCallback(weights_file)
# display the callback's repr in the cell output
callback

In [None]:
# train the agent, reporting progress through the callback
# NOTE(review): frames_to_play=1000 looks like a short smoke-test budget —
# confirm the intended training length
agent.train(callback=callback, frames_to_play=1000)

In [None]:
# dump the scores and losses collected by the callback to CSV files in the
# run directory; both Series stay in scope for later cells
scores_path = '{}/scores.csv'.format(output_dir)
losses_path = '{}/losses.csv'.format(output_dir)
scores = pd.Series(callback.scores)
scores.to_csv(scores_path)
losses = pd.Series(callback.losses)
losses.to_csv(losses_path)

In [None]:
# put scores and losses side by side in one frame, labelling the columns
# through `keys` rather than assigning them afterwards
training_metrics = [scores, losses]
train = pd.concat(training_metrics, axis=1, keys=['Reward', 'Loss'])
train.index.name = 'Episode'
# one subplot per column; the result is bound to `_` so the notebook does not
# echo the returned axes
_ = train.plot(subplots=True, figsize=(12, 5))

## Saving Weights

In [None]:
# save the final weights to weights_file — the same '<output_dir>/weights.h5'
# path that was handed to the training callback — rather than re-deriving the
# path here, so the two locations cannot drift apart
# NOTE(review): agent.model is presumably a Keras model, where overwrite=True
# replaces an existing file without prompting — confirm
agent.model.save_weights(weights_file, overwrite=True)