# Setup

In [None]:
# Run configuration: the root folder for every experiment artifact, and the
# environment id that decides which game the agent is trained on.
output_dir, env_id = 'results', 'Tetris-v0'

### Global Modules

In [None]:
import os
import datetime
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline

### Local Modules

In [None]:
import base
from src.environment.atari import build_atari_environment
from src.agents import DeepQAgent
from src.util import BaseCallback
from src.agents import DeepQAgent
from src.util import BaseCallback

## Output Directory

In [None]:
# Create a timestamped run directory:
#   <output_dir>/<env_id>/DeepQAgent/<YYYY-MM-DD_HH-MM>
# NOTE: `output_dir` is rebound here, so re-running this cell on its own nests
# a new run directory inside the previous one; re-run the Setup cell first.
now = datetime.datetime.today().strftime('%Y-%m-%d_%H-%M')
# Use the configured environment id as the per-game folder name (no
# `game_name` variable is defined in this notebook).
output_dir = '{}/{}/DeepQAgent/{}'.format(output_dir, env_id, now)
# exist_ok=True avoids the check-then-create race of exists() + makedirs().
os.makedirs(output_dir, exist_ok=True)
output_dir

In [None]:
# Path of `weights.h5` inside the run directory; the training callback and
# the final save both write the model weights here.
weights_file = '/'.join((output_dir, 'weights.h5'))
weights_file

# Environment

In [None]:
# Build the environment selected by `env_id`.
# The Tetris and Super Mario Bros. packages are not imported at the top of the
# notebook, so each is imported inside its own branch. This makes the calls
# below resolvable, and only the package actually needed must be installed.
if 'Tetris' in env_id:
    import gym_tetris
    env = gym_tetris.make(env_id)
    env = gym_tetris.wrap(env, clip_rewards=False, skip_frames=None)
elif 'SuperMarioBros' in env_id:
    import gym_super_mario_bros
    env = gym_super_mario_bros.make(env_id)
    env = gym_super_mario_bros.wrap(env, clip_rewards=False)
else:
    # Any other id goes through the project's Atari environment builder.
    env = build_atari_environment(env_id)

In [None]:
# Display the environment's observation space (shown as this cell's output).
env.observation_space

In [None]:
# Display the environment's action space (shown as this cell's output).
env.action_space

# Agent

In [None]:
# Construct the DeepQAgent for this environment with a replay memory size of
# 750,000, then echo it so its configuration appears as the cell output.
agent = DeepQAgent(env, replay_memory_size=750000)
agent

In [None]:
# Record the agent's configuration next to the other run artifacts by
# writing its repr to `agent.py` in the run directory.
with open('/'.join((output_dir, 'agent.py')), 'w') as description_file:
    description_file.write(repr(agent))

# Training

In [None]:
# Run the agent's observe() step before training.
# NOTE(review): presumably this fills the replay memory with initial
# experience before learning starts — confirm against DeepQAgent.observe.
agent.observe()

In [None]:
# Create the training callback, handing it the weights path (presumably for
# checkpointing — confirm in BaseCallback). After training, its `scores` and
# `losses` attributes are read in the Results section.
callback = BaseCallback(weights_file)
callback

In [None]:
# Train for 2,500,000 frames; `callback` accumulates the scores and losses
# that the Results section tabulates.
agent.train(callback=callback, frames_to_play=2500000)

In [None]:
# Save the final model weights to the run's weights file; overwrite=True
# replaces whatever the callback may already have written there.
agent.model.save_weights(weights_file, overwrite=True)

# Results

In [None]:
# Tabulate the training history gathered by the callback: rewards and losses
# become two columns aligned by position, indexed by episode number, and the
# table is saved as CSV in the run directory.
rewards = pd.Series(callback.scores)
losses = pd.Series(callback.losses)
rewards_losses = pd.concat([rewards, losses], axis=1, keys=['Reward', 'Loss'])
rewards_losses.index.name = 'Episode'
rewards_losses.to_csv('/'.join((output_dir, 'rewards_losses.csv')))

In [None]:
# Draw reward and loss on separate subplots of a 12x5 figure, then export the
# current figure as a PDF in the run directory.
rewards_losses.plot(subplots=True, figsize=(12, 5))
plt.savefig('/'.join((output_dir, 'rewards_losses.pdf')))

In [None]:
# Close the environment now that training and reporting are finished.
env.close()