In [1]:
# Mount Google Drive at /gdrive. This cell requires the Google Colab runtime;
# main() below loads and saves its checkpoint under /gdrive/MyDrive.
from google.colab import drive
drive.mount("/gdrive", force_remount=True)

Mounted at /gdrive


In [7]:
class TicTacToe():
    """Two-player tic-tac-toe environment with a gym-like reset/step interface.

    The board is a flat list of nine cells in row-major order. A cell holds
    0 when empty, or 1 / 2 for the mark of the respective player.
    ``turn`` holds the mark of the player who moved last (``step`` flips it
    before placing a mark), so the very first mark placed after ``reset`` is 2.
    """

    # Index triples of the eight winning lines: rows, columns, then diagonals.
    _LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )

    # Mark translation used by get_state() when the perspective is swapped.
    _SWAPPED = {0: 0, 1: 2, 2: 1}

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the board and return the initial (all-zero) state.

        Returns:
            A new list of nine zeros. It is a copy, so later calls to
            ``step`` do not modify a state the caller has already stored.
        """
        self.board = [0] * 9
        self.turn = 1
        self.gameover = False
        return list(self.board)

    def get_state(self):
        """Return a snapshot of the board as a new list.

        When ``turn`` is 1 the board is returned unchanged; otherwise marks
        1 and 2 are swapped. The result never aliases ``self.board``.
        """
        if self.turn == 1:
            return list(self.board)
        return [self._SWAPPED.get(cell, 0) for cell in self.board]

    def step(self, action):
        """Place the next player's mark on cell ``action``.

        Args:
            action: Board index (0-8) to play.

        Returns:
            A ``(state, reward, done)`` tuple where ``state`` comes from
            ``get_state`` and the reward is:
              -10  the cell was already occupied (episode ends),
                2  the move completed a line (episode ends),
               -1  the move leaves the opponent a line they can complete,
                1  the board is full without a winner (episode ends),
                0  otherwise.
        """
        # Switch the turn up front so self.turn is the player making this move.
        self.turn = 3 - self.turn

        # Illegal action: the cell is already taken.
        if self.board[action] != 0:
            self.gameover = True
            return self.get_state(), -10, True
        self.board[action] = self.turn

        # Did this move win the game?
        if self.check_reward():
            self.gameover = True
            return self.get_state(), 2, True

        # Did this move leave the opponent an immediate winning cell?
        if self.check_defeatable():
            return self.get_state(), -1, False

        # Is it a draw (board full)?
        if self.check_isfull():
            self.gameover = True
            return self.get_state(), 1, True

        return self.get_state(), 0, False

    def check_isfull(self):
        """Return True when no empty cell remains."""
        return all(cell != 0 for cell in self.board)

    def check_defeatable(self):
        """Return True if the opponent of ``turn`` can complete a line.

        That is, some winning line holds two opponent marks and one empty cell.
        """
        opponent = 3 - self.turn
        for line in self._LINES:
            cells = [self.board[idx] for idx in line]
            if cells.count(opponent) == 2 and cells.count(0) == 1:
                return True
        return False

    def available_actions(self):
        """Return the indices of all empty cells, in ascending order."""
        return [idx for idx in range(9) if self.board[idx] == 0]

    def check_reward(self):
        """Return 1 if any line is filled with three identical marks, else 0."""
        for first, second, third in self._LINES:
            if self.board[first] == self.board[second] == self.board[third] != 0:
                return 1
        return 0

    def render(self):
        """Print ``turn`` followed by the board ('O' for 1, 'X' for 2)."""
        symbols = {1: 'O', 2: 'X'}
        print(self.turn)
        for row in range(3):
            print(''.join(symbols.get(self.board[row * 3 + col], ' ') for col in range(3)))
        print()


In [9]:
import collections
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

# Hyperparameters
learning_rate = 0.0005  # learning rate passed to the Adam optimizer in main()
gamma = 1  # discount factor applied to the bootstrapped target in train()
buffer_limit = 100000  # maximum number of transitions kept by ReplayBuffer
batch_size = 32  # number of transitions sampled per gradient step in train()


class ReplayBuffer():
    """Fixed-capacity FIFO store of transitions for experience replay.

    Each transition is a 5-tuple ``(s, a, r, s_prime, done_mask)``.
    """

    def __init__(self, maxlen=None):
        """Create an empty buffer.

        Args:
            maxlen: Capacity of the buffer. When None, the module-level
                ``buffer_limit`` is used. Once full, the oldest transitions
                are dropped as new ones are appended.
        """
        capacity = buffer_limit if maxlen is None else maxlen
        self.buffer = collections.deque(maxlen=capacity)

    def put(self, transition):
        """Append one ``(s, a, r, s_prime, done_mask)`` transition."""
        self.buffer.append(transition)

    def sample(self, n, device=None):
        """Draw ``n`` distinct transitions uniformly at random as tensors.

        Args:
            n: Number of transitions to sample.
            device: Device on which the tensors are created. When None,
                CUDA is used if available, otherwise the CPU, so sampling
                no longer fails on machines without a GPU.

        Returns:
            A tuple ``(s, a, r, s_prime, done_mask)`` of tensors. ``s`` and
            ``s_prime`` are float tensors with one row per transition;
            ``a``, ``r`` and ``done_mask`` have one single-element row per
            transition.
        """
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        mini_batch = random.sample(self.buffer, n)

        states = [transition[0] for transition in mini_batch]
        actions = [[transition[1]] for transition in mini_batch]
        rewards = [[transition[2]] for transition in mini_batch]
        next_states = [transition[3] for transition in mini_batch]
        done_masks = [[transition[4]] for transition in mini_batch]

        return (
            torch.tensor(states, dtype=torch.float, device=device),
            torch.tensor(actions, device=device),
            torch.tensor(rewards, device=device),
            torch.tensor(next_states, dtype=torch.float, device=device),
            torch.tensor(done_masks, device=device),
        )

    def size(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)


class Qnet(nn.Module):
    """Fully connected Q-network mapping a 9-cell board to 9 action values."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(9, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 128)
        self.fc5 = nn.Linear(128, 9)

    def forward(self, x):
        """Return one value per board cell for input ``x``.

        Four ReLU-activated hidden layers are followed by a linear output
        layer of width 9.
        """
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            x = F.relu(layer(x))
        return self.fc5(x)

    def sample_action(self, obs, epsilon):
        """Pick an action epsilon-greedily.

        Args:
            obs: Observation tensor fed to ``forward``.
            epsilon: Probability of choosing a uniformly random cell index.

        Returns:
            An int in ``0..8``: a random index with probability ``epsilon``,
            otherwise the argmax of the network output.
        """
        # Decide first so the forward pass is skipped when exploring.
        if random.random() < epsilon:
            return random.randint(0, 8)
        # Action selection needs no gradients; avoid building an autograd graph.
        with torch.no_grad():
            return self.forward(obs).argmax().item()


