# Use QR-DQN to Play Pong

TensorFlow version

In [1]:
%matplotlib inline

import copy
import logging
import itertools
import sys

import numpy as np
np.random.seed(0)
import pandas as pd
import gym
from gym.wrappers.atari_preprocessing import AtariPreprocessing
from gym.wrappers.frame_stack import FrameStack
import matplotlib.pyplot as plt
import tensorflow.compat.v2 as tf
tf.random.set_seed(0)
from tensorflow import keras
from tensorflow import nn
from tensorflow import optimizers
from tensorflow import losses
from tensorflow.keras import layers
from tensorflow.keras import models

# Send DEBUG-and-above records to stdout with a compact wall-clock timestamp.
logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        datefmt='%H:%M:%S',
        format='%(asctime)s [%(levelname)s] %(message)s')

Environment

In [2]:
# Pong wrapped in AtariPreprocessing, with the last 4 processed frames stacked
# into each observation.
env = FrameStack(
        AtariPreprocessing(gym.make('PongNoFrameskip-v4')),
        num_stack=4)
# Log every attribute of the wrapper, then of its spec, for reference.
for attribute_owner in (env, env.spec):
    for key, value in vars(attribute_owner).items():
        logging.info('%s: %s', key, value)

05:01:52 [INFO] env: <AtariPreprocessing<TimeLimit<AtariEnv<PongNoFrameskip-v4>>>>
05:01:52 [INFO] action_space: Discrete(6)
05:01:52 [INFO] observation_space: Box(0, 255, (4, 84, 84), uint8)
05:01:52 [INFO] reward_range: (-inf, inf)
05:01:52 [INFO] metadata: {'render.modes': ['human', 'rgb_array']}
05:01:52 [INFO] num_stack: 4
05:01:52 [INFO] lz4_compress: False
05:01:52 [INFO] frames: deque([], maxlen=4)
05:01:52 [INFO] id: PongNoFrameskip-v4
05:01:52 [INFO] entry_point: gym.envs.atari:AtariEnv
05:01:52 [INFO] reward_threshold: None
05:01:52 [INFO] nondeterministic: False
05:01:52 [INFO] max_episode_steps: 400000
05:01:52 [INFO] _kwargs: {'game': 'pong', 'obs_type': 'image', 'frameskip': 1}
05:01:52 [INFO] _env_name: PongNoFrameskip


Agent

In [3]:
class DQNReplayer:
    """Fixed-capacity experience replay buffer stored in a pandas DataFrame.

    Rows form a ring buffer: once ``capacity`` transitions have been written,
    each new transition overwrites the oldest row.
    """

    def __init__(self, capacity):
        field_names = ['state', 'action', 'reward', 'next_state', 'done']
        self.memory = pd.DataFrame(index=range(capacity), columns=field_names)
        self.i = 0  # row that the next store() writes to
        self.count = 0  # how many rows currently hold a transition
        self.capacity = capacity

    def store(self, *args):
        """Write one ``(state, action, reward, next_state, done)`` row."""
        row = self.i
        self.memory.loc[row] = args
        self.i = (row + 1) % self.capacity
        if self.count < self.capacity:
            self.count += 1

    def sample(self, size):
        """Draw ``size`` filled rows uniformly, with replacement.

        Returns a generator yielding one ``np.stack``-ed array per column,
        in column order (state, action, reward, next_state, done).
        """
        picked = np.random.choice(self.count, size=size)
        columns = self.memory.columns
        return (np.stack(self.memory.loc[picked, name]) for name in columns)

