In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf
import numpy as np

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

tf.compat.v1.enable_v2_behavior()

In [None]:
# Based off of: https://luungoc2005.github.io/blog/2020-06-15-chrome-dino-game-reinforcement-learning/

from selenium import webdriver
#Allows us to emulate the stroke of keyboard keys
from selenium.webdriver.common.keys import Keys
# Simulates holding down the key
from selenium.webdriver.common.action_chains import ActionChains
import time


# Getting the canvas
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC

# Capture screenshot
from io import BytesIO
from PIL import Image
import base64
import cv2

# Manage the state
from collections import deque

# Debug
from matplotlib.pyplot import imshow


SCREEN_WIDTH = 320
SCREEN_HEIGHT = 150

DINO_URL = 'https://chromedino.com/'

NUM_SNAPSHOTS = 4

class DinoGameEnv(py_environment.PyEnvironment):
    def __init__(self, screen_width=SCREEN_WIDTH, screen_height=SCREEN_HEIGHT):
        self.screen_width = screen_width
        self.screen_height = screen_height
        
        # Set up Chrome emulator
        self._driver = webdriver.Chrome()
                
        _chrome_options = webdriver.ChromeOptions()
        _chrome_options.add_argument("--mute-audio")
        _chrome_options.add_argument("--disable-gpu") # if running on Windows
        
        self.actions = [
            Keys.ARROW_UP,     # JUMP
            Keys.ARROW_DOWN,   # DUCK
            Keys.ARROW_RIGHT,  # NOTHING
        ]
        
        self.action_chains = ActionChains(self._driver)
        self.keydown_actions = [self.action_chains.key_down(item) for item in self.actions]
        self.keyup_actions = [self.action_chains.key_up(item) for item in self.actions]
                
        # Observation state. Queue of size NUM_SNAPSHOTS. This is because
        # we want to check the velocity of objects as level increases
        # deque object is a collection that generalizes stacks and queues
        self.state = deque(maxlen=NUM_SNAPSHOTS)

        # 3 possible actions (We will make this discrete rather than continuous from [0,2] using a wrapper later.
        self._action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=2, name='action')
        
        
        # BoundedArraySpec(shape=(320, 150, 4), dtype=dtype('uint8'), name='observation', minimum=0, maximum=255)
        
        # Buffer to store the observation from the state
        # uint8: https://scikit-image.org/docs/dev/user_guide/data_types.html
        self._observation_spec = array_spec.BoundedArraySpec(
            shape=(150, 600, 4),
            dtype=np.uint8,
            minimum=0,
            maximum=255,
            name='observation'
        )
    
    def action_spec(self):
        return self._action_spec
    
    def observation_spec(self):
        return self._observation_spec
    
    def _reset(self):
        self._driver.get(DINO_URL)
        
        # Wait until the game is loaded
        WebDriverWait(self._driver, 10).until(
            EC.presence_of_element_located((By.CLASS_NAME, "runner-canvas"))
        )
        
        body = self._driver.find_element_by_tag_name("body")
        body.send_keys(Keys.SPACE)
        time.sleep(0.1)

        return ts.restart(self._next_observation())
    
    def _screenshot(self):
        canvas = self._driver.find_element_by_class_name("runner-canvas")
        canvas_base64 = self._driver.execute_script("return arguments[0].toDataURL('image/png').substring(21);", canvas)
        return np.array(Image.open(BytesIO(base64.b64decode(canvas_base64))))
    
    def _next_observation(self):
        # Preprocess the image
        gray = cv2.cvtColor(self._screenshot(), cv2.COLOR_BGR2GRAY)
        
        # Crop the image
        # cropped = gray[:500, :480]
        # Scale the image
        # rescaled = cv2.resize(cropped, (self.screen_width, self.screen_height))

        # Push the last 4 screenshots as observations
        self.state.append(gray)
        
        # Collate last number of screenshots to account for speed increase
        # We could probably also use concat, but example uses stack
        if len(self.state) < NUM_SNAPSHOTS:
            initial_snapshots = np.stack([gray]*NUM_SNAPSHOTS, axis=-1)
            
            # Used to output the shape because the observation shape has to be consistent
            # print('shape:')
            # print(np.shape(initial_snapshots))
            
            return initial_snapshots
        else:
            return np.stack(self.state, axis=-1)
    
    def _is_game_stopped(self):
        return not self._driver.execute_script("return Runner.instance_.playing")
    
    def _is_game_over(self):
        didDinoCrash = self._driver.execute_script("return Runner.instance_.crashed")
        print("did dino crash?")
        print(didDinoCrash)
        # Did the dino crash?
        return didDinoCrash
    
    def _get_score(self):
        return int(self._driver.execute_script("return Runner.instance_.distanceRan || 0"))

    def _step(self, action):
