In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import random


In [2]:
# DQN model for Crossy Road game
# input is image of the game screen
# output is the action to take (up, down, left, right)

class DQN(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(DQN, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        
        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)
        
        def conv2d_size_out(size, kernel_size=5, stride=2):
            return (size - (kernel_size - 1) - 1) // stride + 1
        
        convw = conv2d_size_out(conv2d_size_out(conv2d_size_out(input_size[1])))
        convh = conv2d_size_out(conv2d_size_out(conv2d_size_out(input_size[2])))
        linear_input_size = convw * convh * 32
        
        self.fc1 = nn.Linear(linear_input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
        
    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = F.relu(self.fc1(x.reshape(x.size(0), -1)))
        return self.fc2(x)
        
    def act(self, state, epsilon):
        if random.random() > epsilon:
            state = torch.FloatTensor(state).unsqueeze(0)
            q_value = self.forward(state)
            action = q_value.max(1)[1].data[0]
        else:
            action = random.randrange(self.output_size)
        return action


In [3]:
# Replay buffer for DQN
# stores the transitions (state, action, reward, next_state, done)
# and samples a batch of transitions for training

class ReplayBuffer:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0
        
    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity
        
    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done
        
    def __len__(self):
        return len(self.buffer)
    
    
# DQN agent for Crossy Road game
# uses DQN model and replay buffer for training

class DQNAgent:
    def __init__(self, input_size, output_size, hidden_size, replay_buffer_capacity, batch_size, gamma, epsilon_start, epsilon_end, epsilon_decay):
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.replay_buffer_capacity = replay_buffer_capacity
        self.batch_size = batch_size
        self.gamma = gamma
        self.epsilon_start = epsilon_start
        self.epsilon_end = epsilon_end
        self.epsilon_decay = epsilon_decay
        
        self.model = DQN(input_size, output_size, hidden_size)
        self.replay_buffer = ReplayBuffer(replay_buffer_capacity)
        self.optimizer = optim.Adam(self.model.parameters())
        
        self.steps_done = 0
        
    def select_action(self, state):
        epsilon = self.epsilon_end + (self.epsilon_start - self.epsilon_end) * np.exp(-1. * self.steps_done / self.epsilon_decay)
        self.steps_done += 1
        return self.model.act(state, epsilon)
    
    def optimize_model(self):
        if len(self.replay_buffer) < self.batch_size:
            return
        
        state, action, reward, next_state, done = self.replay_buffer.sample(self.batch_size)
        state = torch.FloatTensor(np.float32(state))
        action = torch.LongTensor(action)
        reward = torch.FloatTensor(reward)
        next_state = torch.FloatTensor(np.float32(next_state))
        done = torch.FloatTensor(done)
        
        # print(state.shape, action.shape, reward.shape, next_state.shape, done.shape)
        
        next_state = next_state.permute(0, 3, 1, 2)
        state = state.permute(0, 3, 1, 2)
        
        q_values = self.model(state)
        next_q_values = self.model(next_state)
        q_value = q_values.gather(1, action.unsqueeze(1)).squeeze(1)
        next_q_value = next_q_values.max(1)[0]
        expected_q_value = reward + self.gamma * next_q_value * (1 - done)
        
        loss = F.smooth_l1_loss(q_value, expected_q_value)
        
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.item()
    
    def push(self, state, action, reward, next_state, done):
        self.replay_buffer.push(state, action, reward, next_state, done)
        
    def save(self, path):
        torch.save(self.model.state_dict(), path)
        
    def load(self, path):
        self.model.load_state_dict(torch.load(path))
        self.model.eval()
        
    def reset(self):
        self.steps_done = 0
        

# test DQN agent
agent = DQNAgent((3, 84, 84), 4, 128, 10000, 32, 0.99, 1.0, 0.1, 10000)
state = np.random.rand(3, 84, 84)
action = agent.select_action(state)
reward = 1.0
next_state = np.random.rand(3, 84, 84)
done = 0
agent.push(state, action, reward, next_state, done)
loss = agent.optimize_model()
agent.save('dqn.pth')
agent.load('dqn.pth')
agent.reset()
print('test passed')


test passed


In [4]:
# run DQN agent on Crossy Road game to train and play the game
import pyautogui
import cv2
import time
import keyboard

RES_X = 1920
RES_Y = 1080

GAME_REGION = (405, 210, 850, 480)

restart_button = cv2.imread('restart_button.png', cv2.IMREAD_GRAYSCALE)

# Init restart button image
restart_button = cv2.imread('restart_button.png', cv2.IMREAD_GRAYSCALE)
restart_button = cv2.normalize(restart_button, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

def get_screen(region):
    screen = pyautogui.screenshot(region=(region[0], region[1], region[2], region[3]))
    screen = np.array(screen)
    screen = cv2.resize(screen, (425, 240))
    # screen = np.moveaxis(screen, 2, 0)
    return screen

def is_game_over(image, score_threshold=0.9, scale=0.5):
    # Check if the game is over by checking if the restart button is visible
    grey_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    
    resized_screenshot = cv2.resize(grey_image, (0, 0), fx=scale, fy=scale)
    resized_template = cv2.resize(restart_button, (0, 0), fx=scale, fy=scale)
    h, w = resized_screenshot.shape
    
    resized_screenshot = resized_screenshot.astype(np.float32)
    
    cropped_search_box = resized_screenshot[int(h * 0.7):, int(w * 0.3):int(w * 0.7)]
    result = cv2.matchTemplate(cropped_search_box, resized_template, cv2.TM_CCOEFF_NORMED)
    locations = np.where(result >= score_threshold)
    return len(locations[0]) > 0

def compute_reward(reward_state):
    # Compute reward based on the change in the game screen   
    time = reward_state['time']
    state = reward_state['state']
    action = reward_state['action']
    next_state = reward_state['next_state']
    total_reward = reward_state['total_reward']
    reward = 0
    
    if is_game_over(next_state):
        reward = -100
    else:
        reward = 1
        
    reward += time * 0.1
    
    if action == 0:
        reward += 1
    elif action == 1:
        reward -= 1
    elif action == 2:
        reward = 0.25
    elif action == 3:
        reward += 0.25
    
    return reward
    

# train DQN agent
agent = DQNAgent((3, 425, 240), 4, 128, 1000, 8, 0.99, 1.0, 0.1, 10000)
episodes = 1000
episode_length = 1000
losses = []
rewards = []

print("Model is ready to train")
# start the train after pressing 's' key
keyboard.wait('s')

for episode in range(episodes):
    state = get_screen(GAME_REGION)
    total_reward = 0
    total_loss = 0
    
    
    start_time = time.time()
    for step in range(episode_length):
        action = agent.select_action(state)
        pyautogui.press(['up', 'down', 'left', 'right'][action])
        
        time.sleep(0.25)
        
        next_state = get_screen(GAME_REGION)
        
        reward_state = {
            'state': state,
            'action': action,
            'next_state': next_state,
            'time': time.time() - start_time,
            'total_reward': total_reward,
        }
        
        reward = compute_reward(reward_state)
        done = 0
        
        agent.push(state, action, reward, next_state, done)
        state = next_state
        total_reward += reward
        
        if is_game_over(state):
            done = 1
            agent.push(state, action, reward, next_state, done)
            break
        
        loss = agent.optimize_model()
        if loss is not None:
            total_loss += loss
        
        print('step: {}, loss: {}, reward: {}'.format(step, loss, reward))
    
    # tap space key to restart the game
    keyboard.press_and_release('space')
    time.sleep(4)
    keyboard.press_and_release('space')
    time.sleep(1)
            
    losses.append(total_loss)
    rewards.append(total_reward)
    print('episode: {}, loss: {}, reward: {}'.format(episode, total_loss, total_reward))
    agent.save('dqn.pth')
    agent.reset()


Model is ready to train
step: 0, loss: None, reward: 0.03911564350128183
step: 1, loss: None, reward: 2.0767778158187866
step: 2, loss: None, reward: 2.1142251968383787
step: 3, loss: None, reward: 0.15172569751739506
step: 4, loss: None, reward: 0.25
step: 5, loss: None, reward: 1.4774322032928466
step: 6, loss: None, reward: 1.5173386573791503
step: 7, loss: 3.1072330474853516, reward: 2.3049256086349486
step: 8, loss: 57.31833267211914, reward: 1.6109858751296997
step: 9, loss: 52.83697509765625, reward: 2.4069194078445433
step: 10, loss: 1.9456959962844849, reward: 0.25
step: 11, loss: 6.626850605010986, reward: 0.25
step: 12, loss: 5.461476802825928, reward: 2.5464776039123533
step: 13, loss: 2.4603328704833984, reward: 1.8416924238204957
step: 14, loss: 0.5790379643440247, reward: 0.6388439416885376
step: 15, loss: 0.8058202862739563, reward: 0.25
step: 16, loss: 0.5580230951309204, reward: 0.25
step: 17, loss: 0.6754741072654724, reward: 2.777116870880127
step: 18, loss: 1.15792

KeyboardInterrupt: 