# AlphaZero

In [None]:
#################
# 선공:  1, 'O' #
# 후공: -1, 'X' #
#################

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import math
import torch
import torch.nn as nn
import torch.optim as optim
from collections import namedtuple
from torch.utils.data import DataLoader, TensorDataset
import os

## State

In [None]:
# Board dimensions as (rows, columns); used as the default board size for State.
BOARD_SIZE = (3,3)

In [None]:
class State:
    """Immutable-style game position seen from the side that is about to move.

    ``my_actions`` holds the cell indices already taken by the player to move,
    ``enemy_actions`` those taken by the opponent. ``board`` is a float array
    of shape ``(2, rows, cols)``: plane 0 marks the mover's stones, plane 1 the
    opponent's.
    """

    def __init__(self, board_size=BOARD_SIZE, my_actions=None, enemy_actions=None):
        self.board_size = board_size
        self.num_actions = self.board_size[0] * self.board_size[1]
        self.action_space = range(self.num_actions)

        self.my_actions = [] if my_actions is None else my_actions
        self.enemy_actions = [] if enemy_actions is None else enemy_actions

        self.board = self.create_board(self.my_actions, self.enemy_actions)

        self.available_actions = self.get_available_actions()

    def next(self, action):
        """Return the position after the mover plays ``action``.

        The roles are swapped in the returned state, so it is expressed from
        the opponent's point of view. Both action lists are copied so the new
        state never shares a list with this one.
        """
        return State(
            self.board_size,
            list(self.enemy_actions),
            [*self.my_actions, action],
        )

    def create_board(self, my_actions, enemy_actions):
        """Build the ``(2, rows, cols)`` occupancy planes from two action lists."""
        planes = np.zeros((2, self.num_actions))
        # Explicit integer index arrays keep the assignment valid for empty lists.
        planes[0, np.asarray(my_actions, dtype=int)] = 1
        planes[1, np.asarray(enemy_actions, dtype=int)] = 1

        return planes.reshape((2, self.board_size[0], self.board_size[1]))

    def get_available_actions(self):
        """Return the unoccupied cell indices in ascending order."""
        taken = set(self.my_actions) | set(self.enemy_actions)
        return [action for action in range(self.num_actions) if action not in taken]

    def _has_line(self, plane):
        """Return True if ``plane`` contains a full column, row or diagonal."""
        n_rows, n_cols = self.board_size[0], self.board_size[1]

        full_column = np.sum(plane, axis=0).max() == n_rows
        full_row = np.sum(plane, axis=1).max() == n_cols
        main_diagonal = np.trace(plane) == n_rows
        anti_diagonal = np.trace(np.fliplr(plane)) == n_rows

        return bool(full_column or full_row or main_diagonal or anti_diagonal)

    def is_win(self):
        """True if the player to move already owns a complete line."""
        return self._has_line(self.board[0])

    def is_draw(self):
        """True if the board is full and neither side owns a complete line.

        A full board that was completed by a winning move is a win/loss, not a
        draw, so callers get the right answer regardless of the order in which
        they query ``is_win`` / ``is_draw`` / ``is_lose``.
        """
        board_full = (np.sum(self.board[0]) + np.sum(self.board[1])) >= self.num_actions
        return bool(board_full) and not self.is_win() and not self.is_lose()

    def is_lose(self):
        """True if the opponent owns a complete line."""
        return self._has_line(self.board[1])

    def is_done(self):
        """True once the game is decided (win, loss or draw)."""
        return self.is_win() or self.is_lose() or self.is_draw()

    def is_going_first(self):
        """True if the player to move is the one who opened the game."""
        return len(self.my_actions) == len(self.enemy_actions)

## Environment

