In [8]:
import numpy as np
import pyautogui
import imutils
import cv2
import sys
import time
import random
import math
import pytesseract

import mss
import mss.tools
from PIL import Image
import PIL.ImageOps
import re
import os
from selenium import webdriver
from Xlib import display, X

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

from collections import namedtuple

# Compute device. The notebook was tuned on the second GPU ('cuda:1'), but moving a module to a
# missing device fails, so fall back to the first GPU or the CPU when fewer GPUs are available.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# NOTE(review): discount_factor is not referenced in the visible cells, and GAMMA (0.999) is
# defined separately with the other DQN hyperparameters. Confirm which discount is intended.
discount_factor = 0.9


In [9]:
def open_and_size_browser_window(width, height, x_pos=0, y_pos=0, url='http://www.slither.io',
                                 driver_path='./chromedriver'):
    """Launch Chrome through Selenium, size and position its window, and load ``url``.

    Args:
        width, height: window size in pixels.
        x_pos, y_pos: screen position of the window's top-left corner.
        url: page to open (slither.io by default).
        driver_path: path to the chromedriver executable.

    Returns:
        The ``webdriver.Chrome`` instance controlling the new window.
    """
    chrome_options = webdriver.ChromeOptions()
    chrome_options.add_argument("--disable-infobars")
    # `options=` replaces the `chrome_options=` keyword, which is deprecated since Selenium 3.8
    # and rejected by newer Selenium 4 releases.
    # NOTE(review): Selenium >= 4.10 also dropped the positional driver path; on those versions
    # pass `service=selenium.webdriver.chrome.service.Service(driver_path)` instead.
    driver = webdriver.Chrome(driver_path, options=chrome_options)
    driver.set_window_size(width, height)
    driver.set_window_position(x_pos, y_pos)
    driver.get(url)

    return driver

In [10]:
class Flatten(nn.Module):
    """Flatten every dimension after the batch dimension: (N, d1, d2, ...) -> (N, d1*d2*...)."""

    def forward(self, x):
        # view() needs a contiguous tensor; the conv stack's output is contiguous.
        return x.view(x.size(0), -1)


class DQN(nn.Module):
    """Convolutional Q-network: image in, one Q-value per discrete action out.

    Four 5x5, stride-2 convolutions (each followed by BatchNorm and ReLU) are flattened and
    passed through three fully connected layers (1024 -> 256 -> ``n_actions``).

    Args:
        n_actions: number of Q-values produced (output width of ``dense3``). The default of 8
            matches the eight evenly spaced directions used by ``action``.
        n_flat_features: number of features after flattening conv4's output, i.e.
            ``32 * H_out * W_out``. The default 3584 is a 32x8x14 feature map, which a
            173x269 input produces. Change it if the screenshot size changes.
        in_channels: number of input image channels (1 for grayscale screenshots).

    Attribute names are unchanged, so checkpoints saved from the original class still load.
    """

    def __init__(self, n_actions=8, n_flat_features=3584, in_channels=1):
        super(DQN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 6, kernel_size=5, stride=2)
        self.bn1 = nn.BatchNorm2d(6)
        self.conv2 = nn.Conv2d(6, 10, kernel_size=5, stride=2)
        self.bn2 = nn.BatchNorm2d(10)
        self.conv3 = nn.Conv2d(10, 16, kernel_size=5, stride=2)
        self.bn3 = nn.BatchNorm2d(16)
        self.conv4 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
        self.bn4 = nn.BatchNorm2d(32)
        self.flatten = Flatten()
        self.dense1 = nn.Linear(n_flat_features, 1024)
        self.dense2 = nn.Linear(1024, 256)
        self.dense3 = nn.Linear(256, n_actions)

    def forward(self, x):
        """Map a batch ``x`` of shape (N, in_channels, H, W) to Q-values of shape (N, n_actions)."""
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.flatten(x)

        x = F.relu(self.dense1(x))
        x = F.relu(self.dense2(x))
        # No activation on the output: Q-values are unbounded.
        x = self.dense3(x)

        return x

