In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random

# Prefer the CUDA device when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Uncomment the line below to force CPU execution regardless of CUDA availability:
# device = torch.device("cpu")


In [None]:
def simplify_state(state):
    """Condense a grid observation into a compact, hashable summary.

    Args:
        state: Grid of integer cell codes, in whatever form
            ``get_agent_position``, ``get_nearby_obstacles`` and
            ``get_nearby_timbers`` accept.

    Returns:
        ``((agent_x, agent_y), obstacles, timbers)`` where ``obstacles`` and
        ``timbers`` are tuples of unique ``(x, y)`` cells. Their element
        order follows set iteration order; it is not sorted.
    """
    x, y = get_agent_position(state)
    obstacle_cells = tuple(set(get_nearby_obstacles(state, x, y)))
    timber_cells = tuple(set(get_nearby_timbers(state, x, y)))
    return (x, y), obstacle_cells, timber_cells

def get_agent_position(state):
    """Locate the first cell equal to 1, scanning row by row.

    Args:
        state: 2-D grid of cell codes. The fallback path reads
            ``state.shape``, so an array-like with that attribute is needed
            when no cell equals 1.

    Returns:
        ``(x, y)`` i.e. ``(column, row)`` of the first matching cell. When
        no cell matches, ``(columns // 2, rows - 1)``: the horizontal centre
        of the last row.
    """
    matches = (
        (col, row)
        for row, line in enumerate(state)
        for col, value in enumerate(line)
        if value == 1
    )
    first_hit = next(matches, None)
    if first_hit is not None:
        return first_hit
    # Nothing found: default to the bottom-centre cell.
    n_rows, n_cols = state.shape[0], state.shape[1]
    return n_cols // 2, n_rows - 1
    
            
def get_nearby_obstacles(state, agent_x, agent_y):
    """Collect the ``(x, y)`` coordinates of every cell whose value is 2.

    Args:
        state: 2-D grid of cell codes, iterated row by row.
        agent_x: Accepted but not used by this function.
        agent_y: Accepted but not used by this function.

    Returns:
        List of ``(column, row)`` tuples in row-major scan order. Despite
        the name, no distance filtering is applied: every matching cell in
        the grid is returned.
    """
    return [
        (col, row)
        for row, line in enumerate(state)
        for col, value in enumerate(line)
        if value == 2
    ]

def get_nearby_timbers(state, agent_x, agent_y):
    """Collect the ``(x, y)`` coordinates of every cell whose value is 3.

    Args:
        state: 2-D grid of cell codes, iterated row by row.
        agent_x: Accepted but not used by this function.
        agent_y: Accepted but not used by this function.

    Returns:
        List of ``(column, row)`` tuples in row-major scan order. Despite
        the name, no distance filtering is applied: every matching cell in
        the grid is returned.
    """
    return [
        (col, row)
        for row, line in enumerate(state)
        for col, value in enumerate(line)
        if value == 3
    ]


In [None]:
import pyautogui
import cv2
import time
import keyboard
import torchvision
from ultralytics import YOLO


# Presumably the monitor resolution in pixels; not referenced by the code
# visible in this cell.
RES_X = 1920
RES_Y = 1080

# Region passed to pyautogui.screenshot by get_screen — presumably
# (left, top, width, height) in screen pixels; confirm against pyautogui docs.
GAME_REGION = (405, 210, 850, 480)
# Grayscale template that is_game_over matches against the bottom of the frame.
# NOTE(review): cv2.imread is documented to return None rather than raise when
# the file cannot be read — confirm 'restart_button.png' is present.
restart_button = cv2.imread('restart_button.png', cv2.IMREAD_GRAYSCALE)


def get_screen(region):
    """Capture a screen region and return two BGR arrays derived from it.

    Args:
        region: Indexable whose first four items are forwarded to
            ``pyautogui.screenshot`` as its ``region`` argument.

    Returns:
        ``(screen, non_crop)``. ``screen`` is the capture after the
        rotate / centre-crop / resize pipeline below. ``non_crop`` is the
        untransformed capture resized with ``cv2.resize(..., (425, 240))``.
        Both are converted from RGB to BGR channel order.
    """
    capture = pyautogui.screenshot(
        region=(region[0], region[1], region[2], region[3])
    )

    # Full-frame variant: colour conversion and resize only, no rotation/crop.
    full_frame = cv2.resize(
        cv2.cvtColor(np.array(capture.copy()), cv2.COLOR_RGB2BGR),
        (425, 240),
    )

    # RandomRotation receives the degenerate range (14, 14), so the angle it
    # draws is always the same value.
    pipeline = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation((14, 14)),
        torchvision.transforms.CenterCrop((320, 566)),
        torchvision.transforms.Resize((240, 425)),
    ])
    processed = cv2.cvtColor(np.array(pipeline(capture)), cv2.COLOR_RGB2BGR)

    return processed, full_frame

import numpy as np