In [4]:
class Agent:
    """QR-DQN agent: epsilon-greedy acting, experience replay, and a
    soft-updated target network.

    Each network maps a (4, 84, 84) frame stack to ``quantile_count`` return
    quantiles per action; an action's Q-value is the mean of its quantiles.
    """

    def __init__(self, env):
        self.action_n = env.action_space.n
        self.gamma = 0.99
        self.epsilon = 1.

        self.replayer = DQNReplayer(capacity=100000)
        # Total transitions ever stored. Unlike ``replayer.count`` this keeps
        # growing after the buffer is full, so the "learn every 10 stored
        # transitions" schedule in step() holds for the whole run instead of
        # degrading to "learn on every step" once count is pinned at capacity.
        self.stored_count = 0

        quantile_count = 200
        # Quantile midpoints tau_i = (i + 0.5) / N, shaped (1, N, 1) so they
        # broadcast along the predicted-quantile axis of the TD-error tensor
        # built in learn().
        cumprobs = (np.arange(quantile_count) + 0.5) / quantile_count
        self.cumprob_tensor = tf.constant(
                cumprobs[np.newaxis, :, np.newaxis], dtype=tf.float32)

        self.evaluate_net = self.build_net(self.action_n, quantile_count)
        self.target_net = models.clone_model(self.evaluate_net)
        # clone_model copies the architecture but creates freshly initialized
        # weights; start the target network from the evaluate network's weights.
        self.target_net.set_weights(self.evaluate_net.get_weights())

    def build_net(self, action_n, quantile_count):
        """Build the conv net producing (batch, action_n, quantile_count)."""
        net = keras.Sequential([
                keras.layers.Permute((2, 3, 1), input_shape=(4, 84, 84)),
                layers.Conv2D(32, kernel_size=8, strides=4, activation=nn.relu),
                layers.Conv2D(64, kernel_size=4, strides=2, activation=nn.relu),
                layers.Conv2D(64, kernel_size=3, strides=1, activation=nn.relu),
                layers.Flatten(),
                layers.Dense(512, activation=nn.relu),
                layers.Dense(action_n * quantile_count),
                layers.Reshape((action_n, quantile_count))])
        optimizer = optimizers.Adam(0.0001)
        # compiled only to attach the optimizer; the loss is computed in learn()
        net.compile(optimizer=optimizer)
        return net

    def reset(self, mode=None):
        """Start an episode; ``mode='train'`` enables exploration and learning."""
        self.mode = mode
        if mode == 'train':
            self.trajectory = []

    def step(self, observation, reward, done):
        """Return an action for ``observation``.

        ``reward`` and ``done`` describe the outcome of the previous action
        (play_episode passes ``0.``/``False`` on the first call of an episode).
        In train mode the action is epsilon-greedy, the transition completed
        by this call is stored, and learn() runs on schedule.
        """
        if self.mode != 'train':
            return self._greedy_action(observation)

        if np.random.rand() < self.epsilon:
            action = np.random.randint(0, self.action_n)
        else:
            # only pay for the forward pass when its result is actually used
            action = self._greedy_action(observation)
        self._record(observation, reward, done, action)
        return action

    def _greedy_action(self, observation):
        """Return the action whose mean quantile (Q-value) is largest."""
        state_tensor = tf.convert_to_tensor(np.array(observation)[np.newaxis],
                dtype=tf.float32)
        q_component_tensor = self.evaluate_net(state_tensor)
        q_tensor = tf.reduce_mean(q_component_tensor, axis=2)
        action_tensor = tf.math.argmax(q_tensor, axis=1)
        return action_tensor.numpy()[0]

    def _record(self, observation, reward, done, action):
        """Append one step; store the finished transition and learn on schedule."""
        self.trajectory += [observation, reward, done, action]
        if len(self.trajectory) < 8:
            return  # first step of the episode: no transition completed yet
        # [s_t, r_t, d_t, a_t, s_t+1, r_t+1, d_t+1, a_t+1]: the reward and
        # done flag belonging to (s_t, a_t) arrive with the next observation.
        state, _, _, act, next_state, reward, done, _ = self.trajectory[-8:]
        self.replayer.store(state, act, reward, next_state, done)
        self.stored_count += 1
        # Keep only the newest group, which begins the next transition, so
        # the list does not grow for the whole episode.
        self.trajectory = self.trajectory[-4:]
        if self.replayer.count >= 1024 and self.stored_count % 10 == 0:
            self.learn()

    def close(self):
        pass

    def update_net(self, target_net, evaluate_net, learning_rate=0.005):
        """Soft update: target <- (1 - lr) * target + lr * evaluate."""
        average_weights = [(1. - learning_rate) * t + learning_rate * e for t, e
                in zip(target_net.get_weights(), evaluate_net.get_weights())]
        target_net.set_weights(average_weights)

    def learn(self):
        """Run one quantile-Huber update on a replayed minibatch, then
        soft-update the target network and decay epsilon."""
        # replay
        batch_size = 32
        states, actions, rewards, next_states, dones = \
                self.replayer.sample(batch_size)
        state_tensor = tf.convert_to_tensor(states, dtype=tf.float32)
        reward_tensor = tf.convert_to_tensor(rewards[:, np.newaxis],
                dtype=tf.float32)
        done_tensor = tf.convert_to_tensor(dones[:, np.newaxis],
                dtype=tf.float32)
        next_state_tensor = tf.convert_to_tensor(next_states, dtype=tf.float32)
        batch_indices = np.arange(batch_size, dtype=np.int64)

        # compute target: the evaluate net picks the next action (double-DQN
        # style) and the target net supplies that action's quantiles
        next_q_component_tensor = self.evaluate_net(next_state_tensor)
        next_q_tensor = tf.reduce_mean(next_q_component_tensor, axis=2)
        next_actions = tf.math.argmax(next_q_tensor, axis=1).numpy()
        all_next_q_quantile_tensor = self.target_net(next_state_tensor)
        next_indices = np.stack(
                [batch_indices, next_actions.astype(np.int64)], axis=1)
        next_q_quantile_tensor = tf.gather_nd(all_next_q_quantile_tensor,
                next_indices)
        target_quantile_tensor = reward_tensor + self.gamma \
                * next_q_quantile_tensor * (1. - done_tensor)
        # (batch, 1, N): target quantile j along the last axis
        target_quantile_tensor = target_quantile_tensor[:, np.newaxis, :]
        indices = np.stack([batch_indices, actions.astype(np.int64)], axis=1)

        with tf.GradientTape() as tape:
            all_q_quantile_tensor = self.evaluate_net(state_tensor)
            q_quantile_tensor = tf.gather_nd(all_q_quantile_tensor, indices)
            # (batch, N, 1): predicted quantile i along axis 1
            q_quantile_tensor = q_quantile_tensor[:, :, np.newaxis]
            td_error_tensor = target_quantile_tensor - q_quantile_tensor
            abs_td_error_tensor = tf.math.abs(td_error_tensor)
            huber_delta = 1.
            huber_loss_tensor = tf.where(abs_td_error_tensor < huber_delta,
                    0.5 * tf.square(td_error_tensor),
                    huber_delta * (abs_td_error_tensor - 0.5 * huber_delta))
            # quantile-regression weight |tau_i - 1{u < 0}|, u = target - pred
            comparison_tensor = tf.cast(td_error_tensor < 0, dtype=tf.float32)
            quantile_regression_tensor = tf.math.abs(self.cumprob_tensor -
                    comparison_tensor)
            quantile_huber_loss_tensor = tf.reduce_mean(tf.reduce_sum(
                    huber_loss_tensor * quantile_regression_tensor, axis=-1),
                    axis=1)
            loss_tensor = tf.reduce_mean(quantile_huber_loss_tensor)
        variables = self.evaluate_net.trainable_variables
        grads = tape.gradient(loss_tensor, variables)
        self.evaluate_net.optimizer.apply_gradients(zip(grads, variables))

        self.update_net(self.target_net, self.evaluate_net)

        self.epsilon = max(self.epsilon - 1e-5, 0.05)