In [11]:
def action(number, click, n_directions=8):
    """Perform discrete action ``number`` by steering toward one of ``n_directions`` angles.

    Action ``k`` maps to the angle ``2*pi*k / n_directions``; the cursor is then placed (and
    optionally clicked) by ``move_to_radians``.

    Args:
        number: action index. May be a Python int, a numpy integer, or a 0-dim tensor
            (``select_action`` can produce any of these).
        click: forwarded to ``move_to_radians``; 0 moves only, any other value clicks.
        n_directions: number of evenly spaced directions (default 8, matching the DQN output).
    """
    # float() turns numpy scalars and 0-dim (possibly CUDA) tensors into a Python number,
    # so the angle is computed in float64 on the host.
    radian = 2 * math.pi * float(number) / n_directions
    move_to_radians(radian, click=click)


def move_to_radians(radians, click, radius=100, center_x=728, center_y=492):
    """Point the mouse at angle ``radians`` on a circle around a fixed screen point.

    Args:
        radians: direction in radians. 0 points right. Screen y grows downward, so increasing
            angles turn clockwise on screen.
        click: 0 only moves the cursor. Any other value presses the button at the target,
            waits 0.1 s, then releases it.
        radius: distance in pixels from the center to the target point.
        center_x, center_y: screen coordinates of the circle's center. The defaults
            (728, 492) are the values this notebook was tuned with, presumably the snake's
            on-screen position for its window layout. Verify them for other layouts.

    Returns:
        ``radians``, unchanged.
    """
    # Compute the target once instead of repeating the trig for every pyautogui call.
    target_x = center_x + radius * math.cos(radians)
    target_y = center_y + radius * math.sin(radians)

    if click == 0:
        pyautogui.moveTo(target_x, target_y)
    else:
        pyautogui.mouseDown(target_x, target_y)
        time.sleep(0.1)
        pyautogui.mouseUp(target_x, target_y)

    return radians

def start_game(start_button_position_x, start_button_position_y):
    """Click the start button at the given screen coordinates, then reset the cursor.

    Sleeps 1 s before the click and 0.1 s after it, then moves the cursor (without
    clicking) to angle 0 via ``move_to_radians``.
    """
    settle_delay, post_click_delay = 1, 0.1
    start_button = (start_button_position_x, start_button_position_y)

    time.sleep(settle_delay)
    pyautogui.click(*start_button)
    time.sleep(post_click_delay)
    move_to_radians(radians=0, click=0)


def get_direction():
    x, y = pyautogui.position()

    return math.atan2(y, x)


def Reward(prev_length, cur_length):
    """Return the step reward: the change in length, ``cur_length - prev_length``."""
    return cur_length - prev_length


def screenshot(x, y, w, h, gray, reduction_factor):
    """Capture a screen rectangle and return it as a subsampled PIL image.

    Args:
        x, y: top-left corner of the region in screen pixels.
        w, h: region width and height in pixels.
        gray: if truthy, convert to single-channel grayscale; otherwise to 3-channel BGR.
        reduction_factor: keep every ``reduction_factor``-th row and column (plain
            subsampling, no smoothing).

    Returns:
        A ``PIL.Image.Image`` built from the subsampled array.
    """
    bounding_box = dict(left=x, top=y, width=w, height=h)
    conversion = cv2.COLOR_BGRA2GRAY if gray else cv2.COLOR_BGRA2BGR

    # np.array copies the grabbed pixels, so the capture context can close before conversion.
    with mss.mss() as grabber:
        pixels = np.array(grabber.grab(bounding_box))

    converted = cv2.cvtColor(pixels, conversion)
    step = reduction_factor
    # NOTE(review): in color mode this array is BGR, but PIL treats 3-channel arrays as RGB,
    # so red and blue are swapped in the returned image.
    return Image.fromarray(converted[::step, ::step])


def read_score(driver, score_span_index=32, default_score=10):
    """Read the current length counter from the slither.io page.

    Args:
        driver: Selenium WebDriver showing the game.
        score_span_index: index of the ``<span>`` element holding the score. 32 is the value
            this notebook was tuned with; it depends on the page's DOM and may need updating.
        default_score: score returned when the counter cannot be read (presumably the
            starting length in slither.io; verify).

    Returns:
        ``(score, dead)``. ``score`` is the parsed integer, or ``default_score`` if it could
        not be read. ``dead`` is True exactly when the score could not be read.
    """
    dead = False
    score = default_score

    try:
        # find_elements('tag name', ...) is the By.TAG_NAME locator and works on Selenium 3
        # and 4; find_elements_by_tag_name was removed in Selenium 4.3.
        spans = driver.find_elements('tag name', 'span')
        score = int(spans[score_span_index].text)
        print('Alive: {}'.format(score))
    except Exception:
        # A missing span (IndexError), non-numeric text (ValueError) or a WebDriver error is
        # treated as death; presumably the counter disappears when the snake dies.
        # Catching Exception instead of using a bare except keeps KeyboardInterrupt and
        # SystemExit working, so a long training loop can still be stopped.
        dead = True
        print('Dead')

    return score, dead

