In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


In [2]:
# DQN model for Crossy Road game
# input is the 2d matrix of the game state (90 x 90)
# output is the action to take (0, 1, 2, 3) for (up, down, left, right)

import torch
import torch.nn as nn
import torch.nn.functional as F

class RecurrentIQN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size, n_quantiles=32):
        super(RecurrentIQN, self).__init__()
        self.n_quantiles = n_quantiles
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.quantile_embed = nn.Linear(hidden_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x, quantiles, hidden):
        # x shape: (batch_size, seq_len, input_size)
        lstm_out, hidden = self.lstm(x, hidden)  # lstm_out: (batch_size, seq_len, hidden_size)
        
        # Expand quantiles to match lstm_out
        quantiles = quantiles.unsqueeze(-1)  # (batch_size, n_quantiles, 1)
        pi = torch.acos(torch.zeros(1)).item() * 2  # pi value
        quantile_feats = torch.cos(pi * quantiles * torch.arange(1, self.hidden_size + 1).to(x.device))
        quantile_feats = F.relu(self.quantile_embed(quantile_feats))  # (batch_size, n_quantiles, hidden_size)

        # Combine LSTM output with quantile features
        lstm_out = lstm_out[:, -1, :].unsqueeze(1)  # (batch_size, 1, hidden_size)
        x = lstm_out * quantile_feats  # (batch_size, n_quantiles, hidden_size)

        x = self.fc(x)  # (batch_size, n_quantiles, output_size)
        return x, hidden

    def act(self, state, hidden, epsilon):
        if random.random() > epsilon:
            print("Model acting")
            with torch.no_grad():
                # Prepare the state tensor
                state = torch.FloatTensor(state).unsqueeze(0).unsqueeze(0).to(next(self.parameters()).device)  # (1, 1, input_size)
                # Sample quantiles
                quantiles = torch.rand(1, self.n_quantiles).to(state.device)  # (1, n_quantiles)
                # Forward pass
                q_values, hidden = self.forward(state, quantiles, hidden)  # q_values: (1, n_quantiles, output_size)
                # Average over quantiles
                q_values = q_values.mean(dim=1)  # (1, output_size)
                # Select action with highest Q-value
                action = q_values.argmax(dim=1).item()
        else:
            # Random action
            action = random.randrange(self.fc.out_features)
        return action, hidden


In [3]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
        
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done
        
    def __len__(self):
        return len(self.buffer)
    
    
# DQN agent for Crossy Road game
# uses DQN model and replay buffer for training


class IQNAgent:
    def __init__(self, input_size, output_size, hidden_size, replay_buffer_capacity, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay, n_quantiles=32):
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.replay_buffer_capacity = replay_buffer_capacity
        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        self.n_quantiles = n_quantiles

        self.model = RecurrentIQN(input_size, output_size, hidden_size, n_quantiles).to(device)
        self.model_target = RecurrentIQN(input_size, output_size, hidden_size, n_quantiles).to(device)
        self.model_target.load_state_dict(self.model.state_dict())

        self.hidden = (torch.zeros(1, 1, hidden_size).to(device),
                       torch.zeros(1, 1, hidden_size).to(device))

        self.replay_buffer = ReplayBuffer(replay_buffer_capacity)
        self.optimizer = optim.Adam(self.model.parameters())

        self.steps_done = 0

        # Add counters for target network updates
        self.update_counter = 0
        self.target_update_freq = 1000  # Update target network every 1000 steps

    def select_action(self, state):
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(-1. * self.steps_done / self.epsilon_decay)
        self.steps_done += 1
        action, self.hidden = self.model.act(state, self.hidden, epsilon)
        return action

    def optimize_model(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        self.model.train()
        print("Optimizing model...")
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)

        # Convert to tensors
        states = torch.FloatTensor(states).unsqueeze(1).to(device)
        actions = torch.LongTensor(actions).to(device)
        rewards = torch.FloatTensor(rewards).to(device)
        next_states = torch.FloatTensor(next_states).unsqueeze(1).to(device)
        dones = torch.FloatTensor(dones).to(device)

        # Sample quantiles
        quantiles = torch.rand(self.batch_size, self.n_quantiles).to(device)

        # Initialize hidden states
        hidden = (torch.zeros(1, self.batch_size, self.hidden_size).to(device),
                  torch.zeros(1, self.batch_size, self.hidden_size).to(device))

        # Compute current Q-values
        current_q, _ = self.model(states, quantiles, hidden)
        current_q = current_q.gather(2, actions.unsqueeze(-1).unsqueeze(-1).expand(-1, self.n_quantiles, -1)).squeeze(-1)

        # Compute target Q-values
        with torch.no_grad():
            next_hidden = (torch.zeros(1, self.batch_size, self.hidden_size).to(device),
                           torch.zeros(1, self.batch_size, self.hidden_size).to(device))
            next_quantiles = torch.rand(self.batch_size, self.n_quantiles).to(device)
            next_q, _ = self.model_target(next_states, next_quantiles, next_hidden)
            next_q = next_q.max(2)[0]
            target_q = rewards.unsqueeze(1) + self.gamma * next_q * (1 - dones.unsqueeze(1))

        # Compute quantile Huber loss
        td_errors = target_q.unsqueeze(1) - current_q
        huber_loss = F.smooth_l1_loss(current_q, target_q.unsqueeze(1), reduction='none')
        quantile_loss = (torch.abs(quantiles.unsqueeze(-1) - (td_errors.detach() < 0).float()) * huber_loss).mean()

        self.optimizer.zero_grad()
        quantile_loss.backward()
        self.optimizer.step()

        # Update target network periodically
        self.update_counter += 1
        if self.update_counter % self.target_update_freq == 0:
            self.update_target_network()

        return quantile_loss.item()

    def update_target_network(self):
        self.model_target.load_state_dict(self.model.state_dict())

    def push(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)

    def save(self, path):
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        self.model.load_state_dict(torch.load(path))
        self.model.eval()

    def reset(self):
        self.steps_done = 0


