# Chrome Dino Game Reinforcement Learning Infrastructure
- Much of the Selenium code is adapted from [this blog post](https://luungoc2005.github.io/blog/2020-06-15-chrome-dino-game-reinforcement-learning/).
- Testing multiple split-screen captures as the data inputs (frames of the dinosaur game).

In [1]:
import base64
import os
import time
from collections import deque
from io import BytesIO

import numpy as np

import gym
from gym import spaces

from selenium import webdriver # get webdriver.
from selenium.common.exceptions import WebDriverException
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.common.action_chains import ActionChains
from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys # Keyboard actions
from selenium.webdriver.support import expected_conditions as EC
from selenium.webdriver.support.ui import WebDriverWait

In [9]:
# Relative path to the chromedriver executable; relative paths are resolved
# against the notebook's current working directory.
driver_path = "chrome_driver/chromedriver"

In [2]:
# Manual smoke test: launch Chrome, open the dino game and press SPACE once to
# see what the actions and frames obtained from the environment will look like.
_chrome_options = webdriver.ChromeOptions()
_chrome_options.add_argument("--mute-audio")  # don't play the 100-point milestone / jump sounds
_chrome_options.add_argument("--disable-gpu")
_chrome_options.add_experimental_option('excludeSwitches', ['enable-logging'])

serv_obj = Service("chrome_driver/chromedriver")  # path to the chromedriver executable

_driver = webdriver.Chrome(
    service=serv_obj,
    options=_chrome_options,
)

try:
    _driver.get('chrome://dino')
except WebDriverException:
    # Best-effort navigation: a WebDriverException here is tolerated and the
    # page that was opened is used anyway.
    pass

# Wait for the game page to finish loading *after* navigating. Sleeping before
# get() waited for nothing, and SPACE was sent right after navigation.
time.sleep(2)

# Start the game with a single SPACE press.
_driver.find_element(By.TAG_NAME, "body").send_keys(Keys.SPACE)

# Creating the Environment

Here, we create a Gym environment for our agents/models to train in; it captures the basic game actions.

In [12]:
class TRexEnv(gym.Env):
    """Gym environment that plays the Chrome dinosaur game through Selenium.

    Observations are the last four grayscale frames of the game canvas,
    resized to ``(screen_height, screen_width)`` and stacked on the last axis.
    Action ``i`` sends ``actions_map[i]`` (right arrow, up arrow, down arrow).

    NOTE(review): ``cv2`` (OpenCV) and ``PIL.Image`` are used by the frame
    helpers but are not imported anywhere visible in this file; they must be
    imported before this environment is used.
    """

    def __init__(self, screen_width, screen_height, chromedriver_path, chrome_options, action_space_size=3):
        """Launch Chrome and build the Gym action/observation spaces.

        Args:
            - screen_width (int): width in pixels of each observation frame.
            - screen_height (int): height in pixels of each observation frame.
            - chromedriver_path (str): path to the chromedriver executable.
            - chrome_options (ChromeOptions): options passed to Chrome.
            - action_space_size (int): number of discrete actions; must be
              between 1 and the number of keys in ``actions_map`` (3).
        Raises:
            - ValueError: if ``action_space_size`` is out of range.
        """
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.chromedriver_path = chromedriver_path
        self.chrome_options = chrome_options

        self.num_observations = 0

        # Key sent for each discrete action index.
        self.actions_map = [
            Keys.ARROW_RIGHT,  # do nothing
            Keys.ARROW_UP,  # jump
            Keys.ARROW_DOWN,  # down
        ]
        # Validate before launching Chrome: larger spaces would sample actions
        # that have no key and crash later in step().
        if not 1 <= action_space_size <= len(self.actions_map):
            raise ValueError(
                f"action_space_size must be between 1 and {len(self.actions_map)}, "
                f"got {action_space_size}"
            )

        self.action_space = spaces.Discrete(action_space_size)
        # cv2.resize(img, (width, height)) yields (height, width) frames, so the
        # stacked observation has shape (height, width, 4).
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.screen_height, self.screen_width, 4),
            dtype=np.uint8,
        )

        # The driver must exist before any ActionChains can be bound to it.
        self._driver = webdriver.Chrome(
            service=Service(self.chromedriver_path),
            options=self.chrome_options,
        )

        # ActionChains methods mutate and return the same chain, so build one
        # chain per key instead of sharing a single accumulating chain.
        self.keydown_actions = [ActionChains(self._driver).key_down(key) for key in self.actions_map]
        self.keyup_actions = [ActionChains(self._driver).key_up(key) for key in self.actions_map]

        self.current_key = None
        # Current state is represented by the 4 most recent frames.
        self.state_queue = deque(maxlen=4)

    def step(self, action):
        """Take a step in the game environment.

        Args:
            - action (int): index into ``actions_map`` of the key to press.
        Returns:
            - observation: stacked frames, shape (screen_height, screen_width, 4)
            - reward: 0.1 while the dino is alive, -1 once it has crashed
            - done: whether the game reports a crash
            - info: dict with the current ``"score"``
        """
        self._driver.find_element(By.TAG_NAME, "body").send_keys(self.actions_map[action])
        observation = self._next_observation()

        done = self._get_done()

        # A very simple reward function.
        reward = 0.1 if not done else -1

        # Short pause so consecutive key presses are not sent back to back.
        time.sleep(0.01)

        info = {"score": self._get_score()}

        return observation, reward, done, info

    def reset(self):
        """Reload the game, start a new run and return the first observation."""
        try:
            self._driver.get("chrome://dino")
        except WebDriverException:
            # Best-effort navigation: a WebDriverException here is tolerated
            # and the page that was opened is used anyway.
            pass
        WebDriverWait(self._driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        # Don't let frames from the previous episode leak into the new state.
        self.state_queue.clear()
        # Trigger game start.
        self._driver.find_element(By.TAG_NAME, "body").send_keys(Keys.SPACE)
        return self._next_observation()

    def close(self):
        """Quit the Chrome session so the browser process is not leaked."""
        if self._driver is not None:
            self._driver.quit()
            self._driver = None

    def _get_image(self):
        """Return the game canvas as a numpy array decoded from its PNG data URL."""
        data_url = self._driver.execute_script(
            "return document.querySelector('canvas.runner-canvas').toDataURL()"
        )
        # Drop the "data:image/png;base64," header; keep the base64 payload.
        _, _, encoded = data_url.partition(",")
        return np.array(
            Image.open(BytesIO(base64.b64decode(encoded)))
        )

    def _next_observation(self):
        """Capture, crop and resize a frame, then return the 4-frame stack.

        Returns:
            - uint8 array of shape (screen_height, screen_width, 4). Until four
              frames have been seen, the newest frame is repeated four times.
        """
        # PIL decodes the PNG in RGB(A) channel order.
        image = cv2.cvtColor(self._get_image(), cv2.COLOR_RGB2GRAY)
        image = image[:500, :480]  # cropping
        image = cv2.resize(image, (self.screen_width, self.screen_height))
        self.state_queue.append(image)

        if len(self.state_queue) < 4:
            # During the start, copy the image to make a sequence of 4.
            return np.stack([image] * 4, axis=-1)
        return np.stack(self.state_queue, axis=-1)

    def _get_score(self):
        """Return the score shown by the game's distance meter (0 before any digits exist)."""
        digits = self._driver.execute_script(
            "return Runner.instance_.distanceMeter.digits"
        )
        return int("".join(digits)) if digits else 0

    def _get_done(self):
        """Return the game's ``crashed`` flag."""
        return self._driver.execute_script("return Runner.instance_.crashed")

In [4]:
class DinoEnvironment(gym.Env):
    """Gym environment for the Chrome dinosaur game, driven through Selenium.

    Observations are the last four grayscale frames of the game canvas,
    resized to ``(screen_height, screen_width)`` and stacked on the last axis.
    Actions: 0 = do nothing (right arrow), 1 = jump (up arrow).
    Gym: https://github.com/openai/gym

    NOTE(review): ``cv2`` (OpenCV) and ``PIL.Image`` are used below but are not
    imported anywhere visible in this file; import them before using the class.
    """

    def __init__(self, screen_width: int, screen_height: int, chromedriver_path: str):
        """Launch Chrome and set up the action/observation spaces.

        @params:
            screen_width: int - width in pixels of each observation frame
            screen_height: int - height in pixels of each observation frame
            chromedriver_path: str - location of the chromedriver executable
        """
        self.screen_width = screen_width
        self.screen_height = screen_height
        self.chromedriver_path = chromedriver_path
        # Older attribute name, kept as an alias for backward compatibility.
        self.chrome_driver_path = chromedriver_path

        # Two discrete actions: do nothing and jump.
        self.action_space = spaces.Discrete(2)
        # cv2.resize(img, (width, height)) yields (height, width) frames, so the
        # stacked observation has shape (height, width, 4).
        self.observation_space = spaces.Box(
            low=0,
            high=255,
            shape=(self.screen_height, self.screen_width, 4),
            dtype=np.uint8,
        )

        # Viewer used by render(mode='human'); created lazily.
        self.viewer = None

        # Running Selenium.
        _chrome_options = webdriver.ChromeOptions()
        _chrome_options.add_argument("--mute-audio")  # don't play the 100-point milestone / jump sounds
        _chrome_options.add_argument("--disable-gpu")  # author's note: required on Windows, probably removable elsewhere
        _chrome_options.add_experimental_option('excludeSwitches', ['enable-logging'])
        self._driver = webdriver.Chrome(
            service=Service(self.chromedriver_path),
            options=_chrome_options,
        )
        self.current_key = None
        # Sliding window of the 4 most recent frames.
        self.state_queue = deque(maxlen=4)

        self.actions_map = [
            Keys.ARROW_RIGHT,  # do nothing
            Keys.ARROW_UP,  # jump
        ]
        # One ActionChains per key: chain methods mutate and return the same
        # object, so a shared chain would accumulate every key_up.
        self.keyup_actions = [ActionChains(self._driver).key_up(key) for key in self.actions_map]

    def reset(self):
        """Reload the game, start a new run and return the first observation."""
        try:
            self._driver.get('chrome://dino')
        except WebDriverException:
            # Best-effort navigation: a WebDriverException here is tolerated
            # and the page that was opened is used anyway.
            pass
        WebDriverWait(self._driver, 10).until(
            EC.presence_of_element_located((By.TAG_NAME, "body"))
        )
        # Don't let frames from the previous episode leak into the new state.
        self.state_queue.clear()
        # Trigger game start.
        self._driver.find_element(By.TAG_NAME, "body").send_keys(Keys.SPACE)
        return self._next_observation()

    def _get_image(self):
        """Return the game canvas as a numpy array decoded from its PNG data URL."""
        data_url = self._driver.execute_script(
            "return document.querySelector('canvas.runner-canvas').toDataURL()")
        # Drop the "data:image/png;base64," header; keep the base64 payload.
        _, _, encoded = data_url.partition(",")
        return np.array(Image.open(BytesIO(base64.b64decode(encoded))))

    def _next_observation(self):
        """Capture, crop and resize a frame, then return the 4-frame stack.

        Until four frames have been seen, the newest frame is repeated 4 times.
        """
        # PIL decodes the PNG in RGB(A) channel order.
        image = cv2.cvtColor(self._get_image(), cv2.COLOR_RGB2GRAY)
        image = image[:500, :480]  # TO EDIT IMAGE: crop region of the canvas
        image = cv2.resize(image, (self.screen_width, self.screen_height))

        self.state_queue.append(image)

        if len(self.state_queue) < 4:
            return np.stack([image] * 4, axis=-1)
        return np.stack(self.state_queue, axis=-1)

    def _get_score(self):
        """Return the score shown by the game's distance meter (0 before any digits exist)."""
        digits = self._driver.execute_script("return Runner.instance_.distanceMeter.digits")
        return int(''.join(digits)) if digits else 0

    def _get_done(self):
        """The episode is over once the game is no longer in the playing state."""
        return not self._driver.execute_script("return Runner.instance_.playing")

    def step(self, action: int):
        """Press the key for ``action`` and return (obs, reward, done, info).

        The reward is 0.1 while the game is playing and -1 once it has stopped.
        """
        self._driver.find_element(By.TAG_NAME, "body").send_keys(self.actions_map[action])

        obs = self._next_observation()

        done = self._get_done()
        reward = .1 if not done else -1

        # Short pause so consecutive key presses are not sent back to back.
        time.sleep(.015)

        return obs, reward, done, {"score": self._get_score()}

    def render(self, mode: str = 'human'):
        """Return the RGB frame ('rgb_array') or show it in a window ('human')."""
        # The PNG decodes in RGB(A) order: drop alpha without swapping channels.
        img = cv2.cvtColor(self._get_image(), cv2.COLOR_RGBA2RGB)
        if mode == 'rgb_array':
            return img
        elif mode == 'human':
            # NOTE(review): gym.envs.classic_control.rendering may not exist in
            # newer gym releases; pin gym or swap the viewer if this import fails.
            from gym.envs.classic_control import rendering
            if self.viewer is None:
                self.viewer = rendering.SimpleImageViewer()
            self.viewer.imshow(img)
            return self.viewer.isopen

    def close(self):
        """Close the viewer (if any) and quit Chrome so the process is not leaked."""
        if self.viewer is not None:
            self.viewer.close()
            self.viewer = None
        if self._driver is not None:
            self._driver.quit()
            self._driver = None

In [None]:
import os

from stable_baselines import DQN
from stable_baselines.deepq.policies import CnnPolicy
from stable_baselines.common.vec_env import DummyVecEnv

def env_lambda():
    """Factory used by DummyVecEnv: launch a fresh Chrome-backed dino environment.

    ``__file__`` is not defined inside a notebook, so the chromedriver location
    is taken from the notebook's ``driver_path`` setting instead.
    """
    return DinoEnvironment(
        screen_width=96,
        screen_height=96,
        chromedriver_path=os.path.abspath(driver_path),
    )


save_path = "chrome_dino_dqn_cnn"
env = DummyVecEnv([env_lambda])

try:
    model = DQN(
        CnnPolicy,
        env,
        verbose=1,
        tensorboard_log="./.tb_chromedino_env/",
    )
    model.learn(total_timesteps=100000)
    model.save(save_path)

    # Load the saved model back, attached to the same environment.
    model = DQN.load(save_path, env=env)

    obs = env.reset()

    # Play indefinitely; interrupt the kernel to stop.
    while True:
        action, _states = model.predict(obs)
        obs, reward, done, info = env.step(action)
        env.render(mode="human")
finally:
    # Always shut down the viewer and Chrome, including on KeyboardInterrupt.
    env.close()