In [1]:
# Minimal AlphaZero-style skeleton for Go (self-play + MCTS + neural net)
# Dependencies: torch, numpy

import math
import random
import collections
import copy
from typing import List, Tuple, Optional, Set, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import threading
import time

In [2]:
# ---------------------------
# 1) Basic Go board
# ---------------------------
EMPTY = 0
BLACK = 1
WHITE = 2
# Opponent's color
def opponent(player: int) -> int:
    return BLACK if player == WHITE else WHITE

class Board:
    def __init__(self, size: int = 9, komi: float = None, history_len: int = 8):
        if komi is None:
            match (size): # New switch case notation
                case (19):
                    komi = 5.5
                case (13):
                    komi = 4.5
                case(9):
                    komi = 3.5
                case (_):
                    komi = 1.5
            # end match
        self.size = size
        self.komi = komi
        self.grid = np.zeros((size, size), dtype=np.int8)
        self.history: Set[Tuple[Tuple[int, ...], int]] = set()
        self.past_states: deque = deque(maxlen=history_len)
        # counters captures
        self.black_captures = 0
        self.white_captures = 0
        # ko point: (x,y) or None
        self.ko: Optional[Tuple[int, int]] = None
        # number of consecutive passes
        self.pass_count = 0
        # initially add starting position with Black to move
        self._add_history(player=BLACK)

    def copy(self) -> 'Board':
        b = Board(self.size, self.komi, history_len=self.past_states.maxlen)
        b.grid = self.grid.copy()
        b.history = set(self.history)
        if len(self.past_states):
            # copy elements safely
            copied = (s.copy() if isinstance(s, np.ndarray) else s for s in self.past_states)
            b.past_states = deque(copied, maxlen=self.past_states.maxlen)
        else:
            b.past_states = deque(maxlen=self.past_states.maxlen)
        # copy capture counters and ko state
        b.black_captures = self.black_captures
        b.white_captures = self.white_captures
        b.ko = None if self.ko is None else (self.ko[0], self.ko[1])
        b.pass_count = self.pass_count
        return b

    def in_bounds(self, x:int, y:int) -> bool:
        return 0 <= x < self.size and 0 <= y < self.size

    def neighbors(self, x:int, y:int):
        for dx, dy in ((1,0),(-1,0),(0,1),(0,-1)):
            nx, ny = x+dx, y+dy
            if self.in_bounds(nx, ny):
                yield nx, ny

    def _board_key(self, player: int) -> Tuple[Tuple[int, ...], int]:
        # включает игрока для корректного суперко
        return (tuple(self.grid.ravel().tolist()), player)

    def _add_history(self, player: int):
        self.history.add(self._board_key(player))
        # сохраняем копию сетки для NN-плоскостей
        self.past_states.append(self.grid.copy())

    def get(self, x:int, y:int) -> int:
        return int(self.grid[y,x])

    def set(self, x:int, y:int, v:int):
        self.grid[y,x] = v

    def _collect_group(self, x:int, y:int):
        color = self.get(x, y)
        if color == EMPTY:
            return [], set()
        visited = set()
        stack = [(x,y)]
        group = []
        liberties = set()
        while stack:
            cx, cy = stack.pop()
            if (cx, cy) in visited:
                continue
            visited.add((cx, cy))
            group.append((cx, cy))
            for nx, ny in self.neighbors(cx, cy):
                val = self.get(nx, ny)
                if val == color and (nx, ny) not in visited:
                    stack.append((nx, ny))
                elif val == EMPTY:
                    liberties.add((nx, ny))
        return group, liberties

    def _remove_group(self, group: List[Tuple[int, int]]) -> int:
        if not group:
            return 0
        col = self.get(group[0][0], group[0][1])
        for x,y in group:
            self.set(x,y,EMPTY)
        return len(group)

    def _play_move_no_checks(self, player:int, move:Optional[Tuple[int,int]]) -> int:
        # move: None means pass
        if move is None:
            self.pass_count += 1
            return 0
        x,y = move
        self.set(x,y, player)
        self.pass_count = 0
        # capture adjacent enemy groups with zero liberties
        removed = 0
        to_remove = set()
        for nx, ny in self.neighbors(x,y):
            if self.get(nx, ny) == opponent(player):
                g, libs = self._collect_group(nx, ny)
                if len(libs) == 0:
                    to_remove.update(g)
        for rx, ry in to_remove:
            self.set(rx, ry, EMPTY)
            removed += 1
        if removed:
            if player == BLACK:
                self.black_captures += removed
            else:
                self.white_captures += removed
        if removed == 1:
            removed_list = list(to_remove)
            rx, ry = removed_list[0]
            # ko — это точка, где сейчас нет камня и все соседи заняты так, что возврат создает повтор
            self.ko = (rx, ry)
        else:
            self.ko = None
        return removed

    def is_suicide(self, player:int, move:Tuple[int,int]) -> bool:
        x,y = move
        tmp_grid = self.grid.copy()
        tmp_grid[y, x] = player
        removed_any = False
        # remove opponent dead groups
        for nx, ny in self.neighbors(x, y):
            if tmp_grid[ny, nx] == opponent(player):
                # собрать группу на временной сетке
                # реализуем сборку локально, чтобы не менять self
                stack = [(nx, ny)]
                visited = set()
                libs = set()
                while stack:
                    cx, cy = stack.pop()
                    if (cx, cy) in visited:
                        continue
                    visited.add((cx, cy))
                    for sx, sy in ((cx+1, cy), (cx-1, cy), (cx, cy+1), (cx, cy-1)):
                        if 0 <= sx < self.size and 0 <= sy < self.size:
                            val = tmp_grid[sy, sx]
                            if val == opponent(player) and (sx, sy) not in visited:
                                stack.append((sx, sy))
                            elif val == EMPTY:
                                libs.add((sx, sy))
                if not libs:
                    # эта группа будет снята
                    removed_any = True
                    # очистим ее на tmp_grid
                    for (gx, gy) in visited:
                        tmp_grid[gy, gx] = EMPTY
        # теперь проверить свободы собственной группы
        stack = [(x, y)]
        visited = set()
        libs = set()
        while stack:
            cx, cy = stack.pop()
            if (cx, cy) in visited:
                continue
            visited.add((cx, cy))
            for sx, sy in ((cx+1, cy), (cx-1, cy), (cx, cy+1), (cx, cy-1)):
                if 0 <= sx < self.size and 0 <= sy < self.size:
                    val = tmp_grid[sy, sx]
                    if val == player and (sx, sy) not in visited:
                        stack.append((sx, sy))
                    elif val == EMPTY:
                        libs.add((sx, sy))
        return len(libs) == 0

    def is_legal(self, player:int, move:Optional[Tuple[int,int]]) -> bool:
        if move is None:
            return True
        x,y = move
        if not self.in_bounds(x,y): return False
        if self.get(x,y) != EMPTY: return False
        if self.ko is not None and (x, y) == self.ko:
            pass
        # suicide
        if self.is_suicide(player, (x,y)): return False
        # superko (simple): simulate and check history
        tmp = self.copy()
        tmp._play_move_no_checks(player, (x,y))
        next_player = opponent(player)
        if tmp._board_key(next_player) in self.history:
            return False
        return True

    def legal_moves(self, player:int) -> List[Optional[Tuple[int,int]]]:
        moves: List[Optional[Tuple[int, int]]] = []
        for y in range(self.size):
            for x in range(self.size):
                if self.grid[y,x] == EMPTY and self.is_legal(player, (x,y)):
                    moves.append((x,y))
        moves.append(None)  # pass
        if len(moves) > self.size * 0.25: # NO MORE INFINITE PASSES
            non_pass_moves = [m for m in moves if m != None]
            if non_pass_moves:
                moves = non_pass_moves
        return moves
    
    def play_move(self, player: int, move: Optional[Tuple[int, int]]) -> bool:
        """Попытаться сыграть ход; вернуть True если ход сделан, False если нелегален."""
        if not self.is_legal(player, move):
            return False
        # выполняем ход и обновляем историю с учетом следующего игрока
        removed = self._play_move_no_checks(player, move)
        next_player = opponent(player)
        # обновляем history и past_states
        self._add_history(next_player)
        return True

    def game_over(self) -> bool:
        return self.pass_count >= 2

    def score_chinese(self) -> Tuple[float, float]:
        # area scoring (stones + territory)
        size = self.size
        visited = np.zeros((size, size), dtype=bool)
        black_area = int((self.grid == BLACK).sum())
        white_area = int((self.grid == WHITE).sum())
        from collections import deque as _dq
        for y in range(size):
            for x in range(size):
                if self.get(x, y) != EMPTY or visited[y, x]:
                    continue
                q = _dq()
                q.append((x, y))
                visited[y, x] = True
                region = [(x, y)]
                bordering = set()
                while q:
                    cx, cy = q.popleft()
                    for nx, ny in self.neighbors(cx, cy):
                        val = self.get(nx, ny)
                        if val == EMPTY and not visited[ny, nx]:
                            visited[ny, nx] = True
                            q.append((nx, ny))
                            region.append((nx, ny))
                        elif val in (BLACK, WHITE):
                            bordering.add(val)
                if len(bordering) == 1:
                    owner = next(iter(bordering))
                    if owner == BLACK:
                        black_area += len(region)
                    else:
                        white_area += len(region)
        # komi к белым
        white_area += self.komi
        return float(black_area), float(white_area)

    def score_japanese(self) -> Tuple[float, float]:
        # territory + captures
        size = self.size
        visited = np.zeros((size, size), dtype=bool)
        black_territory = 0
        white_territory = 0
        from collections import deque as _dq
        for y in range(size):
            for x in range(size):
                if self.get(x, y) != EMPTY or visited[y, x]:
                    continue
                q = _dq()
                q.append((x, y))
                visited[y, x] = True
                region = [(x, y)]
                bordering = set()
                while q:
                    cx, cy = q.popleft()
                    for nx, ny in self.neighbors(cx, cy):
                        val = self.get(nx, ny)
                        if val == EMPTY and not visited[ny, nx]:
                            visited[ny, nx] = True
                            q.append((nx, ny))
                            region.append((nx, ny))
                        elif val in (BLACK, WHITE):
                            bordering.add(val)
                if len(bordering) == 1:
                    owner = next(iter(bordering))
                    if owner == BLACK:
                        black_territory += len(region)
                    else:
                        white_territory += len(region)
        black_score = black_territory + self.black_captures
        white_score = white_territory + self.white_captures + self.komi
        return float(black_score), float(white_score)

    def winner(self, japanese: bool = False) -> int:
        if japanese:
            b, w = self.score_japanese()
        else:
            b, w = self.score_chinese()
        return BLACK if b > w else WHITE
    
    def get_state_planes(self, player: int, history_len: int = 8) -> np.ndarray:
        """Возвращает плоскости для NN: shape = (C, size, size).
        Формат:
        - для каждой из последних history_len позиций: 2 плоскости (black, white)
        - плоскость to_move (всегда 1.0, если player == BLACK)
        - плоскость ko (1 в точке ko)
        Итоговое число плоскостей = 2*history_len + 2"""
        hl = min(history_len, self.past_states.maxlen)
        planes = []
        # собираем последние hl позиций, от старых к новым
        # если истории меньше, заполняем нулями
        states = list(self.past_states)
        pad = hl - len(states)
        for _ in range(pad):
            planes.append(np.zeros((self.size, self.size), dtype=np.float32)) # black
            planes.append(np.zeros((self.size, self.size), dtype=np.float32)) # white
        # взять последние hl состояний
        for s in states[-hl:]:
            black_plane = (s == BLACK).astype(np.float32)
            white_plane = (s == WHITE).astype(np.float32)
            planes.append(black_plane)
            planes.append(white_plane)
        # to_move
        turn_plane = np.full((self.size, self.size), 1.0 if player == BLACK else 0.0, dtype=np.float32)
        planes.append(turn_plane)
        # ko
        ko_plane = np.zeros((self.size, self.size), dtype=np.float32)
        if self.ko is not None:
            kx, ky = self.ko
            ko_plane[ky, kx] = 1.0
        planes.append(ko_plane)
        return np.stack(planes, axis=0)

