In [1]:
import numpy as np
from games.tictactoe import TicTacToe
from games.connectfour import ConnectFour
from models.mcts import MCTS, MCTSParallel
from models.resnet import ResNet
from models.deepzero import DeepZero, DeepZeroParallel
import torch
from tqdm import tqdm
from tqdm import trange
import random
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

# tictactoe game

In [2]:
game = TicTacToe()

device = torch.device("cpu")

model = ResNet(game, 4, 32, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

player = 1

args = {
    'C': 2,
    'num_search': 100,
    'num_iterations': 3,
    'num_parallel_games': 100,
    'batch_size': 16,
    'num_selfplay_iterations': 350,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 1
}

deepzero = DeepZeroParallel(model, optimizer, game, args)
deepzero.learn()


  0%|          | 0/3 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [6]:
game = TicTacToe()
player = -1
device = torch.device("cpu")

args = {
    'C': 2,
    'num_search': 100,
    'num_iterations': 3,
    'num_parallel_games': 100,
    'batch_size': 16,
    'num_selfplay_iterations': 350,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}


model = ResNet(game, 4, 32, device=device)
model.load_state_dict(torch.load("weights/model_2_TicTacToe.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}: "))
        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        valid_moves = game.get_valid_moves(state)
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        mcts_probs = mcts_probs * valid_moves  # Mask invalid moves to zero
        action = np.argmax(mcts_probs)
        # Optional: Add a check for no valid moves, though this should not occur in a proper game state
        if valid_moves[action] == 0:
            raise ValueError("No valid moves available; game state may be invalid.")

    state = game.get_next_state(state, action, player)
    value, is_terminate = game.get_value_and_terminated(state, action)
    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
            print(player, "lose")
        break
    player = game.get_opponent(player)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
expected win rate -0.3342079520225525
[[ 0. -1.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  0.]]
val_movies [0, 2, 3, 4, 5, 6, 7, 8]
[[ 0. -1.  0.]
 [ 0.  0.  0.]
 [ 0.  0.  1.]]
expected win rate 0.16230599582195282
[[ 0. -1. -1.]
 [ 0.  0.  0.]
 [ 0.  0.  1.]]
val_movies [0, 3, 4, 5, 6, 7]
[[ 1. -1. -1.]
 [ 0.  0.  0.]
 [ 0.  0.  1.]]
expected win rate 0.49623584747314453
[[ 1. -1. -1.]
 [ 0. -1.  0.]
 [ 0.  0.  1.]]
val_movies [3, 5, 6, 7]
[[ 1. -1. -1.]
 [ 0. -1.  0.]
 [ 0.  1.  1.]]
expected win rate 0.4345284402370453
-1 win


In [8]:
game = ConnectFour()

device = torch.device("cpu")

model = ResNet(game, 9, 64, device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

player = 1

args = {
    'C': 2,
    'num_search': 600,
    'num_iterations': 8,
    'batch_size': 64,
    'num_selfplay_iterations': 500,
    'num_parallel_games': 100,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

deepzero = DeepZeroParallel(model, optimizer, game, args)
deepzero.learn()


  0%|          | 0/5 [01:12<?, ?it/s]


KeyboardInterrupt: 

In [18]:
game = ConnectFour()
player = -1
device = torch.device("cpu")
args = {
    'C': 2,
    'num_search': 500,
    'num_iterations': 10,
    'batch_size': 64,
    'num_selfplay_iterations': 500,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}
model = ResNet(game, 9, 128, device=device)
model.load_state_dict(torch.load("weights/model_4_ConnectFour.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}: "))

        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        action = np.argmax(mcts_probs)

    state = game.get_next_state(state, action, player)

    value, is_terminate = game.get_value_and_terminated(state, action)

    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
             print(player, "lose")
        break

    player = game.get_opponent(player)

AttributeError: 'ConnectFour' object has no attribute 'shape_obs'

In [21]:
class Checkers:
    def __init__(self):
        self.row_count = 8
        self.column_count = 8
        self.shape_obs = 5
        # Действие кодируется как: from_pos * 64 + to_pos
        # from_pos = row * 8 + col, to_pos = row * 8 + col
        self.action_size = 64 * 64  # 4096 возможных действий
        # Фигуры: 1 = шашка текущего игрока, 2 = дамка текущего игрока
        #        -1 = шашка противника, -2 = дамка противника

    def __repr__(self):
        return "Checkers"

    def get_initial_state(self):
        """Возвращает начальное состояние доски 8x8"""
        state = np.zeros((self.row_count, self.column_count), dtype=np.int8)

        # Расставляем шашки только на тёмных клетках (row + col) % 2 == 1
        for row in range(self.row_count):
            for col in range(self.column_count):
                if (row + col) % 2 == 1:
                    if row < 3:
                        state[row, col] = -1  # Шашки противника (вверху)
                    elif row > 4:
                        state[row, col] = 1   # Шашки игрока (внизу)

        return state

    def flip_action(self, action):
        """Переворачивает action после change_perspective"""
        from_pos = action // 64
        to_pos = action % 64

        from_row, from_col = from_pos // 8, from_pos % 8
        to_row, to_col = to_pos // 8, to_pos % 8

        new_from_row = 7 - from_row
        new_to_row = 7 - to_row

        new_from_pos = new_from_row * 8 + from_col
        new_to_pos = new_to_row * 8 + to_col

        return new_from_pos * 64 + new_to_pos

    def get_next_state(self, state, action, player):
        """
        Применяет действие к состоянию и возвращает новое состояние.
        Состояние должно быть с перспективы player (после change_perspective).
        """
        state = state.copy()

        from_pos = action // 64
        to_pos = action % 64

        from_row, from_col = from_pos // 8, from_pos % 8
        to_row, to_col = to_pos // 8, to_pos % 8

        piece = state[from_row, from_col]
        state[from_row, from_col] = 0

        # Проверяем взятие (прыжок через 2 клетки)
        if abs(to_row - from_row) == 2:
            captured_row = (from_row + to_row) // 2
            captured_col = (from_col + to_col) // 2
            state[captured_row, captured_col] = 0

        # Превращение в дамку при достижении верхнего края (row == 0)
        # Игрок 1 всегда двигается вверх
        if piece == 1 and to_row == 0:
            state[to_row, to_col] = 2  # Дамка
        else:
            state[to_row, to_col] = piece

        return state

    def get_valid_moves(self, state):
        """
        Возвращает маску допустимых ходов размером action_size.
        Если есть взятия - возвращает только взятия (взятие обязательно).
        Предполагается что state с перспективы текущего игрока (player=1).
        """
        valid_moves = np.zeros(self.action_size, dtype=np.uint8)
        captures = []
        regular_moves = []

        for row in range(self.row_count):
            for col in range(self.column_count):
                piece = state[row, col]

                # Ищем фигуры текущего игрока (1 = шашка, 2 = дамка)
                if piece != 1 and piece != 2:
                    continue

                is_king = piece == 2

                # Направления движения: шашка вверх, дамка в обе стороны
                if is_king:
                    directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)]
                else:
                    directions = [(-1, -1), (-1, 1)]  # Только вверх

                for dr, dc in directions:
                    new_row, new_col = row + dr, col + dc

                    if not self._is_valid_position(new_row, new_col):
                        continue

                    target = state[new_row, new_col]

                    if target == 0:
                        # Обычный ход
                        from_pos = row * 8 + col
                        to_pos = new_row * 8 + new_col
                        action = from_pos * 64 + to_pos
                        regular_moves.append(action)

                    elif target == -1 or target == -2:
                        # Возможное взятие (фигура противника)
                        jump_row, jump_col = new_row + dr, new_col + dc
                        if self._is_valid_position(jump_row, jump_col) and state[jump_row, jump_col] == 0:
                            from_pos = row * 8 + col
                            to_pos = jump_row * 8 + jump_col
                            action = from_pos * 64 + to_pos
                            captures.append(action)

                # Для обычной шашки проверяем взятие назад
                if not is_king:
                    for dr, dc in [(1, -1), (1, 1)]:  # Назад
                        new_row, new_col = row + dr, col + dc

                        if not self._is_valid_position(new_row, new_col):
                            continue

                        target = state[new_row, new_col]

                        if target == -1 or target == -2:
                            jump_row, jump_col = new_row + dr, new_col + dc
                            if self._is_valid_position(jump_row, jump_col) and state[jump_row, jump_col] == 0:
                                from_pos = row * 8 + col
                                to_pos = jump_row * 8 + jump_col
                                action = from_pos * 64 + to_pos
                                captures.append(action)

        # Взятие обязательно
        moves_to_use = captures if len(captures) > 0 else regular_moves

        for action in moves_to_use:
            valid_moves[action] = 1

        return valid_moves

    def check_win(self, state, action):
        """
        Проверяет победу после хода.
        Победа если у противника нет фигур или нет допустимых ходов.
        """
        if action is None:
            return False

        # Проверяем, есть ли фигуры у противника
        opponent_pieces = np.sum((state == -1) | (state == -2))
        if opponent_pieces == 0:
            return True

        # Проверяем, есть ли допустимые ходы у противника
        # Меняем перспективу и проверяем ходы
        opponent_state = self.change_perspective(state, -1)
        opponent_moves = self.get_valid_moves(opponent_state)
        if np.sum(opponent_moves) == 0:
            return True

        return False

    def get_value_and_terminated(self, state, action):
        """
        Возвращает (value, terminated).
        value = 1 при победе текущего игрока, 0 при ничьей или продолжении.
        """
        if self.check_win(state, action):
            return 1, True

        # Проверка на ничью (нет ходов у текущего игрока)
        if np.sum(self.get_valid_moves(state)) == 0:
            return 0, True

        return 0, False

    def get_opponent(self, player):
        """Возвращает противника"""
        return -player

    def get_opponent_value(self, value):
        """Возвращает значение с точки зрения противника"""
        return -value

    def change_perspective(self, state, player):
        """
        Меняет перспективу доски для игрока.
        После вызова текущий игрок всегда представлен как player=1.
        Доска переворачивается и значения инвертируются.
        """
        if player == -1:
            # Переворачиваем доску и меняем знаки фигур
            return np.flip(state, axis=0) * -1
        return state.copy()

    def get_encoded_state(self, state):
        """
        Кодирует состояние для нейронной сети.
        Возвращает 5 каналов:
        - Шашки текущего игрока (1)
        - Дамки текущего игрока (2)
        - Шашки противника (-1)
        - Дамки противника (-2)
        - Пустые клетки (0)
        """
        encoded_state = np.stack(
            (
                (state == 1).astype(np.float32),   # Мои шашки
                (state == 2).astype(np.float32),   # Мои дамки
                (state == -1).astype(np.float32),  # Шашки противника
                (state == -2).astype(np.float32),  # Дамки противника
                (state == 0).astype(np.float32)    # Пустые клетки
            )
        )

        if len(state.shape) == 3:
            encoded_state = np.swapaxes(encoded_state, 0, 1)

        return encoded_state

    def _is_valid_position(self, row, col):
        """Проверяет, находится ли позиция в пределах доски"""
        return 0 <= row < self.row_count and 0 <= col < self.column_count

    def has_additional_captures(self, state, row, col):
        """Проверяет, есть ли дополнительные взятия для шашки после хода"""
        piece = state[row, col]
        if piece != 1 and piece != 2:
            return False

        # Все направления для взятия
        directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)]

        for dr, dc in directions:
            new_row, new_col = row + dr, col + dc

            if not self._is_valid_position(new_row, new_col):
                continue

            target = state[new_row, new_col]
            if target == -1 or target == -2:
                jump_row, jump_col = new_row + dr, new_col + dc
                if self._is_valid_position(jump_row, jump_col) and state[jump_row, jump_col] == 0:
                    return True

        return False

    def get_captures_for_piece(self, state, row, col):
        """Возвращает список взятий для конкретной шашки"""
        captures = []
        piece = state[row, col]

        if piece != 1 and piece != 2:
            return captures

        directions = [(-1, -1), (-1, 1), (1, -1), (1, 1)]

        for dr, dc in directions:
            new_row, new_col = row + dr, col + dc

            if not self._is_valid_position(new_row, new_col):
                continue

            target = state[new_row, new_col]
            if target == -1 or target == -2:
                jump_row, jump_col = new_row + dr, new_col + dc
                if self._is_valid_position(jump_row, jump_col) and state[jump_row, jump_col] == 0:
                    from_pos = row * 8 + col
                    to_pos = jump_row * 8 + jump_col
                    action = from_pos * 64 + to_pos
                    captures.append(action)

        return captures

    def action_to_coords(self, action):
        """Декодирует действие в координаты (from_row, from_col, to_row, to_col)"""
        from_pos = action // 64
        to_pos = action % 64
        return (from_pos // 8, from_pos % 8, to_pos // 8, to_pos % 8)

    def coords_to_action(self, from_row, from_col, to_row, to_col):
        """Кодирует координаты в действие"""
        from_pos = from_row * 8 + from_col
        to_pos = to_row * 8 + to_col
        return from_pos * 64 + to_pos

    def print_board(self, state):
        """Красивый вывод доски в консоль"""
        symbols = {
            0: '.',
            1: 'w',   # Моя шашка
            2: 'W',   # Моя дамка
            -1: 'b',  # Шашка противника
            -2: 'B'   # Дамка противника
        }

        print("  0 1 2 3 4 5 6 7")
        for row in range(self.row_count):
            print(f"{row} ", end="")
            for col in range(self.column_count):
                print(symbols[state[row, col]] + " ", end="")
            print()
        print()

In [22]:
game = Checkers()
player = -1
device = torch.device("cpu")
args = {
    'C': 2,
    'num_search': 600,
    'num_iterations': 10,
    'batch_size': 64,
    'num_selfplay_iterations': 400,
    'num_epochs': 4,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}
model = ResNet(game, 12, 32, device=device)
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state).flatten()
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        action = int(input(f"{player}: "))

        if valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        valid_moves = game.get_valid_moves(neutral_state)
        mcts_probs = mcts_probs * valid_moves  # Mask invalid moves to zero
        action = np.argmax(mcts_probs)
        action = game.flip_action(action)
    state = game.get_next_state(state, action, player)

    value, is_terminate = game.get_value_and_terminated(state, action)

    if is_terminate:
        if value == 1:
            print(player, "win")
        else:
             print(player, "lose")
        break

    player = game.get_opponent(player)

