In [1]:
import numpy as np
from games.tictactoe import TicTacToe
from games.connectfour import ConnectFour
from games.checkers import Checkers
from models.mcts import MCTS, MCTSParallel
from models.resnet import ResNet
from models.deepzero import DeepZero, DeepZeroParallel
import torch
from tqdm import tqdm
from tqdm import trange
import random
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

In [None]:
# --- TicTacToe: self-play training run (everything stays on the CPU) ---
device = torch.device("cpu")
game = TicTacToe()

# Network and optimizer. The two integers are passed positionally to ResNet;
# their meaning is defined in models/resnet.py.
model = ResNet(game, 4, 32, device=device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.0001)

# Not read anywhere else in this cell; kept so the notebook namespace is unchanged.
player = 1

# Hyper-parameters handed to DeepZeroParallel.
# NOTE(review): the Checkers cells of this notebook spell two of these keys
# 'num_searches' / 'num_selfPlay_iterations' -- confirm which spelling the
# models package actually reads.
args = dict(
    C=2,
    num_search=100,
    num_iterations=3,
    num_parallel_games=100,
    batch_size=16,
    num_selfplay_iterations=350,
    num_epochs=4,
    temperature=1.25,
    dirichlet_epsilon=0.25,
    dirichlet_alpha=1,
)

# Build the trainer and start the learning loop.
deepzero = DeepZeroParallel(model, optimizer, game, args)
deepzero.learn()


In [None]:
# --- TicTacToe: interactive match ---
# The human plays as 1; the trained network (driven by MCTS) plays as -1 and
# makes the first move.
game = TicTacToe()
player = -1
device = torch.device("cpu")

# NOTE(review): the Checkers cells of this notebook spell two of these keys
# 'num_searches' / 'num_selfPlay_iterations' -- confirm which spelling
# models.mcts.MCTS reads before relying on this cell.
args = {
    'C': 2,
    'num_search': 100,
    'num_iterations': 3,
    'num_parallel_games': 100,
    'batch_size': 16,
    'num_selfplay_iterations': 350,
    'num_epochs': 4,
    'temperature': 1.25,
    # Set to 0 for play, mirroring the eval_args of the Checkers
    # engine-vs-engine cell (assumes this key weights the exploration noise
    # used inside MCTS -- TODO confirm).
    'dirichlet_epsilon': 0.0,
    'dirichlet_alpha': 0.3
}