In [None]:
class TicTacToeEnv:
    """Minimal tic-tac-toe environment that tracks the current ``State``.

    ``self.state`` is always expressed from the point of view of the player
    who moves next.
    """

    def __init__(self):
        self.state = State()

        self.reward = {'win': 10, 'lose': -10, 'draw': 0, 'continue': 0}

    def reset(self):
        """Start a new game and return the empty initial state."""
        self.state = State()
        return self.state

    def step(self, action):
        """Apply ``action`` for the player to move.

        Returns:
            A tuple ``(state, next_state, reward, done)`` where ``state`` is the
            new position from the opponent's point of view (it becomes
            ``self.state``), ``next_state`` is the same position from the
            mover's point of view, and ``reward`` / ``done`` are judged from
            the mover's point of view.
        """
        mover_actions = [*self.state.my_actions, action]
        opponent_actions = list(self.state.enemy_actions)

        next_state = State(BOARD_SIZE, mover_actions, opponent_actions)
        # Next step belongs to the opponent, so the two action lists swap roles.
        self.state = State(BOARD_SIZE, list(opponent_actions), list(mover_actions))

        if next_state.is_win():
            reward, done = self.reward['win'], True

        elif next_state.is_draw():
            reward, done = self.reward['draw'], True

        elif next_state.is_lose():
            reward, done = self.reward['lose'], True

        else:
            reward, done = self.reward['continue'], False

        return self.state, next_state, reward, done

    def render(self, state):
        """Print ``state`` with 'O' for the first player, 'X' for the second."""
        mover_plane, opponent_plane = state.board[0], state.board[1]
        # Orient the signed board so that +1 is always the first player.
        if state.is_going_first():
            board = mover_plane - opponent_plane
        else:
            board = opponent_plane - mover_plane

        int_to_symbol = np.where(board == 1, 'O',
                                 np.where(board == -1, 'X', '-'))

        rendering_board = '\n'.join(' '.join(row) for row in int_to_symbol)

        print()
        print(rendering_board)
        print()

## Net

In [None]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    The shortcut is the identity unless ``downsample`` is requested or the
    channel count changes, in which case a 1x1 conv + batch-norm projection
    is used instead.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        conv3x3 = dict(kernel_size=3, stride=1, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv3x3)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv3x3)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = downsample or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)

