In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import numpy as np
from PIL import Image
import cv2 #opencv
import sys 
import io
import time
import pandas as pd
import numpy as np
from IPython.display import clear_output
from random import randint
import os

from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.keys import Keys


import random
import pickle
from io import BytesIO
import base64
import json
import time

# Scores appended each time a run ends in a crash (see Game_state.get_next_state).
generation_score = list()

# Location of the game page and of the chromedriver binary used by Selenium.
game_url = 'chrome://dino'
chromebrowser_path = '/home/snowballfight/Documents/chromedriver'

# JavaScript run once after page load: gives the game canvas a stable id.
init_script = (
    "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'"
)

# JavaScript that returns the canvas as base64 PNG data; substring(22) drops
# the leading "data:image/png;base64," prefix of the data URL.
getbase64Script = (
    "canvasRunner = document.getElementById('runner-canvas'); "
    "return canvasRunner.toDataURL().substring(22)"
)

In [2]:
class Game:
    """Selenium wrapper around Chrome's built-in dinosaur game.

    Every query/command is a small JavaScript snippet executed against the
    page's global ``Runner`` object.

    NOTE(review): this class uses ``executable_path`` and
    ``find_element_by_tag_name``; confirm the installed Selenium version
    still provides them.
    """

    def __init__(self, custom_config=True):
        """Launch Chrome, open the game page and tag the canvas with an id.

        Args:
            custom_config: When true (default), set the game's
                ``ACCELERATION`` config value to 0 so the speed stays
                constant. When false, the game's own setting is left alone.
        """
        chrome_options = Options()
        chrome_options.add_argument('disable-infobars')
        chrome_options.add_argument('--mute-audio')
        self.browser = webdriver.Chrome(executable_path = chromebrowser_path, options = chrome_options)
        self.browser.set_window_position(x=-10,y=0)
        self.browser.get(game_url)
        if custom_config:
            # JavaScript property names are case-sensitive and the game's key
            # is upper-case; assigning to `Acceleration` only creates an
            # unused property and leaves the real setting untouched.
            self.browser.execute_script('Runner.config.ACCELERATION=0')
        self.browser.execute_script(init_script)
        self.browser.implicitly_wait(30)
        self.browser.maximize_window()

    @staticmethod
    def _digits_to_int(digits):
        """Convert a digit sequence (or plain number/string) to an int.

        Empty or missing values yield 0 instead of raising, because the
        game may not have populated its score arrays yet.
        """
        if digits is None:
            return 0
        if isinstance(digits, (int, float)):
            return int(digits)
        text = ''.join(str(d) for d in digits).strip()
        return int(text) if text else 0

    def get_crashed(self):
        """Return the game's ``crashed`` flag."""
        return self.browser.execute_script('return Runner.instance_.crashed')

    def get_playing(self):
        """Return the game's ``playing`` flag."""
        return self.browser.execute_script('return Runner.instance_.playing')

    def restart(self):
        """Restart the game after a crash."""
        # The trailing () matters: without it the script merely references
        # the function and the game is never restarted.
        self.browser.execute_script('Runner.instance_.restart()')

    def press_up(self):
        """Send the UP arrow key to the page body."""
        self.browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_UP)

    def press_down(self):
        """Send the DOWN arrow key to the page body."""
        self.browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_DOWN)

    def press_right(self):
        """Send the RIGHT arrow key to the page body."""
        self.browser.find_element_by_tag_name('body').send_keys(Keys.ARROW_RIGHT)

    def get_score(self):
        """Return the current score as an int (0 if no digits are available)."""
        score_array = self.browser.execute_script('return Runner.instance_.distanceMeter.digits')
        return self._digits_to_int(score_array)

    def get_highscore(self):
        """Return the high score as an int (0 if none is available).

        When the value is a list containing an empty-string separator, only
        the entries after the first separator are treated as score digits;
        otherwise the whole value is parsed.
        """
        score_array = self.browser.execute_script('return Runner.instance_.distanceMeter.highScore')
        if isinstance(score_array, (list, tuple)) and '' in score_array:
            score_array = score_array[list(score_array).index('') + 1:]
        return self._digits_to_int(score_array)

    def pause(self):
        """Call ``Runner.instance_.stop()`` and return its result."""
        return self.browser.execute_script('return Runner.instance_.stop()')

    def resume(self):
        """Call ``Runner.instance_.play()`` and return its result."""
        return self.browser.execute_script('return Runner.instance_.play()')

    def end(self):
        """Shut down the browser session."""
        # quit() ends the whole WebDriver session (all windows plus the
        # driver process); close() would only close the current window.
        self.browser.quit()
        

In [3]:
def screenshot(browser):
    """Grab the game canvas from ``browser`` and return it preprocessed.

    The canvas is fetched as base64 data via ``getbase64Script``, decoded
    into an image array, and passed through ``process_img``.
    """
    encoded_png = browser.execute_script(getbase64Script)
    png_bytes = base64.b64decode(encoded_png)
    frame = np.array(Image.open(BytesIO(png_bytes)))
    return process_img(frame)  # process image as required

def process_img(image):
    """Reduce a raw canvas frame to a binarised 84x84x1 array.

    Args:
        image: Colour image array in RGB(A) channel order, as produced by
            ``screenshot`` via PIL.

    Returns:
        Array of shape (84, 84, 1) in which every non-zero pixel is 255.
    """
    # PIL yields RGB-ordered channels, so use the RGB conversion code; the
    # BGR code would apply the red and blue luminance weights swapped.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    image = image[:500, :600] # crop region of interest (ROI)
    image = cv2.resize(image, (84,84))
    image[image>0] = 255 # binarise: anything not pure black becomes white
    image = np.reshape(image, (84,84,1))
    return image

