In [1]:
from __future__ import absolute_import, division, print_function

import argparse
import base64
import os

import imageio
import IPython
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from tf_agents.agents.dqn import dqn_agent
from tf_agents.drivers import dynamic_step_driver
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import q_network
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory
from tf_agents.utils import common

# Enable TensorFlow 2.x behaviors (e.g. eager execution) for everything below.
# NOTE(review): presumably kept so the notebook also runs on a TF 1.x install;
# on TF 2.x this should be a no-op — confirm before removing.
tf.compat.v1.enable_v2_behavior()

In [2]:
def dqn_args_train(argv=None):
    """Parse DQN training arguments.

    Args:
        argv: Optional sequence of command-line tokens to parse. When None
            (the default) an empty list is parsed, so the notebook kernel's
            own command-line arguments are ignored and every option keeps its
            default value. Pass e.g. ``sys.argv[1:]`` to parse a real command
            line when running as a script.

    Returns:
        args: An ``argparse.Namespace`` with the attributes ``seed``,
            ``num_iterations``, ``initial_collect_steps``,
            ``collect_steps_per_iteration``, ``replay_buffer_max_length``,
            ``batch_size``, ``learning_rate``, ``log_interval``,
            ``num_eval_episodes`` and ``eval_interval``.
    """
    parser = argparse.ArgumentParser(
        description='Train and evaluate a DQN agent.')

    parser.add_argument(
        '--seed',
        dest='seed',
        type=int,
        help='Seed for numpy and tensorflow.',
        default=123)

    parser.add_argument(
        '--num_iterations',
        dest='num_iterations',
        type=int,
        help='Number of training iterations to run.',
        default=20000)

    parser.add_argument(
        '--initial_collect_steps',
        dest='initial_collect_steps',
        type=int,
        help='Steps collected with a random policy before training starts.',
        default=1000)

    parser.add_argument(
        '--collect_steps_per_iteration',
        dest='collect_steps_per_iteration',
        type=int,
        help='Collected steps per iteration.',
        default=1)

    parser.add_argument(
        '--replay_buffer_max_length',
        dest='replay_buffer_max_length',
        type=int,
        help='Size of the replay buffer.',
        default=100000)

    parser.add_argument(
        '--batch_size',
        dest='batch_size',
        type=int,
        help='Number of samples drawn from the replay buffer per training step.',
        default=64)

    parser.add_argument(
        '--lr',
        dest='learning_rate',
        type=float,
        help='The learning rate',
        default=1e-3)

    parser.add_argument(
        '--log_interval',
        dest='log_interval',
        type=int,
        help='Output logs after n steps.',
        default=200)

    parser.add_argument(
        '--num_eval_episodes',
        dest='num_eval_episodes',
        type=int,
        help='Number of episodes used to compute the average return.',
        default=10)

    parser.add_argument(
        '--eval_interval',
        dest='eval_interval',
        type=int,
        help='Evaluate the policy every n training steps.',
        default=1000)

    # Default to an empty argv: inside a notebook the process argv belongs to
    # the kernel, not to this training script.
    args = parser.parse_args(args=[] if argv is None else list(argv))

    return args

In [3]:
def compute_avg_return(environment, policy, num_episodes):
    """Computes the average (undiscounted) return of a policy.

    Runs ``num_episodes`` complete episodes, summing the per-step rewards of
    each one, and averages the episode totals.

    Args:
        environment: The environment; it is reset at the start of every episode.
        policy: The agent's policy; ``policy.action(time_step).action`` is
            passed to ``environment.step``.
        num_episodes: Number of episodes to average over. Must be positive.

    Returns:
        avg_return: The average return. Only the first batch entry is
            returned, so this assumes the environment has batch size 1.

    Raises:
        ValueError: If ``num_episodes`` is not positive.
    """
    # Guard explicitly: 0 would divide by zero and a negative value would skip
    # the loop and then fail on `.numpy()` of a plain float.
    if num_episodes <= 0:
        raise ValueError(
            'num_episodes must be positive, got {}.'.format(num_episodes))

    total_return = 0.0
    for _ in range(num_episodes):
        time_step = environment.reset()
        episode_return = 0.0

        while not time_step.is_last():
            action_step = policy.action(time_step)
            time_step = environment.step(action_step.action)
            episode_return += time_step.reward
        total_return += episode_return

    avg_return = total_return / num_episodes
    # The accumulated reward is batched; take the value for the first env.
    return avg_return.numpy()[0]