In [None]:
class ResNet(nn.Module):
    """Small residual trunk: stem conv, two residual stages, global average pool.

    ``config`` unpacks to ``(block, n_blocks, channels)``; the input is expected
    to have 2 channels, matching the stem convolution.
    """

    def __init__(self, config, zero_init_residual=False):
        super().__init__()
        block, n_blocks, channels = config
        stem_width = channels[0]
        self.in_channels = stem_width

        # Stem
        self.conv1 = nn.Conv2d(2, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1])

        # Global average pooling down to 1x1
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if not zero_init_residual:
            return
        # Zero the last batch-norm scale of every BasicBlock.
        for module in self.modules():
            if isinstance(module, BasicBlock):
                nn.init.constant_(module.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Stack ``n_blocks`` blocks; only the first may change the channel count."""
        entry_channels = self.in_channels
        stage = []
        for position in range(n_blocks):
            if position:
                stage.append(block(channels, channels))
            else:
                stage.append(
                    block(entry_channels, channels, downsample=(entry_channels != channels))
                )
        if stage:
            # Later stages start from this stage's output width.
            self.in_channels = channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the trunk; output shape is (batch_size, channels, 1, 1)."""
        pipeline = (self.conv1, self.bn1, self.relu, self.layer1, self.layer2, self.avgpool)
        for stage in pipeline:
            x = stage(x)
        return x

In [None]:
# Bundle of ResNet hyper-parameters: the block class, the number of blocks per
# stage, and the channel width per stage.
ResNetConfig = namedtuple('ResNetConfig', ['block', 'n_blocks', 'channels'])

In [None]:
# Default trunk: two stages of two BasicBlocks each, 64 channels throughout.
config = ResNetConfig(block=BasicBlock, n_blocks=[2, 2], channels=[64, 64])

In [None]:
class AlphaZeroNet(nn.Module):
    """Policy/value network: a shared ResNet trunk feeding two small heads.

    ``forward`` returns ``(policy, value)`` where ``policy`` is a softmax
    distribution over all board cells and ``value`` is a tanh-squashed scalar
    per sample.
    """

    def __init__(self, board_size=(3, 3), config=config):
        super().__init__()
        _block, _n_blocks, stage_channels = config
        trunk_width = stage_channels[-1]
        n_cells = board_size[0] * board_size[1]
        self.board_size = board_size

        self.resnet = ResNet(config=config)

        # Policy head: 1x1 conv to 2 channels, then a linear layer over the cells.
        policy_layers = [
            nn.Conv2d(trunk_width, 2, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2, n_cells),
            nn.Softmax(dim=1),
        ]
        self.policy_head = nn.Sequential(*policy_layers)

        # Value head: 1x1 conv to 1 channel, then a scalar squashed by tanh.
        value_layers = [
            nn.Conv2d(trunk_width, 1, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1, 1),
            nn.Tanh(),
        ]
        self.value_head = nn.Sequential(*value_layers)

    def forward(self, x):
        """Return ``(policy, value)`` for a batch of board tensors."""
        features = self.resnet(x)  # (batch_size, channels, 1, 1)
        return self.policy_head(features), self.value_head(features)

## MCTS Node

In [None]:
class Node:
    """One node of the Monte Carlo search tree.

    ``w`` accumulates values from the point of view of the player to move in
    ``state``; ``n`` counts visits. ``next_child_node`` relies on that
    convention when it negates a child's mean value.
    """

    def __init__(self, state, agent, parent=None, action=None):
        self.state = state
        self.agent = agent
        self.parent_node = parent
        self.child_nodes = None
        self.action = action  # action that led from the parent to this node
        self.w = 0  # accumulated value
        self.n = 0  # visit count

    def expand(self):
        """Create one child node per available action of this node's state."""
        self.child_nodes = [
            Node(self.state.next(action), self.agent, parent=self, action=action)
            for action in self.state.available_actions
        ]

    def get_result(self):
        """Return the terminal value for the player to move in this state.

        A loss is scored -2 (deliberately larger in magnitude than a win, as
        in the original code), a draw 0 and anything else +1.
        """
        if self.state.is_lose():
            return -2
        elif self.state.is_draw():
            return 0
        else:
            return 1

    def next_child_node(self):
        """Pick the child to explore next using UCB1.

        Any child that has never been visited is returned first. Otherwise
        the child maximising ``-w/n + sqrt(2 * ln(t) / n)`` is chosen, where
        the minus sign converts the child's own-perspective value into this
        node's perspective.
        """
        for child_node in self.child_nodes:
            if child_node.n == 0:
                return child_node

        total_visits = sum(child.n for child in self.child_nodes)
        log_total = math.log(total_visits)
        ucb1_values = [
            -child.w / child.n + (2 * log_total / child.n) ** 0.5
            for child in self.child_nodes
        ]

        return self.child_nodes[self.agent.argmax(ucb1_values)]

    def backpropagate(self, value):
        """Propagate ``value`` from this node up to the root.

        ``value`` is from the point of view of this node's player to move.
        The players alternate at every ply, so the sign is flipped at each
        step up the tree; without the flip a result that is good for one side
        would be credited as good for both.
        """
        node = self
        while node is not None:
            node.w += value
            node.n += 1
            value = -value
            node = node.parent_node

## Agent

In [None]:
# Number of MCTS simulations performed per run_mcts call.
NUM_OF_SIMULATION = 200
# Temperature applied to visit counts: starts at TAU_START and decays linearly
# to TAU_END over DECAY_STEPS run_mcts calls.
TAU_START = 1.0
TAU_END = 0.3
DECAY_STEPS = 20

In [None]:
class AlphaZeroAgent:
    """MCTS-based agent that uses a neural network to evaluate leaf positions."""

    def __init__(self, network, env):
        self.network = network
        self.env = env
        self.num_simulations = NUM_OF_SIMULATION
        self.tau_start = TAU_START
        self.tau_end = TAU_END
        self.decay_steps = DECAY_STEPS
        self.current_step = 0  # number of run_mcts calls made so far

    def get_temperature(self):
        """Return the current temperature, decayed linearly and floored at tau_end."""
        progress = self.current_step / self.decay_steps
        decayed = self.tau_start - (self.tau_start - self.tau_end) * progress
        return max(self.tau_end, decayed)

    def run_mcts(self, root_state):
        """Search from ``root_state`` and return a move distribution.

        Returns:
            A float32 array of length ``root_state.num_actions`` that sums to 1.
            Occupied cells always get probability 0, so sampling from the
            result can never produce an illegal move.
        """
        root = Node(root_state, agent=self)
        root.expand()

        for _ in range(self.num_simulations):
            leaf_node, value = self.simulate(root)
            leaf_node.backpropagate(value)

        # Turn the root children's visit counts into a policy.
        visit_counts = np.zeros(root_state.num_actions, dtype=np.float32)
        for child in root.child_nodes:
            visit_counts[child.action] = child.n
        legal_actions = [child.action for child in root.child_nodes]

        policy = self._policy_from_visits(visit_counts, legal_actions, self.get_temperature())

        self.current_step += 1  # advances the temperature schedule

        return policy

    @staticmethod
    def _policy_from_visits(visit_counts, legal_actions, tau):
        """Convert visit counts into a temperature-scaled distribution.

        Only ``legal_actions`` receive probability mass. Their counts are
        floored at 1e-4 so an unvisited legal move keeps a tiny chance. If no
        action is legal the result falls back to a uniform distribution.
        """
        visit_counts = np.asarray(visit_counts, dtype=np.float32)
        policy = np.zeros_like(visit_counts)
        legal = np.asarray(list(legal_actions), dtype=int)

        if legal.size:
            policy[legal] = np.clip(visit_counts[legal], 1e-4, None) ** (1 / tau)

        total = np.sum(policy)
        if total > 0:
            return policy / total
        return np.ones_like(policy) / len(policy)

    def simulate(self, node):
        """Descend from ``node`` to a leaf and return ``(leaf, value)``.

        ``value`` is from the point of view of the player to move at the leaf:
        the game result for a terminal node, or the network estimate for a
        node that is expanded during this call.
        """
        while True:
            if node.state.is_done():
                return node, node.get_result()

            if node.child_nodes is None:
                node.expand()
                return node, self.predict_value(node.state)

            node = node.next_child_node()

    def predict_value(self, state):
        """Return the network's scalar value estimate for ``state``.

        The network is switched to eval mode for the forward pass so that
        BatchNorm uses its running statistics rather than the statistics of
        this single sample, and so that search does not modify them. The
        previous mode is restored afterwards.
        """
        was_training = self.network.training
        self.network.eval()
        try:
            with torch.no_grad():
                state_tensor = torch.tensor(state.board, dtype=torch.float32).unsqueeze(0)
                _, value = self.network(state_tensor)
        finally:
            self.network.train(was_training)
        return value.item()

    def playout(self, state):
        """Play random moves to the end and return the result for the mover.

        The sign is flipped at every recursion level because the player to
        move alternates.
        """
        if state.is_lose():
            return self.env.reward['lose']

        if state.is_draw():
            return self.env.reward['draw']

        # Evaluate the position after a random move, seen from the opponent.
        return -self.playout(state.next(self.random_action(state)))

    def random_action(self, state):
        """Return a uniformly random available action."""
        return np.random.choice(state.available_actions)

    def argmax(self, collection):
        """Return the index of the maximum, breaking ties uniformly at random."""
        values = np.asarray(collection)
        max_idx_list = np.arange(len(values))[values == np.max(values)]
        return np.random.choice(max_idx_list)

## Self-play

In [None]:
def self_play(env, agent):
    """Play one game of the agent against itself and return training samples.

    Returns:
        ``(states, mcts_policies, values)`` — three lists of equal length.
        ``states[i]`` is the board (from the mover's point of view) before
        move ``i``, ``mcts_policies[i]`` the MCTS distribution it was sampled
        from, and ``values[i]`` the final result for the player who moved in
        ``states[i]``: +1 win, -1 loss, 0 draw.
    """
    states, mcts_policies = [], []
    state = env.reset()

    while not state.is_done():
        policy = agent.run_mcts(state)
        action = np.random.choice(len(policy), p=policy)
        states.append(state.board)
        mcts_policies.append(policy)
        state, _, _, _ = env.step(action)

    # The terminal state is expressed from the side that would move next.
    # Check for a decided game before a draw so a winning move that also
    # fills the board is not recorded as a draw.
    if state.is_lose():
        outcome = -1
    elif state.is_win():
        outcome = 1
    else:
        outcome = 0

    # The last recorded state belongs to the opponent of the terminal side to
    # move, so label from the end backwards, flipping the sign at every ply.
    values = [0] * len(states)
    result = -outcome
    for index in range(len(states) - 1, -1, -1):
        values[index] = result
        result = -result

    return states, mcts_policies, values

## Train

In [None]:
def train(network, optimizer, data):
    """Run one optimisation step per game in ``data`` and return the mean loss.

    Args:
        network: module returning ``(policy, value)`` where ``policy`` is
            already a probability distribution per sample (AlphaZeroNet ends
            its policy head with a Softmax).
        optimizer: optimizer over ``network``'s parameters.
        data: sequence of ``(states, policies, values)`` triples, one per game.

    Returns:
        The average loss over the games, or 0.0 when ``data`` is empty.
    """
    network.train()
    if not data:
        return 0.0

    total_loss = 0.0
    for states, policies, values in data:
        states = torch.tensor(np.array(states), dtype=torch.float32)
        policies = torch.tensor(np.array(policies), dtype=torch.float32)
        values = torch.tensor(np.array(values), dtype=torch.float32)

        pred_policies, pred_values = network(states)

        # pred_policies are probabilities already; a second softmax would
        # flatten them towards uniform and weaken the learning signal.
        log_probs = torch.log(pred_policies + 1e-6)
        # Cross-entropy per sample, then averaged over the batch.
        cross_entropy = -(policies * log_probs).sum(dim=1).mean()

        # Penalty for target mass placed on moves predicted below 0.1.
        # NOTE: the hard threshold has no gradient, so this term only affects
        # the reported loss, not the parameter update.
        penalty = (policies * (pred_policies < 0.1).float()).sum(dim=1).mean()

        policy_loss = cross_entropy + penalty
        value_loss = ((values - pred_values.view(-1)) ** 2).mean()

        loss = policy_loss + value_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(data)

## Evaluate

In [None]:
def evaluate_network(best_net, current_net, env, num_games=50):
    """Pit ``current_net`` against ``best_net`` and return their win counts.

    The two networks alternate who moves first from game to game so that the
    first-move advantage does not favour either side.

    Returns:
        ``(current_net_wins, best_net_wins)``.
    """
    current_agent = AlphaZeroAgent(current_net, env)
    best_agent = AlphaZeroAgent(best_net, env)

    current_net_wins, best_net_wins, draws = 0, 0, 0

    for game_index in range(num_games):
        state = env.reset()
        current_goes_first = game_index % 2 == 0

        while not state.is_done():
            current_to_move = state.is_going_first() == current_goes_first
            mover = current_agent if current_to_move else best_agent

            policy = mover.run_mcts(state)
            action = np.random.choice(len(policy), p=policy)
            state, _, _, _ = env.step(action)

        # The terminal state is seen from the side that would move next.
        # Decided games are checked before draws so that a winning move which
        # also fills the board is not counted as a draw.
        side_to_move_is_current = state.is_going_first() == current_goes_first
        if state.is_lose():
            if side_to_move_is_current:
                best_net_wins += 1
            else:
                current_net_wins += 1
        elif state.is_win():
            if side_to_move_is_current:
                current_net_wins += 1
            else:
                best_net_wins += 1
        else:
            draws += 1

    print(f"Evaluation: Current Net Wins: {current_net_wins}, Best Net Wins: {best_net_wins}, Draws: {draws}")
    return current_net_wins, best_net_wins

## Main

In [None]:
# Mount Google Drive (Colab only) so checkpoints under /content/drive persist.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def save_checkpoint(model, optimizer, epoch, file_path):
    """Write model weights, optimizer state and the epoch index to ``file_path``."""
    model_state = model.state_dict()
    optimizer_state = optimizer.state_dict()
    checkpoint = {
        'model_state': model_state,
        'optimizer_state': optimizer_state,
        'epoch': epoch,
    }
    torch.save(checkpoint, file_path)
    print(f"Checkpoint saved at epoch {epoch} to {file_path}")

In [None]:
def load_checkpoint(model, optimizer, file_path):
    """Restore ``model`` and ``optimizer`` from ``file_path`` if it exists.

    Returns:
        The epoch index to resume from: one past the saved epoch, or 0 when
        no checkpoint file is present.
    """
    if not os.path.exists(file_path):
        print("No checkpoint found. Starting from scratch.")
        return 0

    checkpoint = torch.load(file_path)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])

    # Resume with the epoch after the one that was saved.
    start_epoch = checkpoint['epoch'] + 1
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
    return start_epoch

