1. Install and import dependencies


In [None]:
# Install PyTorch from the cu118 wheel index.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
# Install Stable-Baselines3 with its extras and hold protobuf at 3.20.x.
# The requirements are double-quoted so the shell cannot glob-expand `[extra]` or `*`.
# NOTE(review): this notebook imports `gym`; recent Stable-Baselines3 releases expect
# `gymnasium` instead — confirm the installed version matches the API used below.
%pip install "stable-baselines3[extra]" "protobuf==3.20.*"

In [None]:
# Install the Tesseract OCR engine (the program that pytesseract calls).
# NOTE(review): apt-get exists only on Debian/Ubuntu-style systems — confirm the
# target platform, since Tesseract must be installed another way elsewhere.
!apt-get update
!apt-get install -y tesseract-ocr

In [None]:
# Install the screen grabber (mss), the key/mouse sender (pydirectinput)
# and the Python wrapper for Tesseract (pytesseract).
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install mss pydirectinput pytesseract

In [None]:
# imports
# capturing screen
from mss import mss
# sending commands
import pydirectinput
# frame processing
import cv2
# transformations
import numpy as np
# ocr for text extracting
import pytesseract
# visualization
from matplotlib import pyplot as plt
# pauses
import time
# envs
from gym import Env
from gym.spaces import Box, Discrete

2. Build the environment


In [None]:
# create the environment
class WebGame(Env):
  """Gym environment that plays a browser game via screen capture and key presses.

  Observations are single-channel 83x100 uint8 frames grabbed from
  `game_location`. The three discrete actions are 0 = press space,
  1 = press down, 2 = do nothing. An episode ends when OCR on the
  `done_location` region reads the game-over text.
  """
  def __init__(self):
    """Declare the observation/action spaces and the screen regions to capture."""
    # subclass
    super().__init__()
    # spaces: gym and Stable-Baselines3 read the attribute named `observation_space`
    self.observation_space = Box(low=0, high=255, shape=(1,83,100), dtype=np.uint8)
    # keep the original attribute name available for existing callers
    self.output_space = self.observation_space
    self.action_space = Discrete(3)
    # screen grabber and the regions (in screen pixels) it reads
    self.cap = mss()
    self.game_location = {'top':300, 'left': 0, 'width': 600, 'height': 500}
    self.done_location = {'top':405, 'left': 630, 'width': 660, 'height': 70}

  def step(self, action):
    """Send one action to the game and return (observation, reward, done, info).

    Args:
      action: 0 => space, 1 => down, 2 => no action.
    """
    action_map = {
        0: 'space',
        1: 'down',
        2: 'no_operation'
    }
    # action 2 means "do nothing", so no key is pressed for it
    if action != 2:
      pydirectinput.press(action_map[action])
    # check for game over (the captured image is not needed here)
    gover, _ = self.get_done()
    # get the new observation
    new_observation = self.get_output()

    # one point for every frame the agent takes a step in
    reward = 1
    info = {}
    return new_observation, reward, gover, info

  # function to render the screen frames
  def render(self):
    """Show the current game region in an OpenCV window; pressing 'q' closes it."""
    # the grab has four channels (BGRA); keep only the first three for display
    frame = np.array(self.cap.grab(self.game_location))[:,:,:3]
    cv2.imshow('Game', frame)
    if cv2.waitKey(1) & 0xFF == ord('q'):
      self.close()

  # restart session
  def reset(self):
    """Restart the game and return the first observation."""
    # wait for the button to appear
    time.sleep(1)
    # click the window, then press space to start a new run
    pydirectinput.click(x=150, y=150)
    pydirectinput.press('space')
    return self.get_output()

  # get part of the game to check for conditions
  def get_output(self):
    """Capture the game region and return it as a (1, 83, 100) uint8 array."""
    # grab the screen and drop the alpha channel (keep the three colour channels)
    raw_screen = np.array(self.cap.grab(self.game_location))[:,:,:3].astype(np.uint8)
    # gray scale it
    gray = cv2.cvtColor(raw_screen, cv2.COLOR_BGR2GRAY)
    # resize the gray-scaled image; cv2.resize takes (width, height)
    resized_image = cv2.resize(gray, (100,83))
    # add the leading channel axis expected by the observation space
    channel = np.reshape(resized_image, (1,83,100))
    return channel

  # alias used by the notebook's preview cell
  def get_observation(self):
    """Return the current observation; same as `get_output`."""
    return self.get_output()

  # get the final text (game over) from the running screen in browser
  def get_done(self):
    """Run OCR on the game-over region.

    Returns:
      (gover, done_capture): whether the first four OCR characters match a
      known game-over string, and the captured image that was read.
    """
    # keep the three colour channels; index 3 alone would be the alpha channel
    done_capture = np.array(self.cap.grab(self.done_location))[:, :, :3]
    # validate the text. Sometimes the OCR reads the text as GAHE instead of GAME
    gover_strings = ['GAME', 'GAHE']
    # run the ocr and compare only the first four characters
    result = pytesseract.image_to_string(done_capture)[:4]
    gover = result in gover_strings
    return gover, done_capture

  # close the observation
  def close(self):
    """Close any OpenCV windows opened by `render`."""
    cv2.destroyAllWindows()