In [4]:
def collect_step(environment, policy, buffer):
    """Take a single environment step with ``policy`` and record it.

    The transition (current time step, chosen action, resulting time step) is
    packed into a trajectory and appended to ``buffer`` via ``add_batch``.

    Args:
        environment: The environment to step; its current time step is the
            starting point of the transition.
        policy: The policy that picks the action.
        buffer: The replay buffer receiving the trajectory.

    Returns:
        None. The side effect is one new trajectory in ``buffer``.
    """
    current = environment.current_time_step()
    chosen = policy.action(current)
    following = environment.step(chosen.action)
    buffer.add_batch(trajectory.from_transition(current, chosen, following))

In [5]:
def collect_data(environment, policy, buffer, n_steps):
    """Run ``collect_step`` repeatedly, ``n_steps`` times in total.

    Args:
        environment: The environment to step.
        policy: The policy that picks each action.
        buffer: The replay buffer that receives each trajectory.
        n_steps: How many single-step transitions to collect.

    Returns:
        None. ``n_steps`` trajectories are appended to ``buffer``.
    """
    for _step_index in range(n_steps):
        collect_step(environment, policy, buffer)

In [6]:
def embed_mp4(filename):
    """Embeds an mp4 file in the notebook as an inline HTML5 video.

    The whole file is read and base64-encoded into a ``data:`` URI, so the
    resulting output is self-contained (and larger than the file itself).

    Args:
        filename: Path to the ``.mp4`` file.

    Returns:
        An ``IPython.display.HTML`` object rendering a 640x480 video player.
    """
    # Context manager closes the handle promptly instead of relying on GC.
    with open(filename, 'rb') as video_file:
        video = video_file.read()
    b64 = base64.b64encode(video)
    tag = '''
    <video width="640" height="480" controls>
    <source src="data:video/mp4;base64,{0}" type="video/mp4">
    Your browser does not support the video tag.
    </video>'''.format(b64.decode())

    return IPython.display.HTML(tag)

In [7]:
def create_policy_eval_video(policy, filename, eval_env, eval_py_env, num_episodes=5, fps=30):
    """Record evaluation episodes of ``policy`` to an mp4 and embed it.

    Args:
        policy: Policy whose ``action(time_step).action`` drives ``eval_env``.
        filename: Output path without extension; ``.mp4`` is appended. Missing
            parent directories are created.
        eval_env: The environment that is reset and stepped.
        eval_py_env: The Python environment wrapped by ``eval_env``; its
            ``render()`` output is written as the video frames.
        num_episodes: Number of episodes to record.
        fps: Frames per second of the written video.

    Returns:
        The HTML video player produced by ``embed_mp4`` for the written file.
    """
    filename = filename + ".mp4"
    # The training loop writes to e.g. "videos/cartpole_5000.mp4"; the writer
    # cannot create the file if its directory does not exist yet.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)

    with imageio.get_writer(filename, fps=fps) as video:
        for _ in range(num_episodes):
            time_step = eval_env.reset()
            video.append_data(eval_py_env.render())
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = eval_env.step(action_step.action)
                video.append_data(eval_py_env.render())
    return embed_mp4(filename)

In [8]:
def create_environment(env_name="CartPole-v0"):
    """Create independent training and evaluation environments.

    Two separate Python environments are loaded so that evaluation episodes
    never disturb the state of the training environment; each one is then
    wrapped as a ``TFPyEnvironment``.

    Args:
        env_name: Gym environment id passed to ``suite_gym.load``.

    Returns:
        A tuple ``(train_env, eval_env, train_py_env, eval_py_env)``: the two
        TF wrappers followed by the underlying Python environments (the
        Python ones are what ``render()`` is called on).
    """
    # NOTE: the original also loaded and reset a third environment that was
    # never used or closed; it has been removed.
    train_py_env = suite_gym.load(env_name)
    eval_py_env = suite_gym.load(env_name)
    train_env = tf_py_environment.TFPyEnvironment(train_py_env)
    eval_env = tf_py_environment.TFPyEnvironment(eval_py_env)

    return train_env, eval_env, train_py_env, eval_py_env