In [4]:
def simplify_state(state):
    # Flatten the grid to create a fixed-size input vector
    agent_x, agent_y = get_agent_position(state)
    obstacles = get_nearby_obstacles(state, agent_x, agent_y)
    timbers = get_nearby_timbers(state, agent_x, agent_y)
    
    
    
    simplified_state = ((agent_x, agent_y), tuple(set(obstacles)), tuple(set(timbers)))
    
    return simplified_state

def get_agent_position(state):
    for i, row in enumerate(state):
        for j, cell in enumerate(row):
            if cell == 1:
                return j, i
    return state.shape[1] // 2, state.shape[0] - 1
    
            
def get_nearby_obstacles(state, agent_x, agent_y):
    obstacles = []
    for i, row in enumerate(state):
        for j, cell in enumerate(row):
            if cell == 2:
                obstacles.append((j, i))
    return obstacles

def get_nearby_timbers(state, agent_x, agent_y):
    timbers = []
    for i, row in enumerate(state):
        for j, cell in enumerate(row):
            if cell == 3:
                timbers.append((j, i))
    return timbers


In [5]:
# run DQN agent on Crossy Road game to train and play the game
import pyautogui
import cv2
import time
import keyboard
import torchvision
from ultralytics import YOLO


RES_X = 1920
RES_Y = 1080

GAME_REGION = (405, 210, 850, 480)
restart_button = cv2.imread('restart_button.png', cv2.IMREAD_GRAYSCALE)


def get_screen(region):
    screen = pyautogui.screenshot(region=(region[0], region[1], region[2], region[3]))
    
    non_crop = screen.copy()

    transforms = torchvision.transforms.Compose([
        torchvision.transforms.RandomRotation((14, 14)),
        torchvision.transforms.CenterCrop((320, 566)),
        torchvision.transforms.Resize((240, 425)),
    ])
    
    screen = transforms(screen)   
    
    screen = cv2.cvtColor(np.array(screen), cv2.COLOR_RGB2BGR)
    
    non_crop = cv2.cvtColor(np.array(non_crop), cv2.COLOR_RGB2BGR)
    non_crop = cv2.resize(non_crop, (425, 240))
    
    return screen, non_crop

import numpy as np

