In [1]:
# Minimal AlphaZero-style skeleton for Go (self-play + MCTS + neural net)
# Dependencies: torch, numpy

import math
import random
import collections
import copy
from typing import List, Tuple, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# ---------------------------
# 1) Basic Go board
# ---------------------------
# Values stored in each cell of the board grid.
EMPTY = 0
BLACK = 1
WHITE = 2


def opponent(player: int) -> int:
    """Return the colour playing against *player*.

    WHITE maps to BLACK; any other value (BLACK included) maps to WHITE.
    """
    if player == WHITE:
        return BLACK
    return WHITE

class Board:
    """Square Go board with captures, suicide/superko checks and area scoring.

    Points are addressed as (x, y); the grid is indexed ``grid[y, x]`` and
    holds EMPTY, BLACK or WHITE.  ``history`` is the set of every whole-board
    position reached so far (used for the superko test) and ``pass_count`` is
    the number of consecutive passes.
    """

    # Komi used when the caller leaves ``komi`` at 0, keyed by board size.
    _DEFAULT_KOMI = {19: 5.5, 13: 4.5, 9: 3.5}
    _FALLBACK_KOMI = 1.5

    def __init__(self, size:int=9, komi:float=0):
        """Create an empty board.

        Args:
            size: number of lines per side.
            komi: points added to White's score.  0 means "pick a default for
                this size" (19 -> 5.5, 13 -> 4.5, 9 -> 3.5, otherwise 1.5).
        """
        if komi == 0:
            komi = self._DEFAULT_KOMI.get(size, self._FALLBACK_KOMI)
        self.size = size
        self.komi = komi
        self.grid = np.zeros((size, size), dtype=np.int8)
        self.history = set()
        self._add_history()
        self.pass_count = 0

    def _clone(self, history):
        """Return a board with a copied grid and the given history set."""
        # __init__ is bypassed on purpose: it would allocate a throwaway grid
        # and hash an empty position on every call.
        b = Board.__new__(Board)
        b.size = self.size
        b.komi = self.komi
        b.grid = self.grid.copy()
        b.history = history
        b.pass_count = self.pass_count
        return b

    def copy(self):
        """Return an independent copy (grid, history and pass count)."""
        return self._clone(set(self.history))

    def in_bounds(self, x:int, y:int) -> bool:
        """Return True when (x, y) lies on the board."""
        return 0 <= x < self.size and 0 <= y < self.size

    def neighbors(self, x:int, y:int):
        """Yield the orthogonally adjacent on-board points of (x, y)."""
        for dx, dy in ((1,0),(-1,0),(0,1),(0,-1)):
            nx, ny = x+dx, y+dy
            if self.in_bounds(nx, ny):
                yield nx, ny

    def _board_tuple(self):
        """Return the grid flattened row by row into a hashable tuple."""
        return tuple(self.grid.ravel().tolist())

    def _add_history(self):
        """Record the current position in ``history``."""
        self.history.add(self._board_tuple())

    def get(self, x:int, y:int) -> int:
        """Return the cell value at (x, y)."""
        return int(self.grid[y,x])

    def set(self, x:int, y:int, v:int):
        """Store *v* at (x, y)."""
        self.grid[y,x] = v

    def _collect_group(self, x:int, y:int):
        """Return (stones, liberties) of the group containing (x, y).

        ``stones`` is a list of points and ``liberties`` a set of empty points
        adjacent to the group.  An empty start point gives ``([], set())``.
        """
        color = self.get(x, y)
        if color == EMPTY:
            return [], set()
        visited = set()
        stack = [(x,y)]
        group = []
        liberties = set()
        while stack:
            cx, cy = stack.pop()
            if (cx, cy) in visited:
                continue
            visited.add((cx, cy))
            group.append((cx, cy))
            for nx, ny in self.neighbors(cx, cy):
                val = self.get(nx, ny)
                if val == color and (nx, ny) not in visited:
                    stack.append((nx, ny))
                elif val == EMPTY:
                    liberties.add((nx, ny))
        return group, liberties

    def _remove_group(self, group: List[Tuple[int,int]]):
        """Clear every point in *group* and return how many were cleared."""
        for x, y in group:
            self.set(x, y, EMPTY)
        return len(group)

    def _adjacent_enemy_groups(self, x:int, y:int, player:int):
        """Yield the points next to (x, y) that hold an opposing stone.

        Despite the name this yields individual adjacent stones, not groups.
        """
        enemy = opponent(player)
        for nx, ny in self.neighbors(x, y):
            if self.get(nx, ny) == enemy:
                yield nx, ny

    def _play_move_no_checks(self, player:int, move:Optional[Tuple[int,int]]):
        """Apply *move* for *player* without any legality test.

        ``move`` is an (x, y) point, or None for a pass.  Opposing groups left
        without liberties are removed; the mover's own stones never are.
        Returns the number of captured stones.
        """
        if move is None:
            self.pass_count += 1
            self._add_history()
            return 0
        x,y = move
        self.set(x,y, player)
        self.pass_count = 0
        # Gather all dead neighbours first, then clear them in one go.
        to_remove = set()
        for nx, ny in self.neighbors(x,y):
            if self.get(nx, ny) == opponent(player):
                g, libs = self._collect_group(nx, ny)
                if len(libs) == 0:
                    to_remove.update(g)
        for rx, ry in to_remove:
            self.set(rx, ry, EMPTY)
        self._add_history()
        return len(to_remove)

    def is_suicide(self, player:int, move:Tuple[int,int]) -> bool:
        """Return True if the stone at *move* would have no liberties after captures."""
        x,y = move
        # Trial board: only the stones matter here, so skip copying history.
        tmp = self._clone(set())
        tmp._play_move_no_checks(player, (x,y))
        _, libs = tmp._collect_group(x,y)
        return not libs

    def is_legal(self, player:int, move:Optional[Tuple[int,int]]) -> bool:
        """Return True if *player* may play *move* (None = pass, always legal).

        A point move must be on the board, on an empty point, not suicide, and
        must not recreate any position already in ``history``.
        """
        if move is None:
            return True
        x,y = move
        if not self.in_bounds(x,y):
            return False
        if self.get(x,y) != EMPTY:
            return False
        # One trial play answers both the suicide and the superko question.
        tmp = self._clone(set())
        tmp._play_move_no_checks(player, (x,y))
        _, libs = tmp._collect_group(x,y)
        if not libs:
            return False
        return tmp._board_tuple() not in self.history

    def legal_moves(self, player:int) -> List[Optional[Tuple[int,int]]]:
        """Return every legal point for *player* in row-major order, then None (pass)."""
        moves = []
        for y in range(self.size):
            for x in range(self.size):
                if self.grid[y,x] == EMPTY and self.is_legal(player, (x,y)):
                    moves.append((x,y))
        moves.append(None)  # pass
        return moves

    def game_over(self) -> bool:
        """Return True once two consecutive passes have been played."""
        return self.pass_count >= 2

    def score(self) -> Tuple[float,float]:
        """Return (black, white) area scores; komi is included in White's.

        Area = stones on the board plus empty regions bordered by a single
        colour.  No dead-stone removal is attempted.
        """
        from collections import deque
        size = self.size
        black_area = int((self.grid == BLACK).sum())
        white_area = int((self.grid == WHITE).sum())
        visited = np.zeros((size,size), dtype=bool)
        for y in range(size):
            for x in range(size):
                if self.get(x,y) != EMPTY or visited[y,x]:
                    continue
                # Flood-fill this empty region, noting which colours touch it.
                q = deque([(x,y)])
                visited[y,x] = True
                region_size = 1
                bordering = set()
                while q:
                    cx, cy = q.popleft()
                    for nx, ny in self.neighbors(cx,cy):
                        val = self.get(nx, ny)
                        if val == EMPTY and not visited[ny,nx]:
                            visited[ny,nx] = True
                            q.append((nx,ny))
                            region_size += 1
                        elif val in (BLACK, WHITE):
                            bordering.add(val)
                if len(bordering) == 1:
                    if BLACK in bordering:
                        black_area += region_size
                    else:
                        white_area += region_size
        return float(black_area), float(white_area + self.komi)

    def winner(self) -> int:
        """Return BLACK if Black's score is strictly higher, otherwise WHITE."""
        b, w = self.score()
        return BLACK if b > w else WHITE

    # helper: get neural network planes
    def get_state_planes(self, player:int) -> np.ndarray:
        """Return float32 planes of shape (3, size, size): black, white, to-move.

        The third plane is all ones when *player* is BLACK and all zeros otherwise.
        """
        black = (self.grid == BLACK).astype(np.float32)
        white = (self.grid == WHITE).astype(np.float32)
        turn_plane = np.full((self.size, self.size), 1.0 if player == BLACK else 0.0, dtype=np.float32)
        return np.stack([black, white, turn_plane], axis=0)


