In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
from math import exp
from collections import namedtuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # prefer the GPU when present
print(f"Using device: {device}")

Using device: cpu


## Hyperparameters

In [None]:
# --- State ---
BOARD_SIZE = (3, 3)  # (rows, cols)

# --- Agent: PUCT exploration constant (decayed per self-play game in the main loop) ---
C_PUCT_INIT = 2.5
C_PUCT_DECAY = 0.05
C_PUCT_MIN = 1.0

# --- Agent: move-selection temperature (decayed per self-play game in the main loop) ---
TAU_INIT = 1.0
TAU_DECAY = 0.05
TAU_MIN = 1e-5

NUM_OF_SIMULATION = 200  # MCTS simulations per get_policy call (i.e. per move)

# --- Main loop ---
NUM_EPOCHS = 2   # number of training rounds, each on freshly generated self-play data
NUM_GAMES = 10   # self-play games per epoch
BATCH_SIZE = 1

LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-4

## State

In [None]:
class State:
    """Tic-tac-toe position, always seen from the side to move.

    ``my_actions`` holds the cells (row-major indices) taken by the player to
    move and ``enemy_actions`` those taken by the opponent. ``board`` stacks
    them as two 0/1 planes of shape ``board_size``: plane 0 for the player to
    move, plane 1 for the opponent.
    """

    def __init__(self, board_size=BOARD_SIZE, my_actions=None, enemy_actions=None):
        self.board_size = board_size  # (rows, cols), e.g. (3, 3)
        self.num_actions = self.board_size[0] * self.board_size[1]  # 3 * 3 = 9
        self.action_space = range(self.num_actions)

        self.my_actions = [] if my_actions is None else my_actions
        self.enemy_actions = [] if enemy_actions is None else enemy_actions

        self.board = self._create_board(self.my_actions, self.enemy_actions)

        self.available_actions = self._get_available_actions()

    def next(self, action):
        """Return the position after the player to move plays ``action``.

        The perspective flips: in the returned state the opponent is to move.
        Both action lists are copied, so the two states never share a list.
        """
        my_actions = list(self.my_actions)
        my_actions.append(action)

        return State(self.board_size, list(self.enemy_actions), my_actions)

    def _create_board(self, my_actions, enemy_actions):
        """Build the (2, rows, cols) occupancy planes: [player to move, opponent]."""
        total_board = np.zeros((2, self.board_size[0], self.board_size[1]))

        my_board = np.zeros(self.board_size).flatten()
        enemy_board = np.zeros(self.board_size).flatten()

        my_board[my_actions] = 1
        enemy_board[enemy_actions] = 1

        total_board[0] = my_board.reshape(self.board_size)
        total_board[1] = enemy_board.reshape(self.board_size)

        return total_board

    def _get_available_actions(self):
        """Cells taken by neither player, in ascending index order."""
        taken = set(self.my_actions) | set(self.enemy_actions)
        return [a for a in self.action_space if a not in taken]

    def _has_line(self, plane):
        """True if ``plane`` contains a full column, full row, or full diagonal.

        The diagonal checks assume a square board.
        """
        rows, cols = self.board_size
        full_column = np.sum(plane, axis=0).max() == rows
        full_row = np.sum(plane, axis=1).max() == cols
        diagonal = np.trace(plane) == rows
        anti_diagonal = np.trace(np.fliplr(plane)) == rows
        return bool(full_column or full_row or diagonal or anti_diagonal)

    def is_win(self):
        """True if the player to move has a complete line."""
        return self._has_line(self.board[0])

    def is_draw(self):
        """True if the board is full and neither player has a line.

        A line completed on the last free square is a win/loss, not a draw.
        """
        board_full = (np.sum(self.board[0]) + np.sum(self.board[1])) >= self.num_actions
        return bool(board_full and not self.is_win() and not self.is_lose())

    def is_lose(self):
        """True if the opponent has a complete line."""
        return self._has_line(self.board[1])

    def is_done(self):
        """True once the game is over (a line exists or the board is full)."""
        return self.is_win() or self.is_draw() or self.is_lose()

    def is_going_first(self):
        """True if the player to move is the one who made the first move of the game."""
        return len(self.my_actions) == len(self.enemy_actions)

