In [6]:
from google.colab import drive
drive.mount("/gdrive", force_remount=True)

Mounted at /gdrive


In [10]:
class TicTacToe():
    def __init__(self):
        self.reset()

    def reset(self):
        self.board = [0, 0, 0, 0, 0, 0, 0, 0, 0] * 3
        self.turn = 1
        self.gameover = False
        return self.board

    def get_state(self):
        if self.turn == 1:
            return self.board
        else:
            new_board = []
            for i in range(9):
                if self.board[i] == 1:
                    new_board.append(2)
                elif self.board[i] == 2:
                    new_board.append(1)
                else:
                    new_board.append(0)
            return new_board * 3

    def step(self, action):
        # 턴 미리 바꿔주기
        self.turn = 3 - self.turn

        # 불가능한 액션
        if self.board[action] != 0:
            return self.get_state(), -10, True
        self.board[action] = self.turn

        # 승리하는 수를 놓았는지
        if self.check_reward():
            return self.get_state(), 2, True

        # 패배하는 수를 놓았는지
        if self.check_defeatable():
            return self.get_state(), -1, False

        # 무승부인지
        if self.check_isfull():
            return self.get_state(), 1, True

        return self.get_state(), 0, False

    def check_isfull(self):
        for i in range(9):
            if self.board[i] == 0:
                return False
        return True

    def check_defeatable(self):
        for i in range(3):
            if (self.board[i * 3] == self.board[i * 3 + 1] == (3 - self.turn) and self.board[i * 3 + 2] == 0) or\
                (self.board[i * 3] == self.board[i * 3 + 2] == (3 - self.turn) and self.board[i * 3 + 1] == 0) or \
                    (self.board[i * 3 + 1] == self.board[i * 3 + 2] == (3 - self.turn) and self.board[i * 3] == 0):
                return True
            if (self.board[i] == self.board[i + 3] == (3 - self.turn) and self.board[i + 6] == 0) or \
                (self.board[i] == self.board[i + 6] == (3 - self.turn) and self.board[i + 3] == 0) or \
                    (self.board[i + 3] == self.board[i + 6] == (3 - self.turn) and self.board[i] == 0):
                return True
        if (self.board[0] == self.board[4] == (3 - self.turn) and self.board[8] == 0) or \
            (self.board[0] == self.board[8] == (3 - self.turn) and self.board[4] == 0) or \
                (self.board[4] == self.board[8] == (3 - self.turn) and self.board[0] == 0):
            return True
        if (self.board[2] == self.board[4] == (3 - self.turn) and self.board[6] == 0) or \
            (self.board[2] == self.board[6] == (3 - self.turn) and self.board[4] == 0) or \
                (self.board[4] == self.board[6] == (3 - self.turn) and self.board[2] == 0):
            return True
        return False

    def available_actions(self):
        actions = []
        for i in range(9):
            if self.board[i] == 0:
                actions.append(i)
        return actions

    def check_reward(self):
        for i in range(3):
            if self.board[i * 3] == self.board[i * 3 + 1] == self.board[i * 3 + 2] != 0:
                return 1
            if self.board[i] == self.board[i + 3] == self.board[i + 6] != 0:
                return 1
        if self.board[0] == self.board[4] == self.board[8] != 0:
            return 1
        if self.board[2] == self.board[4] == self.board[6] != 0:
            return 1
        return 0

    def render(self):
        print(self.turn)
        for i in range(3):
            for j in range(3):
                if self.board[i * 3 + j] == 1:
                    print('O', end='')
                elif self.board[i * 3 + j] == 2:
                    print('X', end='')
                else:
                    print(' ', end='')
            print()
        print()


In [None]:
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Hyperparameters
learning_rate = 0.00025
gamma = 1
buffer_limit = 100000
batch_size = 32

loss_sum = 0