In [3]:
# ---------------------------
# 2) Neural network (PyTorch)
# ---------------------------
class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) followed by batch norm and ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(num_features=out_ch)

    def forward(self, x):
        features = self.conv(x)
        normalised = self.bn(features)
        return normalised.relu()

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection around them."""

    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        skip = x
        hidden = self.bn1(self.conv1(x)).relu()
        hidden = self.bn2(self.conv2(hidden))
        # ReLU is applied after the skip connection is added back in.
        return (hidden + skip).relu()

class AlphaNet(nn.Module):
    """Residual policy/value network over 3 input planes.

    ``forward`` returns ``(logp, value)``: log-probabilities over
    ``board_size * board_size + 1`` actions (the last one is pass) and a
    tanh-squashed scalar per sample.
    """

    def __init__(self, board_size=9, num_res=3, channels=128):
        super().__init__()
        points = board_size * board_size
        self.board_size = board_size
        self.conv_in = ConvBlock(3, channels)
        tower = [ResidualBlock(channels) for _ in range(num_res)]
        self.resblocks = nn.Sequential(*tower)
        # policy head
        self.p_conv = nn.Conv2d(channels, 2, kernel_size=1)
        self.p_bn = nn.BatchNorm2d(2)
        self.p_fc = nn.Linear(2 * points, points + 1)  # +1 for pass
        # value head
        self.v_conv = nn.Conv2d(channels, 1, kernel_size=1)
        self.v_bn = nn.BatchNorm2d(1)
        self.v_fc1 = nn.Linear(points, 64)
        self.v_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        # x: [B, 3, H, W]
        trunk = self.resblocks(self.conv_in(x))
        batch = trunk.size(0)
        # policy: 1x1 conv -> flatten -> linear -> log-softmax, shape [B, H*W + 1]
        policy = self.p_bn(self.p_conv(trunk)).relu().view(batch, -1)
        logp = torch.log_softmax(self.p_fc(policy), dim=1)
        # value: 1x1 conv -> flatten -> two linear layers -> tanh, shape [B]
        hidden = self.v_bn(self.v_conv(trunk)).relu().view(batch, -1)
        hidden = self.v_fc1(hidden).relu()
        value = torch.tanh(self.v_fc2(hidden)).squeeze(1)
        return logp, value


In [4]:
# ---------------------------
# 3) MCTS with PUCT
# ---------------------------
class MCTSNode:
    """Search-tree node carrying the statistics of the action that leads to it."""

    def __init__(self, prior:float=0.0):
        self.prior = prior      # P(a): prior probability of this action
        self.visit_count = 0    # N: number of backed-up visits
        self.value_sum = 0.0    # W: sum of backed-up values
        self.children = {}      # action index -> MCTSNode

    def q_value(self):
        """Return the mean value W / N, or 0.0 while the node is unvisited."""
        visits = self.visit_count
        return self.value_sum / visits if visits else 0.0

class MCTS:
    """PUCT Monte-Carlo tree search guided by a policy/value network.

    Actions are flat indices: ``y * board_size + x`` for a stone at (x, y) and
    ``board_size * board_size`` for a pass.
    """

    def __init__(self, net:AlphaNet, board_size:int, c_puct:float=1.0, n_simulations:int=80, device='cpu'):
        """Store the network and search settings.

        Args:
            net: network returning ``(log_probs, value)`` for a batch of planes.
            board_size: side length of the board the network was built for.
            c_puct: exploration constant in the PUCT score.
            n_simulations: simulations performed by each ``run`` call.
            device: device the input tensor is moved to before inference.
        """
        self.net = net
        self.board_size = board_size
        self.c_puct = c_puct
        self.n_simulations = n_simulations
        self.device = device

    def _move_to_action(self, mv) -> int:
        """Convert an (x, y) move, or None for pass, to its flat action index."""
        if mv is None:
            return self.board_size * self.board_size
        return mv[1] * self.board_size + mv[0]

    def _action_to_move(self, action_idx:int):
        """Convert a flat action index back to (x, y), or None for pass."""
        if action_idx == self.board_size * self.board_size:
            return None
        return (action_idx % self.board_size, action_idx // self.board_size)

    def _expand(self, node:MCTSNode, board:Board, player:int) -> float:
        """Evaluate *board* with the network and give *node* one child per legal move.

        Returns the network value, from the viewpoint of *player* (the side to move).
        """
        state = torch.tensor(board.get_state_planes(player), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            logp, v = self.net(state.to(self.device))
        probs = torch.exp(logp).squeeze(0).cpu().numpy()  # length = H*W + 1
        # Illegal moves get no child at all, which acts as a zero prior.
        for mv in board.legal_moves(player):
            idx = self._move_to_action(mv)
            if idx not in node.children:
                node.children[idx] = MCTSNode(prior=float(probs[idx]))
        return float(v.item())

    def run(self, root_board:Board, to_move:int):
        """Search from *root_board* and return the root visit distribution.

        The result is a float32 array of length ``board_size**2 + 1`` that sums
        to 1, or is all zeros when no simulation was run.  *root_board* is not
        modified.
        """
        # Inference must not use (or update) BatchNorm batch statistics, so the
        # network is put in eval mode for the search and restored afterwards.
        was_training = self.net.training
        self.net.eval()
        try:
            root = MCTSNode()
            self._expand(root, root_board, to_move)
            for _ in range(self.n_simulations):
                self._simulate(root_board.copy(), to_move, root)
        finally:
            self.net.train(was_training)
        # build policy target pi from visit counts
        counts = np.zeros(self.board_size*self.board_size + 1, dtype=np.float32)
        for a, node in root.children.items():
            counts[a] = node.visit_count
        total = counts.sum()
        return counts / total if total > 0 else counts

    def _select_child(self, node:MCTSNode, total_N:float):
        """Return the (action, child) of *node* maximising Q + U.

        ``U = c_puct * P * sqrt(total_N) / (1 + N)``.  Ties keep the child that
        was inserted first.  Returns ``(None, None)`` for a childless node.
        """
        best_score = -1e9
        best_action = None
        best_child = None
        sqrt_total = math.sqrt(total_N)
        for a, child in node.children.items():
            score = child.q_value() + self.c_puct * child.prior * sqrt_total / (1 + child.visit_count)
            if score > best_score:
                best_score = score
                best_action = a
                best_child = child
        return best_action, best_child

    def _simulate(self, board:Board, to_move:int, root:MCTSNode):
        """Run one simulation from *root*; *board* is mutated along the way.

        Descends by PUCT to a childless node, evaluates it (game result if the
        game is over, network otherwise) and backs the value up the path.
        """
        node = root
        player = to_move
        visited = []  # nodes entered below the root, in descent order
        # selection
        while node.children:
            # Counting the node's own expansion (+1) keeps U non-zero before
            # any child has been visited, so the priors steer the first pick.
            total_N = 1 + sum(child.visit_count for child in node.children.values())
            action_idx, node = self._select_child(node, total_N)
            board._play_move_no_checks(player, self._action_to_move(action_idx))
            visited.append(node)
            player = opponent(player)
        # evaluation / expansion
        if board.game_over():
            # Finished game: use the real result and leave the node childless.
            value = 1.0 if board.winner() == player else -1.0
        else:
            value = self._expand(node, board, player)
        # backup: `value` is from the viewpoint of `player`, the side to move
        # at the leaf.  A node's statistics are read by its parent, i.e. by the
        # side that moved INTO the node, so the sign flips before each update.
        for visited_node in reversed(visited):
            value = -value
            visited_node.visit_count += 1
            visited_node.value_sum += value
        root.visit_count += 1


In [5]:
# ---------------------------
# 4) Replay buffer and self-play
# ---------------------------
# One training sample: input planes, MCTS policy target, game outcome.
Example = collections.namedtuple('Example', ['state', 'pi', 'z'])


class ReplayBuffer:
    """Bounded FIFO store of Example tuples with uniform random sampling."""

    def __init__(self, capacity=10000):
        # A deque with maxlen discards the oldest entry once it is full.
        self.buffer = collections.deque(maxlen=capacity)

    def __len__(self):
        return len(self.buffer)

    def push(self, example:Example):
        """Append one example."""
        self.buffer.append(example)

    def sample(self, batch_size:int):
        """Draw up to *batch_size* distinct examples as stacked arrays.

        Returns ``(states, pis, zs)`` with the batch on axis 0; ``zs`` is float32.
        """
        how_many = min(batch_size, len(self.buffer))
        chosen = random.sample(self.buffer, how_many)
        state_batch = np.stack([state for state, _, _ in chosen], axis=0)   # [B, 3, H, W]
        policy_batch = np.stack([policy for _, policy, _ in chosen], axis=0)  # [B, A]
        outcome_batch = np.array([outcome for _, _, outcome in chosen], dtype=np.float32)  # [B]
        return state_batch, policy_batch, outcome_batch

def self_play_episode(net:AlphaNet, mcts:MCTS, board_size:int, n_simulations:int, temperature:float=1.0):
    """Play one self-play game and yield one Example per position.

    Args:
        net: unused here; the search uses the network held by *mcts*.
        mcts: search object whose ``run`` supplies the move distribution.
        board_size: side length of the board.
        n_simulations: unused here; *mcts* carries its own simulation count.
        temperature: 0 picks the most-visited move; otherwise moves are sampled
            from ``pi ** (1 / temperature)``.

    Yields:
        ``Example(state, pi, z)`` where ``z`` is +1.0 if the player to move in
        that position went on to win and -1.0 otherwise.  The game stops after
        two consecutive passes or ``4 * board_size**2`` moves.
    """
    pass_idx = board_size * board_size
    max_moves = pass_idx * 4
    board = Board(size=board_size)
    to_move = BLACK
    records = []  # (state planes, policy target, player to move)
    move_no = 0
    while not board.game_over() and move_no < max_moves:
        pi = mcts.run(board, to_move)  # uses net internally
        if pi.sum() <= 0:
            # The search gave no visits at all: fall back to a uniform policy
            # over legal moves instead of crashing in the sampler below.
            pi = np.zeros(pass_idx + 1, dtype=np.float32)
            for legal_mv in board.legal_moves(to_move):
                idx = pass_idx if legal_mv is None else legal_mv[1] * board_size + legal_mv[0]
                pi[idx] = 1.0
            pi /= pi.sum()
        if temperature == 0:
            a = int(np.argmax(pi))
        else:
            # Work in float64 and scale by the maximum first so a small
            # temperature cannot underflow every entry to zero.
            p = pi.astype(np.float64)
            p = (p / p.max()) ** (1.0 / temperature)
            p /= p.sum()
            a = int(np.random.choice(len(p), p=p))
        # The outcome is not known yet, so remember who was to move.
        records.append((board.get_state_planes(to_move), pi.copy(), to_move))
        mv = None if a == pass_idx else (a % board_size, a // board_size)
        board._play_move_no_checks(to_move, mv)
        to_move = opponent(to_move)
        move_no += 1
    # game finished: label every position from its mover's point of view
    winner = board.winner()
    for state, pi, player in records:
        yield Example(state=state, pi=pi, z=1.0 if player == winner else -1.0)


In [6]:
# ---------------------------
# 5) Training loop (simple)
# ---------------------------
def train_loop(board_size=9,
               n_iterations=100,
               self_play_games_per_iter=10,
               mcts_simulations=50,
               batch_size=64,
               epochs=1,
               device='cpu'):
    """Alternate self-play data generation and network training.

    Args:
        board_size: side length of the board.
        n_iterations: number of self-play + training rounds.
        self_play_games_per_iter: games played per round.
        mcts_simulations: simulations per move during self-play.
        batch_size: examples per gradient step; training is skipped while the
            buffer holds fewer than this.
        epochs: gradient steps per round (one sampled batch each).
        device: torch device string or object.

    Returns:
        The trained AlphaNet, left in training mode.
    """
    device = torch.device(device)
    net = AlphaNet(board_size=board_size, num_res=3, channels=64).to(device)
    mcts = MCTS(net, board_size=board_size, c_puct=1.0, n_simulations=mcts_simulations, device=device)
    buffer = ReplayBuffer(capacity=20000)
    optimizer = optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)

    for it in range(n_iterations):
        # 1) self-play to fill buffer.  Eval mode: single-position inference
        # must use BatchNorm running statistics and must not update them.
        net.eval()
        for _ in range(self_play_games_per_iter):
            for ex in self_play_episode(net, mcts, board_size, mcts_simulations, temperature=1.0):
                buffer.push(ex)
        print(f"Iter {it}: buffer size = {len(buffer)}")
        # 2) train network on buffer
        net.train()
        for _ in range(epochs):
            if len(buffer) < batch_size:
                break  # the buffer does not grow inside this loop
            states, pis, zs = buffer.sample(batch_size)
            states_t = torch.tensor(states, dtype=torch.float32).to(device)
            pis_t = torch.tensor(pis, dtype=torch.float32).to(device)
            zs_t = torch.tensor(zs, dtype=torch.float32).to(device)
            logp, v = net(states_t)
            # policy loss: cross entropy between pi (target visit dist) and predicted logp
            policy_loss = - (pis_t * logp).sum(dim=1).mean()
            # value loss: mean squared error against the game outcome
            value_loss = ((v - zs_t) ** 2).mean()
            loss = policy_loss + value_loss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print(f"Iter {it}: training done")

    return net


In [7]:
# ---------------------------
# 6) Quick run
# ---------------------------
if __name__ == "__main__":
    # Quick demo with deliberately tiny settings; not expected to play well.
    demo_kwargs = {
        "board_size": 9,
        "n_iterations": 10,
        "self_play_games_per_iter": 2,
        "mcts_simulations": 32,
        "batch_size": 32,
        "epochs": 1,
        "device": 'cpu',
    }
    net = train_loop(**demo_kwargs)
    print("Done")

Iter 0: buffer size = 194
Iter 0: training done
Iter 1: buffer size = 350
Iter 1: training done
Iter 2: buffer size = 458
Iter 2: training done
Iter 3: buffer size = 560
Iter 3: training done
Iter 4: buffer size = 600
Iter 4: training done
Iter 5: buffer size = 689
Iter 5: training done
Iter 6: buffer size = 723
Iter 6: training done
Iter 7: buffer size = 822
Iter 7: training done
Iter 8: buffer size = 864
Iter 8: training done
Iter 9: buffer size = 899
Iter 9: training done
Done


In [9]:
# watch_match.py  (or paste at the end of az_go.py)
import time
import torch
import numpy as np

# Optional: pygame rendering. PYGAME_OK records whether the import worked.
try:
    import pygame
except Exception:
    PYGAME_OK = False
else:
    PYGAME_OK = True

def load_net(path: str, board_size: int, device='cpu', num_res: int = 3, channels: int = 64):
    """Build an AlphaNet and load its weights from *path* (a torch .pth/.pt state dict).

    Args:
        path: file containing the saved ``state_dict``.
        board_size: board side length the network was trained for.
        device: device the weights are mapped to and the network is moved to.
        num_res: number of residual blocks; must match the saved network.
        channels: trunk width; must match the saved network.

    Returns:
        The network in eval mode on *device*.
    """
    net = AlphaNet(board_size=board_size, num_res=num_res, channels=channels)
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (or pass weights_only=True where torch supports it).
    state_dict = torch.load(path, map_location=device)
    net.load_state_dict(state_dict)
    net.to(device)
    net.eval()
    return net

def board_to_ascii(board: Board):
    """Return a text picture of the board ('.' empty, 'X' black, 'O' white)."""
    glyphs = {EMPTY: '.', BLACK: 'X'}  # anything else is drawn as 'O'
    side = range(board.size)
    return '\n'.join(
        ''.join(glyphs.get(board.get(x, y), 'O') for x in side)
        for y in side
    )

class PygameRenderer:
    """Draws a Board in a pygame window.

    Set ``last_pass`` to True before ``draw`` to show a "passed" message
    instead of the side to move.
    """

    def __init__(self, board_size=9, cell_size=30, padding=20):
        """Open the window.

        Args:
            board_size: number of lines per side.
            cell_size: pixel distance between neighbouring lines.
            padding: pixel margin around the grid.

        Raises:
            RuntimeError: if pygame could not be imported.
        """
        if not PYGAME_OK:
            raise RuntimeError("pygame not available")
        self.board_size = board_size
        self.cell_size = cell_size
        self.padding = padding
        self.width = board_size * cell_size + padding * 2
        self.height = self.width
        self.last_pass = False
        pygame.init()
        self.screen = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption("Watch Match")
        self.font = pygame.font.SysFont(None, 24)

    @staticmethod
    def _hoshi_points(board_size):
        """Return the star points (x, y) for a board of the given size.

        19x19 gives the usual nine points, 13x13 and 9x9 the four corner
        points plus the centre; boards smaller than 7 get none.
        """
        if board_size < 7:
            return []
        offset = 3 if board_size >= 13 else 2
        low, high = offset, board_size - 1 - offset
        points = [(low, low), (low, high), (high, low), (high, high)]
        if board_size % 2 == 1:
            mid = board_size // 2
            points.append((mid, mid))
            if board_size >= 19:
                points += [(low, mid), (high, mid), (mid, low), (mid, high)]
        return points

    def draw(self, board: Board, to_move):
        """Redraw the grid, star points, stones and status text, then flip the display."""
        self.screen.fill((210, 170, 110))
        PADDING = self.padding
        CELL = self.cell_size
        STONE_R = int(CELL * 0.4)
        last = PADDING + (self.board_size - 1) * CELL
        # grid
        for i in range(self.board_size):
            offset = PADDING + i*CELL
            pygame.draw.line(self.screen, (0,0,0), (PADDING, offset), (last, offset), 1)
            pygame.draw.line(self.screen, (0,0,0), (offset, PADDING), (offset, last), 1)
        # hoshi
        for hx, hy in self._hoshi_points(self.board_size):
            pygame.draw.circle(self.screen, (0,0,0),
                (PADDING + hx*CELL, PADDING + hy*CELL), max(2, int(CELL*0.08)))
        # stones
        stone_colours = {BLACK: (0,0,0), WHITE: (255,255,255)}
        for y in range(self.board_size):
            for x in range(self.board_size):
                colour = stone_colours.get(board.get(x,y))
                if colour is not None:
                    pygame.draw.circle(self.screen, colour,
                        (PADDING + x*CELL, PADDING + y*CELL), STONE_R)
        # turn text; after a pass, report who passed (the side NOT to move)
        if self.last_pass:
            txt = "White passed" if to_move == BLACK else "Black passed"
        else:
            txt = "Black to move" if to_move == BLACK else "White to move"
        surf = self.font.render(txt, True, (0,0,0))
        self.screen.blit(surf, (5,5))
        pygame.display.flip()

def pick_action_from_pi(pi: np.ndarray, board: Board, board_size:int, temperature:float=0.0,
                        to_move:Optional[int]=None):
    """Choose an action index from visit probabilities, ignoring illegal moves.

    Args:
        pi: visit probabilities of length ``board_size * board_size + 1``
            (the last entry is pass).
        board: current position; assumed to have side length *board_size*.
        board_size: side length used for the index mapping ``y * size + x``.
        temperature: 0 returns the highest-probability allowed action;
            otherwise the action is sampled from ``pi ** (1 / temperature)``.
        to_move: player about to move.  When given, actions are restricted to
            ``board.legal_moves(to_move)``.  When None the side to move is
            unknown, so only occupied points are excluded.

    Returns:
        The chosen action index as an int.
    """
    pass_idx = board_size * board_size
    mask = np.zeros_like(pi, dtype=np.bool_)
    if to_move is None:
        # Full legality needs the mover's colour; without it, allow every
        # empty point plus pass.
        mask[:pass_idx] = (board.grid == EMPTY).ravel()
        mask[pass_idx] = True
    else:
        for mv in board.legal_moves(to_move):
            idx = pass_idx if mv is None else mv[1]*board_size + mv[0]
            mask[idx] = True
    masked = np.copy(pi)
    masked[~mask] = 0.0
    s = masked.sum()
    if s <= 0:
        # fallback: uniform over allowed actions
        return int(np.random.choice(np.where(mask)[0]))
    if temperature == 0.0:
        return int(np.argmax(masked))
    # Scale by the maximum in float64 so a small temperature cannot underflow
    # every entry to zero.
    p = masked.astype(np.float64)
    p = (p / p.max()) ** (1.0 / temperature)
    p = p / p.sum()
    return int(np.random.choice(len(p), p=p))

def _choose_match_move(pi, legal, size, temperature):
    """Pick one move from *legal* according to the visit distribution *pi*.

    Falls back to a uniformly random legal move when *pi* puts no mass on any
    legal action.  Returns an (x, y) tuple or None for pass.
    """
    pass_idx = size * size
    legal_idx = [pass_idx if mv is None else mv[1]*size + mv[0] for mv in legal]
    masked = np.zeros_like(pi)
    masked[legal_idx] = pi[legal_idx]
    if masked.sum() <= 0:
        return random.choice(legal)
    if temperature == 0.0:
        a = int(np.argmax(masked))
    else:
        p = masked.astype(np.float64) ** (1.0 / temperature)
        p = p / p.sum()
        a = int(np.random.choice(len(p), p=p))
    return None if a == pass_idx else (a % size, a // size)


def play_match(net1: AlphaNet,
               net2: AlphaNet,
               board_size:int=9,
               mcts_simulations:int=160,
               c_puct:float=1.0,
               device='cpu',
               render: str = 'console',  # 'console' or 'pygame'
               pause_time: float = 0.5,
               temperature: float = 0.0):
    """Play one game between net1 (Black) and net2 (White) and display it.

    Args:
        net1, net2: networks for Black and White; each gets its own MCTS.
        board_size: side length of the board.
        mcts_simulations: simulations per move.
        c_puct: exploration constant passed to MCTS.
        device: device passed to MCTS.
        render: 'console' prints the board each move; 'pygame' opens a window
            and falls back to 'console' when pygame is unavailable.
        pause_time: seconds to sleep after every move.
        temperature: 0 plays the most-visited move; otherwise moves are sampled.

    Returns:
        None.  In pygame mode the function returns when the window is closed.
    """
    # one searcher per network
    mcts1 = MCTS(net1, board_size=board_size, c_puct=c_puct, n_simulations=mcts_simulations, device=device)
    mcts2 = MCTS(net2, board_size=board_size, c_puct=c_puct, n_simulations=mcts_simulations, device=device)

    board = Board(size=board_size)
    to_move = BLACK
    move_no = 0

    renderer = None
    if render == 'pygame':
        if not PYGAME_OK:
            print("pygame not available, falling back to console render")
            render = 'console'
        else:
            renderer = PygameRenderer(board_size=board_size, cell_size=30, padding=20)

    # main loop
    while not board.game_over() and move_no < board_size*board_size*4:
        if render == 'console':
            print("\nMove", move_no, "to_move:", "BLACK" if to_move==BLACK else "WHITE")
            print(board_to_ascii(board))
        elif render == 'pygame':
            renderer.draw(board, to_move)
            # handle events to allow window close
            for ev in pygame.event.get():
                if ev.type == pygame.QUIT:
                    pygame.quit()
                    return

        searcher = mcts1 if to_move == BLACK else mcts2
        pi = searcher.run(board, to_move)  # visit-probs

        # Enumerating legal moves is expensive, so do it once per ply.
        legal = board.legal_moves(to_move)
        mv = _choose_match_move(pi, legal, board_size, temperature)
        if renderer is not None:  # there is no renderer in console mode
            renderer.last_pass = mv is None

        # mv was drawn from `legal`, so no further legality check is needed
        board._play_move_no_checks(to_move, mv)

        to_move = opponent(to_move)
        move_no += 1
        time.sleep(pause_time)

    # finished
    bscore, wscore = board.score()
    winner = board.winner()
    if render == 'console':
        print("\nFinal board:")
        print(board_to_ascii(board))
        print(f"Score Black: {bscore:.1f}, White(with komi): {wscore:.1f}")
        print("Winner:", "BLACK" if winner==BLACK else "WHITE")
    elif render == 'pygame':
        renderer.draw(board, to_move)
        print("Game finished. Close window to exit.")
        # keep window open until closed
        while True:
            for ev in pygame.event.get():
                if ev.type == pygame.QUIT:
                    pygame.quit()
                    return
            time.sleep(0.1)

# ---------------------------
# Example usage:
# ---------------------------
# 1) if you have saved weights:
# net_black = load_net("best_black.pth", board_size=9, device='cpu')
# net_white = load_net("best_white.pth", board_size=9, device='cpu')
# play_match(net_black, net_white, board_size=9, mcts_simulations=128, render='console')

# 2) if you have a single network (to watch it play itself), the same model can play both sides:
# net = load_net("model.pth", board_size=9)
# play_match(net, net, board_size=9, mcts_simulations=128, render='pygame', pause_time=0.4)

# 3) if you called train_loop and got a net object back:
# net_trained = train_loop(...)
# play_match(net_trained, net_trained, board_size=9, mcts_simulations=128, render='console')
if __name__ == "__main__":
    # `net` only exists when the file is run as a script or notebook (it is
    # created by the quick-run section), so this demo is guarded the same way.
    play_match(net, net, board_size=9, mcts_simulations=128, render='pygame', pause_time=0.4)


Game finished. Close window to exit.