## Environment

In [None]:
class TicTacToeEnv:
    """Tic-tac-toe environment; ``self.state`` is always seen from the side to move."""

    def __init__(self):
        self.state = State()
        self.reward = {'win': 10, 'lose': -10, 'draw': 0, 'continue': 0}

    def reset(self):
        """Start a new game and return the empty starting position."""
        self.state = State()
        return self.state

    def step(self, action):
        """Play ``action`` for the side to move.

        Returns:
            ``(state, next_state, reward, done)``: ``state`` is the new position
            from the opponent's perspective (the next side to move; also stored
            in ``self.state``); ``next_state`` is the same position from the
            mover's perspective, and ``reward``/``done`` are evaluated on it.

        Raises:
            ValueError: if ``action`` is not in ``self.state.available_actions``
                (playing on an occupied cell would silently corrupt the board).
        """
        current = self.state
        if action not in current.available_actions:
            raise ValueError(
                f"Illegal action {action}; available actions: {current.available_actions}"
            )

        my_actions = current.my_actions.copy()
        enemy_actions = current.enemy_actions.copy()
        my_actions.append(action)

        board_size = current.board_size
        next_state = State(board_size, my_actions, enemy_actions)  # mover's view after the move
        self.state = State(board_size, enemy_actions, my_actions)  # turn passes to the opponent

        if next_state.is_win():
            reward, done = self.reward['win'], True

        elif next_state.is_draw():
            reward, done = self.reward['draw'], True

        elif next_state.is_lose():
            reward, done = self.reward['lose'], True

        else:
            reward, done = self.reward['continue'], False

        return self.state, next_state, reward, done

    def render(self, state):
        """Print the board with the first player's stones as 'O' and the second's as 'X'."""
        board = state.board[0] - state.board[1] if state.is_going_first() else state.board[1] - state.board[0]
        int_to_symbol = np.vectorize({1: 'O', -1: 'X', 0: '-'}.__getitem__)
        rendering_board = int_to_symbol(board)
        print("\n" + "\n".join(" ".join(row) for row in rendering_board) + "\n")

## Net

In [None]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BN layers plus a shortcut connection.

    Spatial size is preserved (stride 1, padding 1). When ``downsample`` is
    requested or the channel count changes, the shortcut is a 1x1 conv + BN
    projection; otherwise it is the identity (``self.downsample is None``).
    """

    def __init__(self, in_channels, out_channels, downsample=False):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        needs_projection = downsample or in_channels != out_channels
        self.downsample = (
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(out + shortcut)

In [None]:
class ResNet(nn.Module):
    """Residual trunk: 3x3 stem conv, two residual stages, global average pool.

    ``config`` unpacks to ``(block, n_blocks, channels)``; stage k stacks
    ``n_blocks[k]`` blocks producing ``channels[k]`` channels. The stem expects
    a 2-channel input, and the output is pooled to spatial size 1x1.
    """

    def __init__(self, config, zero_init_residual=False):
        super().__init__()
        block, n_blocks, channels = config
        self.in_channels = channels[0]

        # Stem
        self.conv1 = nn.Conv2d(2, self.in_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1])

        # Global average pooling
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if zero_init_residual:
            # Zero the last BN scale of each BasicBlock so its residual branch starts near zero.
            for module in self.modules():
                if isinstance(module, BasicBlock):
                    nn.init.constant_(module.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Build one stage of ``n_blocks`` blocks; ``stride`` is accepted but unused."""
        if n_blocks <= 0:
            return nn.Sequential()
        stage = [block(self.in_channels, channels, downsample=(self.in_channels != channels))]
        stage += [block(channels, channels) for _ in range(n_blocks - 1)]
        self.in_channels = channels
        return nn.Sequential(*stage)

    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.layer2(self.layer1(x))
        return self.avgpool(x)  # (batch_size, channels, 1, 1)

In [None]:
# ResNet trunk configuration: residual block class, blocks per stage, channels per stage.
ResNetConfig = namedtuple('ResNetConfig', 'block n_blocks channels')