def train(q, q_target, memory, optimizer):
    """Run ten gradient steps of Q-learning on batches from ``memory``.

    Args:
        q: Online network being optimised.
        q_target: Target network used to bootstrap the TD target; it is
            only evaluated here, never updated.
        memory: Object whose ``sample(batch_size)`` returns the tensors
            ``(s, a, r, s_prime, done_mask)``.
        optimizer: Optimizer over the parameters of ``q``.
    """
    for _ in range(10):
        s, a, r, s_prime, done_mask = memory.sample(batch_size)

        # Q-value of the action stored in each sampled transition.
        q_a = q(s).gather(1, a)

        # The target is a constant for this update: build it without autograd
        # so backward() does not traverse q_target or accumulate gradients
        # on its parameters.
        with torch.no_grad():
            max_q_prime = q_target(s_prime).max(1)[0].unsqueeze(1)
            # done_mask is 0 for terminal transitions, removing the bootstrap term.
            target = r + gamma * max_q_prime * done_mask

        loss = F.smooth_l1_loss(q_a, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def main():
    """Train the Q-network by self-play on TicTacToe and save the weights.

    Both players are driven by the same network. Weights are resumed from
    the checkpoint on Google Drive when it exists and written back to the
    same path after 20000 episodes.
    """
    import os  # local import: only needed for the checkpoint existence check

    checkpoint_path = '/gdrive/MyDrive/ossp2/DQN1.pth'

    env = TicTacToe()
    q = Qnet().cuda()
    q_target = Qnet().cuda()

    # Load previously saved weights; start from scratch when there are none
    # (for example on the very first run) instead of crashing.
    if os.path.exists(checkpoint_path):
        q.load_state_dict(torch.load(checkpoint_path))
    else:
        print('No checkpoint found at ' + checkpoint_path + '; training from scratch.')

    q_target.load_state_dict(q.state_dict())
    memory = ReplayBuffer()

    print_interval = 20
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    for n_epi in range(20000):
        # Exploration rate decays linearly from 0.2 down to a floor of 0.01.
        epsilon = max(0.01, 0.2 - 0.01*(n_epi/200))
        s = env.reset()
        done = False

        while not done:
            # env.render()
            a = q.sample_action(torch.from_numpy(np.array(s)).float().cuda(), epsilon)

            # Replace an illegal choice with a random legal one, and keep `a`
            # pointing at the action that is actually executed so the stored
            # transition matches what happened on the board.
            available_actions = env.available_actions()
            if a not in available_actions:
                a = random.choice(available_actions)
            s_prime, r, done = env.step(a)

            done_mask = 0.0 if done else 1.0
            # Store copies so later board updates cannot alter saved states.
            memory.put((list(s), a, r/100.0, list(s_prime), done_mask))
            s = s_prime

        # Start learning only once enough transitions have been collected.
        if memory.size() > 2000:
            train(q, q_target, memory, optimizer)

        # Every print_interval episodes: sync the target network and log progress.
        if n_epi % print_interval == 0 and n_epi != 0:
            q_target.load_state_dict(q.state_dict())
            print(n_epi)

    torch.save(q.state_dict(), checkpoint_path)


# Start training when this cell is executed (or the file is run as a script).
if __name__ == '__main__':
    main()


20
40
60
80
100
120
140
160
180
200
220
240
260
280
300
320
340
360
380
400
420
440
460
480
500
520
540
560
580
600
620
640
660
680
700
720
740
760
780
800
820
840
860
880
900
920
940
960
980
1000
1020
1040
1060
1080
1100
1120
1140
1160
1180
1200
1220
1240
1260
1280
1300
1320
1340
1360
1380
1400
1420
1440
1460
1480
1500
1520
1540
1560
1580
1600
1620
1640
1660
1680
1700
1720
1740
1760
1780
1800
1820
1840
1860
1880
1900
1920
1940
1960
1980
2000
2020
2040
2060
2080
2100
2120
2140
2160
2180
2200
2220
2240
2260
2280
2300
2320
2340
2360
2380
2400
2420
2440
2460
2480
2500
2520
2540
2560
2580
2600
2620
2640
2660
2680
2700
2720
2740
2760
2780
2800
2820
2840
2860
2880
2900
2920
2940
2960
2980
3000
3020
3040
3060
3080
3100
3120
3140
3160
3180
3200
3220
3240
3260
3280
3300
3320
3340
3360
3380
3400
3420
3440
3460
3480
3500
3520
3540
3560
3580
3600
3620
3640
3660
3680
3700
3720
3740
3760
3780
3800
3820
3840
3860
3880
3900
3920
3940
3960
3980
4000
4020
4040
4060
4080
4100
4120
4140
4160
4180
4200
422