In [None]:
# Training driver: self-play -> train -> periodic evaluation -> checkpoint.
env = TicTacToeEnv()
net = AlphaZeroNet()
agent = AlphaZeroAgent(net, env)
optimizer = optim.Adam(net.parameters(), lr=0.0005)

# Load the checkpoint (if any) into the training network first.
folder_path = '/content/drive/MyDrive/Colab Notebooks/kanghwa'
file_path = os.path.join(folder_path, 'checkpoint2.pth')
# Make sure the target folder exists so the first save cannot fail after a
# full epoch of self-play.
os.makedirs(folder_path, exist_ok=True)
start_epoch = load_checkpoint(net, optimizer, file_path)

# Initialise the best network from the (possibly restored) training network.
# Doing this before loading would leave best_net with random weights on
# resume, and those would then overwrite the checkpoint at the next save.
best_net = AlphaZeroNet()
best_net.load_state_dict(net.state_dict())

num_epochs = 16
num_games = 100

for epoch in range(start_epoch, num_epochs):
    data = [self_play(env, agent) for _ in range(num_games)]
    loss = train(net, optimizer, data)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {loss}")

    # Every second epoch, promote the training network if it beats the best one.
    if (epoch + 1) % 2 == 0:
        current_net_wins, best_net_wins = evaluate_network(best_net, net, env, num_games)

        if current_net_wins > best_net_wins:
            print("Update best network.")
            best_net.load_state_dict(net.state_dict())

    save_checkpoint(best_net, optimizer, epoch, file_path)