In [9]:
def create_network(train_env, fc_layer_params=(100,)):
    """Build the Q-network for the environment's observation/action specs.

    Args:
        train_env: Environment whose ``observation_spec()`` and
            ``action_spec()`` define the network's input and output.
        fc_layer_params: Sizes of the fully connected hidden layers. The
            default ``(100,)`` is a single hidden layer of 100 units, which
            matches the previously hard-coded value.

    Returns:
        A ``q_network.QNetwork``.
    """
    return q_network.QNetwork(
        train_env.observation_spec(),
        train_env.action_spec(),
        fc_layer_params=fc_layer_params)

In [10]:
def create_agent(train_env, q_net, learning_rate):
    """Construct and initialize a DQN agent.

    Args:
        train_env: Environment whose time-step and action specs define the agent.
        q_net: The Q-network the agent trains.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        An initialized ``dqn_agent.DqnAgent`` using an element-wise squared
        TD-error loss.
    """
    adam = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
    # Step counter starting at 0; the training loop reads it back as the step.
    step_counter = tf.Variable(0)

    dqn = dqn_agent.DqnAgent(
        train_env.time_step_spec(),
        train_env.action_spec(),
        q_network=q_net,
        optimizer=adam,
        td_errors_loss_fn=common.element_wise_squared_loss,
        train_step_counter=step_counter)
    dqn.initialize()

    return dqn

In [11]:
def create_policies(train_env, agent):
    """Gather the policies used by this notebook.

    Args:
        train_env: Environment whose specs define the random policy.
        agent: The DQN agent providing ``policy`` and ``collect_policy``.

    Returns:
        A tuple ``(eval_policy, collect_policy, random_policy)``:
        ``agent.policy``, ``agent.collect_policy``, and a
        ``RandomTFPolicy`` over the environment's specs.
    """
    agent_policies = (agent.policy, agent.collect_policy)
    uniform = random_tf_policy.RandomTFPolicy(
        train_env.time_step_spec(), train_env.action_spec())
    return agent_policies + (uniform,)

In [12]:
def create_replay_buffer(train_env, agent, random_policy, replay_buffer_max_length,
                         initial_collect_steps, batch_size):
    """Create a uniform replay buffer, pre-fill it, and return a sampling iterator.

    Args:
        train_env: Environment to collect from; its ``batch_size`` sets the
            buffer's batch size.
        agent: Agent whose ``collect_data_spec`` describes stored items.
        random_policy: Policy used for the warm-up collection.
        replay_buffer_max_length: Capacity of the buffer.
        initial_collect_steps: Number of warm-up steps collected before training.
        batch_size: ``sample_batch_size`` of the returned dataset.

    Returns:
        A tuple ``(replay_buffer, iterator)``: the buffer and an iterator over
        a prefetched dataset sampling ``batch_size`` items of ``num_steps=2``
        consecutive steps each.
    """
    buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=agent.collect_data_spec,
        batch_size=train_env.batch_size,
        max_length=replay_buffer_max_length)

    # Warm up the buffer with random-policy experience before training.
    collect_data(train_env, random_policy, buffer, n_steps=initial_collect_steps)

    sampled = buffer.as_dataset(
        num_parallel_calls=3,
        sample_batch_size=batch_size,
        num_steps=2)

    return buffer, iter(sampled.prefetch(3))