In [3]:
# ---------------------------
# 2) Neural network (PyTorch)
# ---------------------------
class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1, bias=False)
        self.gn = nn.GroupNorm(8, out_ch)
    def forward(self, x):
        return torch.relu(self.gn(self.conv(x)))

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.gn1 = nn.GroupNorm(8, channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1, bias=False)
        self.gn2 = nn.GroupNorm(8, channels)
    def forward(self, x):
        out = torch.relu(self.gn1(self.conv1(x)))
        out = self.gn2(self.conv2(out))
        out = torch.relu(out + x)
        return out

class AlphaNet(nn.Module):
    def __init__(self, board_size=9, in_channels=18, num_res=6, channels=128):
        super().__init__()
        self.board_size = board_size
        self.conv_in = ConvBlock(in_channels, channels)
        self.resblocks = nn.Sequential(*[ResidualBlock(channels) for _ in range(num_res)])

        # policy head
        self.p_conv1 = nn.Conv2d(channels, 32, 1, bias=False)
        self.p_gn1 = nn.GroupNorm(8, 32)
        self.p_conv2 = nn.Conv2d(32, 2, 1, bias=False)
        self.p_gn2 = nn.GroupNorm(2, 2)
        self.p_fc = nn.Linear(2 * board_size * board_size, board_size * board_size + 1)  # +1 for pass
        # value head
        self.v_conv1 = nn.Conv2d(channels, 32, 1, bias=False)
        self.v_gn1 = nn.GroupNorm(8, 32)
        self.v_fc1 = nn.Linear(32 * board_size * board_size, 128)
        self.v_fc2 = nn.Linear(128, 1)
    def forward(self, x, legal_mask=None):
        out = self.conv_in(x)
        out = self.resblocks(out)
        # policy
        p = torch.relu(self.p_gn1(self.p_conv1(out)))
        p = torch.relu(self.p_gn2(self.p_conv2(p)))
        p = p.view(p.size(0), -1)
        logits = self.p_fc(p)

        if legal_mask is not None:
            logits = logits.masked_fill(~legal_mask, float("-inf"))

        logp = torch.log_softmax(logits, dim=1)
        # value
        v = torch.relu(self.v_gn1(self.v_conv1(out)))
        v = v.view(v.size(0), -1)
        v = torch.relu(self.v_fc1(v))
        value = torch.tanh(self.v_fc2(v)).squeeze(1)
        
        return logp, value  # log probabilities and scalar value