def image_to_tensor(image):
    """Convert an HxWxC image array into a float32 CxHxW torch tensor.

    The tensor is moved to the GPU when CUDA is available.
    """
    # Channels-last -> channels-first (e.g. 84x84x1 becomes 1x84x84).
    channels_first = np.transpose(image, (2, 0, 1)).astype(np.float32)
    tensor = torch.from_numpy(channels_first)
    return tensor.cuda() if torch.cuda.is_available() else tensor

def show_img(graphs=False):
    """Coroutine that shows every frame sent to it in an OpenCV window.

    Prime it once with ``next()``, then push frames with ``.send(frame)``.
    Pressing 'q' in the window closes all OpenCV windows and finishes the
    coroutine, so the ``send()`` call in progress raises ``StopIteration``.

    Args:
        graphs: Currently unused; kept so existing callers keep working.
    """
    window_title = 'Dino Agent'
    while True:
        screen = (yield)
        cv2.namedWindow(window_title, cv2.WINDOW_NORMAL)
        # Show the enlarged copy: previously it was computed and then thrown
        # away, and the small raw frame was displayed instead.
        enlarged = cv2.resize(screen, (800,400))
        cv2.imshow(window_title, enlarged)
        if (cv2.waitKey(1) & 0xFF == ord('q')):
            cv2.destroyAllWindows()
            break

In [4]:
class DinoAgent:
    """Thin controller that maps agent actions onto key presses of a game."""

    def __init__(self, game):
        """Store the game handle and issue one jump immediately."""
        self.dinoGame = game
        self.jump()

    # --- actions -----------------------------------------------------
    def jump(self):
        """Press the UP key."""
        self.dinoGame.press_up()

    def duck(self):
        """Press the DOWN key."""
        self.dinoGame.press_down()

    def DoNothing(self):
        """Press the RIGHT key (used as the 'no action' move)."""
        self.dinoGame.press_right()

    # --- status queries ----------------------------------------------
    def is_running(self):
        """Return the game's ``get_playing()`` result."""
        return self.dinoGame.get_playing()

    def is_crashed(self):
        """Return the game's ``get_crashed()`` result."""
        return self.dinoGame.get_crashed()

In [5]:
class Game_state:
    """Ties an agent to its game and advances the game one action at a time."""

    def __init__(self, agent, game):
        """Keep references to the agent and game and start the frame display."""
        self.dinoGame = game
        self._agent = agent
        # Prime the display coroutine so it is ready to receive frames.
        self._display = show_img()
        next(self._display)

    def get_next_state(self, actions):
        """Apply a one-hot action and observe the outcome.

        Args:
            actions: Indexable of length >= 3; slot 0 = jump, 1 = duck,
                2 = do nothing. The first slot equal to 1 is executed.

        Returns:
            Tuple ``(image, reward, is_over, score, high_score)`` where
            ``image`` is the processed frame as a tensor, ``reward`` is 0.1
            while alive or -1 on a crash, and the two scores are the values
            read before the action was taken.
        """
        score = self.dinoGame.get_score()
        high_score = self.dinoGame.get_highscore()

        # Slot order mirrors the action vector; only the first hit fires.
        moves = (self._agent.jump, self._agent.duck, self._agent.DoNothing)
        for slot, move in enumerate(moves):
            if actions[slot] == 1:
                move()
                break

        frame = screenshot(self.dinoGame.browser)
        self._display.send(frame)

        crashed = bool(self._agent.is_crashed())
        if crashed:
            # Record the finished run, pause briefly, then start a new game.
            generation_score.append(score)
            time.sleep(0.1)
            self.dinoGame.restart()

        reward = -1 if crashed else 0.1
        is_over = crashed  # game over
        return image_to_tensor(frame), reward, is_over, score, high_score

In [6]:
def RandomAgent():
    """Play the game with uniformly random actions until interrupted.

    Opens a browser, takes one initial jump step to build a 4-frame state
    stack, then loops forever choosing one of three actions at random and
    sliding the newest frame into the stack. The loop has no exit condition
    of its own; it ends only when an exception (for example a kernel
    interrupt) propagates, at which point the browser is shut down.
    """
    number_of_actions = 3
    game = Game()
    try:
        dino = DinoAgent(game)
        game_state = Game_state(dino, game)

        # First step: jump, and replicate the resulting frame four times.
        action = torch.zeros([number_of_actions], dtype=torch.float32)
        action[0] = 1
        image_data, _, _, _, _ = game_state.get_next_state(action)
        state = torch.cat((image_data, image_data, image_data, image_data)).unsqueeze(0)

        while True:
            action = torch.zeros([number_of_actions], dtype=torch.float32)
            # Plain int index instead of a list wrapping a 0-dim tensor.
            action_index = int(torch.randint(number_of_actions, torch.Size([]), dtype=torch.int))
            action[action_index] = 1
            image_data_1, _, _, _, _ = game_state.get_next_state(action)
            # Drop the oldest frame and append the newest one.
            state = torch.cat((state.squeeze(0)[1:, :, :], image_data_1)).unsqueeze(0)
    finally:
        # Release the browser even when the loop is interrupted.
        game.end()
    

In [7]:
# Start the random-action agent. RandomAgent() contains an endless loop, so
# this cell runs until an exception or a kernel interrupt stops it.
# RandomAgent has no return statement, so dino_1 would be None if it returned.
dino_1 = RandomAgent()

KeyboardInterrupt: 