def map_to_grid(image_size, grid_size, boxes, class_labels):
    """Rasterise labelled bounding boxes onto a coarse occupancy grid.

    Args:
        image_size: ``(width, height)`` of the image the boxes refer to.
        grid_size: ``(grid_width, grid_height)``, the number of cells per axis.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` in image pixels.
        class_labels: Iterable of integer class ids, paired with ``boxes``.

    Returns:
        Integer numpy array of shape ``(grid_height, grid_width)``. Every
        cell overlapped by a box holds ``label + 1`` (0 means empty). Where
        boxes overlap, the later box wins. Parts of a box that fall outside
        the image are ignored.
    """
    width, height = image_size
    grid_width, grid_height = grid_size
    grid = np.zeros((grid_height, grid_width), dtype=int)

    cell_width = width / grid_width
    cell_height = height / grid_height

    def _clamp(index, upper):
        # Restrict a cell index to the valid slice range [0, upper].
        return max(0, min(int(index), upper))

    for (x_min, y_min, x_max, y_max), label in zip(boxes, class_labels):
        # Clamp to the grid. Without this, a negative coordinate would wrap
        # around through negative indexing and mark cells on the opposite
        # edge, and a box reaching past the image would raise IndexError.
        x_start = _clamp(x_min // cell_width, grid_width)
        y_start = _clamp(y_min // cell_height, grid_height)
        x_end = _clamp(np.ceil(x_max / cell_width), grid_width)
        y_end = _clamp(np.ceil(y_max / cell_height), grid_height)

        # Both bounds are non-negative, so an inverted or fully out-of-range
        # box yields an empty slice and leaves the grid untouched.
        grid[y_start:y_end, x_start:x_end] = label + 1

    return grid


def get_state(screen):
    """Run the detector on a frame and rasterise its boxes onto a grid.

    Args:
        screen: Image passed straight to the module-level ``cv_model``.

    Returns:
        ``(grid, boxes, labels)``. ``boxes`` is a list of
        ``(x_min, y_min, x_max, y_max)`` float tuples, ``labels`` the
        matching integer class ids, and ``grid`` the result of
        ``map_to_grid`` for a fixed 425x240 image and a 36x32 grid.
    """
    boxes, labels = [], []

    # Only the first result is inspected.
    for detection in cv_model(screen, verbose=False)[0].boxes:
        left, top, right, bottom = detection.xyxy[0].tolist()
        boxes.append((left, top, right, bottom))
        labels.append(int(detection.cls[0].item()))

    grid = map_to_grid((425, 240), (36, 32), boxes, labels)
    return grid, boxes, labels

def draw_bounding_boxes(screen, boxes, class_labels):
    """Draw each box and its label onto ``screen`` and return it.

    Args:
        screen: Image drawn on in place via ``cv2.rectangle`` / ``cv2.putText``.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)``. Values are
            truncated to ``int`` before drawing.
        class_labels: Iterable of labels paired with ``boxes``. Each is
            rendered with ``str(label)`` 10 pixels above the box's top-left
            corner.

    Returns:
        The same ``screen`` object that was passed in.
    """
    green = (0, 255, 0)
    for box, label in zip(boxes, class_labels):
        left, top, right, bottom = (int(value) for value in box)
        cv2.rectangle(screen, (left, top), (right, bottom), green, 2)
        cv2.putText(
            screen,
            str(label),
            (left, top - 10),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.9,
            green,
            2,
        )

    return screen

def is_game_over(image, score_threshold=0.5, scale=0.5):
    """Report whether the restart-button template is found in the frame.

    The module-level ``restart_button`` template is rescaled by ``scale``
    and matched (``cv2.TM_CCOEFF_NORMED``) against a strip of the frame:
    the rows from 87% of the height downward and the columns between 43%
    and 57% of the width.

    Args:
        image: BGR frame; converted to grayscale before matching.
        score_threshold: Minimum match score that counts as a detection.
        scale: Factor applied to both template dimensions before matching.

    Returns:
        ``True`` when the best match score exceeds ``score_threshold``.

    Raises:
        FileNotFoundError: If ``restart_button`` is ``None``, meaning the
            template image failed to load.
    """
    if restart_button is None:
        # Without this check the failure shows up as an opaque OpenCV
        # assertion inside cv2.resize.
        raise FileNotFoundError(
            "restart_button template is None; the template image could not be loaded"
        )

    grey_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    resized_template = cv2.resize(restart_button, (0, 0), fx=scale, fy=scale)
    h, w = grey_image.shape

    cropped_search_box = grey_image[int(h * 0.87):, int(w * 0.43):int(w * 0.57)]

    scores = cv2.matchTemplate(cropped_search_box, resized_template, cv2.TM_CCOEFF_NORMED)

    # Only the best score matters, so take the maximum directly rather than
    # sorting the whole score map first.
    return bool(scores.max() > score_threshold)



# Detection model used by get_state.
cv_model = YOLO("best_cv.pt")

# Hold off until 'q' is pressed, so the loop below starts on request.
keyboard.wait('q')

try:
    while True:
        screenshot, non_crop_state = get_screen(GAME_REGION)
        state_raw, boxes, labels = get_state(screenshot)
        state = simplify_state(state_raw)

        # Overwrite state.txt on every frame with the latest grid, one
        # space-separated row per line.
        with open('state.txt', 'w') as f:
            for row in state_raw:
                f.write(' '.join(map(str, row)) + '\n')

        screen_with_boxes = draw_bounding_boxes(screenshot, boxes, labels)

        cv2.imshow('Game Screen with Bounding Boxes', screen_with_boxes)

        # Pressing 'q' while the preview window has focus ends the loop.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

        if is_game_over(non_crop_state):
            keyboard.press('space')
            try:
                time.sleep(3.25)
            finally:
                # Release the key even if the sleep is interrupted, so space
                # is never left held down.
                keyboard.release('space')
finally:
    # Close the preview window on normal exit, on an error, and on a kernel
    # interrupt alike.
    cv2.destroyAllWindows()
                