config = ResNetConfig(BasicBlock, [2, 2], [16, 16])

class AlphaZeroNet(nn.Module):
    """Policy/value network: a shared ResNet trunk with two small heads.

    ``forward(x)`` returns ``(policy, value)``: ``policy`` is a softmax over the
    ``board_size[0] * board_size[1]`` cells, and ``value`` is a tanh output of
    shape (batch, 1).

    NOTE(review): the trunk ends in global average pooling, so the heads only
    see per-channel averages (2 features for the policy, 1 for the value);
    spatial information is gone before either head. Changing this alters the
    state_dict and would invalidate existing checkpoints.
    """

    def __init__(self, board_size=(3, 3), config=config):
        super().__init__()
        _, _, stage_channels = config
        trunk_channels = stage_channels[-1]
        num_cells = board_size[0] * board_size[1]
        self.board_size = board_size

        self.resnet = ResNet(config=config)

        # Policy head -> probabilities over all cells
        self.policy_head = nn.Sequential(
            nn.Conv2d(trunk_channels, 2, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(2, num_cells),
            nn.Softmax(dim=1),
        )

        # Value head -> scalar in (-1, 1)
        self.value_head = nn.Sequential(
            nn.Conv2d(trunk_channels, 1, kernel_size=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(1, 1),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.resnet(x)  # (batch_size, channels, 1, 1)
        return self.policy_head(features), self.value_head(features)

## MCTS Node

In [None]:
class Node:
    """Node of the MCTS search tree.

    Each node wraps one ``state`` and keeps search statistics: ``n`` (visit
    count), ``w`` (sum of values backed up through it) and ``p`` (prior of the
    move that led here, obtained from ``agent.get_prior_prob`` on the parent
    state). ``backup`` flips the value's sign at every ply, so ``select``
    scores a child by its negated mean value ``-w / n``.
    """

    def __init__(self, state, agent, parent=None, action=None, p=None):
        self.agent = agent
        self.parent_node = parent
        self.child_nodes = []
        self.state = state
        self.action = action
        self.n = 0  # visit count
        self.w = 0  # accumulated (backed-up) value
        self.p = p  # prior probability of `action` from the policy net

    def expand(self):
        """Create one child per available action (no-op if already fully expanded)."""
        if self.is_fully_expanded():
            return
        children = []
        for move in self.state.available_actions:
            child_state = self.state.next(move)
            prior = self.agent.get_prior_prob(self.state, move)
            children.append(
                Node(state=child_state, agent=self.agent, parent=self, action=move, p=prior)
            )
        self.child_nodes = children

    def select(self):
        """Return the child with the highest Q + U score (PUCT).

        Returns ``self`` while this node still has unexpanded actions, and
        ``None`` if it has no children at all.
        """
        if not self.is_fully_expanded():
            return self
        if not self.child_nodes:
            return None

        visits = np.array([child.n for child in self.child_nodes])
        mean_values = np.array([-child.w / max(child.n, 1) for child in self.child_nodes])
        priors = np.array([child.p for child in self.child_nodes])

        exploration = self.agent.c_puct * priors * np.sqrt(self.n) / (1 + visits)
        scores = mean_values + exploration
        return self.child_nodes[np.argmax(scores)]

    def backup(self, value):
        """Count a visit and add ``value`` here, negating it at each ancestor."""
        current, signed_value = self, value
        while current is not None:
            current.n += 1
            current.w += signed_value
            current, signed_value = current.parent_node, -signed_value

    def is_fully_expanded(self):
        """True when there is a child for every available action."""
        return len(self.child_nodes) == len(self.state.available_actions)

    def get_terminal_value(self):
        """Result of a finished game for the side to move: -1 loss, 0 draw, otherwise 1."""
        if self.state.is_lose():
            return -1  # loss
        return 0 if self.state.is_draw() else 1  # draw / win

## Agent

In [None]:
class AlphaZeroAgent:
    """MCTS player guided by a policy/value network.

    Attributes:
        network: module mapping a batch of boards to ``(policy, value)``.
        env: environment the agent plays in.
        num_simulations: MCTS simulations run per ``get_policy`` call.
        tau: temperature turning root visit counts into move probabilities
            (0 means pick the most-visited move).
        c_puct: exploration constant read by ``Node.select``.
    """

    def __init__(self, network, env):
        self.network = network
        self.env = env

        self.num_simulations = NUM_OF_SIMULATION
        self.tau = TAU_INIT  # temperature
        self.c_puct = C_PUCT_INIT

        # One-entry cache for get_prior_prob: Node.expand asks for the prior of
        # every legal action of the same state, which then costs one forward pass.
        self._prior_cache_state = None
        self._prior_cache = None

    def get_policy(self, state):
        """Run MCTS from ``state`` and return a move distribution over all cells.

        Returns:
            float64 array of length ``state.num_actions`` summing to 1; only
            actions expanded at the root (the legal moves) get non-zero mass.
        """
        # The network may have changed since the previous search; never reuse its priors.
        self._prior_cache_state = None
        self._prior_cache = None

        root = Node(state=state, agent=self)

        for _ in range(self.num_simulations):
            node = root

            # Selection: follow PUCT through fully expanded, non-terminal nodes
            while node.is_fully_expanded() and not node.state.is_done():
                node = node.select()

            # Expansion and evaluation
            if node.state.is_done():
                value = node.get_terminal_value()
            else:
                node.expand()
                value = self._predict_value(node.state)

            # Backpropagation (Node.backup flips the sign at every ply)
            node.backup(value)

        return self._compute_policy(root)

    def _compute_policy(self, root_node):
        """Turn the root's child visit counts into move probabilities.

        Unvisited or illegal actions get exactly 0. With ``tau == 0`` the result
        is one-hot on the most-visited move. Otherwise the counts are divided by
        their maximum before being raised to ``1 / tau``. This keeps the result
        finite for very small ``tau``; plain ``n ** (1 / tau)`` overflows to inf
        and produces NaN probabilities.

        Raises:
            ValueError: if the root has no children (no legal moves).
        """
        counts = np.zeros(root_node.state.num_actions, dtype=np.float64)
        for child in root_node.child_nodes:
            counts[child.action] = child.n

        if counts.sum() <= 0:
            # No child visits yet (e.g. num_simulations < 2): fall back to
            # uniform over the expanded children.
            for child in root_node.child_nodes:
                counts[child.action] = 1.0
            if counts.sum() <= 0:
                raise ValueError("cannot build a policy: the root has no child nodes")

        if self.tau == 0:
            # Greedy (evaluation / test): argmax of the visit counts
            policy = np.zeros_like(counts)
            policy[np.argmax(counts)] = 1.0
            return policy

        scaled = (counts / counts.max()) ** (1.0 / self.tau)
        return scaled / scaled.sum()

    def _simulate(self, node):
        """Recursive single-simulation variant; get_policy does not call it.

        Returns ``(leaf_node, value)`` with ``value`` negated at each level so it
        is expressed from the caller's (parent's) side.
        """
        if node.state.is_done():
            return node, node.get_terminal_value()

        if not node.is_fully_expanded():
            node.expand()
            value = self._predict_value(node.state)
            return node, value

        best_child = node.select()
        # Recurse; the child's value is from the child's side
        leaf_node, value = self._simulate(best_child)
        return leaf_node, -value  # flip to the parent's side

    def _board_tensor(self, state):
        """``state.board`` as a float32 batch of one, on the network's device."""
        try:
            target_device = next(self.network.parameters()).device
        except StopIteration:
            target_device = torch.device("cpu")
        board = torch.as_tensor(np.asarray(state.board), dtype=torch.float32)
        return board.unsqueeze(0).to(target_device)

    def _legal_prior(self, state):
        """Network policy for ``state`` restricted to legal moves and renormalised."""
        # eval(): BatchNorm must use running statistics during search, not the
        # statistics of a single position (and must not update them).
        self.network.eval()
        with torch.no_grad():
            policy, _ = self.network(self._board_tensor(state))
        policy = policy.squeeze(0).detach().cpu().numpy().astype(np.float64)

        legal = np.asarray(state.available_actions, dtype=np.int64)
        prior = np.zeros_like(policy)
        if legal.size:
            prior[legal] = policy[legal]
            total = prior.sum()
            if total > 0:
                prior /= total
            else:
                prior[legal] = 1.0 / legal.size
        return prior

    def get_prior_prob(self, state, action):
        """Prior probability of ``action`` in ``state`` from the policy head.

        The network's distribution is restricted to ``state.available_actions``
        and renormalised (uniform over legal moves if the network puts no mass
        on them); illegal actions get 0.0. The result for the most recently
        queried state object is cached, so asking for every action of one state
        costs a single forward pass.
        """
        if getattr(self, "_prior_cache_state", None) is not state:
            self._prior_cache = self._legal_prior(state)
            self._prior_cache_state = state
        return float(self._prior_cache[action])

    def _predict_value(self, state):
        """Value-head estimate for ``state`` (network in eval mode, no grad)."""
        self.network.eval()
        with torch.no_grad():
            _, value = self.network(self._board_tensor(state))
        return value.item()

    def random_action(self, state):
        """Uniformly random legal move."""
        return np.random.choice(state.available_actions)

## Self-play

(AlphaGo Zero 논문 기준) 반복(iteration)마다 셀프플레이 25,000 게임, 한 수마다 MCTS 시뮬레이션 1,600번 (AlphaZero 논문은 수당 800번)

각 게임마다 처음 30수는 temp = 1, 이후의 수는 temp → 0

In [None]:
def self_play(env, agent):
    """Play one game of ``agent`` against itself and collect training targets.

    Returns:
        ``(states, mcts_policies, values)``: for every position played, the
        board (``state.board``, seen from the side to move), the MCTS move
        distribution, and the final result from that same side's perspective
        (+1 win, 0 draw, -1 loss).
    """
    states, mcts_policies = [], []
    state = env.reset()

    while not state.is_done():
        policy = agent.get_policy(state)
        action = np.random.choice(len(policy), p=policy)
        states.append(state.board)
        mcts_policies.append(policy)
        state, _, _, _ = env.step(action)

    # The final `state` is seen by the side that would move next, i.e. the side
    # that did NOT make the last move. Check for a loss before a draw: a line
    # completed on the last free square also fills the board.
    if state.is_win():
        outcome = 1
    elif state.is_lose():
        outcome = -1
    else:
        outcome = 0

    # Position i belongs to the same side as the final position iff
    # (len(states) - i) is even; the other side gets the negated outcome.
    # (Assigning from i = 0 instead inverts every target in odd-length games.)
    num_positions = len(states)
    values = [
        outcome if (num_positions - i) % 2 == 0 else -outcome
        for i in range(num_positions)
    ]

    return states, mcts_policies, values

## Train

In [None]:
def train(network, optimizer, data, batch_size=20):
    """Run one pass over self-play data and return the mean batch loss.

    Args:
        network: model returning ``(policy, value)``. The policy output is used
            directly as probabilities; AlphaZeroNet's policy head already ends
            in Softmax, so no second softmax is applied here.
        optimizer: optimizer over ``network``'s parameters.
        data: list of ``(states, mcts_policies, values)`` tuples as returned by
            ``self_play``.
        batch_size: mini-batch size.

    Returns:
        Mean over batches of (policy cross-entropy + value MSE), or 0.0 when
        ``data`` contains no positions.
    """
    network.train()
    target_device = next(network.parameters()).device

    games = [d for d in data if len(d[0]) > 0]
    if not games:
        return 0.0

    states = torch.tensor(np.concatenate([np.asarray(d[0]) for d in games]), dtype=torch.float32, device=target_device)
    policies = torch.tensor(np.concatenate([np.asarray(d[1]) for d in games]), dtype=torch.float32, device=target_device)
    values = torch.tensor(np.concatenate([np.asarray(d[2]) for d in games]), dtype=torch.float32, device=target_device)

    dataloader = DataLoader(TensorDataset(states, policies, values), batch_size=batch_size, shuffle=True)

    total_loss = 0.0
    for batch_states, batch_policies, batch_values in dataloader:
        pred_policies, pred_values = network(batch_states)

        # pred_policies are already probabilities; a second softmax would flatten
        # them toward uniform and make the cross-entropy target unreachable.
        policy_loss = -torch.sum(batch_policies * torch.log(pred_policies + 1e-6), dim=1).mean()
        value_loss = ((batch_values - pred_values.view(-1)) ** 2).mean()

        loss = policy_loss + value_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

## Evaluate

→ 메서드 이름 변경 검토 (기존 best net과 현재 net을 대결시켜 현재 net의 승률이 55% 이상이면 best net 업데이트)

+ temp = 0 으로 설정해야 함

In [None]:
def evaluate_network(best_net, current_net, env, num_games=50, tau=0):
    """Play ``current_net`` against ``best_net`` and count the results.

    Sides alternate: in even-numbered games (0, 2, ...) the current network
    moves first, and in odd-numbered games the best network does, so neither
    side gets the first-move advantage every game. Both agents use temperature
    ``tau``; 0 (greedy) by default, as intended for evaluation.

    Args:
        best_net: network of the reigning best agent.
        current_net: newly trained network being evaluated.
        env: TicTacToeEnv used to play the games.
        num_games: number of games to play.
        tau: move-selection temperature for both agents.

    Returns:
        ``(current_net_wins, best_net_wins)``; the draw count is only printed.
    """
    current_agent = AlphaZeroAgent(current_net, env)
    best_agent = AlphaZeroAgent(best_net, env)
    current_agent.tau = tau
    best_agent.tau = tau

    current_net_wins, best_net_wins, draws = 0, 0, 0

    for i in range(num_games):
        current_goes_first = (i % 2 == 0)
        state = env.reset()

        while not state.is_done():
            # is_going_first() is True when the first player is to move
            current_to_move = (state.is_going_first() == current_goes_first)
            mover = current_agent if current_to_move else best_agent
            policy = mover.get_policy(state)

            action = np.random.choice(len(policy), p=policy)
            state, _, _, _ = env.step(action)

        # `state` is seen by the side to move next (the side that did NOT make the
        # last move). Test for a loss before a draw: a line completed on the last
        # free square also fills the board.
        current_to_move = (state.is_going_first() == current_goes_first)
        if state.is_lose():
            if current_to_move:
                best_net_wins += 1
            else:
                current_net_wins += 1
        elif state.is_win():
            if current_to_move:
                current_net_wins += 1
            else:
                best_net_wins += 1
        elif state.is_draw():
            draws += 1

    print(f"Evaluation: Current Net Wins: {current_net_wins}, Best Net Wins: {best_net_wins}, Draws: {draws}")
    return current_net_wins, best_net_wins

## Main

In [None]:
 # Colab only: mount Google Drive at /content/drive; the checkpoint folder lives under it.
 from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
def save_checkpoint(model, optimizer, epoch, file_path):
    """Write model state, optimizer state and the epoch index to ``file_path``.

    Missing parent directories are created first, because ``torch.save`` fails
    when the target folder does not exist.
    """
    directory = os.path.dirname(file_path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    torch.save({
        'model_state': model.state_dict(),
        'optimizer_state': optimizer.state_dict(),
        'epoch': epoch
    }, file_path)
    print(f"Checkpoint saved\n")

In [None]:
def load_checkpoint(model, optimizer, file_path):
    """Restore ``model`` and ``optimizer`` from ``file_path`` if it exists.

    Expects the dict written by ``save_checkpoint`` (``model_state``,
    ``optimizer_state``, ``epoch``). Tensors are mapped to the device of
    ``model``'s parameters, so a checkpoint saved on GPU can be resumed on CPU.
    Loading uses ``weights_only=True`` when the installed PyTorch supports it,
    so arbitrary pickled objects are not executed.

    Returns:
        The epoch to resume from: the saved epoch + 1, or 0 if there is no file.
    """
    if not os.path.exists(file_path):
        print("No checkpoint found. Starting from scratch.")
        return 0

    try:
        map_location = next(model.parameters()).device
    except StopIteration:
        map_location = torch.device("cpu")

    try:
        checkpoint = torch.load(file_path, map_location=map_location, weights_only=True)
    except TypeError:
        # PyTorch < 1.13 has no `weights_only` argument
        checkpoint = torch.load(file_path, map_location=map_location)

    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optimizer_state'])
    start_epoch = checkpoint['epoch'] + 1  # resume from the following epoch
    print(f"Checkpoint loaded. Resuming from epoch {start_epoch}")
    return start_epoch

In [None]:
try_num = input()  # run identifier; selects checkpoint{try_num}.pth

env = TicTacToeEnv()
net = AlphaZeroNet().to(device)
agent = AlphaZeroAgent(net, env)
optimizer = optim.Adam(net.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

folder_path = '/content/drive/MyDrive/Colab Notebooks/kanghwa'
file_path = os.path.join(folder_path, f'checkpoint{try_num}.pth')
start_epoch = load_checkpoint(net, optimizer, file_path)

# Create the reference ("best") network only after the checkpoint has been
# restored into `net`. Copying before loading would leave best_net randomly
# initialised on a resumed run.
best_net = AlphaZeroNet().to(device)
best_net.load_state_dict(net.state_dict())

losses = []
win_rates = []

for epoch in range(start_epoch, NUM_EPOCHS):
    # --- Self-play with a per-game exploration / temperature schedule ---
    data = []
    for i in range(NUM_GAMES):
        if i <= (NUM_GAMES * 0.1):
            agent.c_puct = C_PUCT_INIT
            agent.tau = TAU_INIT
        else:
            agent.c_puct = max(C_PUCT_MIN, C_PUCT_INIT * exp(-C_PUCT_DECAY * i))
            agent.tau = max(TAU_MIN, TAU_INIT * exp(-TAU_DECAY * i))

        if (i + 1) % 10 == 0:
            print(f"C_puct : {agent.c_puct}")
            print(f"Tau : {agent.tau}")
            print(f"{i + 1}/{NUM_GAMES}\n")

        game_data = self_play(env, agent)
        data.append(game_data)

    # --- Training on this epoch's games ---
    loss = train(net, optimizer, data, batch_size=BATCH_SIZE)
    losses.append(loss)
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}, Loss: {loss}")

    # --- Every 2nd epoch: promote `net` if it scores >= 55% against best_net ---
    if (epoch + 1) % 2 == 0:
        current_net_wins, best_net_wins = evaluate_network(best_net, net, env, NUM_GAMES)
        win_rate = current_net_wins / NUM_GAMES
        win_rates.append(win_rate)

        if win_rate >= 0.55:
            print("Update best network.")
            best_net.load_state_dict(net.state_dict())

    # NOTE(review): this stores best_net's weights together with the optimizer
    # state of `net`; on resume both are loaded into `net`.
    save_checkpoint(best_net, optimizer, epoch, file_path)

잘됨
Checkpoint loaded. Resuming from epoch 2


  checkpoint = torch.load(file_path)


In [None]:
# Training loss for each epoch run in this session
loss_fig, loss_ax = plt.subplots(figsize=(12, 6))
loss_ax.plot(range(1, len(losses) + 1), losses, label='Loss')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss Over Epochs')
loss_ax.legend()
loss_ax.grid()
plt.show()

# Win rate of the current net against the best net (evaluated every 2nd epoch)
eval_epochs = range(2, len(win_rates) * 2 + 1, 2)
win_fig, win_ax = plt.subplots(figsize=(12, 6))
win_ax.plot(eval_epochs, win_rates, label='Win Rate')
win_ax.set_xlabel('Epochs')
win_ax.set_ylabel('Win Rate')
win_ax.set_title('Win Rate Over Epochs')
win_ax.legend()
win_ax.grid()
plt.show()

## Test

In [None]:
def test(agent, env, num_games=10, human_play=False, random_first=True):
    player = ''

    agent_wins = 0
    draws = 0
    losses = 0

    agent.tau = 0

    for game in range(num_games):
        print(f"GAME {game + 1}/{num_games}")
        state = env.reset()

        # 게임 시작 전 랜덤하게 선공/후공 결정
        if random_first:
            agent_first = np.random.choice([True, False])
            print(f"===Agent goes {'first (O)' if agent_first else 'second (X)'}===")
        else:
            agent_first = True  # 기본값은 에이전트가 선공

        env.render(state)

        while not state.is_done():
            is_agent_turn = (state.is_going_first() == agent_first)

            if is_agent_turn:  # Agent's turn
                player = 'Agent\'s'
                policy = agent.get_policy(state)
                action = np.random.choice(len(policy), p=policy)
            else:  # Opponent's turn
                if human_play:
                    # 사람의 입력을 받아서 동작
                    player = 'Your'
                    print("0 1 2\n3 4 5\n6 7 8")
                    action = int(input("Enter your move : "))
                    while action not in state.available_actions:
                        action = int(input("Invalid move. Try again : "))
                else:
                    # 랜덤한 상대
                    player = 'Random player\'s'
                    action = np.random.choice(state.available_actions)

            state, _, _, _ = env.step(action)
            print(f"{player} turn")
            env.render(state)
            print("---------------------")

        # 게임 결과 처리
        if state.is_lose():
            if (state.is_going_first() == agent_first):
                print("Opponent wins!\n")
                losses += 1
            else:
                print("Agent wins!\n")
                agent_wins += 1
        elif state.is_draw():
            print("It's a draw!\n")
            draws += 1

    print("\nTest Results:")
    print(f"Agent Wins: {agent_wins}")
    print(f"Draws: {draws}")
    print(f"Losses: {losses}")

In [None]:
# Test: wrap the best network kept by the training loop in a fresh agent.
agent = AlphaZeroAgent(best_net, env)

In [None]:
# Greedy agent (test() sets tau = 0) vs. a random opponent; sides are drawn at random each game.
test(agent, env, num_games=10, human_play=False)

GAME 1/10
===Agent goes second (X)===

- - -
- - -
- - -

Random player's turn

- - -
- - -
- O -

---------------------
Agent's turn

- X -
- - -
- O -

---------------------
Random player's turn

- X O
- - -
- O -

---------------------
Agent's turn

- X O
- X -
- O -

---------------------
Random player's turn

O X O
- X -
- O -

---------------------
Agent's turn

O X O
- X X
- O -

---------------------
Random player's turn

O X O
- X X
- O O

---------------------
Agent's turn

O X O
X X X
- O O

---------------------
Agent wins!

GAME 2/10
===Agent goes second (X)===

- - -
- - -
- - -

Random player's turn

- - -
- - -
- O -

---------------------
Agent's turn

- X -
- - -
- O -

---------------------
Random player's turn

- X -
O - -
- O -

---------------------
Agent's turn

X X -
O - -
- O -

---------------------
Random player's turn

X X O
O - -
- O -

---------------------
Agent's turn

X X O
O - X
- O -

---------------------
Random player's turn

X X O
O O X
- O -

----

In [None]:
# Interactive play: you enter cell indices 0-8 (row-major) when prompted.
test(agent, env, num_games=3, human_play=True)

GAME 1/3
===Agent goes second (X)===

- - -
- - -
- - -

0 1 2
3 4 5
6 7 8
Enter your move : 7
Your turn

- - -
- - -
- O -

---------------------
Agent's turn

- X -
- - -
- O -

---------------------
0 1 2
3 4 5
6 7 8
Enter your move : 2
Your turn

- X O
- - -
- O -

---------------------
Agent's turn

- X O
- X -
- O -

---------------------
0 1 2
3 4 5
6 7 8
Enter your move : 8
Your turn

- X O
- X -
- O O

---------------------
Agent's turn

- X O
- X X
- O O

---------------------
0 1 2
3 4 5
6 7 8
Enter your move : 6
Your turn

- X O
- X X
O O O

---------------------
Opponent wins!

GAME 2/3
===Agent goes first (O)===

- - -
- - -
- - -

Agent's turn

- O -
- - -
- - -

---------------------
0 1 2
3 4 5
6 7 8
Enter your move : 4
Your turn

- O -
- X -
- - -

---------------------
Agent's turn

O O -
- X -
- - -

---------------------
0 1 2
3 4 5
6 7 8
Enter your move : 2
Your turn

O O X
- X -
- - -

---------------------
Agent's turn

O O X
- X -
O - -

---------------------
0