# 00. Import

In [1]:
from typing import Tuple
import numpy as np
import random
import copy
import pickle
from math import sqrt

import torch
from torch import nn, optim
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset
from torchvision import datasets, transforms

# 01. Environment

In [2]:
# Board dimensions passed to Environment(state_size).
state_size = (3,3)

In [3]:
class Environment:
    """Two-player n-in-a-row game on an ``n x n`` board (tic-tac-toe for n=3).

    The board is stored as two ``n x n`` planes in which an occupied cell
    holds ``-1``.  ``step`` swaps the planes before placing a stone in
    plane 1, so after every move ``present_state[1]`` belongs to the player
    who just moved and ``present_state[0]`` to the player who moves next.
    """

    def __init__(self, state_size: Tuple):
        # Board geometry; only state_size[0] is used, i.e. the board is square.
        self.state_size = state_size  # e.g. (3, 3)
        self.n = self.state_size[0]  # side length
        self.num_actions = self.n ** 2  # one action per cell

        # Two occupancy planes and the flat list of action indices.
        self.present_state = np.zeros((2, self.n, self.n))
        self.action_space = np.arange(self.num_actions)  # [0, 1, ..., n*n - 1]

        # Reward table and termination flag.
        self.reward_dict = {'win':1, 'lose':-1, 'draw':0, 'progress':0}
        self.done = False

        # True while it is the first player's turn.
        self.player = True


    def step(self, action_idx):
        """Place a stone at ``action_idx`` for the player to move.

        Legality of the action is not checked here.

        Returns:
            ``(next_state, reward, done, is_win)``.  ``next_state`` is the
            live internal board array, not a copy.
        """
        row, col = divmod(action_idx, self.n)

        # Hand the turn over first so the mover's stones end up in plane 1.
        self.change_player()
        self.present_state[1][row, col] = -1

        # Check for the end of the game and the winner.
        next_state = self.present_state
        done, is_win = self.is_done(next_state)
        reward = self.check_reward(is_win)
        self.done = done

        return next_state, reward, done, is_win


    def reset(self):
        """Clear the board and give the turn back to the first player."""
        self.present_state = np.zeros((2, self.n, self.n))
        self.done = False
        self.player = True

    def render(self, state):
        """Print the board: first player ``X``, second player ``O``, empty ``.``."""
        # Put the first player's stones in plane 0 regardless of whose turn it is.
        state = state if self.player else state[[1, 0]]
        flat = state.reshape(2, -1)
        board = flat[0] - flat[1]  # -1: plane-0 stone / 1: plane-1 stone
        symbols = ['X' if cell == -1 else 'O' if cell == 1 else '.' for cell in board]

        # Every cell takes two characters ("X "), so one board row is 2 * n wide.
        board_string = ' '.join(symbols)
        row_width = 2 * self.n
        formatted_string = '\n'.join(
            board_string[i:i + row_width] for i in range(0, len(board_string), row_width)
        )

        print(formatted_string)
        print("-"*10)


    def check_legal_action(self, state):
        """Return a 0/1 integer array marking the cells that are still empty."""
        flat = state.reshape(2, -1)
        occupancy = flat[0] + flat[1]
        return (occupancy == 0).astype(int)


    def is_done(self, state):
        """Report whether the game in ``state`` has ended.

        Returns:
            ``(is_done, is_win)``.  ``is_win`` is True when the player whose
            stones are in ``state[1]`` owns a complete row, column or
            diagonal; a full board without such a line is a draw
            (``is_done`` True, ``is_win`` False).
        """
        is_done, is_win = False, False
        player_state = state[1]

        # Draw: every cell is occupied (each stone contributes -1).
        if np.sum(state) == -self.num_actions:
            is_done, is_win = True, False

        # Win: some row, column or diagonal is filled by the same player.
        line_sums = np.concatenate([
            player_state.sum(axis=0),
            player_state.sum(axis=1),
            [player_state.trace()],
            [np.fliplr(player_state).trace()],
        ])  # (2n + 2, )
        if -self.n in line_sums:
            is_done, is_win = True, True

        return is_done, is_win


    # Helper methods
    def change_player(self):
        """Swap the two planes and flip the turn flag."""
        self.present_state[[0, 1]] = self.present_state[[1, 0]]
        self.player = not self.player


    def check_reward(self, is_win):
        """Return the reward from the first player's point of view.

        Called after the turn flag has been flipped, so ``self.player`` being
        False means the first player made the winning move.  Draws and
        unfinished games give 0.
        """
        reward = 0

        if is_win:
            reward = self.reward_dict["lose"] if self.player else self.reward_dict["win"]

        return reward

    def choose_random_action(self, state):
        """Pick one of the empty cells uniformly at random."""
        legal_actions = self.check_legal_action(state)
        legal_action_idxs = np.where(legal_actions != 0)[0]
        action = np.random.choice(legal_action_idxs)

        return action