[[ 0 -1  0 -1  0 -1  0 -1]
 [-1  0 -1  0 -1  0 -1  0]
 [ 0 -1  0 -1  0 -1  0 -1]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 1  0  1  0  1  0  1  0]
 [ 0  1  0  1  0  1  0  1]
 [ 1  0  1  0  1  0  1  0]]
expected win rate 0.06525762379169464
[[ 0 -1  0 -1  0 -1  0 -1]
 [-1  0 -1  0 -1  0 -1  0]
 [ 0 -1  0  0  0 -1  0 -1]
 [ 0  0  0  0 -1  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 1  0  1  0  1  0  1  0]
 [ 0  1  0  1  0  1  0  1]
 [ 1  0  1  0  1  0  1  0]]
val_movies [2593, 2721, 2723, 2851, 2853, 2981, 2983]
[[ 0 -1  0 -1  0 -1  0 -1]
 [-1  0 -1  0 -1  0 -1  0]
 [ 0 -1  0  0  0 -1  0 -1]
 [ 0  0  0  0 -1  0  0  0]
 [ 0  0  0  1  0  0  0  0]
 [ 1  0  0  0  1  0  1  0]
 [ 0  1  0  1  0  1  0  1]
 [ 1  0  1  0  1  0  1  0]]
expected win rate 0.08224310725927353
[[ 0 -1  0 -1  0 -1  0 -1]
 [-1  0 -1  0 -1  0 -1  0]
 [ 0 -1  0  0  0 -1  0 -1]
 [ 0  0  0  0  0  0  0  0]
 [ 0  0  0  0  0  0  0  0]
 [ 1  0 -1  0  1  0  1  0]
 [ 0  1  0  1  0  1  0  1]
 [ 1  0  1  0  1  0  1  0]]
va

KeyboardInterrupt: Interrupted by user

In [15]:
mcts_probs[0].shape

(4096,)

In [36]:
сheckers.print_board(state)

  0 1 2 3 4 5 6 7
0 . w . w . w . w 
1 w . w . w . w . 
2 . w . w . w . w 
3 . . . . . . . . 
4 . . . . . . . . 
5 b . b . b . b . 
6 . b . b . b . b 
7 b . b . b . b . 


In [37]:
сheckers.action_size

128