# AI Learns Chrome T-Rex Game!

In [11]:
"""
Let's import these babies
"""

# Bread and butter of any DL Problem
import numpy as np
from matplotlib import pyplot as plt

# For Automating browser actions
from selenium import webdriver
from selenium.webdriver.chrome.options import Options
from selenium.webdriver.common.keys import Keys

# Importing Pytorch
import torch
import torch.nn as nn
import torch.optim as optim

# Now this is random
import random

# Collections
from collections import namedtuple

# For image processing
import cv2

# Just because I can't do math
import math

# Yeah this is the time stone
import time

import pyautogui

In [12]:


"""# Setting up display
%matplotlib inline
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython: from IPython import display"""


"# Setting up display\n%matplotlib inline\nis_ipython = 'inline' in matplotlib.get_backend()\nif is_ipython: from IPython import display"

In [13]:
# One transition record: the state the action was taken in, the action,
# the state that followed, the reward received and the episode-ended flag.
Experience = namedtuple(
    'Experience',
    ['state', 'action', 'next_state', 'reward', 'done'],
)

# Quick sanity check that the fields line up as expected.
e = Experience(state=1, action=2, next_state=3, reward=5, done=True)
e

Experience(state=1, action=2, next_state=3, reward=5, done=True)

In [14]:
class ReplayMemory():
    """Fixed-capacity ring buffer of experiences for experience replay.

    Once `capacity` items have been stored, each new item overwrites the
    oldest slot (selected by `push_count % capacity`).

    Args:
        capacity: maximum number of experiences to keep; must be positive.

    Raises:
        ValueError: if `capacity` is not a positive number.
    """

    def __init__(self, capacity):
        # A non-positive capacity would otherwise only fail later, inside
        # push(), with an unrelated ZeroDivisionError / IndexError.
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity!r}")
        self.capacity = capacity
        self.memory = []
        self.push_count = 0

    def push(self, experience):
        """Store `experience`, overwriting the oldest entry when full."""
        if len(self.memory) < self.capacity:
            self.memory.append(experience)
        else:
            self.memory[self.push_count % self.capacity] = experience
        self.push_count += 1

    def sample(self, batch_size):
        """Return `batch_size` distinct stored experiences chosen at random.

        Raises:
            ValueError: (from `random.sample`) if fewer than `batch_size`
                experiences are stored; check `can_provide_sample` first.
        """
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least `batch_size` experiences are stored."""
        return len(self.memory) >= batch_size

    def __len__(self):
        """Number of experiences currently stored (enables `len(memory)`)."""
        return len(self.memory)

    def len(self):
        """Number of experiences currently stored (kept for existing callers)."""
        return len(self)


class EpsilonGreedyStrategy():
    """Exponential schedule between two rates, driven by the episode index.

    The value returned is `start + (end - start) * exp(-episode * decay)`.
    It therefore equals `end` at episode 0 and approaches `start` as the
    episode index grows.

    NOTE(review): this runs in the opposite direction to the usual epsilon
    schedule. Agent.select_action explores when the returned rate is
    `<= random.random()`, so the two choices compensate for each other;
    change them together or not at all.
    """

    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_episode):
        """Return the scheduled rate for `current_episode`."""
        span = self.end - self.start
        falloff = math.exp(-1. * current_episode * self.decay)
        return self.start + span * falloff




In [65]:
class Agent():
    """Chooses actions by mixing random moves with the policy network's pick.

    Attributes:
        current_episode: index handed to the strategy when asking for a rate.
        strategy: object exposing `get_exploration_rate(episode)`.
        num_actions: number of discrete actions to choose from.
        device: stored for callers; not read by `select_action`.
    """

    def __init__(self, strategy, num_actions, device):
        self.strategy = strategy
        self.num_actions = num_actions
        self.device = device
        self.current_episode = 0

    def select_action(self, state, policy_net):
        """Return an action index for `state`.

        A random action is taken when the strategy's rate is less than or
        equal to a fresh uniform draw; otherwise the arg-max of
        `policy_net(state)` along dim 1 is returned as a Python int.

        NOTE(review): with this comparison the rate acts as the probability
        of exploiting, not exploring; it pairs with the direction of
        EpsilonGreedyStrategy's schedule.
        """
        threshold = self.strategy.get_exploration_rate(self.current_episode)
        explore = threshold <= random.random()
        if not explore:
            # Exploitation: no gradients are needed just to pick an action.
            with torch.no_grad():
                return policy_net(state).argmax(dim=1).item()
        # Exploration
        return random.randrange(self.num_actions)




# Settings for the environment manager (ChromeManager) defined below.
# Path to the chromedriver executable handed to webdriver.Chrome.
chrome_browser_path = ".//Driver/chromedriver.exe"
# JavaScript run on reset: gives the first 'runner-canvas' element an id of the same name.
init_script = "document.getElementsByClassName('runner-canvas')[0].id = 'runner-canvas'"
# Page the driver navigates to in order to load the game.
game_url = 'chrome://dino'

class ChromeManager():
    """Drives the Chrome dino game through Selenium as an RL environment.

    Construction launches a Chrome window. Game state is read and changed
    by running JavaScript against the page's `Runner` object, and actions
    are delivered as key presses.

    Attributes:
        device: torch device handed to the image pre-processor.
        done: whether the last `take_action` call ended in a crash.
        driver: the Selenium Chrome driver.
        get_state: ImagePreProcessing helper used to capture states.
    """

    def __init__(self, device):
        print("Hello")
        self.device = device
        self.done = False
        self.initialize_chrome()
        self.get_state = ImagePreProcessing(self.driver, device)

    def initialize_chrome(self):
        """Launch Chrome (maximised, muted) and store the driver on `self`."""
        chrome_options = Options()
        chrome_options.add_argument("start-maximized")
        chrome_options.add_argument("disable-infobars")
        chrome_options.add_argument("--mute-audio")
        self.driver = webdriver.Chrome(executable_path=chrome_browser_path, chrome_options=chrome_options)

    def _load_game(self):
        """Navigate to the game page, ignoring Selenium navigation errors."""
        # Imported here so the block does not rely on a file-level import.
        from selenium.common.exceptions import WebDriverException
        try:
            self.driver.get(game_url)
        except WebDriverException:
            # The original code deliberately ignored failures here; the dino
            # page is Chrome's offline page, so navigation is expected to
            # report an error even though the game is shown. Anything that is
            # not a Selenium error now propagates instead of being hidden.
            pass

    # Restarting our game
    def reset(self):
        """Reload the page, switch off acceleration and restart the run.

        Returns:
            Whatever `Runner.instance_.restart()` returns in the page.
        """
        self._load_game()
        self.driver.execute_script("Runner.config.ACCELERATION=0")
        self.driver.execute_script(init_script)
        return self.driver.execute_script("return Runner.instance_.restart()")

    # To close the game
    def close(self):
        """Return the game's crashed flag.

        NOTE(review): despite its name this does not close anything; it is
        identical to `get_crashed`. Behaviour is left as-is because the
        intended semantics are not visible here -- use `close_all` to shut
        the browser down.
        """
        return self.driver.execute_script("return Runner.instance_.crashed")

    # crash
    def get_crashed(self):
        """Return `Runner.instance_.crashed` from the page."""
        return self.driver.execute_script("return Runner.instance_.crashed")

    # Start?
    def start(self):
        """Load the page, mark the runner as playing and press the up key."""
        self._load_game()
        self.driver.execute_script('Runner.instance_.playing=true')
        self.press_up()

    # To get the current score of our game
    def get_score(self):
        """Return the current score as an int (0 when no digits are shown)."""
        score_array = self.driver.execute_script('return Runner.instance_.distanceMeter.digits')
        # Scores are stored in the form '0, 1, 4', for a score of 14.
        score = ''.join(str(digit) for digit in (score_array or []))
        # An empty digit list would make int('') raise ValueError.
        return int(score) if score else 0

    # To get the highscore of our game
    def get_highscore(self):
        """Return the high score as an int, or 0 when none is available.

        The page value may be a plain number/string or a list of glyphs.
        Presumably the list form carries a "HI" glyph prefix followed by an
        empty-string spacer before the digits (verify against the Chrome
        version in use); everything up to the last spacer is dropped.
        """
        raw = self.driver.execute_script('return Runner.instance_.distanceMeter.highScore')
        if not raw:
            return 0
        if isinstance(raw, (int, float)):
            return int(raw)
        glyphs = [str(glyph) for glyph in raw]
        if '' in glyphs:
            # Keep only what follows the last spacer.
            glyphs = glyphs[len(glyphs) - glyphs[::-1].index(''):]
        highscore = ''.join(glyphs)
        return int(highscore) if highscore.isdigit() else 0

    # This function will be called when we want to jump
    def press_up(self):
        self.driver.find_element_by_tag_name('body').send_keys(Keys.ARROW_UP)

    # This function will be called when we want to duck
    def press_down(self):
        # Held for 0.2 s via pyautogui rather than sent as a single key event.
        pyautogui.keyDown("down")
        time.sleep(0.2)
        pyautogui.keyUp("down")

    # This function will be called when we don't want to jump or duck
    def press_right(self):
        self.driver.find_element_by_tag_name('body').send_keys(Keys.ARROW_RIGHT)

    # Taking actions
    def take_action(self, action):
        """Perform `action` and observe the result.

        Args:
            action: 0 = jump, 1 = duck, 2 = do nothing (right arrow);
                any other value presses no key.

        Returns:
            Tuple `(state, reward, done, score, highscore)`. `score` and
            `highscore` are read before the key press; `reward` is 0.1, or
            -1 when the game reports a crash afterwards.
        """
        score = self.get_score()
        highscore = self.get_highscore()
        reward = 0.1
        if action == 0:
            # T-Rex Jumps
            self.press_up()
        elif action == 1:
            # T-Rex ducks
            self.press_down()
        elif action == 2:
            # T-Rex procrastinates
            self.press_right()

        state = self.get_state.screenshot()
        self.done = False
        if self.get_crashed():
            time.sleep(0.1)
            reward = -1
            self.done = True
        return state, reward, self.done, score, highscore

    # Hard stop
    def close_all(self):
        """Close the browser and kill any leftover chromedriver process."""
        # Local import: the original used `os` without importing it, and the
        # resulting NameError was swallowed by a bare except.
        import os

        self.driver.close()
        self.driver.quit()
        # os.system reports failure through its exit status, not an exception.
        if os.system('cmd /c taskkill /F /IM chromedriver.exe') != 0:
            print('No tasks found!')


class EnviromentManager():
    """Simple holder pairing an agent with the environment it acts in."""

    def __init__(self, agent, environment):
        self.agent, self.environment = agent, environment


class ImagePreProcessing():
    """Captures browser screenshots and converts them into network input.

    Args:
        driver: Selenium driver used to take screenshots.
        device: torch device the resulting tensors are moved to.
    """

    def __init__(self, driver, device):
        self.driver = driver
        self.device = device

    # Take Screenshot. We'll be needing 4 screenshots, so yeah
    def screenshot(self):
        """Capture four consecutive frames and stack them into one state.

        Each frame is saved under `.//Screenshots/`, read back and
        pre-processed to a (1, 84, 84) tensor.

        Returns:
            A float32 tensor of shape (1, 4, 84, 84) on `self.device`.
        """
        # Local import so the block does not depend on a file-level import.
        import os

        file_path = './/Screenshots/'
        # save_screenshot cannot write into a directory that does not exist.
        os.makedirs(file_path, exist_ok=True)
        frames = []
        for i in range(1, 5):
            file_name = file_path + 'Screenshot' + str(i) + '.jpg'
            self.driver.save_screenshot(file_name)
            frames.append(self.process_image(file_name))
        current_state = torch.cat(frames).unsqueeze(0).to(self.device)
        print('THis marks the end of th capturing process')
        return current_state

    def process_image(self, image_file):
        """Load `image_file` as an 84x84 grayscale frame tensor.

        Returns:
            A float32 tensor of shape (1, 84, 84) on `self.device`.

        Raises:
            FileNotFoundError: if the image cannot be read.
        """
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread signals failure by returning None; fail here with a
            # clear message instead of a cryptic error inside cvtColor.
            raise FileNotFoundError(f"Could not read screenshot: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        # (A former `image[image > 255] = 255` clamp was dropped: it can never
        # match on the 8-bit image that imread returns.)
        image = cv2.resize(image, (84, 84))
        image = np.reshape(image, (1, 84, 84))
        return self.image_to_tensor(image)

    def image_to_tensor(self, image):
        """Convert a NumPy image array to a float32 tensor on `self.device`."""
        image_tensor = torch.from_numpy(image.astype(np.float32))
        return image_tensor.to(self.device, dtype=torch.float32)


    

In [66]:

# Starting to Build our DQN
class DinoNetwork(nn.Module):
    """Convolutional Q-network mapping stacked frames to one value per action.

    Three conv blocks (Conv2d -> ReLU -> BatchNorm2d) feed two linear
    layers. The output layer is linear: Q-value targets in this file
    include negative rewards and discounted sums, which a sigmoid-bounded
    output in (0, 1) could never match.

    Args:
        number_of_actions: width of the output layer (defaults to 3).
    """

    def __init__(self, number_of_actions=3):
        super().__init__()
        self.number_of_actions = number_of_actions
        # The settings below are not read by forward(); they are kept as
        # attributes so existing code that inspects them still works.
        self.gamma = 0.99
        self.initial_epsilon = 0.1
        self.final_epsilon = 0.0001
        self.number_of_iterations = 10000
        self.replay_memory_size = 1000
        self.minibatch_size = 1

        self.conv1 = nn.Sequential(
            nn.Conv2d(4, 32, 5, 3),
            nn.ReLU(),
            nn.BatchNorm2d(32)
        )

        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 3),
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )

        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 64, 3, 3),
            nn.ReLU(),
            nn.BatchNorm2d(64)
        )

        # For 84x84 inputs the spatial size shrinks 84 -> 27 -> 8 -> 2, so
        # the flattened feature vector has 64 * 2 * 2 = 256 entries.
        self.fc1 = nn.Sequential(
            nn.Linear(256, 64),
            nn.ReLU(),
        )

        # Kept as a Sequential so parameter names (fc2.0.*) are unchanged.
        self.fc2 = nn.Sequential(
            nn.Linear(64, self.number_of_actions),
        )

    def forward(self, inputs):
        """Return Q-values of shape (batch, number_of_actions).

        Args:
            inputs: float tensor of shape (batch, 4, 84, 84).
        """
        x = self.conv1(inputs)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.flatten(start_dim=1, end_dim=-1)
        x = self.fc1(x)
        return self.fc2(x)


In [67]:
# Defining Hyperparameters
# Number of experiences sampled from replay memory for each update
batch_size = 4

# Discount factor for the Bellman equation
gamma = 0.99

# `start` value passed to EpsilonGreedyStrategy
# NOTE(review): with the strategy's formula the rate equals eps_end at
# episode 0 and approaches eps_start as the episode index grows.
eps_start = 1

# `end` value passed to EpsilonGreedyStrategy
eps_end = 0.01

# Decay rate used to decay epsilon over time
eps_decay = 0.001

# How frequently the weights of the target network are updated, in terms of episodes
target_update = 10

# Capacity of replay memory
memory_size = 100_000

# Learning rate of the policy network's optimizer
lr = 0.001

# Total number of episodes to train for
num_episodes = 12

# Setting up device: use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Setting loss: mean squared error between predicted and target Q-values
criterion = nn.MSELoss()


In [68]:
# Environment goes brr
# Constructing ChromeManager launches a Chrome window as a side effect.

environ = ChromeManager(device=device)

# Strategy goes brr
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)

# Agent goes brr (3 actions: jump, duck, do nothing)
agent = Agent(strategy, 3, device)

# Replay memory goes brr
memory = ReplayMemory(memory_size)

# Defining our Policy network and target network
policy_net = DinoNetwork().to(device)
target_net = DinoNetwork().to(device)

# For taking screenshots
# (a second pre-processor on the same driver; environ.get_state is another instance)
get_ss = ImagePreProcessing(environ.driver, device)

# Now we set the weights of our target net to be the same as our policy net
target_net.load_state_dict(policy_net.state_dict())
# The target network is only used for inference, so keep it in eval mode
target_net.eval()

# Choosing our optimizer (only the policy network's parameters are optimised)
optimizer = optim.Adam(params=policy_net.parameters(), lr=lr)

# Disabled smoke test, kept as a bare string so it does not run.
"""environ.reset()
environ.start()
state = get_ss.screenshot()
environ.close_all()"""


Hello


'environ.reset()\nenviron.start()\nstate = get_ss.screenshot()\nenviron.close_all()'

In [69]:
def extract_tensors(experiences):
    """Turn a list of Experience tuples into five batched tensors.

    The experiences are transposed into a single Experience whose fields
    each hold a tuple of tensors, and every field is concatenated along
    dim 0.

    Returns:
        Tuple `(states, actions, next_states, rewards, dones)`.
    """
    # Batch of experiences -> one Experience of batches.
    transposed = Experience(*zip(*experiences))
    return tuple(torch.cat(column) for column in transposed)

In [72]:
def _optimize_policy(experiences):
    """Run one gradient step on the policy network for a sampled minibatch.

    Returns:
        The scalar loss value as a Python float.
    """
    states, actions, next_states, rewards, terminal = extract_tensors(experiences)
    rewards = rewards.to(torch.float32)

    # Q(s, a) for the actions actually taken, shape (batch,).
    predicted = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    # Bootstrap from the target network; it must not receive gradients.
    with torch.no_grad():
        best_next = target_net(next_states).max(dim=1).values

    # Terminal transitions get the bare reward, others add the discounted
    # best next-state value.
    not_terminal = terminal.logical_not().to(torch.float32)
    targets = rewards + gamma * best_next * not_terminal

    # Both tensors are shape (batch,), so the loss is not silently broadcast.
    loss = criterion(predicted, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


def train():
    """Play `num_episodes` games, learning from replayed experience.

    Each step selects an action, applies it in the browser, stores the
    transition in replay memory and, once enough transitions exist,
    performs one optimisation step on the policy network. The target
    network is synchronised with the policy network every `target_update`
    episodes.
    """
    possible_actions = ['Jump', 'Duck', 'Do Nothing']
    for episode in range(num_episodes):
        # Let the exploration schedule see how far training has progressed.
        agent.current_episode = episode
        environ.reset()
        environ.start()
        state = get_ss.screenshot()
        done = False
        # While the episode is running
        while not done:
            action = agent.select_action(state, policy_net)
            print(f"The agent decided to {possible_actions[action]}")
            next_state, reward, done, score, highscore = environ.take_action(action)

            # Rewards are fractional (0.1), so they must be stored as floats;
            # an integer dtype would truncate them to 0.
            memory.push(Experience(
                state,
                torch.tensor([action], dtype=torch.int64, device=device),
                next_state,
                torch.tensor([reward], dtype=torch.float32, device=device),
                torch.tensor([done], dtype=torch.bool, device=device),
            ))
            # Advance on every step, including before memory has filled up.
            state = next_state

            if memory.can_provide_sample(batch_size):
                _optimize_policy(memory.sample(batch_size))

        # Sync on the episode counter (not the constant episode total).
        if episode % target_update == 0:
            target_net.load_state_dict(policy_net.state_dict())


            

In [73]:
# Kick off training; this drives the already-open Chrome window.
train()

THis marks the end of th capturing process
The agent decided to Jump
THis marks the end of th capturing process
The agent decided to Do Nothing
THis marks the end of th capturing process
THis marks the end of th capturing process
The agent decided to Duck
THis marks the end of th capturing process
The agent decided to Duck
THis marks the end of th capturing process
THis marks the end of th capturing process
The agent decided to Do Nothing
THis marks the end of th capturing process
The agent decided to Jump
THis marks the end of th capturing process
THis marks the end of th capturing process
The agent decided to Duck
THis marks the end of th capturing process
The agent decided to Jump
THis marks the end of th capturing process
THis marks the end of th capturing process
The agent decided to Duck
THis marks the end of th capturing process
THis marks the end of th capturing process
The agent decided to Do Nothing
THis marks the end of th capturing process
THis marks the end of th capturing

In [None]:
# Scratch cell inspecting one entry of `rewards`.
# NOTE(review): in the visible code `rewards` is only bound as a local inside
# the training code, so this likely raises NameError at notebook scope -- confirm or remove.
rewards[1]

In [77]:
# Scratch check of the modulo operator (evaluates to 1).
10%3

1