#         print(f'is game over: {self._is_game_over()} and current score: {self._get_score()}')
        if self._is_game_over():
            # The last action ended the episode. Ignore the current action and start a new episode.
            return self.reset()
        
        # Trigger an action based on action id.
        int_action = int(action)

        if int_action == 0:
            self.action_chains.key_down(Keys.ARROW_UP).pause(0.5).key_up(Keys.ARROW_UP).perform()
#             self.action_chains.key_down(Keys.ARROW_UP).pause(0.25).Build().perform()

        elif int_action == 1:
            self.action_chains.key_down(Keys.ARROW_DOWN).pause(0.5).key_up(Keys.ARROW_DOWN).perform()
#             self.action_chains.key_down(Keys.ARROW_DOWN).pause(0.25).Build().perform()
        else:
            self.action_chains.key_down(Keys.ARROW_RIGHT).pause(0.5).key_up(Keys.ARROW_RIGHT).perform()
#             self.action_chains.key_down(Keys.ARROW_RIGHT).pause(0.25).Build().perform()
        
        if self._is_game_over():
            print("GAME OVER!!!")
            return ts.termination(self._next_observation(), reward=-1)
        else:
            return ts.transition(self._next_observation(), reward=0.1, discount=1.0)

    def render(self, mode='rgb_array'):
        gray = cv2.cvtColor(self._screenshot(), cv2.COLOR_BGR2GRAY)
        return gray

        
        

In [43]:


jump_action = np.array(0, dtype=np.int32)
duck = np.array(1, dtype=np.int32)
nothing = np.array(1, dtype=np.int32)


import PIL.Image
from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common


# Test the environment
dino_env = DinoGameEnv()
dino_env = wrappers.ActionDiscretizeWrapper(dino_env, num_actions=3)
print('Discretized Action Spec:', dino_env.action_spec())


utils.validate_py_environment(dino_env, episodes=5)




Discretized Action Spec: BoundedArraySpec(shape=(), dtype=dtype('int32'), name='action', minimum=0, maximum=2)


NoSuchWindowException: Message: no such window: target window already closed
from unknown error: web view not found
  (Session info: chrome=92.0.4515.131)


In [7]:
jump_action = np.array(0, dtype=np.int32)
duck = np.array(1, dtype=np.int32)
nothing = np.array(1, dtype=np.int32)


import PIL.Image
from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential, q_network
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common

from tf_agents.train.utils import spec_utils, train_utils

# Test the environment
dino_env = DinoGameEnv()
dino_env = wrappers.ActionDiscretizeWrapper(dino_env, num_actions=3)
# print('Discretized Action Spec:', discrete_action_env.action_spec())

# dino_env.reset()
# PIL.Image.fromarray(dino_env.render())


# Environment step takes action in environment and returns a TimeStep tuple containing the next observation of the env and reward for the action
# print('Observation')
# print(dino_env.time_step_spec().observation)



In [8]:
train_env = tf_py_environment.TFPyEnvironment(dino_env)
eval_env = tf_py_environment.TFPyEnvironment(dino_env)

unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(train_env))


#Hyperparameters
num_iterations = 20000 # @param {type:"integer"}

initial_collect_steps = 100  # @param {type:"integer"} 
collect_steps_per_iteration = 1  # @param {type:"integer"}
replay_buffer_max_length = 100000  # @param {type:"integer"}

batch_size = 64  # @param {type:"integer"}
learning_rate = 1e-3  # @param {type:"number"}
log_interval = 200  # @param {type:"integer"}

num_eval_episodes = 10  # @param {type:"integer"}
eval_interval = 1000  # @param {type:"integer"}

# Agent
# Algorithm that will used to solve the RL problem
# In this case, we're going to create a QNetwork (predict expected returns)
env = train_env

fc_layer_params = (100, 50)
# action_tensor_spec = tensor_spec.from_spec(env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

# Define a helper function to create Dense layers configured with the right
# activation and kernel initializer.
def dense_layer(num_units):
  return tf.keras.layers.Dense(
      num_units,
      activation=tf.keras.activations.relu,
      kernel_initializer=tf.keras.initializers.VarianceScaling(
          scale=2.0, mode='fan_in', distribution='truncated_normal'))