In [None]:
# Instantiate the environment; the constructor also creates the mss screen grabber
env = WebGame()

In [None]:
# Restart the game (click, then space) and display the first observation it returns
env.reset()

In [None]:
# Show the captured game region in an OpenCV window
env.render()

In [None]:
# Preview the grayscale observation; index 0 drops the channel axis.
# The frame has a single channel, so it is expanded with GRAY2RGB for matplotlib.
plt.imshow(cv2.cvtColor(env.get_output()[0], cv2.COLOR_GRAY2RGB))

In [None]:
# Capture the game-over region and run the OCR check on it
gover, done_capture = env.get_done()

In [None]:
# Display the image that the game-over OCR check was run on
plt.imshow(done_capture)

Test the environment


In [None]:
# Create a fresh environment instance for the tests below (rebinds `env`)
env = WebGame()

In [None]:
# Capture the current observation from the game region
output = env.get_output()

In [None]:
# Display the captured observation; index 0 drops the channel axis.
# The frame has a single channel, so it is expanded with GRAY2RGB for matplotlib.
plt.imshow(cv2.cvtColor(output[0], cv2.COLOR_GRAY2RGB))

In [None]:
# Run the OCR game-over check; the bare `gover` on the last line displays the flag
gover, done_capture = env.get_done()
gover

In [None]:
# Smoke-test the environment: play 10 episodes using random actions
for instance in range(10):
  observation = env.reset()
  reward_tally = 0
  # keep stepping until the game-over check fires
  while True:
    observation, reward, gover, info = env.step(env.action_space.sample())
    reward_tally += reward
    if gover:
      break
  # report the total reward collected in this episode
  print(f'Total reward for instance {instance} is {reward_tally}')

3. Train the model

In [None]:
# create callback to train model
import os
# base callback for saving
from stable_baselines3.common.callbacks import BaseCallback
# check environment
from stable_baselines3.common import env_checker

In [None]:
# Run Stable-Baselines3's environment checker against the custom environment
env_checker.check_env(env)

In [None]:
# write the callback class
class TrainAndLoggingCallback(BaseCallback):
  """Callback that saves a model checkpoint every `check_frequency` calls.

  Args:
    check_frequency: number of `_on_step` calls between two checkpoints.
    save_path: directory the checkpoints are written to; None disables saving.
    verbose: verbosity level passed on to BaseCallback.
  """
  # initialize the class
  def __init__(self, check_frequency, save_path, verbose=1):
    super().__init__(verbose)
    self.check_frequency = check_frequency
    self.save_path = save_path
  # initialize callback
  def _init_callback(self):
    """Create the checkpoint directory if a save path was given."""
    if self.save_path is not None:
      os.makedirs(self.save_path, exist_ok=True)
  # on all steps
  def _on_step(self):
    """Save the model as 'best_model_<n_calls>' on every `check_frequency`-th call.

    Returns:
      True, so training always continues.
    """
    # same None guard as _init_callback: without a save path there is nothing to write
    if self.save_path is not None and self.n_calls % self.check_frequency == 0:
      model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
      self.model.save(model_path)
    return True



In [None]:
# Paths for model checkpoints and for TensorBoard logs.
# These lines only define the paths; no directory is created here.
TRAIN_DIRECTORY = './train'
LOG_DIRECTORY = './log'

In [None]:
# Instantiate the callback: checkpoint every 1000 calls into TRAIN_DIRECTORY
cb = TrainAndLoggingCallback(check_frequency = 1000, save_path=TRAIN_DIRECTORY)

Build and train the DQN model


In [None]:
# import dqn
from stable_baselines3 import DQN

In [None]:
# Create the DQN agent with a CNN policy on the custom environment; logs go to LOG_DIRECTORY.
# NOTE(review): one stored frame is 1x83x100 uint8 (8,300 bytes), so 150000 buffer
# entries need roughly 1.2 GB per stored observation array — confirm the machine has the RAM.
model = DQN('CnnPolicy', env, tensorboard_log=LOG_DIRECTORY, verbose=1, buffer_size=150000, learning_starts=100)

In [None]:
# Start training; `cb` is the checkpoint callback instantiated above
model.learn(total_timesteps=6000, callback=cb)

4. Test the model

In [None]:
# Load a previously trained checkpoint; DQN.load takes the checkpoint path directly.
# NOTE(review): 'train_first' differs from TRAIN_DIRECTORY ('./train'), and step 88000
# is beyond the 6000 timesteps trained above — confirm this checkpoint file exists.
model = DQN.load(os.path.join('train_first', 'best_model_88000'))

In [None]:
# Play 10 episodes again, this time choosing actions with the loaded model
for instance in range(10):
  observation = env.reset()
  reward_tally = 0
  # keep stepping until the game-over check fires
  while True:
    action, _ = model.predict(observation)
    observation, reward, gover, info = env.step(int(action))
    reward_tally += reward
    if gover:
      break
  # report the total reward collected in this episode
  print(f'Total reward for instance {instance} is {reward_tally}')