In [12]:
# One experience-replay record.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ReplayMemory(object):
    """Fixed-capacity ring buffer of ``Transition`` records for experience replay.

    Pushes append until ``capacity`` items are stored. After that, each push overwrites the
    oldest slot, so the buffer always holds the most recent ``capacity`` transitions.

    Args:
        capacity: maximum number of stored transitions; must be positive.

    Raises:
        ValueError: if ``capacity`` is not positive (the ring index would divide by zero).
    """

    def __init__(self, capacity):
        if capacity <= 0:
            raise ValueError('capacity must be positive, got {}'.format(capacity))
        self.capacity = capacity
        self.memory = []
        # Index of the slot the next push writes to.
        self.position = 0

    def push(self, *args):
        """Store a transition (positional args in ``Transition`` field order).

        The record is built before the buffer is touched. A malformed call therefore raises
        without leaving a ``None`` placeholder that ``sample`` could later return.
        """
        record = Transition(*args)
        if len(self.memory) < self.capacity:
            self.memory.append(record)
        else:
            self.memory[self.position] = record
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen uniformly at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

In [13]:
# --- DQN hyperparameters ---
BATCH_SIZE = 128       # transitions drawn from replay memory per optimisation step
GAMMA = 0.999          # presumably the DQN discount factor (see also discount_factor above)
EPS_START = 0.9        # initial exploration probability of the epsilon-greedy policy
EPS_END = 0.05         # value epsilon decays towards
EPS_DECAY = 200        # time constant, in steps, of the exponential epsilon decay
TARGET_UPDATE = 10     # NOTE(review): not used in the visible cells; presumably the target-sync interval

# Experience buffer and the global step counter read by select_action's epsilon schedule.
memory = ReplayMemory(capacity=10000)
steps_done = 0

# Online network, plus a target copy that starts with identical weights and is kept in eval mode.
dqn = DQN().to(device)
# To resume from a checkpoint: dqn.load_state_dict(torch.load('Model'))
target_dqn = DQN().to(device)
target_dqn.load_state_dict(dqn.state_dict())
target_dqn.eval()

optimizer = optim.RMSprop(params=dqn.parameters())


def select_action(env):
    """Choose an action epsilon-greedily, perform it on screen, and return its index.

    Epsilon decays exponentially from ``EPS_START`` toward ``EPS_END`` with time constant
    ``EPS_DECAY`` steps, driven by the global ``steps_done``, which this function increments.

    Args:
        env: observation fed to ``dqn``. Presumably a (1, 1, H, W) float tensor on
            ``device``; TODO confirm against the caller.

    Returns:
        The chosen action index as a plain ``int`` in [0, 8).
    """
    global steps_done
    epsilon = EPS_END + (EPS_START - EPS_END) * \
        math.exp(-1. * steps_done / EPS_DECAY)
    # Without this increment epsilon never decays, leaving the agent ~90% random forever.
    steps_done += 1

    if np.random.random() < epsilon:
        # 8 must match the DQN's output size. A 16-action variant (8 directions x
        # click / no-click) would use randint(0, 16) and
        # action(action_number // 2, action_number % 2).
        action_number = np.random.randint(0, 8)
    else:
        # Inference only: don't build an autograd graph for the greedy pick.
        with torch.no_grad():
            q_values = dqn(env)
        action_number = torch.argmax(q_values).item()

    # Both branches return the same type (numpy int / tensor -> int).
    action_number = int(action_number)
    action(action_number, 0)
    return action_number





In [None]:
def train_model():
    if len(memory) < BATCH_SIZE:
        return
        
    transitions = memory.sample(BATCH_SIZE)
