In [None]:
# Link do jogo para o navegador Brave
# brave://dino/

1 - Instalando e importando os módulos

In [None]:
# Install the notebook's dependencies into the kernel's own environment
# (%pip, not !pip, so the packages land where this kernel imports from).
# NOTE(review): requirements.txt is presumably expected to cover the packages
# imported below (mss, pydirectinput, opencv, pytesseract, gym,
# stable-baselines3, matplotlib) — verify and pin versions there.
# pytesseract also needs the Tesseract OCR executable, which pip does not install.
%pip install -r requirements.txt

In [None]:
from mss import mss # mss é usado para capturar a tela
import pydirectinput # usado para enviar comandos
import cv2 # open cv para processar a tela
import numpy as np # transformação
import pytesseract # OCR para ver o game over do jogo
from matplotlib import pyplot as plt # Visualizar frames
import time # Ter pausas
from gym import Env # Componentes de ambiente
from gym.spaces import Box, Discrete

2 - Construindo o ambiente

2.1 - Criando o ambiente

In [None]:
class WebGame(Env):
    """Gym environment that plays the browser dino game through screen capture.

    Observations are grayscale captures of the play area resized to 100x83
    pixels and shaped ``(1, 83, 100)`` with dtype uint8 (channel-first).
    Actions: ``0`` presses space, ``1`` presses down, ``2`` sends no key
    (no-op). Every step returns a reward of 1, so an episode's return is the
    number of steps survived. An episode ends when OCR on the game-over banner
    region reads "GAME".

    Args:
        game_location: mss region dict (``top``/``left``/``width``/``height``,
            absolute screen pixels) of the play area. Defaults to the region
            this notebook was tuned for.
        done_location: mss region dict of the game-over banner. Defaults to
            the region this notebook was tuned for.
    """

    def __init__(self, game_location=None, done_location=None):
        super().__init__()

        # Spaces: one grayscale 83x100 channel; three discrete actions.
        self.observation_space = Box(low=0, high=255, shape=(1, 83, 100), dtype=np.uint8)
        self.action_space = Discrete(3)

        # Screen grabber.
        self.cap = mss()

        # Capture regions depend on the monitor layout and browser window placement.
        if game_location is None:
            game_location = {'top': 300, 'left': 0, 'width': 600, 'height': 500}
        if done_location is None:
            done_location = {'top': 405, 'left': 630, 'width': 660, 'height': 70}
        self.game_location = dict(game_location)
        self.done_location = dict(done_location)

        # Last raw BGR frame captured by get_observation(); displayed by render().
        self.current_frame = None

    def step(self, action):
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        Args:
            action: 0 (space), 1 (down) or 2 (no-op).

        Returns:
            The next ``(1, 83, 100)`` observation, a constant reward of 1,
            the game-over flag from :meth:`get_done`, and an empty info dict.
        """
        action_map = {0: 'space', 1: 'down', 2: 'sem_operacao'}

        # Action 2 is the no-op: no key is sent.
        if action != 2:
            pydirectinput.press(action_map[action])

        done, _ = self.get_done()
        observation = self.get_observation()

        reward = 1  # constant per-step survival reward
        info = {}
        return observation, reward, done, info

    def reset(self):
        """Restart the game and return the first observation.

        Waits 1 s, clicks at screen position (150, 150) — presumably to focus
        the browser window; confirm for your layout — and presses space to
        start a new run.
        """
        time.sleep(1)
        pydirectinput.click(x=150, y=150)
        pydirectinput.press('space')
        return self.get_observation()

    def render(self, mode='human'):
        """Show the latest captured frame in an OpenCV window; press 'q' to close it.

        ``mode`` is accepted for compatibility with the Gym ``render``
        signature and is otherwise ignored.
        """
        if self.current_frame is None:
            # Nothing captured yet: grab a frame so there is something to show.
            self.get_observation()

        cv2.imshow('Game', self.current_frame)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            self.close()

    def close(self):
        """Close any OpenCV windows opened by :meth:`render`."""
        cv2.destroyAllWindows()

    def get_observation(self):
        """Capture the play area and return it as a ``(1, 83, 100)`` uint8 grayscale array.

        The raw 3-channel capture is also stored in ``self.current_frame``
        for :meth:`render`.
        """
        # mss captures are 4-channel (BGRA); keep the first three (BGR).
        raw = np.array(self.cap.grab(self.game_location))[:, :, :3].astype(np.uint8)
        self.current_frame = raw

        gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY)
        # cv2.resize takes (width, height).
        resized = cv2.resize(gray, (100, 83))
        return np.reshape(resized, (1, 83, 100))

    def get_done(self):
        """Detect the game-over screen by OCR.

        Returns:
            ``(done, done_cap)``: ``done`` is True when the first four
            characters Tesseract reads from the banner region are "GAME"
            (or "GAHE", a misread of it); ``done_cap`` is the raw captured
            region, returned for debugging.
        """
        done_cap = np.array(self.cap.grab(self.done_location))

        done_strings = ['GAME', 'GAHE']
        # Normalise the OCR text: leading whitespace/newlines or a different
        # case would otherwise shift or change the 4-character prefix.
        res = pytesseract.image_to_string(done_cap).strip().upper()[:4]
        done = res in done_strings
        return done, done_cap

2.2 - Testando o ambiente

In [None]:
# Create the environment. WebGame captures fixed screen regions, so the game
# must be visible where those regions point.
env = WebGame()

In [None]:
# Grab one preprocessed (1, 83, 100) frame to check that the capture region is right.
obs = env.get_observation()

In [None]:
# Preview the single grayscale channel directly; a fixed 0-255 range avoids
# contrast stretching, and the trailing ';' hides the AxesImage repr.
plt.imshow(obs[0], cmap='gray', vmin=0, vmax=255);

In [None]:
# Run the game-over detector once; done_cap is the raw captured banner region.
done, done_cap = env.get_done()

In [None]:
# mss captures are BGRA; convert to RGB so matplotlib shows the true colours
# (it would otherwise read the array as RGBA and swap red/blue).
plt.imshow(cv2.cvtColor(done_cap, cv2.COLOR_BGRA2RGB));

In [None]:
# Inspect the first 4 characters Tesseract reads from the banner region
# (get_done compares its OCR prefix against 'GAME'/'GAHE').
pytesseract.image_to_string(done_cap)[:4]

In [None]:
# Show whether the game-over screen was detected.
done

In [None]:
# Baseline: play 10 episodes with uniformly random actions to sanity-check the
# reset/step loop before any training.
for episode in range(10):
    obs, done, total_reward = env.reset(), False, 0

    while not done:
        action = env.action_space.sample()
        obs, reward, done, info = env.step(action)
        total_reward += reward

    print(f'Recompensa total da geração {episode} é {total_reward}')

3 - Treinando o Modelo

3.1 - Criando callback

In [None]:
# Importando os para gerenciamento de path
import os 
# Importando Base Callback para salvar os modelos
from stable_baselines3.common.callbacks import BaseCallback
# Verificando o ambiente
from stable_baselines3.common import env_checker

In [None]:
# Validate that WebGame follows the Gym API expected by stable-baselines3.
# NOTE: the checker calls reset()/step(), so it sends real key presses to the game.
env_checker.check_env(env)

In [None]:
class TrainAndLoggingCallback(BaseCallback):
    """Periodically checkpoint the model being trained.

    Every ``check_freq`` callback calls the current model is saved to
    ``<save_path>/best_model_<n_calls>``.

    Args:
        check_freq: number of callback calls between checkpoints.
        save_path: directory for checkpoints; created when training starts
            if it is not None.
        verbose: verbosity level forwarded to ``BaseCallback``.
    """

    def __init__(self, check_freq, save_path, verbose=1):
        super().__init__(verbose)
        self.check_freq = check_freq
        self.save_path = save_path

    def _init_callback(self):
        # Create the checkpoint directory up front so model.save() has somewhere to write.
        if self.save_path is None:
            return
        os.makedirs(self.save_path, exist_ok=True)

    def _on_step(self):
        checkpoint_due = self.n_calls % self.check_freq == 0
        if checkpoint_due:
            checkpoint = os.path.join(self.save_path, f'best_model_{self.n_calls}')
            self.model.save(checkpoint)
        # Returning True tells stable-baselines3 to continue training.
        return True

In [None]:
# Output directories: model checkpoints and TensorBoard logs.
CHECKPOINT_DIR, LOG_DIR = './train/', './logs/'

In [None]:
# Checkpoint every 1000 callback calls into CHECKPOINT_DIR.
callback = TrainAndLoggingCallback(save_path=CHECKPOINT_DIR, check_freq=1000)

3.2 - Construindo e treinando o DQN

In [None]:
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack

In [None]:
# Fresh environment instance for training; this rebinds `env`, replacing the
# instance created during the environment tests above.
env = WebGame()

In [None]:
# DQN with a CNN policy over the (1, 83, 100) uint8 observations.
# NOTE(review): 1.2M transitions x 8,300 bytes per observation is ~10 GB for
# one observation array; SB3's replay buffer also keeps next observations
# unless optimize_memory_usage=True — confirm the machine has enough RAM.
model = DQN(
    'CnnPolicy',
    env,
    tensorboard_log=LOG_DIR,
    verbose=1,
    buffer_size=1_200_000,
    learning_starts=1000,
)

In [None]:
# Train for 100k environment steps; the callback saves periodic checkpoints.
model.learn(total_timesteps=100_000, callback=callback)

In [None]:
# DQN.load is a classmethod that returns a NEW model, so the result must be
# assigned for the evaluation below to use the loaded weights.
# NOTE(review): 'train_first' is presumably a saved copy of an earlier run's
# checkpoint directory (the callback writes best_model_<n> files) — adjust the path as needed.
model = DQN.load('train_first/best_model_50000', env=env)

4 - Testando o Modelo

In [None]:
# Evaluate the agent for 5 episodes. deterministic=True disables DQN's
# epsilon-greedy exploration, so this measures the learned policy itself
# rather than a policy mixed with random actions.
for episode in range(5):
    obs = env.reset()

    done = False
    total_reward = 0

    while not done:
        action, _ = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(int(action))
        time.sleep(0.01)  # brief pause between steps
        total_reward += reward

    print(f'Recompensa total da geração {episode} é {total_reward}')
    time.sleep(2)  # pause between episodes