class ReplayBuffer():
    def __init__(self):
        self.buffer = collections.deque(maxlen=buffer_limit)

    def put(self, transition):
        self.buffer.append(transition)

    def sample(self, n):
        mini_batch = random.sample(self.buffer, n)
        s_lst, a_lst, r_lst, s_prime_lst, done_mask_lst = [], [], [], [], []

        for transition in mini_batch:
            s, a, r, s_prime, done_mask = transition
            s_lst.append(s)
            a_lst.append([a])
            r_lst.append([r])
            s_prime_lst.append(s_prime)
            done_mask_lst.append([done_mask])

        return torch.tensor(s_lst, dtype=torch.float).cuda(), torch.tensor(a_lst).cuda(), \
            torch.tensor(r_lst).cuda(), torch.tensor(s_prime_lst, dtype=torch.float).cuda(), \
            torch.tensor(done_mask_lst).cuda()

    def size(self):
        return len(self.buffer)


class Qnet(nn.Module):
    def __init__(self):
        super(Qnet, self).__init__()
        self.fc1 = nn.Linear(9*3, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 256)
        self.fc5 = nn.Linear(256, 9)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x

    def sample_action(self, obs, epsilon):
        out = self.forward(obs)
        coin = random.random()
        if coin < epsilon:
            return random.randint(0, 8)
        else:
            return out.argmax().item()


def train(q, q_target, memory, optimizer):
    global loss_sum
    for i in range(10):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)

        q_out = q(s)
        q_a = q_out.gather(1, a)
        max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
        target = r + gamma * max_q_prime * done_mask
        loss = F.smooth_l1_loss(q_a, target)
        loss_sum += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def main():
    env = TicTacToe()
    q = Qnet().cuda()
    q_target = Qnet().cuda()
    
    # 불러오기
    # q.load_state_dict(torch.load('/gdrive/MyDrive/ossp2/DQN1.pth'))
    
    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()

    print_interval = 20
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    for n_epi in range(20000):
        epsilon = max(0.01, 0.2 - 0.01*(n_epi/200))
        s = env.reset()
        done = False

        while not done:
            # env.render()
            a = q.sample_action(torch.from_numpy(np.array(s)).float().cuda(), epsilon)

            available_actions = env.available_actions()
            if a not in available_actions:
                s_prime, r, done = env.step(random.choice(available_actions))
            else:
                s_prime, r, done = env.step(a)
            
            done_mask = 0.0 if done else 1.0
            memory.put((s, a, r/100.0, s_prime, done_mask))
            s = s_prime

            if done:
                break

        if memory.size() > 2000:
            train(q, q_target, memory, optimizer)

        if n_epi % print_interval == 0 and n_epi != 0:
            q_target.load_state_dict(q.state_dict())
            print(n_epi, "loss sum: ",loss_sum)

    torch.save(q.state_dict(), '/gdrive/MyDrive/DQN1.pth')


if __name__ == '__main__':
    main()


20 loss sum:  0
40 loss sum:  0
60 loss sum:  0
80 loss sum:  0
100 loss sum:  0
120 loss sum:  0
140 loss sum:  0
160 loss sum:  0
180 loss sum:  0
200 loss sum:  0
220 loss sum:  0
240 loss sum:  0
260 loss sum:  0.014479881792794913
280 loss sum:  0.03950716927101894
300 loss sum:  0.05791999700704764
320 loss sum:  0.0784501064408687
340 loss sum:  0.0987197378926794
360 loss sum:  0.11764931467405404
380 loss sum:  0.13678867762064328
400 loss sum:  0.15356049586989684
420 loss sum:  0.17127715349124628
440 loss sum:  0.1871655227168958
460 loss sum:  0.204757799003346
480 loss sum:  0.21994588607594778
500 loss sum:  0.23451442164878245
520 loss sum:  0.24930027136360877
540 loss sum:  0.26158529751046444
560 loss sum:  0.2759781268414372
580 loss sum:  0.2884376626516314
600 loss sum:  0.30142019568666
620 loss sum:  0.31134146262866125
640 loss sum:  0.3223078941982749
660 loss sum:  0.33151479156094865
680 loss sum:  0.3415781360681649
700 loss sum:  0.3508474117397782
720 los