In [None]:
# Helpful information
# https://stable-baselines.readthedocs.io/en/master/guide/examples.html#using-callback-monitoring-training
# ROM files must be named rom.<ext> (rom.nes for NES games, rom.a26 for Atari 2600 games) and placed in each game's own folder
# (`python -m retro.import <rom dir>` does this for you). My path is:
# {path on colossus}
# Scenario files also go in these folders; they determine how the agent is rewarded.

# To Do:
# Train a 5M-timestep model for each of the 9 games.
# Collect the results into a report and prepare a presentation.
# Then try to auto-optimize each model's hyperparameters and compare again.
# Finally, try to make our own, if the (limited) time allows.

# Hard: SMB (Super Mario Bros.), Life Force, Mega Man
# Medium: Breakout-Atari2600, SpaceInvaders-Atari2600, Asteroids-Atari2600
# Easy: CartPole-v0, Pendulum-v0, MountainCar-v0

In [None]:
import retro
import random
import gym
import numpy as np
import os 
import time
import datetime

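# MlpPolicy uses a multilayer perceptron (a fully connected network); it is probably the fastest option but likely the weakest here.
# Per the stable-baselines docs: CnnPolicies are for image observations only, while MlpPolicies are for other kinds of features (e.g. robot joints).
# The Lstm / LnLstm variants add a recurrent LSTM layer (LnLstm = with layer normalization).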
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy, CnnPolicy, CnnLstmPolicy, CnnLnLstmPolicy
from stable_baselines.deepq.policies import DQNPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C, PPO2, TRPO
# Documentation is here: https://stable-baselines.readthedocs.io/en/master/guide/examples.html#using-callback-monitoring-training

In [None]:
# Accumulators filled by callback(): per-stat high, mean and low over training.
output = {}
output['highs'] = {}
output['avgs'] = {}
output['lows'] = {}

# Total training timesteps. Defined up here so callback() can estimate the time
# remaining and the training cell can report the total elapsed time.
max_timesteps = 10000

# Names of the learn() locals that callback() records. They differ per algorithm,
# so keep exactly one `stats` line active, matching the model being trained;
# otherwise callback() raises KeyError. The functions read the name `stats`.
stats = ['fps', 'policy_entropy', 'value_loss']  # A2C
#stats = ['fps', 'loss_val', 'explained_var', 'lr_now']  # PPO2
#stats = ['t_start']  # TRPO: its callback locals have little useful info, but verbose=1 output does.

In [None]:
def format_time(secs):
    """Split a duration in seconds into (hours, minutes, seconds).

    Returns (int hours, int minutes, seconds rounded to 2 decimals). A full hour
    or minute carries into the larger unit, so 3600 -> (1, 0, 0) and
    60 -> (0, 1, 0). Negative durations are returned unsplit as (0, 0, secs).
    """
    if secs < 0:
        return 0, 0, round(secs, 2)
    hours, secs = divmod(secs, 3600)
    minutes, secs = divmod(secs, 60)
    return int(hours), int(minutes), round(secs, 2)

def round_output(stat_names=None, results=None, ndigits=4):
    """Round the recorded high/mean/low of each tracked stat in place.

    stat_names: stats to round; defaults to the module-level `stats` list.
    results: dict with 'highs', 'avgs' and 'lows' sub-dicts; defaults to the
        module-level `output` accumulator.
    ndigits: decimal places to keep (4 by default).

    Stats with no recorded value are skipped, for example when learn() was not
    run, so callback() never stored anything.
    """
    if stat_names is None:
        stat_names = stats
    if results is None:
        results = output
    for bucket in ('highs', 'avgs', 'lows'):
        recorded = results[bucket]
        for stat in stat_names:
            if stat in recorded:
                recorded[stat] = round(recorded[stat], ndigits)
        
def _record_stats(_locals):
    """Fold the current value of every tracked stat into output's high/mean/low."""
    # Per-stat sample counts, so 'avgs' holds a true running mean rather than
    # the previous (avg + x) / 2, which weighted the latest values most.
    counts = output.setdefault('counts', {})
    for stat in stats:
        value = _locals[stat]
        seen = counts.get(stat, 0) + 1
        counts[stat] = seen
        if seen == 1 or stat not in output['avgs']:
            output['avgs'][stat] = value
        else:
            # Incremental mean: m_n = m_(n-1) + (x_n - m_(n-1)) / n
            output['avgs'][stat] += (value - output['avgs'][stat]) / seen
        output['highs'][stat] = max(output['highs'].get(stat, value), value)
        output['lows'][stat] = min(output['lows'].get(stat, value), value)


def _print_eta(done, total):
    """Print the estimated training time left, extrapolating the rate since `start`."""
    elapsed_time = time.time() - start
    if done <= 0 or elapsed_time <= 0:
        return  # no progress measured yet, so the rate is undefined
    rate = done / elapsed_time
    hour_remain, minute_remain, second_remain = format_time((total - done) / rate)
    print(f"approx: training time remaining: {hour_remain}:{minute_remain}:{second_remain} {done}/{total}")


def callback(_locals, _globals):
    """Training callback for model.learn(): record stats and report time remaining.

    _locals / _globals are the locals() / globals() of the algorithm's learn()
    loop. Which progress variables exist depends on the algorithm:
      * PPO2: 'update' and 'nupdates'. The ETA is printed on every call.
      * TRPO: 'timesteps_so_far' and 'total_timesteps'. The ETA is printed on every call.
      * A2C: 'update' (one per batch of n_batch timesteps) and 'total_timesteps'.
        The ETA is printed whenever the timestep count is a multiple of 600.

    Always returns True; returning False would stop training early.
    """
    _record_stats(_locals)

    if 'nupdates' in _locals:  # PPO2
        _print_eta(_locals['update'], _locals['nupdates'])
    elif 'timesteps_so_far' in _locals:  # TRPO
        total = _locals.get('total_timesteps')
        if total is not None:
            _print_eta(_locals['timesteps_so_far'], total)
    else:  # A2C
        # 'update' counts batches of n_batch timesteps: 5 with one env and
        # A2C's default n_steps=5, which is the fallback if the model is not found.
        n_batch = getattr(_locals.get('self'), 'n_batch', 5)
        curr_up = _locals['update'] * n_batch
        if curr_up % 600 == 0:
            _print_eta(curr_up, _locals['total_timesteps'])
    return True

def print_output():
    """Round the recorded stats and print one low/mean/high line per tracked stat.

    Stats that callback() never recorded (e.g. because learn() was not run) are
    reported as having no data instead of raising KeyError.
    """
    round_output()
    for stat in stats:
        if any(stat not in output[bucket] for bucket in ('lows', 'avgs', 'highs')):
            print(f"{stat}: no data recorded")
            continue
        print(f"{stat} low: {output['lows'][stat]} mean: {output['avgs'][stat]} high: {output['highs'][stat]}")

In [None]:
# gym-retro allows only one emulator per process, so close the environment left
# over from a previous run of this cell before creating a new one.
if 'env' in globals():
    env.close()

# Pick exactly one environment: retro.make -> emulated games, gym.make -> classic control.
#game_env = retro.make(game="SuperMarioBros-Nes")
#game_env = retro.make(game="LifeForce-Nes")
#game_env = retro.make(game="MegaMan-Nes")
#game_env = retro.make(game="Asteroids-Atari2600")
game_env = retro.make(game="Breakout-Atari2600")
#game_env = retro.make(game="SpaceInvaders-Atari2600")
#game_env = gym.make('Pendulum-v0')
#game_env = gym.make('CartPole-v0')
#game_env = gym.make('MountainCar-v0')

# The algorithms require a vectorized environment. The raw env is bound to its own
# name because a lambda returning `env` would resolve `env` only when it is called,
# and by then the name refers to this DummyVecEnv.
env = DummyVecEnv([lambda: game_env])


In [None]:
# Set TRAIN_MODEL = True to train for max_timesteps and save the result. With
# False, the cell only builds an untrained model, as the commented-out learn() did.
TRAIN_MODEL = False

start = time.time()  # read by callback() to estimate the time remaining
model = A2C(MlpPolicy, env, verbose=1)
#model = PPO2(MlpPolicy, env, verbose=1)
#model = TRPO(MlpPolicy, env, verbose=1)
if TRAIN_MODEL:
    model.learn(total_timesteps=max_timesteps, callback=callback)
    model.save("savedModel")
end = time.time()

hours, minutes, seconds = format_time(end - start)
print(f"Time Elapsed: {hours}:{minutes}:{seconds}")
if TRAIN_MODEL:
    # Stats are recorded only by callback() during learn().
    print_output()

In [None]:
# To load a previously saved model instead of retraining, uncomment the line below.
#model = A2C.load("LFA2CMlp5m.pkl")

In [None]:
# Roll out the model for one episode, rendering each frame.
MAX_EVAL_STEPS = 5000  # cap so an episode that never ends cannot hang the cell

obs = env.reset()
episode_reward = 0.0
for _ in range(MAX_EVAL_STEPS):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # DummyVecEnv returns one entry per sub-env; this notebook uses a single env.
    episode_reward += float(rewards[0])
    if dones[0]:
        print('Reward: %s' % episode_reward)
        break
else:
    print(f"Episode not finished after {MAX_EVAL_STEPS} steps; reward so far: {episode_reward}")

In [None]:
# Show where the retro package is installed; the rom and scenario folders
# mentioned at the top of the notebook live under this install.
retro_package_file = retro.__file__
print(retro_package_file)