def map_to_grid(image_size, grid_size, boxes, class_labels):
    """
    Map detected bounding boxes to a grid representation.

    Args:
        image_size: Tuple (width, height) of the image.
        grid_size: Tuple (N, M) of the grid dimensions.
        boxes: List of bounding boxes [(x_min, y_min, x_max, y_max)].
        class_labels: List of class labels corresponding to the boxes.

    Returns:
        grid: 2D numpy array of shape (N, M) with object class labels.
    """
    width, height = image_size
    grid_width, grid_height = grid_size
    grid = np.zeros((grid_height, grid_width), dtype=int)

    cell_width = width / grid_width
    cell_height = height / grid_height

    for (x_min, y_min, x_max, y_max), label in zip(boxes, class_labels):
        x_start = int(x_min // cell_width)
        y_start = int(y_min // cell_height)
        x_end = int(np.ceil(x_max / cell_width))
        y_end = int(np.ceil(y_max / cell_height))

        for y in range(y_start, y_end):
            for x in range(x_start, x_end):
                grid[y, x] = label + 1

    return grid


def get_state(screen):
    results = cv_model(screen, verbose=False)
    
    # save the image with bounding boxes
    # results[0].save("dataset/screen_detections.png")

    image_size = (425, 240)  # Example image dimensions (width, height)
    grid_size = (36, 32)    # Example grid dimensions (N, M)

    boxes = []
    labels = []
    
    boxes_ = results[0].boxes
    for box in boxes_:
        x_min, y_min, x_max, y_max = box.xyxy[0].tolist()
        
        class_id = int(box.cls[0].item())
        
        boxes.append((x_min, y_min, x_max, y_max))
        labels.append(class_id)  # Assuming class_id is the label

    grid = map_to_grid(image_size, grid_size, boxes, labels)
    # print(grid)
    return grid

def is_game_over(image, score_threshold=0.5, scale=0.5):
    grey_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    resized_template = cv2.resize(restart_button, (0, 0), fx=scale, fy=scale)
    h, w = grey_image.shape

    cropped_search_box = grey_image[int(h * 0.87):, int(w * 0.43):int(w * 0.57)]
    
    # cv2.imwrite('cropped_search_box.png', cropped_search_box)
    
    result = cv2.matchTemplate(cropped_search_box, resized_template, cv2.TM_CCOEFF_NORMED)
    result = np.sort(result.flatten())[::-1]
    
    return result.max() > score_threshold

def process_state(state, max_obstacles=1152, max_timbers=1152):
    agent_pos, obstacles, timbers = state[0], state[1], state[2]
    
    # Agent position
    # print(agent_pos)
    state_vector = list(agent_pos)
    
    # Obstacles
    obstacles = list(obstacles)
    for i in range(max_obstacles):
        if i < len(obstacles):
            state_vector.extend(obstacles[i])  # Assuming obstacle positions are tuples
        else:
            state_vector.extend([0, 0])  # Padding
   
    # Timbers
    timbers = list(timbers)
    for i in range(max_timbers):
        if i < len(timbers):
            state_vector.extend(timbers[i])  # Assuming timber positions are tuples
        else:
            state_vector.extend([0, 0])  # Padding
    
    return state_vector

def compute_reward(reward_state):
    # Compute reward based on the change in the game screen   
    time = reward_state['time']
    state = reward_state['state']
    action = reward_state['action']
    prev_action = reward_state['prev_action']
    next_state = reward_state['next_state']
    total_reward = reward_state['total_reward']
    non_crop_state = reward_state['non_crop_state']
    reward = 0
    
    if is_game_over(non_crop_state):
        reward = -100
        return reward
    
    if action == 0:
        reward += 10
    elif action == 1:
        reward -= 5
    elif action == 2:
        reward += 0.25
    elif action == 3:
        reward += 0.25
    elif action == 4:
        reward += 0.2

    # if action == prev_action:
    #     reward += 0
    
    hen_count = 0
    for row in next_state:
        for cell in row:
            if cell == 1:
                hen_count += 1
                
    reward += (hen_count * 0.25)
    

    if action == prev_action:
        if state[0][0] // 5 == next_state[0][0] // 5 \
        and state[0][1] // 5 == next_state[0][1] // 5:
            reward -= 10
            
    reward = round(reward, 2)
    
    return reward
    

# train DQN agent
actions = [0, 1, 2, 3, 4]

input_size = 4610
agent = IQNAgent(input_size, 5, 128, 10000, 32, 0.99, 1.0, 0.1, 1000)
cv_model = YOLO('best_cv.pt')
episodes = 1000
episode_length = 1000
losses = []
rewards = []

# agent.load('dqn.pth')

print("Model is ready to train")

keyboard.wait('q')

# Loop for self-play training
for episode in range(episodes):
    screenshot, non_crop_state = get_screen(GAME_REGION)
    state_raw = get_state(screenshot)
    
    total_reward = 0
    total_loss = 0
    action = 0

    state = simplify_state(state_raw)
    total_reward = 0
    done = False
    
    
    start_time = time.time()
    for step in range(episode_length):
        prev_action = action
        time.sleep(0.1)
        
        state_vector = process_state(state)
        
        action = agent.select_action(state_vector)
        
        if action < 4:
            pyautogui.press(['up', 'down', 'left', 'right'][action])
        
        
        next_screenshot, non_crop_state = get_screen(GAME_REGION)
        next_state_raw = get_state(next_screenshot)
        
        with open('state.txt', 'w') as f:
            for row in next_state_raw:
                f.write(' '.join(map(str, row)) + '\n')
                
        
        # cv2.imwrite(f'state_{step}.png', state)
        
        next_state = simplify_state(next_state_raw)
        next_state_vector = process_state(next_state)
        
        reward_state = {
            'state': state,
            'action': action,
            'prev_action': prev_action,
            'next_state': next_state,
            'time': time.time() - start_time,
            'total_reward': total_reward,
            'non_crop_state': non_crop_state
        }
        
        done = 0
        reward = compute_reward(reward_state)

        agent.push(state_vector, action, reward, next_state_vector, done)

        state = next_state
        total_reward += reward
        
        if is_game_over(non_crop_state):
            done = 1
            break
        
        loss = agent.optimize_model()
        # else:
        #     print("False")
        
        print(f"Step: {step}, Action: {action}, Reward: {reward}, Total Reward: {total_reward}, Loss: {loss}")
 
    # tap space key to restart the game
    keyboard.press_and_release('space')
    losses.append(total_loss)
    rewards.append(total_reward)
    print('\nepisode: {}, reward: {}'.format(episode, total_reward))
    
    # Save model
    if episode % 10 == 0:
        agent.save('dqn.pth')
        print("Model saved")
    
    time.sleep(3.25)
    keyboard.press_and_release('space')



Model is ready to train
Step: 0, Action: 2, Reward: 0.25, Total Reward: 0.25, Loss: None
Step: 1, Action: 0, Reward: 10.0, Total Reward: 10.25, Loss: None
Step: 2, Action: 3, Reward: 0.25, Total Reward: 10.5, Loss: None
Step: 3, Action: 1, Reward: -5.0, Total Reward: 5.5, Loss: None
Step: 4, Action: 3, Reward: 0.25, Total Reward: 5.75, Loss: None
Step: 5, Action: 4, Reward: 0.2, Total Reward: 5.95, Loss: None
Step: 6, Action: 1, Reward: -5.0, Total Reward: 0.9500000000000002, Loss: None
Step: 7, Action: 4, Reward: 0.2, Total Reward: 1.1500000000000001, Loss: None
Step: 8, Action: 1, Reward: -5.0, Total Reward: -3.8499999999999996, Loss: None
Step: 9, Action: 1, Reward: -5.0, Total Reward: -8.85, Loss: None
Step: 10, Action: 2, Reward: 0.25, Total Reward: -8.6, Loss: None
Step: 11, Action: 4, Reward: 0.2, Total Reward: -8.4, Loss: None
Step: 12, Action: 1, Reward: -5.0, Total Reward: -13.4, Loss: None
Step: 13, Action: 3, Reward: 0.25, Total Reward: -13.15, Loss: None
Step: 14, Action: 

  huber_loss = F.smooth_l1_loss(current_q, target_q.unsqueeze(1), reduction='none')


Step: 15, Action: 4, Reward: 0.2, Total Reward: 7.7, Loss: 2.8623082637786865
Optimizing model...
Step: 16, Action: 3, Reward: 0.25, Total Reward: 7.95, Loss: 2.9446310997009277
Model acting
Optimizing model...
Step: 17, Action: 0, Reward: 10.0, Total Reward: 17.95, Loss: 2.831799030303955
Optimizing model...
Step: 18, Action: 2, Reward: 0.25, Total Reward: 18.2, Loss: 2.999375104904175
Optimizing model...
Step: 19, Action: 3, Reward: 0.25, Total Reward: 18.45, Loss: 3.105541467666626
Optimizing model...
Step: 20, Action: 1, Reward: -5.0, Total Reward: 13.45, Loss: 3.0252881050109863
Optimizing model...
Step: 21, Action: 2, Reward: 0.25, Total Reward: 13.7, Loss: 2.8815200328826904
Optimizing model...
Step: 22, Action: 4, Reward: 0.2, Total Reward: 13.899999999999999, Loss: 2.856795310974121
Optimizing model...
Step: 23, Action: 1, Reward: -5.0, Total Reward: 8.899999999999999, Loss: 1.6173111200332642
Optimizing model...
Step: 24, Action: 1, Reward: -15.0, Total Reward: -6.10000000000

KeyboardInterrupt: 