# QNetwork consists of a sequence of Dense layers followed by a dense layer
# with `num_actions` units to generate one q_value per available action as
# it's output.
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
    num_actions,
    activation=None,
    kernel_initializer=tf.keras.initializers.RandomUniform(
        minval=-0.03, maxval=0.03),
    bias_initializer=tf.keras.initializers.Constant(-0.2))

# Unlike CNN this predicts the QValue or expected return for all actions. Input shape is therefore 3.
# q_net = sequential.Sequential(dense_layers + [q_values_layer], input_spec=tf.TensorSpec(shape=(3,)))
# q_net = sequential.Sequential(dense_layers + [q_values_layer])

optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(train_env))

train_step_counter = tf.Variable(0)
train_step = train_utils.create_train_step()

q_net = q_network.QNetwork(
    observation_tensor_spec,
    action_tensor_spec,
    fc_layer_params=fc_layer_params)


agent = dqn_agent.DqnAgent(
    time_step_tensor_spec,
    action_tensor_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step)

agent.initialize()



In [9]:
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.trajectories import trajectory

### Metrics and Eval

def compute_avg_return(environment, policy, num_episodes=10):
  total_return = 0.0
  for _ in range(num_episodes):

    time_step = environment.reset()
    episode_return = 0.0

    while not time_step.is_last():
      action_step = policy.action(time_step)
      time_step = environment.step(action_step.action)
      episode_return += time_step.reward
    print("Episode return")
    print(episode_return)
    total_return += episode_return

  avg_return = total_return / num_episodes
  return avg_return.numpy()[0]

def collect_step(environment, policy, buffer):
  time_step = environment.current_time_step()
  action_step = policy.action(time_step)
  next_time_step = environment.step(action_step.action)
  traj = trajectory.from_transition(time_step, action_step, next_time_step)

  # Add trajectory to the replay buffer
  buffer.add_batch(traj)

def collect_data(env, policy, buffer, steps):
  for _ in range(steps):
    collect_step(env, policy, buffer)


# Policy
# Way the agent acts in the environment

# eval_policy = agent.policy
# collect_policy = agent.collect_policy
# random_policy = random_tf_policy.RandomTFPolicy(train_env.time_step_spec(),
#                                                 train_env.action_spec())

dino_env = DinoGameEnv()
dino_env = wrappers.ActionDiscretizeWrapper(dino_env, num_actions=3)
train_env = tf_py_environment.TFPyEnvironment(dino_env)
eval_env = tf_py_environment.TFPyEnvironment(dino_env)


replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=train_env.batch_size,
    max_length=replay_buffer_max_length)

# Dataset generates trajectories with shape [Bx2x...]
dataset = replay_buffer.as_dataset(
    num_parallel_calls=3, 
    sample_batch_size=batch_size, 
    num_steps=2).prefetch(3)

iterator = iter(dataset)

try:
  %%time
except:
  pass

# (Optional) Optimize by wrapping some of the code in a graph using TF function.
agent.train = common.function(agent.train)

# Reset the train step
agent.train_step_counter.assign(0)

# Evaluate the agent's policy once before training.
avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
returns = [avg_return]

for _ in range(num_iterations):

  # Collect a few steps using collect_policy and save to the replay buffer.
  collect_data(train_env, agent.collect_policy, replay_buffer, collect_steps_per_iteration)

  # Sample a batch of data from the buffer and update the agent's network.
  experience, unused_info = next(iterator)
  train_loss = agent.train(experience).loss

  step = agent.train_step_counter.numpy()

  if step % log_interval == 0:
    print('step = {0}: loss = {1}'.format(step, train_loss))

  if step % eval_interval == 0:
    avg_return = compute_avg_return(eval_env, agent.policy, num_eval_episodes)
    print('step = {0}: Average Return = {1}'.format(step, avg_return))
    returns.append(avg_return)


Instructions for updating:
Use `as_dataset(..., single_deterministic_pass=False) instead.
GAME OVER!
Episode return
tf.Tensor([-0.6], shape=(1,), dtype=float32)
GAME OVER!
Episode return
tf.Tensor([-1.], shape=(1,), dtype=float32)
GAME OVER!
Episode return
tf.Tensor([-1.], shape=(1,), dtype=float32)
GAME OVER!
Episode return
tf.Tensor([-1.], shape=(1,), dtype=float32)
GAME OVER!
Episode return
tf.Tensor([-1.], shape=(1,), dtype=float32)


KeyboardInterrupt: 