In [13]:
def dqn_train_eval():
    """Train a DQN agent on CartPole and periodically evaluate it.

    Hyperparameters come from ``dqn_args_train()``. Each iteration:
      * collects ``collect_steps_per_iteration`` steps with the agent's
        collect policy into the replay buffer,
      * runs one training step on a batch sampled from the buffer,
      * prints the loss every ``log_interval`` train steps,
      * prints and records the average return every ``eval_interval`` steps,
      * writes an evaluation video every ``5 * eval_interval`` steps.

    Returns:
        returns: List of average returns, with one value measured before
            training and then one per evaluation. It can be passed to
            ``plot(returns, num_iterations, eval_interval)``.
    """
    args = dqn_args_train()

    # Seed FIRST: network initialization and the initial random-policy
    # collection below are stochastic, so seeding after them (as before)
    # does not make the run reproducible.
    if args.seed is not None:
        np.random.seed(args.seed)
        tf.random.set_seed(args.seed)

    train_env, eval_env, train_py_env, eval_py_env = create_environment()
    q_net = create_network(train_env)
    agent = create_agent(train_env, q_net, args.learning_rate)
    eval_policy, collect_policy, random_policy = create_policies(train_env, agent)
    replay_buffer, iterator = create_replay_buffer(
        train_env,
        agent,
        random_policy,
        args.replay_buffer_max_length,
        args.initial_collect_steps,
        args.batch_size)

    # (Optional) Optimize by wrapping the train step in a graph using TF function.
    agent.train = common.function(agent.train)

    # Reset the train step counter.
    agent.train_step_counter.assign(0)

    # Evaluate the agent's policy once before training.
    avg_return = compute_avg_return(eval_env, agent.policy, args.num_eval_episodes)
    returns = [avg_return]

    for _ in range(args.num_iterations):

        # Collect a few steps using collect_policy and save to the replay buffer.
        for _ in range(args.collect_steps_per_iteration):
            collect_step(train_env, agent.collect_policy, replay_buffer)

        # Sample a batch of data from the buffer and update the agent's network.
        experience, unused_info = next(iterator)
        train_loss = agent.train(experience).loss

        step = agent.train_step_counter.numpy()

        if step % args.log_interval == 0:
            print('step = {0}: loss = {1}'.format(step, train_loss))

        if step % args.eval_interval == 0:
            avg_return = compute_avg_return(eval_env, agent.policy, args.num_eval_episodes)
            print('step = {0}: Average Return = {1}'.format(step, avg_return))
            returns.append(avg_return)

        if step % (args.eval_interval * 5) == 0:
            create_policy_eval_video(agent.policy, "videos/cartpole_" + str(step), eval_env, eval_py_env)

    # Previously discarded; returned so the results can be plotted.
    return returns
    

In [14]:
# Run the whole pipeline: build environments, network, agent and replay
# buffer, then train for `num_iterations` steps while printing the loss
# and average return at the configured intervals.
dqn_train_eval()

step = 200: loss = 38.37812423706055
step = 400: loss = 55.52467727661133
step = 600: loss = 20.308481216430664
step = 800: loss = 22.60995864868164
step = 1000: loss = 37.635780334472656
step = 1000: Average Return = 41.0
step = 1200: loss = 24.170095443725586
step = 1400: loss = 62.31285095214844
step = 1600: loss = 34.68653106689453
step = 1800: loss = 61.01420593261719
step = 2000: loss = 4.443548202514648
step = 2000: Average Return = 55.29999923706055
step = 2200: loss = 16.577068328857422
step = 2400: loss = 36.24903869628906
step = 2600: loss = 84.85283660888672
step = 2800: loss = 30.161462783813477
step = 3000: loss = 70.71626281738281
step = 3000: Average Return = 45.70000076293945
step = 3200: loss = 67.6693344116211
step = 3400: loss = 56.20784378051758
step = 3600: loss = 133.0223846435547
step = 3800: loss = 38.070770263671875
step = 4000: loss = 84.8382568359375
step = 4000: Average Return = 97.0999984741211
step = 4200: loss = 75.252685546875
step = 4400: loss = 130.38



step = 5200: loss = 32.419471740722656
step = 5400: loss = 157.88783264160156
step = 5600: loss = 141.82223510742188
step = 5800: loss = 69.51016998291016
step = 6000: loss = 52.09953689575195
step = 6000: Average Return = 115.4000015258789
step = 6200: loss = 225.6766357421875
step = 6400: loss = 153.56661987304688
step = 6600: loss = 203.82626342773438
step = 6800: loss = 192.7113494873047
step = 7000: loss = 90.13687133789062
step = 7000: Average Return = 196.8000030517578
step = 7200: loss = 94.67214965820312
step = 7400: loss = 186.1062469482422
step = 7600: loss = 16.9442138671875
step = 7800: loss = 512.1292114257812
step = 8000: loss = 237.49038696289062
step = 8000: Average Return = 183.0
step = 8200: loss = 579.2872314453125
step = 8400: loss = 465.34539794921875
step = 8600: loss = 21.985334396362305
step = 8800: loss = 293.43133544921875
step = 9000: loss = 167.97677612304688
step = 9000: Average Return = 199.8000030517578
step = 9200: loss = 210.78369140625
step = 9400: lo