model = ResNet(game, 4, 32, device=device)
model.load_state_dict(torch.load("weights/model_2_TicTacToe.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        # Non-numeric text must not crash the game; ask again instead.
        try:
            action = int(input(f"{player}: "))
        except ValueError:
            print("action not val")
            continue
        # The explicit range check also rejects negative numbers, which would
        # otherwise wrap around and index the mask from its end.
        if not 0 <= action < len(valid_moves) or valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        valid_moves = game.get_valid_moves(state)
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        mcts_probs = mcts_probs * valid_moves  # Mask invalid moves to zero
        action = np.argmax(mcts_probs)
        # argmax of an all-zero vector is 0, which may be an occupied cell.
        if valid_moves[action] == 0:
            raise ValueError("No valid moves available; game state may be invalid.")

    state = game.get_next_state(state, action, player)
    value, is_terminate = game.get_value_and_terminated(state, action)
    if is_terminate:
        # Show the final position: the loop would otherwise exit without it.
        print(state)
        if value == 1:
            print(player, "win")
        elif value == 0:
            # A TicTacToe player cannot lose on their own move, so a finished
            # game without a winner is a draw.
            print("draw")
        else:
            print(player, "lose")
        break
    player = game.get_opponent(player)

In [None]:
# --- ConnectFour: self-play training run (everything stays on the CPU) ---
device = torch.device("cpu")
game = ConnectFour()

# Network and optimizer. The two integers are passed positionally to ResNet;
# their meaning is defined in models/resnet.py.
# NOTE(review): the ConnectFour play cell of this notebook builds
# ResNet(game, 9, 128, ...) -- confirm the checkpoint it loads was trained
# with matching sizes.
model = ResNet(game, 9, 64, device=device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001, weight_decay=0.0001)

# Not read anywhere else in this cell; kept so the notebook namespace is unchanged.
player = 1

# Hyper-parameters handed to DeepZeroParallel.
# NOTE(review): the Checkers cells of this notebook spell two of these keys
# 'num_searches' / 'num_selfPlay_iterations' -- confirm which spelling the
# models package actually reads.
args = dict(
    C=2,
    num_search=600,
    num_iterations=8,
    batch_size=64,
    num_selfplay_iterations=500,
    num_parallel_games=100,
    num_epochs=4,
    temperature=1.25,
    dirichlet_epsilon=0.25,
    dirichlet_alpha=0.3,
)

# Build the trainer and start the learning loop.
deepzero = DeepZeroParallel(model, optimizer, game, args)
deepzero.learn()


In [None]:
# --- ConnectFour: interactive match ---
# The human plays as 1; the trained network (driven by MCTS) plays as -1 and
# makes the first move.
game = ConnectFour()
player = -1
device = torch.device("cpu")

# NOTE(review): the Checkers cells of this notebook spell two of these keys
# 'num_searches' / 'num_selfPlay_iterations' -- confirm which spelling
# models.mcts.MCTS reads before relying on this cell.
args = {
    'C': 2,
    'num_search': 500,
    'num_iterations': 10,
    'batch_size': 64,
    'num_selfplay_iterations': 500,
    'num_epochs': 4,
    'temperature': 1.25,
    # Set to 0 for play, mirroring the eval_args of the Checkers
    # engine-vs-engine cell (assumes this key weights the exploration noise
    # used inside MCTS -- TODO confirm).
    'dirichlet_epsilon': 0.0,
    'dirichlet_alpha': 0.3
}

# NOTE(review): the ConnectFour training cell of this notebook builds
# ResNet(game, 9, 64, ...); load_state_dict will fail if this checkpoint was
# produced with those sizes -- confirm which configuration wrote the file.
model = ResNet(game, 9, 128, device=device)
model.load_state_dict(torch.load("weights/model_4_ConnectFour.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    print(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state)
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        # Non-numeric text must not crash the game; ask again instead.
        try:
            action = int(input(f"{player}: "))
        except ValueError:
            print("action not val")
            continue
        # The explicit range check also rejects negative numbers, which would
        # otherwise wrap around and index the mask from its end.
        if not 0 <= action < len(valid_moves) or valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        valid_moves = game.get_valid_moves(state)
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        # Mask and validate the engine move the same way the TicTacToe cell
        # does, so a full column can never be selected.
        mcts_probs = mcts_probs * valid_moves
        action = np.argmax(mcts_probs)
        if valid_moves[action] == 0:
            raise ValueError("No valid moves available; game state may be invalid.")

    state = game.get_next_state(state, action, player)

    value, is_terminate = game.get_value_and_terminated(state, action)

    if is_terminate:
        # Show the final position: the loop would otherwise exit without it.
        print(state)
        if value == 1:
            print(player, "win")
        elif value == 0:
            # A ConnectFour player cannot lose on their own move, so a
            # finished game without a winner is a draw.
            print("draw")
        else:
            print(player, "lose")
        break

    player = game.get_opponent(player)

In [5]:
# --- Checkers: interactive match ---
# The human plays as 1; the trained network (driven by MCTS) plays as -1 and
# makes the first move.
game = Checkers()
player = -1
device = torch.device("cpu")
args = {
    'C': 2,
    'num_searches': 400,
    'num_iterations': 10,
    'num_parallel_games': 200,
    'batch_size': 128,
    'num_selfPlay_iterations': 1000,
    'num_epochs': 10,
    'temperature': 1.25,
    # Set to 0 for play, mirroring the eval_args of the Checkers
    # engine-vs-engine cell (assumes this key weights the exploration noise
    # used inside MCTS -- TODO confirm).
    'dirichlet_epsilon': 0.0,
    'dirichlet_alpha': 0.3
}

model = ResNet(game, 24, 256, device=device)
model.load_state_dict(torch.load("weights/model_3_Checkers.pt", map_location=device))
model.eval()
mcts = MCTS(game, args, model)
state = game.get_initial_state()

while True:
    game.print_board(state)
    if player == 1:
        valid_moves = game.get_valid_moves(state).flatten()
        print("val_movies", [i for i in range(game.action_size) if valid_moves[i] == 1])
        # Non-numeric text must not crash the game; ask again instead.
        try:
            action = int(input(f"{player}: "))
        except ValueError:
            print("action not val")
            continue
        # The explicit range check also rejects negative numbers, which would
        # otherwise wrap around and index the mask from its end.
        if not 0 <= action < len(valid_moves) or valid_moves[action] == 0:
            print("action not val")
            continue
    else:
        # The engine searches from its own point of view ...
        neutral_state = game.change_perspective(state, player)
        mcts_probs, net_win_value = mcts.search(neutral_state)
        print("expected win rate", net_win_value)
        valid_moves = game.get_valid_moves(neutral_state)
        mcts_probs = mcts_probs * valid_moves  # Mask invalid moves to zero
        # argmax of an all-zero vector is 0; refuse to play that silently.
        if mcts_probs.sum() == 0:
            raise ValueError("No valid moves available; game state may be invalid.")
        action = np.argmax(mcts_probs)
        # ... so the chosen move is flipped before it is applied to the real board.
        action = game.flip_action(action)
    state = game.get_next_state(state, action, player)

    value, is_terminate = game.get_value_and_terminated(state, action)

    if is_terminate:
        # Show the final position: the loop would otherwise exit without it.
        game.print_board(state)
        # NOTE(review): any value other than 1 is reported as a loss for the
        # player who just moved; confirm whether Checkers can also signal a
        # draw here.
        if value == 1:
            print(player, "win")
        else:
            print(player, "lose")
        break

    player = game.get_opponent(player)

  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . b 
3 . . . . . . . . 
4 . . . . . . . . 
5 w . w . w . w . 
6 . w . w . w . w 
7 w . w . w . w . 

expected win rate_w 0.3734758198261261
  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . . 
3 . . . . . . b . 
4 . . . . . . . . 
5 w . w . w . w . 
6 . w . w . w . w 
7 w . w . w . w . 

expected win rate_p 0.3088635206222534
  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . . 
3 . . . . . . b . 
4 . . . . . . . . 
5 w . w . w . w . 
6 . w . w . w . w 
7 w . w . w . w . 



KeyboardInterrupt: 

In [7]:
# --- Checkers: engine-vs-engine match ---
# White is player 1 and moves first; black is player -1.
game = Checkers()
device = torch.device("cpu")

# Base MCTS / training hyper-parameters.
args = {
    'C': 2,
    'num_searches': 800,
    'num_iterations': 10,
    'num_parallel_games': 200,
    'batch_size': 128,
    'num_selfPlay_iterations': 1000,
    'num_epochs': 10,
    'temperature': 1.25,
    'dirichlet_epsilon': 0.25,
    'dirichlet_alpha': 0.3
}

# For a match it is better to switch off Dirichlet noise and use
# temperature = 1 (or lower).
eval_args = args.copy()
eval_args['dirichlet_epsilon'] = 0.0
eval_args['temperature'] = 1.0

# Checkpoint for each side. Both currently point at the same file; change one
# path to pit two different networks against each other.
WHITE_WEIGHTS = "weights/model_3_Checkers.pt"
BLACK_WEIGHTS = "weights/model_3_Checkers.pt"


def _load_engine(weights_path):
    """Load a checkpoint into a fresh ResNet and wrap it in an MCTS.

    Args:
        weights_path: path of the state-dict file passed to ``torch.load``.

    Returns:
        A ``(model, mcts)`` pair; the model is in eval mode and the MCTS is
        configured with ``eval_args``.
    """
    net = ResNet(game, 24, 256, device=device)
    net.load_state_dict(torch.load(weights_path, map_location=device))
    net.eval()
    return net, MCTS(game, eval_args, net)


model_white, mcts_white = _load_engine(WHITE_WEIGHTS)
model_black, mcts_black = _load_engine(BLACK_WEIGHTS)

state = game.get_initial_state()
player = 1  # white moves first

while True:
    side = 'белые' if player == 1 else 'чёрные'
    game.print_board(state)
    print(f"Ход игрока {player} ({side})")

    # Pick the search object that belongs to the side to move.
    mcts_current = mcts_white if player == 1 else mcts_black

    # Re-express the position from the point of view of the side to move.
    neutral_state = game.change_perspective(state, player)

    # Check for legal moves before searching. A side that cannot move ends
    # the game, and no search is needed to find that out. The message stays
    # neutral about win/draw because the rule Checkers applies here is not
    # visible from this cell.
    valid_moves = game.get_valid_moves(neutral_state)
    valid_indices = np.where(valid_moves == 1)[0]
    if len(valid_indices) == 0:
        print(f"Игрок {player} не может ходить — проигрыш или ничья.")
        break

    mcts_probs, net_win_value = mcts_current.search(neutral_state)
    print("Оценка шансов на победу (с точки зрения текущего игрока):", net_win_value)

    # Zero out anything the rules forbid.
    mcts_probs = mcts_probs * valid_moves

    if mcts_probs.sum() == 0:
        # The search put no weight on any legal move: fall back to a random one.
        action_neutral = np.random.choice(valid_indices)
    else:
        # Greedy choice of the most visited / most probable move.
        action_neutral = np.argmax(mcts_probs)

    # Black's move was chosen on the flipped board; map it back to real
    # board coordinates.
    action = game.flip_action(action_neutral) if player == -1 else action_neutral

    # Apply the move to the real position.
    state = game.get_next_state(state, action, player)

    # Stop when the game is over.
    value, is_terminate = game.get_value_and_terminated(state, action)
    if is_terminate:
        game.print_board(state)
        # NOTE(review): any value other than 1 is reported as a loss for the
        # player who just moved; confirm whether Checkers can also signal a
        # draw here.
        if value == 1:
            print(f"Игрок {player} ({side}) выиграл!")
        else:
            print(f"Игрок {player} ({side}) проиграл!")
        break

    # Hand the move to the other side.
    player = game.get_opponent(player)

  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . b 
3 . . . . . . . . 
4 . . . . . . . . 
5 w . w . w . w . 
6 . w . w . w . w 
7 w . w . w . w . 

Ход игрока 1 (белые)
Оценка шансов на победу (с точки зрения текущего игрока): 0.3537505269050598
  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . b 
3 . . . . . . . . 
4 . . . . . w . . 
5 w . w . . . w . 
6 . w . w . w . w 
7 w . w . w . w . 

Ход игрока -1 (чёрные)
Оценка шансов на победу (с точки зрения текущего игрока): 0.39288923144340515
  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . . 
3 . . . . . . b . 
4 . . . . . w . . 
5 w . w . . . w . 
6 . w . w . w . w 
7 w . w . w . w . 

Ход игрока 1 (белые)
Оценка шансов на победу (с точки зрения текущего игрока): 0.34579038619995117
  0 1 2 3 4 5 6 7
0 . b . b . b . b 
1 b . b . b . b . 
2 . b . b . b . w 
3 . . . . . . . . 
4 . . . . . . . . 
5 w . w . . . w . 
6 . w . w . w . w 
7 w . w . w . w . 

Ход игрока -1 (чёр

KeyboardInterrupt: 