## Test

In [4]:
# Build a fresh 3x3 environment and print the empty board.
state_size = (3,3)
env = Environment(state_size)
env.render(env.present_state)

. . . 
. . . 
. . .
----------


In [5]:
# Sum over both planes; stones are stored as -1, so an empty board sums to 0.
np.sum(env.present_state)

0.0

In [6]:
# Play one game with random legal moves, printing the result of every step.
env.reset()
while not env.done:
    action = env.choose_random_action(env.present_state)
    next_state, reward, done, is_win = env.step(action)
    print(reward, done, is_win, env.player)
    env.render(next_state)

0 False False False
X . . 
. . . 
. . .
----------
0 False False True
X . . 
. . . 
. O .
----------
0 False False False
X . . 
. . . 
X O .
----------
0 False False True
X . . 
O . . 
X O .
----------
0 False False False
X . . 
O . X 
X O .
----------
0 False False True
X O . 
O . X 
X O .
----------
0 False False False
X O X 
O . X 
X O .
----------
0 False False True
X O X 
O . X 
X O O
----------
1 True True False
X O X 
O X X 
X O O
----------


# 02. Net

In [7]:
# Network hyper-parameters.
state_size = (3, 3) # same value as env.state_size
action_size = 9 # same value as env.num_actions

CONV_UNITS = 64  # number of convolution channels passed to Net
RESIDUAL_NUM = 16  # how many times Net.forward applies its residual block
BATCHSIZE = 64  # upper bound on the mini-batch size used by make_dataset