No checkpoint found. Starting from scratch.
Epoch 1/16, Loss: 12.014530134201049
Checkpoint saved at epoch 0 to /content/drive/MyDrive/Colab Notebooks/kanghwa/checkpoint2.pth
Epoch 2/16, Loss: 11.67139497756958
Evaluation: Current Net Wins: 98, Best Net Wins: 2, Draws: 0
Update best network.
Checkpoint saved at epoch 1 to /content/drive/MyDrive/Colab Notebooks/kanghwa/checkpoint2.pth
Epoch 3/16, Loss: 11.357373151779175
Checkpoint saved at epoch 2 to /content/drive/MyDrive/Colab Notebooks/kanghwa/checkpoint2.pth
Epoch 4/16, Loss: 11.41596399307251
Evaluation: Current Net Wins: 98, Best Net Wins: 2, Draws: 0
Update best network.
Checkpoint saved at epoch 3 to /content/drive/MyDrive/Colab Notebooks/kanghwa/checkpoint2.pth
Epoch 5/16, Loss: 11.40026351928711
Checkpoint saved at epoch 4 to /content/drive/MyDrive/Colab Notebooks/kanghwa/checkpoint2.pth
Epoch 6/16, Loss: 11.306687183380127
Evaluation: Current Net Wins: 94, Best Net Wins: 6, Draws: 0
Update best network.
Checkpoint saved at e