step = 10000: Average Return = 200.0
step = 10200: loss = 84.01322937011719
step = 10400: loss = 884.515625
step = 10600: loss = 217.8933868408203
step = 10800: loss = 361.18560791015625
step = 11000: loss = 553.664794921875
step = 11000: Average Return = 200.0
step = 11200: loss = 567.9567260742188
step = 11400: loss = 1790.2293701171875
step = 11600: loss = 445.34130859375
step = 11800: loss = 31.82424545288086
step = 12000: loss = 82.710693359375
step = 12000: Average Return = 200.0
step = 12200: loss = 37.00714111328125
step = 12400: loss = 18.624347686767578
step = 12600: loss = 175.9971923828125
step = 12800: loss = 166.88958740234375
step = 13000: loss = 482.79541015625
step = 13000: Average Return = 200.0
step = 13200: loss = 519.7514038085938
step = 13400: loss = 230.82632446289062
step = 13600: loss = 35.094215393066406
step = 13800: loss = 1824.8187255859375
step = 14000: loss = 70.08675384521484
step = 14000: Average Return = 200.0
step = 14200: loss = 1568.5677490234375
st



step = 15000: Average Return = 200.0
step = 15200: loss = 161.66961669921875
step = 15400: loss = 3647.3544921875
step = 15600: loss = 2617.25048828125
step = 15800: loss = 2542.709716796875
step = 16000: loss = 2738.21044921875
step = 16000: Average Return = 200.0
step = 16200: loss = 86.77226257324219
step = 16400: loss = 2887.104248046875
step = 16600: loss = 2208.19189453125
step = 16800: loss = 52.97126007080078
step = 17000: loss = 53.186248779296875
step = 17000: Average Return = 200.0
step = 17200: loss = 1218.6190185546875
step = 17400: loss = 2900.903564453125
step = 17600: loss = 3360.69384765625
step = 17800: loss = 70.0076904296875
step = 18000: loss = 77.5316162109375
step = 18000: Average Return = 200.0
step = 18200: loss = 178.0081329345703
step = 18400: loss = 235.01988220214844
step = 18600: loss = 96.42203521728516
step = 18800: loss = 5020.37646484375
step = 19000: loss = 360.89434814453125
step = 19000: Average Return = 200.0
step = 19200: loss = 1498.7930908203125



step = 20000: Average Return = 200.0


In [None]:
def plot(returns, num_iterations, eval_interval, ylim_top=250):
    """Plot average evaluation return against training iteration.

    Args:
        returns: Average returns in the order the training loop records them:
            one before training, then one every ``eval_interval`` iterations.
        num_iterations: Total number of training iterations.
        eval_interval: Number of iterations between evaluations.
        ylim_top: Upper limit of the y-axis (previously hard-coded to 250).

    Returns:
        The matplotlib ``Axes`` containing the plot.

    Raises:
        ValueError: If ``returns`` does not have one value per evaluation point.
    """
    iterations = list(range(0, num_iterations + 1, eval_interval))
    # Fail with an actionable message instead of matplotlib's shape error.
    if len(iterations) != len(returns):
        raise ValueError(
            'Expected {} returns (one per evaluation), got {}.'.format(
                len(iterations), len(returns)))

    _fig, ax = plt.subplots()
    ax.plot(iterations, returns)
    ax.set(title='DQN average return during training',
           xlabel='Iterations',
           ylabel='Average Return')
    ax.set_ylim(top=ylim_top)
    return ax

In [15]:
# Release GPU 0 held by this kernel so other processes can use it.
# NOTE(review): closing the CUDA context likely also invalidates the context
# TensorFlow is using, so restart the kernel before any further GPU work — confirm.
from numba import cuda

# Guard so "Restart & Run All" does not crash on machines without CUDA.
if cuda.is_available():
    cuda.select_device(0)
    cuda.close()