In [None]:
import gym
from baselines import deepq
import tensorflow as tf

def callback(lcl, _glb):
    """Tell the trainer whether the task counts as solved.

    Args:
        lcl: Mapping of the training loop's local variables. Reads ``'t'``
            and, only when ``t`` is above 100, ``'episode_rewards'``.
        _glb: Mapping of globals; unused.

    Returns:
        Truthy when ``t > 100`` and the sum of the 100 rewards preceding the
        last entry of ``episode_rewards``, divided by 100, is at least 199.
    """
    # Too early to judge: bail out without touching the reward history.
    if not lcl['t'] > 100:
        return False

    # Slice skips the last entry; the divisor stays 100 even when fewer
    # rewards are available.
    recent_rewards = lcl['episode_rewards'][-101:-1]
    mean_reward = sum(recent_rewards) / 100
    return mean_reward >= 199

In [None]:
def train_cartpole(total_timesteps=1000):
    """Train a DQN agent with an MLP Q-network on CartPole-v0.

    Training is delegated to ``deepq.learn`` with the notebook's ``callback``
    passed as its ``callback`` argument. The trained policy is only returned;
    nothing is written to disk.

    Args:
        total_timesteps: Value forwarded to ``deepq.learn`` as
            ``total_timesteps``. Defaults to 1000, the value this notebook
            previously hard-coded.

    Returns:
        Whatever ``deepq.learn`` returns; ``evaluate_cartpole`` calls it on a
        batched observation to choose an action.
    """
    env = gym.make("CartPole-v0")
    try:
        act = deepq.learn(
            env,
            network='mlp',
            lr=1e-3,
            total_timesteps=total_timesteps,
            buffer_size=50000,
            exploration_fraction=0.1,
            exploration_final_eps=0.02,
            print_freq=10,
            callback=callback
        )
    finally:
        # Evaluation builds its own environment, so release this one whether
        # or not training succeeded.
        env.close()
    # The old message claimed the model was saved to cartpole_model.pkl, but
    # no file is written here.
    print("CartPole training finished; model returned in memory (not saved to disk)")
    return act

In [None]:
def evaluate_cartpole(act):
    """Run one CartPole-v0 episode with a trained policy and print its reward.

    Args:
        act: Callable policy, e.g. the value returned by ``train_cartpole``.
            It is called with the observation wrapped in a leading batch
            axis (``obs[None]``); the first element of its result is the
            action passed to ``env.step``.
    """
    env = gym.make("CartPole-v0")
    try:
        obs, done = env.reset(), False
        episode_rew = 0

        while not done:
            # env.render()  # uncomment to watch the episode
            obs, rew, done, _ = env.step(act(obs[None])[0])
            episode_rew += rew
        print("Episode reward", episode_rew)
    finally:
        # Release the environment even if the rollout raises, matching the
        # cleanup done by evaluate_mountaincar.
        env.close()

In [None]:
def train_mountaincar(total_timesteps=1000):
    """Train a DQN agent with an MLP Q-network on MountainCar-v0.

    Training is delegated to ``deepq.learn`` with ``param_noise=True`` and no
    callback. The trained policy is only returned; nothing is written to
    disk.

    Args:
        total_timesteps: Value forwarded to ``deepq.learn`` as
            ``total_timesteps``. Defaults to 1000, the value this notebook
            previously hard-coded.

    Returns:
        Whatever ``deepq.learn`` returns; ``evaluate_mountaincar`` calls it on
        a batched observation to choose an action.
    """
    env = gym.make("MountainCar-v0")
    try:
        act = deepq.learn(
            env,
            network='mlp',
            lr=1e-3,
            total_timesteps=total_timesteps,
            buffer_size=50000,
            exploration_fraction=0.1,
            exploration_final_eps=0.1,
            print_freq=10,
            param_noise=True,
            callback=None
        )
    finally:
        # Evaluation builds its own environment, so release this one whether
        # or not training succeeded.
        env.close()
    # The old message claimed the model was saved to mountaincar_model.pkl,
    # but no file is written here.
    print("MountainCar training finished; model returned in memory (not saved to disk)")
    return act

In [None]:
def evaluate_mountaincar(act):
    """Run one MountainCar-v0 episode with a trained policy and print its reward.

    Args:
        act: Callable policy, e.g. the value returned by
            ``train_mountaincar``. It is called with the observation wrapped
            in a leading batch axis (``obs[None]``); the first element of its
            result is the action passed to ``env.step``.
    """
    env = gym.make("MountainCar-v0")
    try:
        obs, done = env.reset(), False
        episode_rew = 0
        while not done:
            # env.render()  # uncomment to watch the episode
            obs, rew, done, _ = env.step(act(obs[None])[0])
            episode_rew += rew
        print("Episode reward", episode_rew)
    finally:
        # Close in a finally block so the environment is released even when
        # the policy or env.step raises mid-episode.
        env.close()

In [None]:
# Train the CartPole agent inside the 'cartpole' TensorFlow variable scope and
# keep the returned policy in `cartpole_action` for the evaluation cell.
# NOTE(review): the dedicated scope presumably keeps this model's variables
# from clashing with the MountainCar model built later — confirm.
with tf.variable_scope('cartpole'):
    cartpole_action = train_cartpole()

In [None]:
# Evaluate CartPole: roll out a single episode with the trained policy and
# print its total reward. Requires the CartPole training cell to have run,
# since that is where `cartpole_action` is defined.
evaluate_cartpole(cartpole_action)

In [None]:
# Train the MountainCar agent inside the 'mtncar' TensorFlow variable scope
# and keep the returned policy in `mtncar_action` for the evaluation cell.
# NOTE(review): the dedicated scope presumably keeps this model's variables
# separate from the CartPole model trained earlier — confirm.
with tf.variable_scope('mtncar'):
    mtncar_action = train_mountaincar()

In [None]:
# Evaluate MountainCar: roll out a single episode with the trained policy and
# print its total reward. Requires the MountainCar training cell to have run,
# since that is where `mtncar_action` is defined.
evaluate_mountaincar(mtncar_action)