In [1]:
import numpy as np
import time
from selenium import webdriver

import cv2 as cv

from gymnasium import Env
from gymnasium.spaces import Box, Discrete

import matplotlib.pyplot as plt

In [2]:
class Dino(Env):
    def __init__(self, render_mode="human"):
        super().__init__()
        self.render_mode = render_mode
        self.observation_space = Box(low= 0, high=255, shape=(180, 260, 3), dtype=np.uint8)
        self.action_space = Discrete(3)
        self.game_over = cv.imread("game_over.png")
        self.init_dino()


    def init_dino(self):
        chrome_options = webdriver.ChromeOptions()
        chrome_options.add_argument("--mute-audio")
        chrome_options.add_argument("--window-size=640,480")
        chrome = webdriver.Chrome(options=chrome_options)
        # To avoid an exception while opening a URL without a connection
        try:
            chrome.get("chrome://dino")
        except:
            pass

        self.chrome = chrome
        # Load the game
        time.sleep(1)
        self.send_key(webdriver.Keys.SPACE)
        time.sleep(2)
        self.chrome.execute_script("Runner.instance_.gameOver()")
    
    def step(self, action):
        action_map = {0: webdriver.Keys.SPACE, 1: webdriver.Keys.ARROW_DOWN, 2: "no_op"}
        reward = 2
        if action != 2:
            self.send_key(action_map[action])
            reward = 1
        
        self.set_last_record()
        done, _ = self.get_done()
        observation = self.get_observation()
        info = {}
        return observation, reward, done, False, info
    
    def set_last_record(self):
        binary_image = self.chrome.get_screenshot_as_png()
        buffered_image = np.frombuffer(binary_image, np.uint8)
        last_record = cv.imdecode(buffered_image, cv.IMREAD_COLOR)
        gray_mask = np.all(last_record == np.array([172, 172, 172]), axis=-1)
        last_record[~gray_mask] = [0,0,0]
        self.last_record = last_record

    def get_last_record(self):
        return self.last_record

    def get_observation(self):
        observation = self.last_record[125:305, 45:305]

        return observation
    
    def get_done(self):
        done_record = np.copy(self.last_record[157:180,305:638])
        # Keeping only the gray pixels (172,172,172)
        gray_mask = np.all(done_record == np.array([172, 172, 172]), axis=-1)
        done_record[~gray_mask] = [0,0,0]
        similarity = self.get_game_over_simmilarity(done_record)
        # I take an error margin of 5%
        done = similarity >= 95

        return done, done_record
    
    def close(self):
        self.chrome.quit()
        cv.destroyAllWindows()

    def render(self):
        rendered_image = self.reshape_images()
        cv.imshow("Dino", rendered_image)
        if cv.waitKey(1) & 0xFF == ord('q'):
            cv.destroyAllWindows()
        del rendered_image

    def reset(self,seed=None):
        time.sleep(2)
        self.send_key(webdriver.Keys.SPACE)
        self.set_last_record()
        info = {}
        return (self.get_observation(), info)

    def get_game_over_simmilarity(self, image):
        difference = np.sum(np.abs(self.game_over - image))
        pixels = image.size
        similarity = (100 - (difference / pixels))

        return similarity
    
    def send_key(self, key):
        webdriver.ActionChains(self.chrome).send_keys(key).perform()
    
    # TO DO: Don't do unnecesary calculations
    def reshape_images(self):
        is_done, done_image = self.get_done()
        observation_image = self.get_observation()
        # Calculate dimensions
        done_box_height = observation_image.shape[0] - done_image.shape[0]
        observation_box_width = done_image.shape[1] - observation_image.shape[1]
        # Create colored boxes
        color = [0, 255, 0] if is_done else [255, 0, 0]
        upper_done_box = np.full((done_box_height // 2, done_image.shape[1], 3), color, dtype=np.uint8)
        bottom_done_box = np.full((done_box_height - upper_done_box.shape[0], done_image.shape[1], 3), color, dtype=np.uint8)
        left_observation_box = np.full((observation_image.shape[0], observation_box_width // 2, 3), 0, dtype=np.uint8)
        right_observation_box = np.full((observation_image.shape[0], observation_box_width - left_observation_box.shape[1], 3), 0, dtype=np.uint8)
        # Construct the final images
        done = np.vstack((upper_done_box, done_image, bottom_done_box))
        observation = np.hstack((left_observation_box, observation_image, right_observation_box))

        return np.vstack((observation, done))



In [8]:
env = Dino()
for episode in range(30): 
    try:
        obs = env.reset()
        done = False  
        total_reward   = 0
        while not done:

                obs, reward,  done, _, info =  env.step(env.action_space.sample())
                total_reward  += reward
                env.render()
    except:
        env.close()
        cv.destroyAllWindows()
        break
    print('Total Reward for episode {} is {}'.format(episode, total_reward)) 
   

Total Reward for episode 0 is 21
Total Reward for episode 1 is 39
Total Reward for episode 2 is 27
Total Reward for episode 3 is 30
Total Reward for episode 4 is 27
Total Reward for episode 5 is 31
Total Reward for episode 6 is 25
Total Reward for episode 7 is 28
Total Reward for episode 8 is 20
Total Reward for episode 9 is 24
Total Reward for episode 10 is 30
Total Reward for episode 11 is 28
Total Reward for episode 12 is 28
Total Reward for episode 13 is 22
Total Reward for episode 14 is 39
Total Reward for episode 15 is 24
Total Reward for episode 16 is 29
Total Reward for episode 17 is 23
Total Reward for episode 18 is 31
Total Reward for episode 19 is 25
Total Reward for episode 20 is 24
Total Reward for episode 21 is 23


NoSuchWindowException: Message: no such window: target window already closed
from unknown error: web view not found
  (Session info: chrome=120.0.6099.201)
Stacktrace:
	GetHandleVerifier [0x00007FF67A0A2142+3514994]
	(No symbol) [0x00007FF679CC0CE2]
	(No symbol) [0x00007FF679B676AA]
	(No symbol) [0x00007FF679B40AFD]
	(No symbol) [0x00007FF679BDCB1B]
	(No symbol) [0x00007FF679BF218F]
	(No symbol) [0x00007FF679BD5D93]
	(No symbol) [0x00007FF679BA4BDC]
	(No symbol) [0x00007FF679BA5C64]
	GetHandleVerifier [0x00007FF67A0CE16B+3695259]
	GetHandleVerifier [0x00007FF67A126737+4057191]
	GetHandleVerifier [0x00007FF67A11E4E3+4023827]
	GetHandleVerifier [0x00007FF679DF04F9+689705]
	(No symbol) [0x00007FF679CCC048]
	(No symbol) [0x00007FF679CC8044]
	(No symbol) [0x00007FF679CC81C9]
	(No symbol) [0x00007FF679CB88C4]
	BaseThreadInitThunk [0x00007FF8699B257D+29]
	RtlUserThreadStart [0x00007FF86ADAAA58+40]


In [3]:
# Import os for file path management
import os
# Import Base Callback for saving models
from stable_baselines3.common.callbacks import BaseCallback
# Check Environment    
from stable_baselines3.common import env_checker

In [53]:
env = Dino()
env_checker.check_env(env)
env.close()

In [4]:
class TrainAndLoggingCallback(BaseCallback):

    def __init__(self, check_freq, save_path, verbose=1):
        super(TrainAndLoggingCallback, self).__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [5]:
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs/'

In [6]:
callback = TrainAndLoggingCallback(check_freq=100, save_path=CHECKPOINT_DIR)

In [7]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

In [8]:
env = Dino()
model = DQN('CnnPolicy', env, tensorboard_log=LOG_DIR, verbose=1, buffer_size= 10000 , learning_starts=100)
model.learn(total_timesteps=10000, callback=callback)


Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Wrapping the env in a VecTransposeImage.
Logging to ./logs/DQN_1
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 26.2     |
|    ep_rew_mean      | 35       |
|    exploration_rate | 0.9      |
| time/               |          |
|    episodes         | 4        |
|    fps              | 2        |
|    time_elapsed     | 36       |
|    total_timesteps  | 105      |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.983    |
|    n_updates        | 1        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 19.6     |
|    ep_rew_mean      | 27       |
|    exploration_rate | 0.851    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 2        |
|    time_elapsed     | 66       |
|

NoSuchWindowException: Message: no such window: target window already closed
from unknown error: web view not found
  (Session info: chrome=120.0.6099.201)
Stacktrace:
	GetHandleVerifier [0x00007FF67A0A2142+3514994]
	(No symbol) [0x00007FF679CC0CE2]
	(No symbol) [0x00007FF679B676AA]
	(No symbol) [0x00007FF679B40AFD]
	(No symbol) [0x00007FF679BDCB1B]
	(No symbol) [0x00007FF679BF218F]
	(No symbol) [0x00007FF679BD5D93]
	(No symbol) [0x00007FF679BA4BDC]
	(No symbol) [0x00007FF679BA5C64]
	GetHandleVerifier [0x00007FF67A0CE16B+3695259]
	GetHandleVerifier [0x00007FF67A126737+4057191]
	GetHandleVerifier [0x00007FF67A11E4E3+4023827]
	GetHandleVerifier [0x00007FF679DF04F9+689705]
	(No symbol) [0x00007FF679CCC048]
	(No symbol) [0x00007FF679CC8044]
	(No symbol) [0x00007FF679CC81C9]
	(No symbol) [0x00007FF679CB88C4]
	BaseThreadInitThunk [0x00007FF8699B257D+29]
	RtlUserThreadStart [0x00007FF86ADAAA58+40]


In [62]:
env.close()

In [16]:
model.load('train/best_model_16000') 


<stable_baselines3.dqn.dqn.DQN at 0x224e4ccaf90>

In [20]:
for episode in range(5): 
    obs,_ = env.reset()
    done = False
    total_reward = 0
    while not done: 
        action, _ = model.predict(obs)
        obs, reward, done,_ ,info = env.step(int(action))
        time.sleep(0.01)
        total_reward += reward
    print('Total Reward for episode {} is {}'.format(episode, total_reward))
    time.sleep(2)

Total Reward for episode 0 is 565.0
Total Reward for episode 1 is 3.5
Total Reward for episode 2 is 0.5
Total Reward for episode 3 is 1


KeyboardInterrupt: 