## Test

In [None]:
def test_agent(agent, env, num_games=10, human_play=False):
    """Play ``num_games`` games with the agent moving first and print a summary.

    Args:
        agent: object providing ``run_mcts(state)``.
        env: environment providing ``reset``, ``step`` and ``render``.
        num_games: number of games to play.
        human_play: if True the opponent's moves are read from standard input,
            otherwise the opponent plays uniformly random legal moves.
    """
    agent_wins = 0
    draws = 0
    losses = 0

    for game in range(num_games):
        print(f"Game {game + 1}/{num_games}")
        state = env.reset()
        env.render(state)

        while not state.is_done():
            if state.is_going_first():  # Agent's turn
                policy = agent.run_mcts(state)
                action = np.random.choice(len(policy), p=policy)
            elif human_play:
                # Keep asking until the input is an integer naming a free cell.
                # Non-numeric input is treated like any other invalid move
                # instead of aborting the whole session.
                prompt = "Enter your move (0-8): "
                while True:
                    try:
                        action = int(input(prompt))
                    except ValueError:
                        action = None
                    if action in state.available_actions:
                        break
                    prompt = "Invalid move. Try again (0-8): "
            else:
                # Random opponent
                action = np.random.choice(state.available_actions)

            state, _, _, _ = env.step(action)
            env.render(state)

        # The terminal state is seen from the side that would move next, so a
        # "lose" here means whoever moved last has won.
        if state.is_lose():
            if state.is_going_first():
                print("Opponent wins!")
                losses += 1
            else:
                print("Agent wins!")
                agent_wins += 1
        elif state.is_draw():
            print("It's a draw!")
            draws += 1

    print("\nTest Results:")
    print(f"Agent Wins: {agent_wins}")
    print(f"Draws: {draws}")
    print(f"Losses: {losses}")

