## 1. Install Dependencies

In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
!pip list

In [None]:
!pip install stable-baselines3[extra] protobuf==3.20.*

In [None]:
!pip install mss 
!pip install pydirectinput 
!pip install pytesseract

In [None]:
!pip install gym

In [2]:
# Mss used for screen capture 
from mss import mss 
# Sending commands 
import pydirectinput
# OpenCV for frame processing
import cv2
# Numerical array framework
import numpy as np 
from numpy import shape
# OCR for game over extraction 
import pytesseract
# visualize the captured frames 
from matplotlib import pyplot as plt 
# Bring in time for pauses 
import time 
# Environment components 
from gym import Env

from gymnasium.spaces import Box,Discrete


import gymnasium

## 2. Build the Environment

### 2.1 Create Environment

In [3]:



class WebGame(gymnasium.Env):
    """Gymnasium environment that plays a browser game through screen capture.

    Observations are grayscale crops of the screen region ``game_location``,
    returned channel-first with shape (1, 150, 600) and dtype uint8.
    Actions are discrete: 0 presses 'space', 1 presses 'down', 2 does nothing.
    An episode ends when OCR on the ``done_location`` region reads one of the
    game-over strings checked in ``get_done``.

    The screen coordinates used here are hard-coded pixel positions; they
    presumably match one specific monitor/browser layout — adjust as needed.
    """

    def __init__(self):
        super().__init__()

        # Single grayscale channel, channel-first; must match get_observation().
        self.observation_space = Box(low=0, high=255, shape=(1, 150, 600), dtype=np.uint8)
        self.action_space = Discrete(3)

        # Screen-capture handle and the screen regions (in pixels) to grab.
        self.cap = mss()
        self.game_location = {'top': 200, 'left': 500, 'width': 600, 'height': 150}
        self.done_location = {'top': 160, 'left': 825, 'width': 260, 'height': 60}

    def step(self, action):
        """Apply ``action`` and return ``(obs, reward, <slot3>, <slot4>, info)``.

        ``action`` is 0 (press 'space'), 1 (press 'down') or 2 (no-op; no key
        is sent). Every step yields a reward of 1 ("a point for every frame we
        stay alive").

        NOTE(review): Gymnasium's order is ``(obs, reward, terminated,
        truncated, info)``. Game over is currently returned in the 4th slot
        (``truncated``) while ``terminated`` is always False. Stable-Baselines3
        treats truncation as a time-limit cut-off by default
        (``handle_timeout_termination=True``), so DQN bootstraps past the
        death instead of treating it as terminal. Swapping the two slots is the
        correct fix, but the play/evaluation loops in this notebook read the
        4th slot as "done" and must be updated at the same time.
        """
        action_map = {0: 'space', 1: 'down', 2: 'no_op'}
        if action != 2:
            pydirectinput.press(action_map[action])

        # Check for the game-over text before grabbing the next frame.
        _, done, _ = self.get_done()
        new_observation = self.get_observation()

        reward = 1
        info = {}
        return new_observation, reward, False, done, info

    def render(self):
        """Show the live game region in an OpenCV window; pressing 'q' closes it."""
        cv2.imshow('Game', np.array(self.cap.grab(self.game_location))[:, :, :3])
        if cv2.waitKey(1) & 0xff == ord("q"):
            self.close()

    def close(self):
        """Destroy any OpenCV windows opened by ``render``."""
        cv2.destroyAllWindows()

    def reset(self, seed=None, options=None):
        """Restart the game and return ``(observation, info)``.

        Args:
            seed: Optional seed for ``self.np_random`` (the browser game itself
                is not seeded).
            options: Accepted for Gymnasium API compatibility; unused.
        """
        # Gymnasium expects custom envs to forward the seed so np_random is set.
        super().reset(seed=seed)
        time.sleep(1)
        # Click into the game window to focus it, then press space to start.
        pydirectinput.click(x=950, y=240)
        pydirectinput.press('space')
        info = {}
        return self.get_observation(), info

    def get_observation(self):
        """Capture the game region and return it as a (1, H, W) uint8 grayscale array.

        H and W are taken from ``observation_space`` so the result always
        matches the declared space.
        """
        # mss yields BGRA pixels; drop alpha to get BGR.
        raw = np.array(self.cap.grab(self.game_location))[:, :, :3].astype(np.uint8)
        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)

        channels, height, width = self.observation_space.shape
        # cv2.resize takes dsize as (width, height).
        resized = cv2.resize(gray, (width, height))

        # Add the leading channel axis (channel-first layout).
        return np.reshape(resized, (channels, height, width))

    def get_done(self):
        """OCR the game-over region.

        Returns:
            (res, done, done_cap): the first 4 OCR'd characters, whether they
            match a known game-over string, and the captured BGR image.
        """
        done_cap = np.array(self.cap.grab(self.done_location))[:, :, :3].astype(np.uint8)

        # 'GAHE' covers a common OCR misread of 'GAME'.
        done_strings = ['GAME', 'GAHE']

        res = pytesseract.image_to_string(done_cap)[:4]
        done = res in done_strings
        return res, done, done_cap
     
 