agent = Agent(env)

Train & Test

In [None]:
def play_episode(env, agent, max_episode_steps=None, mode=None, render=False):
    """Run one episode of ``agent`` in ``env``.

    The agent sees every observation, including the terminal one, before the
    loop stops. The episode also ends after ``max_episode_steps`` environment
    steps when that limit is truthy.

    Returns:
        (sum of rewards, number of env.step calls)
    """
    state = env.reset()
    last_reward, is_done = 0., False
    agent.reset(mode=mode)
    total, steps = 0., 0
    while True:
        chosen = agent.step(state, last_reward, is_done)
        if render:
            env.render()
        if is_done:
            break
        state, last_reward, is_done, _ = env.step(chosen)
        total += last_reward
        steps += 1
        hit_limit = bool(max_episode_steps) and steps >= max_episode_steps
        if hit_limit:
            break
    agent.close()
    return total, steps


logging.info('==== train ====')
# Train until the mean reward of the latest (up to) 5 episodes exceeds 16.
episode_rewards = []
episode = 0
while True:
    episode_reward, elapsed_steps = play_episode(env, agent, mode='train')
    episode_rewards.append(episode_reward)
    logging.debug('train episode %d: reward = %.2f, steps = %d',
            episode, episode_reward, elapsed_steps)
    if np.mean(episode_rewards[-5:]) > 16.:
        break
    episode += 1
plt.plot(episode_rewards)


logging.info('==== test ====')
# Evaluate the greedy policy (no mode -> no exploration, no learning) on 100
# episodes and report mean ± standard deviation of the episode rewards.
episode_rewards = []
while len(episode_rewards) < 100:
    episode = len(episode_rewards)
    episode_reward, elapsed_steps = play_episode(env, agent)
    episode_rewards.append(episode_reward)
    logging.debug('test episode %d: reward = %.2f, steps = %d',
            episode, episode_reward, elapsed_steps)
logging.info('average episode reward = %.2f ± %.2f',
        np.mean(episode_rewards), np.std(episode_rewards))

05:01:54 [INFO] ==== train ====
05:02:09 [DEBUG] train episode 0: reward = -20.00, steps = 1049
05:02:47 [DEBUG] train episode 1: reward = -19.00, steps = 1070
05:03:18 [DEBUG] train episode 2: reward = -20.00, steps = 874
05:03:55 [DEBUG] train episode 3: reward = -18.00, steps = 1031
05:04:31 [DEBUG] train episode 4: reward = -20.00, steps = 1073
05:04:57 [DEBUG] train episode 5: reward = -21.00, steps = 764
05:05:34 [DEBUG] train episode 6: reward = -19.00, steps = 1114
05:06:00 [DEBUG] train episode 7: reward = -21.00, steps = 778
05:06:32 [DEBUG] train episode 8: reward = -19.00, steps = 958
05:06:58 [DEBUG] train episode 9: reward = -21.00, steps = 762
05:07:25 [DEBUG] train episode 10: reward = -21.00, steps = 821
05:07:53 [DEBUG] train episode 11: reward = -21.00, steps = 823
05:08:27 [DEBUG] train episode 12: reward = -19.00, steps = 1007
05:09:01 [DEBUG] train episode 13: reward = -21.00, steps = 998
05:09:33 [DEBUG] train episode 14: reward = -20.00, steps = 969
05:10:00 [DE