In [8]:
# Residual block
class ResidualBlock(nn.Module):
    """Residual unit: conv -> BN -> ReLU -> conv -> BN, add the input, ReLU.

    NOTE(review): the same convolution and batch-norm layers are applied in
    both stages, so the two stages share weights and the block only works
    when ``in_channels == out_channels``.  Confirm the sharing is intended
    before changing it, since that would alter the state_dict layout.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        shortcut = x

        # First stage.
        hidden = F.relu(self.bn(self.conv(x)))

        # Second stage reuses the same layers, then adds the skip connection.
        out = self.bn(self.conv(hidden))
        out += shortcut
        return F.relu(out)

In [239]:
# main Net
class Net(nn.Module):
    """Policy/value network for a two-plane board.

    Args:
        state_size: ``(height, width)`` of the board; sizes the linear layers
            of both heads.
        action_size: number of policy outputs.
        conv_units: number of channels of the convolutional trunk.

    ``forward`` returns ``(policy, value)``: ``policy`` is a softmax over
    ``action_size`` entries and ``value`` a tanh output with one entry per
    sample.
    """

    def __init__(self, state_size, action_size, conv_units):
        super().__init__()
        # All convolutions and poolings below keep the spatial size, so each
        # head flattens to (head channels) * (number of board cells).
        board_cells = state_size[0] * state_size[1]

        self.conv = nn.Conv2d(in_channels=2, out_channels=conv_units, kernel_size=(3,3), bias=False, padding=1)
        # Not applied in forward(); kept so the parameter layout (state_dict
        # keys) stays the same.
        self.bn = nn.BatchNorm2d(conv_units)
        self.pool = nn.MaxPool2d(kernel_size=(3,3), stride=1, padding=1)
        # NOTE(review): a single block instance is applied RESIDUAL_NUM times
        # in forward(), so all residual stages share weights - confirm intent.
        self.residual_block = ResidualBlock(conv_units, conv_units)

        self.batch_size = BATCHSIZE

        self.policy_head = nn.Sequential(
            nn.Conv2d(conv_units, 2, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2 * board_cells, action_size),
            nn.Softmax(dim=-1)
        )

        self.value_head = nn.Sequential(
            nn.Conv2d(conv_units, 1, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(board_cells, 1),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        x = self.pool(x)

        # Residual trunk.
        for _ in range(RESIDUAL_NUM):
            x = self.residual_block(x)

        # Pooling before the heads.
        x = self.pool(x)

        # Policy and value heads.
        policy = self.policy_head(x)
        value = self.value_head(x)

        return policy, value

## Test

In [10]:
# Build an environment and an untrained network sized for it.
state_size = (3,3)
env = Environment(state_size)
model = Net(state_size, env.num_actions, CONV_UNITS)

# 03. MCTS

In [11]:
# MCTS hyper-parameters.
C_PUCT = 1.0  # weight of the prior/exploration term in the PUCB score
EVAL_CNT = 10  # number of root evaluations per Mcts.get_policy call

In [117]:
# define Mcts class
class Mcts:
    """Monte-Carlo tree search guided by a policy/value network.

    Args:
        env: game rules object.  A private deep copy is kept and used only to
            query legality, termination and successor positions.
        model: network returning ``(policy, value)`` for a batch of states.
            It is shared with the caller (not copied) and is run in whatever
            train/eval mode the caller left it in.
        state: board the search is created for, two planes of shape
            ``(2, n, n)``.
        temperature: Boltzmann temperature applied to the visit counts;
            ``0`` makes the returned policy one-hot on the most visited move.
    """

    class Node:
        """One position in the search tree.

        ``state`` is stored in the layout ``Environment.step`` produces:
        plane 0 holds the stones of the side to move, plane 1 those of the
        side that just moved.  ``w`` accumulates values from the point of
        view of the side to move at this node.
        """

        def __init__(self, mcts, state, p):
            # The owning Mcts is shared by reference; copying it (and with it
            # the model) for every node is unnecessary.
            self.mcts = mcts
            self.state = np.array(state, copy=True)
            self.p = p  # prior probability from the parent's policy
            self.n = 0  # visit count
            self.w = 0.0  # accumulated value
            self.child_nodes = None

        def evaluate(self):
            """Run one simulation through this node and return its value."""
            env = self.mcts.env
            is_done, is_win = env.is_done(self.state)

            # Finished game: is_win refers to plane 1, i.e. the side that
            # just moved, so the side to move here has lost.
            if is_done:
                value = -1.0 if is_win else 0.0

                self.w += value
                self.n += 1
                return value

            # Leaf: query the network and expand.
            if not self.child_nodes:
                batch = torch.tensor(self.state, dtype=torch.float32).unsqueeze(0)
                policies, value = self.predict(batch)
                value = float(np.asarray(value).reshape(-1)[0])

                self.w += value
                self.n += 1

                # One child per move that is legal in THIS position; each
                # child gets the prior of its own action index.
                self.child_nodes = []
                legal_idxs = np.where(env.check_legal_action(self.state) != 0)[0]

                for action in legal_idxs:
                    env.present_state = self.state.copy()
                    next_state, _, _, _ = env.step(action)
                    self.child_nodes.append(
                        type(self)(self.mcts, next_state, float(policies[0][action]))
                    )

                return value

            # Inner node: descend; the child's value is from the opponent's
            # point of view, hence the sign flip.
            value = -self.next_child_node().evaluate()

            self.w += value
            self.n += 1
            return value


        def next_child_node(self):
            """Select the child with the highest PUCB score."""
            total_visits = sum(child.n for child in self.child_nodes)

            def pucb(child):
                exploit = -child.w / child.n if child.n else 0.0
                explore = C_PUCT * child.p * sqrt(total_visits) / (1 + child.n)
                return exploit + explore

            # Iterating in reverse makes the last child win ties.
            return max(reversed(self.child_nodes), key=pucb)


        def predict(self, state):
            """Evaluate ``state`` (a batched tensor) with the network.

            Returns:
                ``(policies, value)`` as NumPy arrays; the policy is masked
                to this node's legal moves and renormalised to sum to 1
                (left as is when nothing remains after masking).
            """
            model = self.mcts.model
            first_param = next(model.parameters(), None)
            if first_param is not None:
                state = state.to(first_param.device)

            # No gradients are needed during search.
            with torch.no_grad():
                policies, value = model(state)
            policies = policies.detach().cpu().numpy()
            value = value.detach().cpu().numpy()

            policies = policies * self.mcts.env.check_legal_action(self.state)
            total = np.sum(policies)
            if total:
                policies /= total

            return policies, value

    def __init__(self, env, model, state, temperature):
        self.env = copy.deepcopy(env)
        self.model = model
        self.state = state
        self.temperature = temperature
        self.legal_actions = self.env.check_legal_action(self.state)

    ########################
    # methods of Mcts
    @staticmethod
    def _to_move_first(state):
        """Return a copy of ``state`` with the side to move in plane 0.

        With alternating turns the side to move never owns more stones than
        its opponent, so a fuller plane 0 means the planes arrived in the
        opposite order (e.g. fixed first-player/second-player order).
        """
        state = np.array(state, copy=True)
        if np.abs(state[0]).sum() > np.abs(state[1]).sum():
            state = state[[1, 0]]
        return state

    def get_policy(self, state):
        """Search from ``state`` and return a policy over its legal moves.

        The result has one entry per legal action, in increasing action
        index, derived from the root's child visit counts.

        Raises:
            ValueError: if ``state`` is already a finished game.
        """
        root_node = self.Node(self, self._to_move_first(state), 0)

        for _ in range(EVAL_CNT):
            root_node.evaluate()

        if root_node.child_nodes is None:
            raise ValueError("cannot search from a finished game")

        scores = [c.n for c in root_node.child_nodes]

        if self.temperature == 0:  # one-hot on the most visited move
            action = np.argmax(scores)
            scores = np.zeros(len(scores))
            scores[action] = 1

        else:  # soften the visit counts with a Boltzmann distribution
            scores = self.boltzman(scores, self.temperature)

        return scores


    def boltzman(self, xs, temperature):
        """Normalise ``x ** (1 / temperature)`` over ``xs``.

        Falls back to a uniform distribution when every entry is zero.
        """
        powered = [x ** (1/temperature) for x in xs]
        total = sum(powered)
        if total == 0:
            return [1 / len(powered) for _ in powered]
        return [x/total for x in powered]


    def get_action(self, state):
        """Sample an action from the MCTS policy for ``state``.

        Returns:
            ``(policy, action)`` where ``policy`` is the output of
            ``get_policy`` and ``action`` one of the moves legal in ``state``.
        """
        legal_actions = np.where(self.env.check_legal_action(state) != 0)[0]
        policy = self.get_policy(state)
        action = np.random.choice(legal_actions, p=policy)
        return policy, action

# 04. Self-play

In [118]:
# Self-play parameters.
SP_GAME_COUNT = 3  # number of self-play games to run (original: 25,000)
SP_TEMPERATURE = 1.0  # temperature parameter of the Boltzmann distribution

CONV_UNITS = 64

state_size = (3,3)
env = Environment(state_size)

In [144]:
# Play a single self-play game.
def play_one_game(model):
    """Play one game with MCTS on the module-level ``env``.

    Returns:
        A list of ``(state, policy, reward)`` tuples, one per move.  ``state``
        is a copy of the board taken before the move with the first player's
        plane first, ``policy`` has one entry per action (0.0 for illegal
        ones) and ``reward`` is the reward of the final move for every entry.
    """
    env.reset()
    records = []
    done = False

    while not done:
        legal_idxs = np.where(env.check_legal_action(env.present_state))[0]

        # Keep a fixed (first player, second player) plane order.
        state = env.present_state.copy()
        if not env.player:
            state[[0, 1]] = state[[1, 0]]

        searcher = Mcts(env, model, state, temperature = SP_TEMPERATURE)
        scores, chosen = searcher.get_action(state)
        _, reward, done, _ = env.step(chosen)

        # Scatter the per-legal-move scores into a full-size policy vector.
        policies = [0.0]*env.num_actions
        for idx, score in zip(legal_idxs, scores):
            policies[idx] = score

        records.append((state, policies))

    return [(board, policy, reward) for board, policy in records]

In [145]:
# Self-play driver.
def self_play(model):
    """Play ``SP_GAME_COUNT`` self-play games and return all their records.

    Returns:
        The concatenated histories of ``play_one_game``, which plays on the
        module-level ``env``.
    """
    data = []

    for i in range(SP_GAME_COUNT):
        data.extend(play_one_game(model))
        # Progress message every tenth game.
        if i % 10 == 0:
            print(f"game {i+1}")

    return data

## Test

In [146]:
# Fresh environment and untrained network for the self-play test.
state_size = (3,3)
env = Environment(state_size)
model = Net(state_size, env.num_actions, CONV_UNITS)

In [147]:
# Collect self-play records with the untrained network.
hist = self_play(model)

game 1


In [148]:
# Number of recorded positions and one sample (state, policy, reward) record.
print(len(hist))
print(hist[-3])

25
(array([[[-1.,  0.,  0.],
        [ 0.,  0.,  0.],
        [ 0., -1., -1.]],

       [[ 0.,  0., -1.],
        [-1.,  0.,  0.],
        [-1.,  0.,  0.]]]), [0.0, 0.2222222222222222, 0.0, 0.0, 0.3333333333333333, 0.4444444444444444, 0.0, 0.0, 0.0], 1)


# 05. Evaluate Network

In [149]:
# Evaluation parameters.
NUM_GAME = 10  # number of evaluation games in evaluate_network
TEMPERATURE = 1.0 # Boltzmann distribution temperature

file_name = "model1"  # prefix of the pickle files read and written below
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [150]:
# Play a single evaluation game.
def play_game(mcts_list):
    """Play one game between two searchers and score it for the first player.

    Args:
        mcts_list: ``[first_player, second_player]``; each entry supplies the
            ``model`` and ``temperature`` used for that side's moves.

    Returns:
        1 if the first player wins, 0.5 for a draw, 0 if it loses.
    """
    env = Environment(state_size)
    reward = 0

    while not env.done:
        # Same plane order as self-play: first player's plane first.
        state = env.present_state.copy()
        if not env.player:
            state[[0, 1]] = state[[1, 0]]

        # Search from the current position; the searchers in mcts_list were
        # built for a different board, so only their settings are reused.
        template = mcts_list[0] if env.player else mcts_list[1]
        searcher = Mcts(env, template.model, state, template.temperature)
        _, action = searcher.get_action(state)
        _, reward, _, _ = env.step(action)

    # The environment reports rewards from the first player's point of view.
    if reward == 1:
        return 1
    if reward == -1:
        return 0
    return 0.5

In [151]:
# Evaluate the latest network against the best one.
def evaluate_network():
    """Pit the latest network against the best one over ``NUM_GAME`` games.

    Both networks are loaded from ``{file_name}_model_latest.pkl`` and
    ``{file_name}_model_best.pkl`` (pickled state dicts).  The first move
    alternates between the two.  If the latest network averages more than
    0.5 points per game its state dict replaces the best one on disk.

    Returns:
        True if the best network was replaced, False otherwise.
    """
    model_latest = Net(state_size, env.num_actions, CONV_UNITS).to(device)
    model_best = Net(state_size, env.num_actions, CONV_UNITS).to(device)

    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f'{file_name}_model_latest.pkl', 'rb') as f:
        model_latest.load_state_dict(pickle.load(f))

    with open(f'{file_name}_model_best.pkl', 'rb') as f:
        model_best.load_state_dict(pickle.load(f))

    mcts_latest = Mcts(env, model_latest, env.present_state, TEMPERATURE)
    mcts_best = Mcts(env, model_best, env.present_state, TEMPERATURE)

    # Play the matches, alternating who moves first.
    total_point = 0
    for i in range(NUM_GAME):
        if i % 2 == 0:  # first player: latest
            point = play_game([mcts_latest, mcts_best])

        else:  # first player: best; convert its score into latest's score
            point = 1 - play_game([mcts_best, mcts_latest])

        total_point += point

    average_point = total_point/NUM_GAME
    print(f"Average point: {average_point}")

    # Promote the latest network if it did better than even.
    if average_point > 0.5:
        with open(f'{file_name}_model_best.pkl', 'wb') as f:
            # Store the state dict, which is what the loader above expects.
            pickle.dump(model_latest.state_dict(), f)

        return True

    return False

# 06. Train

In [228]:
# Training parameters.
TRAIN_NUM = 10  # number of self-play/train cycles
TRAIN_EPOCHS = 100  # number of training iterations per cycle
BATCHSIZE = 64  # upper bound on the mini-batch size used by make_dataset
LEARN_MAX = 0.001  # learning rate passed to Adam

SP_GAME_COUNT = 10

file_name = "model1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

state_size = (3,3)
env = Environment(state_size)

CONV_UNITS = 64
model = Net(state_size, env.num_actions, CONV_UNITS).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_MAX, weight_decay=1e-5)

In [253]:
# Loss function.
def loss_function(pred_policy, pred_value, y):
    """Combined value and policy loss.

    Args:
        pred_policy: predicted move probabilities, shape ``(batch, actions)``.
        pred_value: predicted game value, shape ``(batch, 1)``.
        y: targets; every column but the last holds the target policy, the
            last column holds the game result.

    Returns:
        Scalar tensor: mean-squared error of the value plus the cross-entropy
        between the target and the predicted policy.
    """
    target_policy = y[:, :-1]
    target_value = y[:, -1:]

    value_loss = F.mse_loss(pred_value, target_value)
    # Clamp so that zero probabilities do not produce log(0) = -inf / NaN.
    log_policy = torch.log(torch.clamp(pred_policy, min=1e-8))
    policy_loss = -torch.mean(torch.sum(target_policy * log_policy, dim=1))
    return value_loss + policy_loss

In [254]:
# Build a training mini-batch.
def make_dataset(history):
    """Sample a mini-batch from ``history`` and turn it into tensors.

    Args:
        history: list of ``(state, policy, result)`` records.

    Returns:
        ``(X, Y)`` on ``device``: ``X`` stacks the sampled states, ``Y`` holds
        the policy targets followed by the result in the last column.

    Raises:
        ValueError: if ``history`` is empty.
    """
    if not history:
        raise ValueError("history is empty; nothing to sample")

    batch_size = min(BATCHSIZE, len(history))
    mini_batch = random.sample(history, batch_size)
    states, policies, results = zip(*mini_batch)

    targets = np.concatenate(
        [np.asarray(policies), np.asarray(results).reshape(-1, 1)], axis=1
    )

    # Stack into one ndarray first; building a tensor from a sequence of
    # arrays is much slower.
    X = torch.tensor(np.stack(states), dtype=torch.float32).to(device)
    Y = torch.tensor(targets, dtype=torch.float32).to(device)
    return X, Y

In [255]:
# Train the network from a saved history file.
def train_network():
    """Train the module-level ``model`` on ``{file_name}_history.pkl``.

    Runs ``TRAIN_EPOCHS`` optimisation steps, each on a freshly sampled
    mini-batch, using the module-level ``optimizer``, then pickles the
    model's state dict to ``{file_name}_model_latest.pkl``.
    """
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    with open(f'{file_name}_history.pkl', 'rb') as f:
        history = pickle.load(f)

    for _ in range(TRAIN_EPOCHS):
        X, Y = make_dataset(history)
        # Train the current model in place.
        pred_policy, pred_value = model(X)
        loss = loss_function(pred_policy, pred_value, Y)
        # Back-propagation.  The loss already tracks gradients through the
        # model parameters; forcing requires_grad would only hide a detached
        # graph.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Save the latest model.
    with open(f'{file_name}_model_latest.pkl', 'wb') as f:
        pickle.dump(model.state_dict(), f)

    # TODO: tune learning rate, number of epochs, ...

In [271]:
# Train the network on an in-memory history.
def train_network(model, history):
    """Run ``TRAIN_EPOCHS`` optimisation steps of ``model`` on ``history``.

    Each step samples a new mini-batch with ``make_dataset`` and updates the
    model through the module-level ``optimizer``.
    """
    for _ in range(TRAIN_EPOCHS):
        X, Y = make_dataset(history)
        # Train the given model in place.
        pred_policy, pred_value = model(X)
        loss = loss_function(pred_policy, pred_value, Y)
        # Back-propagation.  The loss already tracks gradients through the
        # model parameters; forcing requires_grad would only hide a detached
        # graph.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # TODO: tune learning rate, number of epochs, ...

In [272]:
state_size = (3,3)
env = Environment(state_size)

CONV_UNITS = 64
model = Net(state_size, env.num_actions, CONV_UNITS).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARN_MAX)

# network train cycle
for i in range(TRAIN_NUM):
    history = self_play(model)
    train_network(model, history)

    # Demo game after every cycle: the network (greedy on its raw policy)
    # moves first against a uniformly random opponent.
    env = Environment(state_size)
    env.reset()

    while not env.done:
        if env.player:
            # Build the input on the model's device and skip gradient tracking.
            state = torch.tensor(env.present_state, dtype=torch.float32, device=device)
            state = state.unsqueeze(0)

            with torch.no_grad():
                policy, value = model(state)
            policy = policy.cpu().numpy()
            # Zero out occupied cells before picking the best move.
            policy = policy * env.check_legal_action(env.present_state)
            print(policy)
            action = np.argmax(policy)
        else:
            legal_actions = np.where(env.check_legal_action(env.present_state)!=0)[0]
            action = random.choice(legal_actions)

        _, reward, done, is_win = env.step(action)

        print(reward, done, is_win, env.player)
        env.render(env.present_state)

    # update_best_player = evaluate_network()

    # if update_best_player:
    #     evaluate_best_player()

game 1
[[0.09430505 0.08764277 0.09351463 0.09197062 0.12300365 0.14932774
  0.09552693 0.12627751 0.13843104]]
0 False False False
. . . 
. . X 
. . .
----------
0 False False True
. . . 
. O X 
. . .
----------
[[0.09430505 0.08764277 0.09351463 0.09197062 0.         0.
  0.09552693 0.12627751 0.13843104]]
0 False False False
. . . 
. O X 
. . X
----------
0 False False True
. . . 
. O X 
O . X
----------
[[0.09430505 0.08764277 0.09351463 0.09197062 0.         0.
  0.         0.12627751 0.        ]]
0 False False False
. . . 
. O X 
O X X
----------
-1 True True True
. . O 
. O X 
O X X
----------
game 1
[[0.09577799 0.07429262 0.10749502 0.10580871 0.12213466 0.14040369
  0.08962057 0.13660912 0.1278577 ]]
0 False False False
. . . 
. . X 
. . .
----------
0 False False True
. . . 
. . X 
. . O
----------
[[0.09577799 0.07429262 0.10749502 0.10580871 0.12213466 0.
  0.08962057 0.13660912 0.        ]]
0 False False False
. . . 
. . X 
. X O
----------
0 False False True
O . . 
. . X

KeyboardInterrupt: 

In [269]:
# Demo game: the network (greedy on its raw policy) moves first against a
# uniformly random opponent.
env = Environment(state_size)
env.reset()

while not env.done:
    if env.player:
        # Build the input on the model's device and skip gradient tracking.
        state = torch.tensor(env.present_state, dtype=torch.float32, device=device)
        state = state.unsqueeze(0)

        with torch.no_grad():
            policy, value = model(state)
        policy = policy.cpu().numpy()
        # Zero out occupied cells before picking the best move.
        policy = policy * env.check_legal_action(env.present_state)
        print(policy)
        action = np.argmax(policy)
    else:
        legal_actions = np.where(env.check_legal_action(env.present_state)!=0)[0]
        action = random.choice(legal_actions)

    _, reward, done, is_win = env.step(action)

    print(reward, done, is_win, env.player)
    env.render(env.present_state)

[[[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]]
[[0.11440203 0.08727124 0.08985619 0.10936214 0.10744875 0.11670816
  0.1109485  0.10561733 0.15838566]]
0 False False False
. . . 
. . . 
. . X
----------
0 False False True
. . O 
. . . 
. . X
----------
[[[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0. -1.]]

 [[ 0.  0. -1.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]]
[[0.11440203 0.08727124 0.         0.10936214 0.10744875 0.11670816
  0.1109485  0.10561733 0.        ]]
0 False False False
. . O 
. . X 
. . X
----------
0 False False True
. . O 
O . X 
. . X
----------
[[[ 0.  0.  0.]
  [ 0.  0. -1.]
  [ 0.  0. -1.]]

 [[ 0.  0. -1.]
  [-1.  0.  0.]
  [ 0.  0.  0.]]]
[[0.11440203 0.08727124 0.         0.         0.10744875 0.
  0.1109485  0.10561733 0.        ]]
0 False False False
X . O 
O . X 
. . X
----------
0 False False True
X O O 
O . X 
. . X
----------
[[[-1.  0.  0.]
  [ 0.  0. -1.]
  [ 0.  0. -1.]]

 [[ 0. -1. -1.]
  [-1.  0.  0.]
  [ 0.  0.  0.]]]
[[0.    