In [157]:
env = WebGame()

In [None]:
env.reset()

In [159]:
env.render()

In [160]:
env.close()

In [None]:
# get_observation() is single-channel grayscale (1, 150, 600); channel 0 is a
# 2-D image, which cv2's BGR->RGB conversion rejects. Show it with a gray colormap.
plt.imshow(env.get_observation()[0], cmap='gray', vmin=0, vmax=255)

In [162]:
res, done , done_cap = env.get_done()

In [None]:
res

In [164]:
done

False

In [None]:
# done_cap is BGR (mss capture with alpha dropped); matplotlib expects RGB.
plt.imshow(cv2.cvtColor(done_cap, cv2.COLOR_BGR2RGB))

In [166]:
env.get_done()[2].shape

(60, 260, 3)

### 2.2 Test Environment

In [7]:
env = WebGame()

In [None]:
# Channel 0 of the observation is a 2-D grayscale image; cv2 colour
# conversions like BGR2RGBA require 3/4 channels, so display it directly.
obs = env.get_observation()[0]
plt.imshow(obs, cmap='gray', vmin=0, vmax=255)

In [9]:
res, done, done_cap = env.get_done()

In [10]:
res

'GAME'

In [None]:
done

In [None]:
# done_cap is BGR (mss capture with alpha dropped); matplotlib expects RGB.
plt.imshow(cv2.cvtColor(done_cap, cv2.COLOR_BGR2RGB))

In [15]:
# Play the game with random actions to sanity-check the environment.
n_episodes = 1
for episode in range(n_episodes):
    # reset() returns (observation, info); unpack so obs is the array.
    obs, info = env.reset()
    done = False
    total_reward = 0

    while not done:
        obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
        # Either flag ends the episode, regardless of which slot step() uses.
        done = terminated or truncated
        total_reward += reward

    print(f'Total Reward for episode {episode} is {total_reward}')

Total Reward for episode 0 is 12


## 3. Train The Model

### 3.1 Create a Callback

In [16]:
# import OS for file path management    
import os 

# import Base Callback for saving models 
from stable_baselines3.common.callbacks import BaseCallback
     
# check environment 
from stable_baselines3.common import env_checker

In [17]:
## Check that the environment follows the Gymnasium API
env_checker.check_env(env=env)

In [18]:
class TrainAndLoggingCallback(BaseCallback):
    """Periodically checkpoint the model during training.

    Every ``check_freq`` calls to ``_on_step`` the current model is saved to
    ``<save_path>/best_model_<n_calls>``. No evaluation is performed, so
    despite the file name these are periodic checkpoints, not a "best" model.

    Args:
        check_freq: Positive number of steps between checkpoints.
        save_path: Directory to save checkpoints into (created if missing).
        verbose: Verbosity level passed to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        # check_freq is used as a modulus in _on_step; 0 would divide by zero.
        if check_freq <= 0:
            raise ValueError(f'check_freq must be a positive integer, got {check_freq}')
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Single leading underscore is required: BaseCallback.init_callback()
        # calls self._init_callback(); a double-underscore name is mangled to
        # _TrainAndLoggingCallback__init_callback and would never run.
        if self.save_path is not None:
            os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        if self.n_calls % self.check_freq == 0:
            model_path = os.path.join(self.save_path, 'best_model_{}'.format(self.n_calls))
            self.model.save(model_path)
        # Returning True tells SB3 to continue training.
        return True

In [19]:
CHECKPOINT_DIR = './train/'
LOG_DIR = './logs'

In [20]:
callback = TrainAndLoggingCallback(check_freq=50,save_path=CHECKPOINT_DIR)

### 3.2 Build DQN and Train

In [None]:
# import the DQN algorithm
from stable_baselines3 import DQN

In [None]:
# Create DQN Model
model = DQN(policy='CnnPolicy',env=env,tensorboard_log=LOG_DIR,verbose=1,buffer_size=6000,learning_starts=0)

In [None]:
# Kick Off Training 
model.learn(total_timesteps=50,callback=callback)

In [None]:
obs

In [28]:
action , _  = model.predict(obs)

In [36]:
int(action)

0

## 4. Test out Model

In [None]:
# Load a previously trained DQN checkpoint for evaluation.
# NOTE(review): TrainAndLoggingCallback saves into CHECKPOINT_DIR ('./train/'),
# but this loads from './best_model/' — confirm the checkpoint was copied there,
# or load from os.path.join(CHECKPOINT_DIR, ...) instead.
model = DQN.load(os.path.join('best_model','best_model_88000'))

### TEST THE MODEL

In [None]:
# Evaluate the trained agent over 10 episodes.
for episode in range(10):
    # reset() returns (observation, info); passing the tuple to predict() would fail.
    obs, info = env.reset()
    done = False
    total_reward = 0

    while not done:
        # deterministic=True: act greedily, without DQN's epsilon-random exploration.
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, terminated, truncated, info = env.step(int(action))
        # Either flag ends the episode, regardless of which slot step() uses.
        done = terminated or truncated
        total_reward += reward

    print(f'Total Reward for episode {episode} is {total_reward}')