In [4]:
# ---------------------------
# 3) MCTS with PUCT
# ---------------------------
class MCTSNode:
    def __init__(self, prior:float=0.0):
        self.prior = float(prior)   # P(a)
        self.visit_count = 0        # N
        self.value_sum = 0.0        # W
        self.children = {}          # action -> MCTSNode
        self.lock = threading.Lock() # for parallel updates
        self.virtual_loss = 0       # for virtual loss

    def q_value(self) -> float:
        if self.visit_count == 0:
            return 0.0
        return self.value_sum / self.visit_count
    
    def n(self) -> int:
        return self.visit_count

class MCTS:
    def __init__(self, net:AlphaNet, board_size:int, c_puct:float=1.0, n_simulations:int=80, device='cpu'):
        self.net = net
        self.board_size = board_size
        self.c_puct = c_puct
        self.n_simulations = n_simulations
        self.device = device
        self.prof = {'nn': 0.0, 'copy': 0.0, 'legal': 0.0, 'sim': 0.0}

    def run(self, root_board:Board, to_move:int, eps:float=0.25, alpha:float=0.03, temperature: float = 1.0, parallel: bool = False):
        root = MCTSNode()
        size = self.board_size

        t0 = time.perf_counter()
        # evaluate root with network to get priors
        state = torch.tensor(root_board.get_state_planes(to_move), dtype=torch.float32).unsqueeze(0)
        with torch.no_grad():
            logp, v = self.net(state.to(self.device))
        probs = torch.exp(logp).squeeze(0).cpu().numpy()  # length = H*W + 1
        self.prof['nn'] += time.perf_counter() - t0

        # build list of legal indices at root
        legal = root_board.legal_moves(to_move)
        legal_indices: List[int] = []
        for mv in legal:
            if mv is None:
                idx = size*size
            else:
                idx = mv[1]*size + mv[0]
            legal_indices.append(idx)

        # mask illegal priors, then renormalize
        masked_priors = np.zeros_like(probs, dtype=np.float32)
        masked_priors[legal_indices] = probs[legal_indices]
        total = masked_priors.sum()
        if total > 0:
            masked_priors /= total
        else:
            # fallback uniform over legal
            for idx in legal_indices:
                masked_priors[idx] = 1.0 / len(legal_indices)
        
        # add Dirichlet noise for root during self-play
        dir_noise = np.random.dirichlet([alpha] * len(legal_indices))
        for i, idx in enumerate(legal_indices):
            masked_priors[idx] = (1 - eps) * masked_priors[idx] + eps * dir_noise[i]

        # create root children with priors 
        for idx in legal_indices:
            root.children[idx] = MCTSNode(prior=float(masked_priors[idx]))

        # run simulations
        if parallel:
            threads = []
            for _ in range(self.n_simulations):
                th = threading.Thread(target=self._simulate, args=(root_board.copy(), to_move, root))
                th.start()
                threads.append(th)
            for th in threads:
                th.join()
        else:
            for _ in range(self.n_simulations):
                board_copy = root_board.copy()
                self._simulate(board_copy, to_move, root)

        # build policy target pi from visit counts
        counts = np.zeros(self.board_size*self.board_size + 1, dtype=np.float32)
        for a, node in root.children.items():
            counts[a] = node.visit_count
        total = counts.sum()
        if total <= 0:
            # fallback uniform over legal actions
            for idx in legal_indices:
                counts[idx] = 1.0
            total = counts.sum()
        pi = counts / total
        #return pi  # return visit-probabilities
        # temperature: sample or take argmax depending on temperature
        if temperature == 0 or temperature < 1e-6:
            action = int(np.argmax(pi))
            final_pi = np.zeros_like(pi)
            final_pi[action] = 1.0
            return final_pi
        else:
            # apply temperature to pi
            pi_temp = pi ** (1.0 / temperature)
            pi_temp_sum = pi_temp.sum()
            if pi_temp_sum > 0:
                pi_temp /= pi_temp_sum
            else:
                pi_temp = pi
            return pi_temp

    def _select_child(self, node:MCTSNode, total_N:float) -> Tuple[int, MCTSNode]:
        # choose action maximizing Q + U
        best_score = -1e9
        best_action = None
        best_child = None
        for a, child in node.children.items():
            Q = 0.0
            if child.visit_count > 0:
                Q = child.q_value()
            U = self.c_puct * child.prior * math.sqrt(total_N) / (1 + child.visit_count + child.virtual_loss)
            score = Q + U
            if score > best_score:
                best_score = score
                best_action = a
                best_child = child
        assert best_action is not None
        return best_action, best_child

    def _simulate(self, board:Board, to_move:int, root:MCTSNode):
        start = time.perf_counter()
        node = root
        player = to_move
        path: List[Tuple[MCTSNode, int]] = [] # nodes and actions

        # selection
        while True:
            with node.lock:
                if len(node.children) == 0: # leaf
                    break
                total_N = sum(ch.visit_count for ch in node.children.values())
                action_idx, child = self._select_child(node, total_N)
                child.virtual_loss += 1
            # apply action
            if action_idx == self.board_size*self.board_size:
                mv = None
            else:
                x = action_idx % self.board_size
                y = action_idx // self.board_size
                mv = (x,y)
            board._play_move_no_checks(player, mv)
            path.append((node, action_idx))
            node = child
            player = opponent(player)

        # Evaluate leaf with neural network
        if board.game_over():
            # terminal: value is +1 for winner from viewpoint of last player who moved
            winner = board.winner()
            value = 1.0 if winner == opponent(player) else -1.0
        else:
            t0 = time.perf_counter()
            state = torch.tensor(board.get_state_planes(player), dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                logp, v = self.net(state.to(self.device))
            probs = torch.exp(logp).squeeze(0).cpu().numpy()
            self.prof['nn'] += time.perf_counter() - t0
            # expand node with legal moves
            legal = board.legal_moves(player)
            legal_idxs = []
            for mv in legal:
                if mv is None:
                    idx = self.board_size*self.board_size
                else:
                    idx = mv[1]*self.board_size + mv[0]
                legal_idxs.append(idx)
            # mask and renormalize probs for this leaf
            masked = np.zeros_like(probs, dtype=np.float32)
            masked[legal_idxs] = probs[legal_idxs]
            s = masked.sum()
            if s > 0:
                masked /= s
            else:
                for idx in legal_idxs:
                    masked[idx] = 1.0 / len(legal_idxs)
            # create children for the leaf node
            for idx in legal_idxs:
                if idx not in node.children:
                    node.children[idx] = MCTSNode(prior=float(masked[idx]))
            value = float(v.item())
        
        # early termination heuristic: if abs(value) is extremely close to 1 and node depth small, backpropagate
        # backup
        cur_value = value
        for parent, action_idx in reversed(path):
            # remove virtual loss and update stats
            child = parent.children[action_idx]
            with parent.lock:
                # decrement virtual loss applied during selection
                if child.virtual_loss > 0:
                    child.virtual_loss -= 1
                parent.visit_count += 1
                parent.value_sum += cur_value
            cur_value = -cur_value


        # if path empty we still may need to update root if leaf was root
        if len(path) == 0:
            # we're at root leaf; update root directly
            with root.lock:
                root.visit_count += 1
                root.value_sum += cur_value
        self.prof['sim'] += time.perf_counter() - start

    # utility to print profiling
    def print_profile(self):
        print("MCTS profile:")
        for k, v in self.prof.items():
            print(f" {k}: {v:.4f} sec")

In [5]:
# ---------------------------
# 4) Replay buffer and self-play
# ---------------------------
Example = collections.namedtuple('Example', ['state', 'pi', 'z'])

class ReplayBuffer:
    def __init__(self, capacity=10000):
        self.buffer = collections.deque(maxlen=capacity)

    def push(self, example:Example):
        self.buffer.append(example)

    def push_many(self, examples):
        for ex in examples:
            self.buffer.append(ex)

    def sample(self, batch_size:int, return_torch=False, device='cpu'):
        if len(self.buffer) == 0:
            raise ValueError('ReplayBuffer is empty')
        batch_size = min(batch_size, len(self.buffer))
        batch = random.sample(self.buffer, batch_size)

        states = np.stack([e.state for e in batch], axis=0)
        pis = np.stack([e.pi for e in batch], axis=0)
        zs = np.array([e.z for e in batch], dtype=np.float32)

        if return_torch:
            return (
                torch.tensor(states, dtype=torch.float32, device=device),
                torch.tensor(pis, dtype=torch.float32, device=device),
                torch.tensor(zs, dtype=torch.float32, device=device),
            )
        return states, pis, zs

    def __len__(self):
        return len(self.buffer)

def self_play_episode(net:AlphaNet, mcts:MCTS, board_size:int, n_simulations:int, temperature:float=1.0):
    board = Board(size=board_size)
    to_move = BLACK
    examples:List[Example] = []
    move_no = 0
    for move_no in range(board_size * board_size * 4):
        if board.game_over():
            break
        pi = mcts.run(board, to_move)  # uses net internally
        # safety: если MCTS вернул нули (не должно быть, но возможная защита)
        s = pi.sum()
        if s <= 0:
            pi = np.ones_like(pi, dtype=np.float32)
            pi /= pi.size
        # apply temperature: simple softmax on visit counts (here pi already normalized)
        if temperature == 0.0:
            # pick argmax
            a = int(np.argmax(pi))
        else:
            p = pi ** (1.0 / temperature)
            p = p / (p.sum() + 1e-12)
            a = int(np.random.choice(len(p), p=p))
        # store example: state plane for current player, policy pi (from MCTS), z unknown yet
        state = board.get_state_planes(to_move)
        examples.append(Example(state=state, pi=pi.copy(), z=None))
        # convert a to move
        if a == board_size*board_size:
            mv = None
        else:
            mv = (a % board_size, a // board_size)
        if not board.is_legal(to_move, mv):
            # forced pass if no legal moves
            mv = None
        board._play_move_no_checks(to_move, mv)
        to_move = opponent(to_move)
    # game finished: compute z for each example (from perspective of player to move at that state)
    winner = board.winner()
    if winner is None:
        winner = BLACK
    
    # final scoring of examples
    for ex in examples:
        # if winner == BLACK -> z=+1 for states where to_move was BLACK, else -1
        # but ex.state includes turn_plane indicating player at that state: turn_plane=1.0 means black to move
        turn_plane = ex.state[-1]
        black_to_move = bool(turn_plane[0, 0] > 0.5)
        if winner == BLACK:
            z = 1.0 if black_to_move else -1.0
        else:
            z = -1.0 if black_to_move else 1.0  # if white won, z=+1 for white-to-move states
        # replace example with z
        ex_z = Example(state=ex.state, pi=ex.pi, z=z)
        yield ex_z


In [6]:
# ---------------------------
# 5) Training loop (PyTorch)
# ---------------------------
import os

# гиперпараметры по умолчанию, можно поменять
TRAIN_PARAMS = {
    "batch_size": 256,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "epochs_per_iteration": 1,    # сколько шагов обучения за одну итерацию self-play
    "self_play_games": 100,       # сколько партий собрать до обучения (на CPU может быть меньше)
    "replay_capacity": 200000,
    "n_simulations": 160,         # MCTS sim per move
    "c_puct": 1.5,
    "dirichlet_alpha": 0.3,
    "dirichlet_eps": 0.25,
    "temperature": 1.0,
    "save_dir": "./checkpoints",
    "checkpoint_every": 1,        # сохранять каждая итерация
    "eval_games": 40,
    "num_iters": 1000,
    "grad_clip": 1.0,
    "device": ("cuda" if torch.cuda.is_available() else "cpu"),
}

# util: ensure dir
def ensure_dir(d):
    if not os.path.exists(d):
        os.makedirs(d, exist_ok=True)

# loss: policy cross entropy (negative log likelihood) + value MSE
# weight for value loss:
VALUE_LOSS_WEIGHT = 0.7

def train_step(net: AlphaNet, optimizer, replay: ReplayBuffer, batch_size: int, device: str):
    # sample batch and run one optimizer step
    states_np, pis_np, zs_np = replay.sample(batch_size)
    # convert to tensors
    states = torch.tensor(states_np, dtype=torch.float32, device=device)
    pis = torch.tensor(pis_np, dtype=torch.float32, device=device)
    zs = torch.tensor(zs_np, dtype=torch.float32, device=device)

    print("len(replay) =", len(replay))

    net.train()
    optimizer.zero_grad()
    # compute legal mask if you want; for now assume pis already zeros for illegal moves
    logp, value = net(states)   # logp: [B, A], value: [B]
    # policy loss: cross entropy with target pi (we use -sum pi * logp)
    # note: if some pi entries are zero it's fine
    policy_loss = - (pis * logp).sum(dim=1).mean()
    value_loss = (value - zs).pow(2).mean()
    loss = policy_loss + VALUE_LOSS_WEIGHT * value_loss
    loss.backward()
    torch.nn.utils.clip_grad_norm_(net.parameters(), TRAIN_PARAMS['grad_clip'])
    optimizer.step()
    return {
        "loss": float(loss.item()),
        "policy_loss": float(policy_loss.item()),
        "value_loss": float(value_loss.item())
    }

# optional: simple prioritized replay wrapper (sum-tree not implemented here; lightweight version)
class PrioritizedReplay:
    # lightweight: store (example, priority), sample weighted by priority
    def __init__(self, capacity=100000, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.pos = 0
        self.max_prio = 1.0

    def push(self, example):
        prio = self.max_prio
        if len(self.buffer) < self.capacity:
            self.buffer.append((example, prio))
        else:
            self.buffer[self.pos] = (example, prio)
            self.pos = (self.pos + 1) % self.capacity

    def sample(self, batch_size):
        if len(self.buffer) == 0:
            raise ValueError("empty buffer")
        prios = np.array([p for (_, p) in self.buffer], dtype=np.float32)
        probs = prios ** self.alpha
        probs /= probs.sum()
        idxs = np.random.choice(len(self.buffer), size=min(batch_size, len(self.buffer)), p=probs, replace=False)
        batch = [self.buffer[i][0] for i in idxs]
        # return also indices so caller can update priorities
        return batch, idxs

    def update_priorities(self, idxs, priorities):
        for i, pr in zip(idxs, priorities):
            self.buffer[i] = (self.buffer[i][0], pr)
            self.max_prio = max(self.max_prio, pr)

    def __len__(self):
        return len(self.buffer)


# self-play worker sketch using multiprocessing.Process
# it runs games and writes Examples to a queue for the trainer to read and insert into replay
import multiprocessing as mp

def self_play_worker(proc_id: int, net: AlphaNet, queue: mp.Queue, params: dict, stop_event: mp.Event):
    # local MCTS uses a copy of network (weights are read-only here)
    local_mcts = MCTS(net, board_size=net.board_size, c_puct=params["c_puct"],
                      n_simulations=params["n_simulations"], device=params["device"])
    random_seed = int(time.time()) + proc_id * 1000
    random.seed(random_seed)
    np.random.seed(random_seed)
    while not stop_event.is_set():
        # play single episode
        for ex in self_play_episode(net, local_mcts, net.board_size, params["n_simulations"], temperature=params["temperature"]):
            # push to queue as raw Example
            queue.put(ex)
        # small sleep to yield CPU
        time.sleep(0.01)
    return

# main training loop
def train_loop(net: AlphaNet, params: dict):
    device = params["device"]
    net.to(device)
    # optimizer and scheduler
    optimizer = torch.optim.AdamW(net.parameters(), lr=params["lr"], weight_decay=params["weight_decay"])
    # optional scheduler, e.g. cosine or step LR
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000, eta_min=1e-5)

    # replay buffer
    replay = ReplayBuffer(capacity=params["replay_capacity"])
    # or prioritized: 
    # replay = PrioritizedReplay(capacity=params["replay_capacity"])

    ensure_dir(params["save_dir"])

    # queue for self-play workers
    manager = mp.Manager()
    queue = manager.Queue(maxsize=10000)
    stop_event = manager.Event()

    # spawn workers
    n_workers = max(1, mp.cpu_count() - 1)
    workers = []
    for i in range(n_workers):
        # send net state dict to worker via deepcopy to avoid heavy IPC; worker uses same net object if in-process
        p = mp.Process(target=self_play_worker, args=(i, net, queue, params, stop_event))
        p.start()
        workers.append(p)

    try:
        iter_no = 0
        while iter_no < params["num_iters"]:
            if iter_no > 15:
                fe = int(input())
                if fe != "":
                    break
            # collect self-play games until we have enough
            collected = 0
            target_games = params["self_play_games"]
            while collected < target_games:
                try:
                    ex = queue.get(timeout=5.0)  # block waiting for new example
                except Exception:
                    # timeout or empty queue; can continue waiting
                    continue
                replay.push(ex)
                # you might want to group examples by game to ensure episode integrity
                collected += 1  # this counts examples; if you prefer games, change producer to signal game boundaries
                # print progress occasionally
                if len(replay) % 1000 == 0:
                    print(f"Replay size: {len(replay)}")
            # train for some epochs / steps
            for ep in range(params["epochs_per_iteration"]):
                # number of gradient steps per epoch can be set relative to self_play_games
                steps = max(1, len(replay) // params["batch_size"])
                for step in range(steps):
                    stats = train_step(net, optimizer, replay, params["batch_size"], device)
                    if step % 10 == 0:
                        print(f"Iter {iter_no} Ep {ep} Step {step} loss {stats['loss']:.4f}")
                scheduler.step()

            # save checkpoint
            if iter_no % params["checkpoint_every"] == 0:
                ckpt = {
                    "iter": iter_no,
                    "net_state": net.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "replay_size": len(replay),
                    "params": params,
                }
                path = os.path.join(params["save_dir"], f"checkpoint_{iter_no}.pt")
                torch.save(ckpt, path)
                print("Saved checkpoint", path)

            iter_no += 1

    finally:
        # stop workers
        stop_event.set()
        for p in workers:
            p.terminate()
            p.join()

def train_loop_debug(
    board_size=9,
    n_iterations=20,
    self_play_games_per_iter=5,
    mcts_simulations=64,
    batch_size=64,
    epochs=1,
    device="cpu",
    seed=0,
):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    device = torch.device(device)

    net = AlphaNet(board_size=board_size, num_res=6, channels=128).to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
    replay = ReplayBuffer(capacity=50000)

    for it in range(n_iterations):
        print("=" * 50)
        print("Iteration", it)

        # -------------------------
        # 1) self-play
        # -------------------------
        t0 = time.time()
        total_moves = 0

        for g in range(self_play_games_per_iter):
            mcts = MCTS(
                net,
                board_size=board_size,
                c_puct=1.5,
                n_simulations=mcts_simulations,
                device=device,
            )
            traj = list(self_play_episode(
                net, mcts, board_size,
                mcts_simulations, temperature=1.0
            ))
            for ex in traj:
                replay.push(ex)
            total_moves += len(traj)

        t1 = time.time()
        sp_time = t1 - t0
        sp_speed = total_moves / max(sp_time, 1e-9)

        print(f"Self-play: {self_play_games_per_iter} games")
        print(f"Moves generated: {total_moves}")
        print(f"Self-play time: {sp_time:.2f} sec ({sp_speed:.1f} moves/sec)")
        print(f"Replay size: {len(replay)}")

        # -------------------------
        # 2) training
        # -------------------------
        if len(replay) < batch_size:
            print("Not enough data in buffer, skipping training")
            continue

        t2 = time.time()

        loss_list = []
        policy_list = []
        value_list = []

        for ep in range(epochs):
            # number of steps per epoch
            steps = max(1, len(replay) // batch_size)
            for _ in range(steps):
                states, pis, zs = replay.sample(batch_size)
                states_t = torch.tensor(states, dtype=torch.float32, device=device)
                pis_t = torch.tensor(pis, dtype=torch.float32, device=device)
                zs_t = torch.tensor(zs, dtype=torch.float32, device=device)

                net.train()
                optimizer.zero_grad()

                logp, v = net(states_t)

                policy_loss = - (pis_t * logp).sum(dim=1).mean()
                value_loss = ((v - zs_t) ** 2).mean()
                loss = policy_loss + 0.5 * value_loss

                loss.backward()
                torch.nn.utils.clip_grad_norm_(net.parameters(), 1.0)
                optimizer.step()

                loss_list.append(loss.item())
                policy_list.append(policy_loss.item())
                value_list.append(value_loss.item())

        t3 = time.time()
        train_time = t3 - t2

        print(f"Train time: {train_time:.2f} sec")
        print(f"Avg loss:   {np.mean(loss_list):.4f}")
        print(f"Avg policy: {np.mean(policy_list):.4f}")
        print(f"Avg value:  {np.mean(value_list):.4f}")

    return net


# simple evaluate function: play net vs current best policy (self play with fixed temperature=0)
def evaluate(net: AlphaNet, opponent_net: AlphaNet, games: int = 40):
    # plays games alternating colors
    wins = 0
    for g in range(games):
        board = Board(size=net.board_size)
        to_move = BLACK if (g % 2 == 0) else WHITE
        if g % 2 == 0:
            nets = {BLACK: net, WHITE: opponent_net}
        else:
            nets = {BLACK: opponent_net, WHITE: net}
        mcts_players = {p: MCTS(nets[p], board_size=net.board_size, c_puct=1.5, n_simulations=160, device=TRAIN_PARAMS["device"]) for p in (BLACK, WHITE)}
        move_no = 0
        while not board.game_over() and move_no < net.board_size * net.board_size * 4:
            cur_net = nets[to_move]
            mcts = mcts_players[to_move]
            pi = mcts.run(board, to_move)
            a = int(np.argmax(pi))
            if a == net.board_size * net.board_size:
                mv = None
            else:
                mv = (a % net.board_size, a // net.board_size)
            board._play_move_no_checks(to_move, mv)
            to_move = opponent(to_move)
            move_no += 1
        winner = board.winner()
        if winner == (BLACK if (g % 2 == 0) else WHITE):
            wins += 1
    print(f"Evaluation: {wins}/{games} wins")

# Example of running:
# net = AlphaNet(board_size=9, in_channels=18, num_res=6, channels=128)
# train_loop(net, TRAIN_PARAMS)


In [7]:
# ---------------------------
# 6) Training runner
# ---------------------------

def run_training():
    # параметры можешь менять свободно
    board_size = 9
    n_iterations = 10
    self_play_games_per_iter = 4
    mcts_simulations = 48
    batch_size = 64
    epochs = 2
    device = 'cpu'
    seed = 0

    print("Starting training...")
    print(f"Device: {device}")

    # если используешь debug-версию тренировки:
    net = train_loop_debug(
        board_size=board_size,
        n_iterations=n_iterations,
        self_play_games_per_iter=self_play_games_per_iter,
        mcts_simulations=mcts_simulations,
        batch_size=batch_size,
        epochs=epochs,
        device=device,
        seed=seed
    )

    # если используешь обычную train_loop — замени строку выше на:
    # net = train_loop(...)

    print("Training finished.")
    return net


net = run_training()

Starting training...
Device: cpu
Iteration 0
Self-play: 4 games
Moves generated: 680
Self-play time: 234.52 sec (2.9 moves/sec)
Replay size: 680
Train time: 5.77 sec
Avg loss:   4.3147
Avg policy: 4.2799
Avg value:  0.0697
Iteration 1
Self-play: 4 games
Moves generated: 722
Self-play time: 247.48 sec (2.9 moves/sec)
Replay size: 1402
Train time: 12.54 sec
Avg loss:   4.8324
Avg policy: 4.2900
Avg value:  1.0848
Iteration 2
Self-play: 4 games
Moves generated: 709
Self-play time: 233.24 sec (3.0 moves/sec)
Replay size: 2111
Train time: 17.77 sec
Avg loss:   5.0897
Avg policy: 4.2958
Avg value:  1.5879
Iteration 3
Self-play: 4 games
Moves generated: 896
Self-play time: 299.74 sec (3.0 moves/sec)
Replay size: 3007
Train time: 25.76 sec
Avg loss:   4.8049
Avg policy: 4.2947
Avg value:  1.0204
Iteration 4
Self-play: 4 games
Moves generated: 1005
Self-play time: 341.57 sec (2.9 moves/sec)
Replay size: 4012
Train time: 34.51 sec
Avg loss:   4.7983
Avg policy: 4.3048
Avg value:  0.9871
Iteratio

In [9]:
# watch_match.py  (или вставь в конец az_go.py)
import time
import torch
import numpy as np

# опционально: pygame рендер
try:
    import pygame
    PYGAME_OK = True
except Exception:
    PYGAME_OK = False

def load_net(path: str, board_size: int, device='cpu'):
    """Создает AlphaNet и загружает веса из path (torch .pth/.pt)."""
    net = AlphaNet(board_size=board_size, num_res=3, channels=64)
    net.load_state_dict(torch.load(path, map_location=device))
    net.to(device)
    net.eval()
    return net

def board_to_ascii(board: Board):
    """Возвращает строковое представление доски ('.' empty, 'X' black, 'O' white)."""
    lines = []
    for y in range(board.size):
        row = []
        for x in range(board.size):
            v = board.get(x, y)
            if v == EMPTY:
                row.append('.')
            elif v == BLACK:
                row.append('X')
            else:
                row.append('O')
        lines.append(''.join(row))
    return '\n'.join(lines)

class PygameRenderer:
    def __init__(self, board_size=9, cell_size=30, padding=20):
        if not PYGAME_OK:
            raise RuntimeError("pygame not available")
        self.board_size = board_size
        self.cell_size = cell_size
        self.padding = padding
        self.width = board_size * cell_size + padding * 2
        self.height = self.width
        self.last_pass = False
        pygame.init()
        self.screen = pygame.display.set_mode((self.width, self.height))
        pygame.display.set_caption("Watch Match")
        self.font = pygame.font.SysFont(None, 24)

    def draw(self, board: Board, to_move):
        self.screen.fill((210, 170, 110))
        PADDING = self.padding
        CELL = self.cell_size
        STONE_R = int(CELL * 0.4)
        # grid
        for i in range(self.board_size):
            pygame.draw.line(
                self.screen, (0,0,0),
                (PADDING, PADDING + i*CELL),
                (PADDING + (self.board_size-1)*CELL, PADDING + i*CELL), 1)
            pygame.draw.line(
                self.screen, (0,0,0),
                (PADDING + i*CELL, PADDING),
                (PADDING + i*CELL, PADDING + (self.board_size-1)*CELL), 1)
        # hoshi
        hoshi = [3, 9, 15]
        for hx in hoshi:
            for hy in hoshi:
                if 0 <= hx < self.board_size and 0 <= hy < self.board_size:
                    pygame.draw.circle(self.screen, (0,0,0),
                        (PADDING + hx*CELL, PADDING + hy*CELL), max(2, int(CELL*0.08)))
        # stones
        for y in range(self.board_size):
            for x in range(self.board_size):
                v = board.get(x,y)
                if v == BLACK:
                    pygame.draw.circle(self.screen, (0,0,0),
                        (PADDING + x*CELL, PADDING + y*CELL), STONE_R)
                elif v == WHITE:
                    pygame.draw.circle(self.screen, (255,255,255),
                        (PADDING + x*CELL, PADDING + y*CELL), STONE_R)
        # turn text
        txt = "Black to move" if to_move == BLACK else "White to move"
        if self.last_pass: txt = "White passed" if to_move == BLACK else "Black passed"
        surf = self.font.render(txt, True, (0,0,0))
        self.screen.blit(surf, (5,5))
        pygame.display.flip()

def pick_action_from_pi(pi: np.ndarray, board: Board, board_size:int, temperature:float=0.0):
    """pi is visit-probabilities (length board_size*board_size+1). Returns chosen action index."""
    # mask illegal moves
    current_player = board.turn
    legal = board.legal_moves(board.turn)  # we will set current_player outside, but easier pass as global? avoid
    # Better: build legal mask here by checking each index
    mask = np.zeros_like(pi, dtype=np.bool_)
    for mv in board.legal_moves(current_player):
        if mv is None:
            idx = board_size*board_size
        else:
            idx = mv[1]*board_size + mv[0]
        mask[idx] = True
    masked = np.copy(pi)
    masked[~mask] = 0.0
    s = masked.sum()
    if s <= 0:
        # fallback: uniform over legal
        legal_idxs = np.where(mask)[0]
        return int(np.random.choice(legal_idxs))
    masked = masked / s
    if temperature == 0.0:
        return int(np.argmax(masked))
    else:
        p = masked ** (1.0 / temperature)
        p = p / p.sum()
        return int(np.random.choice(len(p), p=p))

def play_match(net1: AlphaNet,
               net2: AlphaNet,
               board_size:int=9,
               mcts_simulations:int=160,
               c_puct:float=1.2,
               device='cpu',
               render: str = 'console',  # 'console' or 'pygame'
               pause_time: float = 0.15,
               temperature: float = 0.0):
    """
    Провести партию между net1 (играет Black) и net2 (White).
    render = 'console' или 'pygame'. pause_time в секундах между ходами.
    """
    # подготовка MCTS для каждой сети
    mcts1 = MCTS(net1, board_size=board_size, c_puct=c_puct, n_simulations=mcts_simulations, device=device)
    mcts2 = MCTS(net2, board_size=board_size, c_puct=c_puct, n_simulations=mcts_simulations, device=device)

    board = Board(size=board_size)
    to_move = BLACK
    move_no = 0

    renderer = None
    if render == 'pygame':
        if not PYGAME_OK:
            print("pygame not available, falling back to console render")
            render = 'console'
        else:
            renderer = PygameRenderer(board_size=board_size, cell_size=30, padding=30)

    # main loop
    while not board.game_over() and move_no < board_size*board_size*4:
        if render == 'console':
            print("\nMove", move_no, "to_move:", "BLACK" if to_move==BLACK else "WHITE")
            print(board_to_ascii(board))
        elif render == 'pygame':
            renderer.draw(board, to_move)
            # handle events to allow window close
            for ev in pygame.event.get():
                if ev.type == pygame.QUIT:
                    pygame.quit()
                    return

        # choose MCTS for current player
        if to_move == BLACK:
            pi = mcts1.run(board, to_move)  # visit-probs
        else:
            pi = mcts2.run(board, to_move)

        # pick action deterministically if temperature==0
        # convert pi to mask and argmax
        # build legal mask and choose
        size = board_size
        mask = np.zeros_like(pi, dtype=np.bool_)
        for mv in board.legal_moves(to_move):
            if mv is None:
                idx = size*size
            else:
                idx = mv[1]*size + mv[0]
            mask[idx] = True
        masked = np.copy(pi)
        masked[~mask] = 0.0
        if masked.sum() <= 0:
            # fallback random legal
            legal = board.legal_moves(to_move)
            mv = random.choice(legal)
        else:
            if temperature == 0.0:
                a = int(np.argmax(masked))
            else:
                p = masked ** (1.0 / temperature)
                p = p / p.sum()
                a = int(np.random.choice(len(p), p=p))
            if a == size*size:
                mv = None
            else:
                mv = (a % size, a // size)
            renderer.last_pass = mv == None

        # apply move (we use internal _play_move_no_checks to avoid extra superko checks done by is_legal,
        # but to be safe, you can check is_legal before applying)
        if mv is None:
            board._play_move_no_checks(to_move, None)
        else:
            # ensure legal
            if not board.is_legal(to_move, mv):
                mv = None   # либо pass
            board._play_move_no_checks(to_move, mv)

        to_move = opponent(to_move)
        move_no += 1
        time.sleep(pause_time)

    # finished
    bscore, wscore = board.score_chinese()
    winner = board.winner()
    if render == 'console':
        print("\nFinal board:")
        print(board_to_ascii(board))
        print(f"Score Black: {bscore:.1f}, White(with komi): {wscore:.1f}")
        print("Winner:", "BLACK" if winner==BLACK else "WHITE")
    elif render == 'pygame':
        renderer.draw(board, to_move)
        print("Game finished. Close window to exit.")
        print(f"Score Black(with komi): {bscore:.1f}, White: {wscore:.1f}")
        print("Winner:", "BLACK" if winner==BLACK else "WHITE")
        # keep window open until closed
        while True:
            for ev in pygame.event.get():
                if ev.type == pygame.QUIT:
                    pygame.quit()
                    return
            time.sleep(0.1)

# ---------------------------
# Example usage:
# ---------------------------
# 1) если у тебя есть сохраненные веса:
# net_black = load_net("best_black.pth", board_size=9, device='cpu')
# net_white = load_net("best_white.pth", board_size=9, device='cpu')
# play_match(net_black, net_white, board_size=9, mcts_simulations=128, render='console')

# 2) если у тебя есть одна сеть (самонаблюдение), можно использовать одну и ту же модель:
# net = load_net("model.pth", board_size=9)
# play_match(net, net, board_size=9, mcts_simulations=128, render='pygame', pause_time=0.4)

# 3) если ты вызвал train_loop и получил net object:
# net_trained = train_loop(...)
# play_match(net_trained, net_trained, board_size=9, mcts_simulations=128, render='console')

play_match(net, net, board_size=9, mcts_simulations=256, render='pygame', pause_time=0.25)


Game finished. Close window to exit.
Score Black(with komi): 81.0, White: 3.5
Winner: BLACK