In [None]:
# Test: wrap the best network in a fresh agent and play 5 games against a
# human who enters moves on standard input.
agent = AlphaZeroAgent(best_net, env)

test_agent(agent, env, num_games=5, human_play=True)

Game 1/5

- - -
- - -
- - -


- - O
- - -
- - -

Enter your move (0-8): 4

- - O
- X -
- - -


O - O
- X -
- - -

Enter your move (0-8): 1

O X O
- X -
- - -


O X O
- X O
- - -

Enter your move (0-8): 8

O X O
- X O
- - X


O X O
O X O
- - X

Enter your move (0-8): 7

O X O
O X O
- X X

Opponent wins!
Game 2/5

- - -
- - -
- - -


- - -
- - -
- - O

Enter your move (0-8): 4

- - -
- X -
- - O


- - -
- X O
- - O

Enter your move (0-8): 2

- - X
- X O
- - O


- - X
O X O
- - O

Enter your move (0-8): 6

- - X
O X O
X - O

Opponent wins!
Game 3/5

- - -
- - -
- - -


- - -
- - -
- - O

Enter your move (0-8): 4

- - -
- X -
- - O


- - -
- X O
- - O

Enter your move (0-8): 2

- - X
- X O
- - O


- - X
O X O
- - O

Enter your move (0-8): 6

- - X
O X O
X - O

Opponent wins!
Game 4/5

- - -
- - -
- - -


O - -
- - -
- - -

Enter your move (0-8): 4

O - -
- X -
- - -


O - O
- X -
- - -

Enter your move (0-8): 1

O X O
- X -
- - -


O X O
- X O
- - -

Enter your move (0-8): 7

O X O
- X O
-

In [None]:
# Play 10 games against an opponent that picks random legal moves.
test_agent(agent, env, num_games=10, human_play=False)

Game 1/10

- - -
- - -
- - -


O - -
- - -
- - -


O - -
- - X
- - -


O O -
- - X
- - -


O O -
- - X
X - -


O O O
- - X
X - -

Agent wins!
Game 2/10

- - -
- - -
- - -


- O -
- - -
- - -


X O -
- - -
- - -


X O -
- - -
- O -


X O -
- - -
- O X


X O -
- O -
- O X

Agent wins!
Game 3/10

- - -
- - -
- - -


- - -
- - -
- - O


- - X
- - -
- - O


O - X
- - -
- - O


O - X
- - -
- X O


O - X
- O -
- X O

Agent wins!
Game 4/10

- - -
- - -
- - -


O - -
- - -
- - -


O - -
X - -
- - -


O O -
X - -
- - -


O O -
X - -
- - X


O O O
X - -
- - X

Agent wins!
Game 5/10

- - -
- - -
- - -


- O -
- - -
- - -


- O X
- - -
- - -


- O X
- - O
- - -


- O X
- X O
- - -


O O X
- X O
- - -


O O X
X X O
- - -


O O X
X X O
- O -


O O X
X X O
- O X


O O X
X X O
O O X

It's a draw!
Game 6/10

- - -
- - -
- - -


- - -
O - -
- - -


- - -
O - -
- - X


- - -
O - -
O - X


- - -
O - -
O X X


O - -
O - -
O X X

Agent wins!
Game 7/10

- - -
- - -
- - -


O - -
- - -
- - -


O - -
- - -
X - 