In [1]:
from mss import mss
import pydirectinput
import cv2
import numpy as np
import pytesseract
from matplotlib import pyplot as plt
import time
from gymnasium import Env
from gymnasium.spaces import Box, Discrete
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker

In [6]:
class WebGame(Env):
    """Gymnasium environment that plays an on-screen game through screen capture.

    Observations are grayscale screenshots of ``game_location`` downscaled to
    100x83 pixels and returned as a ``uint8`` array of shape ``(1, 83, 100)``.
    Actions are delivered as key presses through ``pydirectinput``. The end of
    an episode is detected by running Tesseract OCR on ``done_location``.
    """

    def __init__(self):
        super().__init__()
        # One channel, 83 rows, 100 columns -- matches get_observation().
        self.observation_space = Box(low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)
        # 0 = press space, 1 = press down, 2 = do nothing (see step()).
        self.action_space = Discrete(3)
        self.cap = mss()
        # Screen regions (in pixels) handed to mss.grab().
        self.game_location = {'top':300, 'left':0, 'width':600, 'height':500}
        self.done_location = {'top':420, 'left':630, 'width':660, 'height':70}
        # NOTE: this sets a pytesseract module-level attribute to a hard-coded
        # Windows install path; adjust it if Tesseract lives elsewhere.
        pytesseract.pytesseract.tesseract_cmd = r'C:\Program Files\Tesseract-OCR\tesseract.exe'

    def step(self, action):
        """Apply ``action`` and return the Gymnasium 5-tuple.

        Args:
            action: Index into the action space (0, 1 or 2).

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The reward
            is always 1, ``truncated`` is always ``False`` and ``info`` is an
            empty dict.
        """
        action_map = {
            0:'space',
            1:'down',
            2:'no_op'
        }
        # Action 2 is the no-op: no key is sent for it.
        if action != 2:
            pydirectinput.press(action_map[action])
        # Only the flag is needed here; the captured image is discarded.
        done, _ = self.get_done()
        new_observation = self.get_observation()
        truncated = False
        reward = 1
        info = {}
        return new_observation, reward, done, truncated, info

    def render(self):
        """Show the current game region in an OpenCV window; 'q' closes it."""
        cv2.imshow('Game', np.array(self.cap.grab(self.game_location))[:,:,:3])
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def reset(self, seed=None, options=None):
        """Restart the game and return the first observation.

        Args:
            seed: Optional seed. It is forwarded to ``Env.reset`` so that
                ``self.np_random`` is seeded, and also applied to NumPy's
                global RNG as before.
            options: Accepted for compatibility with the Gymnasium
                ``Env.reset`` signature; currently unused.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        # Gymnasium expects subclasses to call the base reset for seeding.
        super().reset(seed=seed)
        if seed is not None:
            np.random.seed(seed)
        time.sleep(1)
        # Click at a fixed screen position, then press space.
        pydirectinput.click(x=150,y=150)
        pydirectinput.press('space')
        info = {}
        return self.get_observation(), info

    def get_observation(self):
        """Grab the game region and return it as a (1, 83, 100) grayscale array."""
        # Keep the first three of the captured channels.
        raw = np.array(self.cap.grab(self.game_location))[:,:,:3]
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes (width, height), so the result is 83 rows x 100 cols.
        resized = cv2.resize(gray,(100,83))
        channel = np.reshape(resized, (1,83,100))
        return channel

    def get_done(self):
        """Detect the end of an episode via OCR on ``done_location``.

        Returns:
            ``(done, done_cap)`` where ``done`` is ``True`` when the first four
            OCR'd characters match one of the accepted strings, and
            ``done_cap`` is the captured image that was analysed.
        """
        done_cap = np.array(self.cap.grab(self.done_location))[:,:,:3]
        # 'GAHE' is presumably a common OCR misread of 'GAME' -- TODO confirm.
        done_strings = ['GAME' , 'GAHE']

        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in done_strings

        return done, done_cap

    def close(self):
        """Close any OpenCV windows opened by render()."""
        cv2.destroyAllWindows()

In [7]:
# Create the environment instance used by every cell below.
env = WebGame()

In [23]:
# Play one episode with random actions as a smoke test of the environment.
for episode in range(1):
    # reset() returns an (observation, info) pair.
    obs, info = env.reset()
    done = False
    total_reward = 0
    while not done:
        # step() returns five values; the episode ends on either flag.
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        done = terminated or truncated
        total_reward += reward
    print(f'Total reward for episode {episode} is {total_reward}')

Total reward for episode 0 is 8


In [8]:
import os 
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common import env_checker

In [9]:
# Run Stable-Baselines3's environment checker on the env before training.
env_checker.check_env(env)

In [10]:
class TrainAndLoggingCallback(BaseCallback):
    """Callback that periodically saves the model being trained.

    Args:
        check_freq: Save a checkpoint every ``check_freq`` callback calls.
        save_path: Directory for the checkpoints, or ``None`` to disable
            saving.
        verbose: Verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        """Create the checkpoint directory if a save path was given."""
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        """Save a checkpoint on every ``check_freq``-th call.

        Returns:
            Always ``True``.
        """
        # Mirror the None guard in _init_callback: without it, os.path.join
        # would raise TypeError when no save path is configured.
        if self.save_path is not None and self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)

        return True

In [11]:
# Directory handed to the training callback as its checkpoint save path.
CHECKPOINT_DIR='./train/'
# Directory handed to the model as its TensorBoard log location.
LOG_DIR = './logs/'

In [17]:
# Checkpoint the model into CHECKPOINT_DIR every 300 callback calls.
callback = TrainAndLoggingCallback(check_freq = 300, save_path = CHECKPOINT_DIR)

In [18]:
from stable_baselines3 import DQN

In [19]:
# DQN agent with a CNN policy over the screenshot observations.
# TensorBoard logs go to LOG_DIR; learning_starts=0 requests no warm-up
# period before learning begins.
model = DQN(
    'CnnPolicy',
    env,
    buffer_size=100000,
    learning_starts=0,
    tensorboard_log=LOG_DIR,
    verbose=1,
)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [20]:
# Train for 1000 timesteps, passing the checkpoint callback to learn().
model.learn(total_timesteps=1000, callback=callback)

Logging to ./logs/DQN_2
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 6.5      |
|    ep_rew_mean      | 6.5      |
|    exploration_rate | 0.753    |
| time/               |          |
|    episodes         | 4        |
|    fps              | 0        |
|    time_elapsed     | 26       |
|    total_timesteps  | 26       |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.46     |
|    n_updates        | 6        |
----------------------------------
----------------------------------
| rollout/            |          |
|    ep_len_mean      | 7.75     |
|    ep_rew_mean      | 7.75     |
|    exploration_rate | 0.411    |
| time/               |          |
|    episodes         | 8        |
|    fps              | 1        |
|    time_elapsed     | 56       |
|    total_timesteps  | 62       |
| train/              |          |
|    learning_rate    | 0.0001   |
|    loss             | 0.213  

KeyboardInterrupt: 