In [None]:
# Cell 1: Install dependencies
!pip install numpy torch coloredlogs
!pip install  tqdm

Collecting coloredlogs
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)


 Utilities

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import random
import pickle
import matplotlib.pyplot as plt
import time

GoLogic (Board mechanics, stone expiration, limited stones)

In [2]:
# ============================
# Module 1: Board and Game Logic
# ============================

class Board:
    def __init__(self, n=5):
        self.n = n
        self.pieces = [[(0, 0) for _ in range(n)] for _ in range(n)]
        self.stones_left = {1: 30, -1: 30}
        self.turn_count = 0

    def copy(self):
        new_board = Board(self.n)
        new_board.pieces = [row.copy() for row in self.pieces]
        new_board.stones_left = self.stones_left.copy()
        new_board.turn_count = self.turn_count
        return new_board

    def get_legal_moves(self, color):
        if self.stones_left[color] <= 0:
            return []
        return [(x, y) for y in range(self.n) for x in range(self.n) if self.pieces[y][x][0] == 0]

    def execute_move(self, move, color):
      if move is not None:
        x, y = move
        self.pieces[y][x] = (color, 0)
        self.stones_left[color] -= 1

    # advance ages only every second move  ➜ one full “round”
      if self.turn_count % 2 == 1:
        self.increment_ages_and_expire()

      self.turn_count += 1


    def increment_ages_and_expire(self):
        for y in range(self.n):
            for x in range(self.n):
                color, age = self.pieces[y][x]
                if color != 0:
                    age += 1
                    self.pieces[y][x] = (0, 0) if age >= 3 else (color, age)

    def count_stones(self, color):
        return sum(1 for row in self.pieces for stone, _ in row if stone == color)

    def get_remaining_stones(self, color):
        return self.stones_left[color]

    def is_full(self):
        return all(stone != 0 for row in self.pieces for stone, _ in row)

    def display(self):
        header = "  " + " ".join(map(str, range(self.n)))
        board_representation = [header]
        for y, row in enumerate(self.pieces):
            row_str = f"{y} " + " ".join(
                f"W{age}" if color == 1 else f"B{age}" if color == -1 else " . "
                for color, age in row
            )
            board_representation.append(row_str)
        footer = f"Stones left - White: {self.stones_left[1]}, Black: {self.stones_left[-1]}"
        board_representation.append(footer)
        print("\n".join(board_representation))



 GoGame (Game rules and interface)

In [3]:
class GoGame:
    def __init__(self, n=5):
        self.n = n

    # ---------- helpers -------------------------------------------------
    @staticmethod
    def territory_score(board, color):
        """
        Very simple territory heuristic:
        • live stones count 1 point
        • each empty intersection counts 0.5 point for BOTH players
          (so total board score is always 25)
        This prevents the 0-0 draws you saw when all stones evaporate.
        """
        own   = sum(s == color  for row in board.pieces for s,_ in row)
        empty = sum(s == 0      for row in board.pieces for s,_ in row)
        return own + 0.5 * empty
    # --------------------------------------------------------------------

    def getInitBoard(self):
        return Board(self.n)

    def getBoardSize(self):
        return self.n, self.n

    def getActionSize(self):
        return self.n * self.n + 1          # +1 for pass

    def getNextState(self, board, player, action):
        b = board.copy()
        if action == self.n * self.n:       # Pass
            b.increment_ages_and_expire()
            b.turn_count += 1
            return b, -player
        move = (action // self.n, action % self.n)
        b.execute_move(move, player)
        return b, -player

    def getValidMoves(self, board, player):
        valids = [0] * self.getActionSize()
        legal  = board.get_legal_moves(player)
        if not legal or board.stones_left[player] <= 0:
            valids[-1] = 1                  # only pass
        else:
            for x, y in legal:
                valids[self.n * x + y] = 1
        return np.array(valids)

    def getGameEnded(self, board, player, pass_count=0):
    # game ends on two passes or both players have no stones left
      if pass_count >= 2 or (board.stones_left[1] == 0 and board.stones_left[-1] == 0):
        # 1) primary score = territory heuristic
        p1 = self.territory_score(board,  1)
        p2 = self.territory_score(board, -1)

        # 2) if still tied (e.g. empty board) use "stones placed" tiebreak
        if abs(p1 - p2) < 1e-6:
            placed1 = 30 - board.stones_left[1]
            placed2 = 30 - board.stones_left[-1]
            if   placed1 > placed2: return 1
            elif placed2 > placed1: return -1
            else:                   return 1e-4     # absolute tie (very rare)
        else:
            return 1 if p1 > p2 else -1
      return 0


    def display(self, board):
        board.display()


 Neural Network (GoNNet)

In [4]:
# ============================
# Module 2: Neural Network (GoNNet)
# ============================


# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class ResBlock(nn.Module):
    """Residual Block with two convolution layers"""
    def __init__(self, num_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.conv2 = nn.Conv2d(num_channels, num_channels, 3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, x):
        residual = x
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual  # Skip connection
        out = F.relu(out)
        return out


class GoNNet(nn.Module):
    """Convolutional Neural Network for MicroGo"""
    def __init__(self, board_size=5, num_channels=64):
        super().__init__()
        self.board_size = board_size
        self.conv1 = nn.Conv2d(4, num_channels, 3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(num_channels)

        # Residual blocks
        self.res_blocks = nn.ModuleList([ResBlock(num_channels) for _ in range(3)])

        # Policy head
        self.policy_conv = nn.Conv2d(num_channels, 32, 1)
        self.policy_bn = nn.BatchNorm2d(32)
        self.policy_fc = nn.Linear(32 * board_size * board_size, board_size * board_size + 1)

        # Value head
        self.value_conv = nn.Conv2d(num_channels, 32, 1)
        self.value_bn = nn.BatchNorm2d(32)
        self.value_fc1 = nn.Linear(32 * board_size * board_size, 64)
        self.value_fc2 = nn.Linear(64, 1)

    def forward(self, x):
        x = x.to(device)  # Send to GPU if available
        x = F.relu(self.bn1(self.conv1(x)))

        # Apply residual blocks
        for block in self.res_blocks:
            x = block(x)

        # Policy head
        policy = F.relu(self.policy_bn(self.policy_conv(x)))
        policy = policy.view(policy.size(0), -1)
        policy = self.policy_fc(policy)

        # Value head
        value = F.relu(self.value_bn(self.value_conv(x)))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = torch.tanh(self.value_fc2(value))

        return policy, value


Using device: cuda


In [5]:
# ============================
# Module 3: Data Augmentation & Tensor Conversion
# ============================
def board_to_tensor(board: Board, player):
    """
    Return a *Torch tensor* shaped (4, 5, 5) on **CPU**.
    No batch-dim here, no GPU move here.
    """
    n = board.n
    player_stones   = np.zeros((n, n), dtype=np.float32)
    opponent_stones = np.zeros((n, n), dtype=np.float32)
    stone_ages      = np.zeros((n, n), dtype=np.float32)
    legal_moves     = np.zeros((n, n), dtype=np.float32)

    for y in range(n):
        for x in range(n):
            c, age = board.pieces[y][x]
            if c == player:
                player_stones[y, x] = 1
            elif c == -player:
                opponent_stones[y, x] = 1
            if c != 0:
                stone_ages[y, x] = age / 3.0

    for x, y in board.get_legal_moves(player):
        legal_moves[y, x] = 1

    tensor = np.stack([player_stones, opponent_stones, stone_ages, legal_moves], axis=0)
    return torch.tensor(tensor, dtype=torch.float32)        # ← Torch tensor, CPU

def augment_data(state, pi, board_size=5):
    """
    state : (4,5,5) ndarray
    pi    : list/ndarray length 26
    returns lists of numpy arrays (NOT torch tensors)
    """
    state   = np.asarray(state, dtype=np.float32)
    pi      = np.asarray(pi,    dtype=np.float32)

    pi_board  = pi[:-1].reshape(board_size, board_size)
    pass_move = pi[-1]

    aug_states, aug_pis = [], []

    for k in range(4):
        rot_s  = np.rot90(state, k=k, axes=(1, 2)).copy()     # (4,5,5)
        rot_pi = np.rot90(pi_board, k=k).copy()

        aug_states.append(rot_s)
        aug_pis   .append(np.append(rot_pi.flatten(), pass_move))

        flip_s  = np.flip(rot_s, axis=2).copy()               # h-flip
        flip_pi = np.flip(rot_pi, axis=1).copy()
        aug_states.append(flip_s)
        aug_pis  .append(np.append(flip_pi.flatten(), pass_move))

    return aug_states, aug_pis


In [6]:
# ============================
# Module 4: Monte Carlo Tree Search (MCTS)
# ============================

class MCTS:
    def __init__(self, game, nnet, args):
        """
        Monte Carlo Tree Search for game simulation and evaluation.
        Parameters:
        - game: The game instance (MicroGo)
        - nnet: Neural Network for policy and value estimation
        - args: Hyperparameters for MCTS
        """
        self.game = game
        self.nnet = nnet
        self.args = args
        self.Qsa = {}       # Q values for (state, action)
        self.Nsa = {}       # Visit count for (state, action)
        self.Ns = {}        # Visit count for state
        self.Ps = {}        # Policy returned by the neural net

    def getActionProb(self, board, player, temp=1):
        """
        Executes MCTS simulations and returns the action probabilities.
        """
        for _ in range(self.args['numMCTSSims']):
            self.search(board.copy(), player)

        s = self.string_rep(board, player)
        counts = [self.Nsa.get((s, a), 0) for a in range(self.game.getActionSize())]

        if temp == 0:
            bestA = np.argmax(counts)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x ** (1. / temp) for x in counts]
        total = float(sum(counts))
        return [x / total for x in counts]

    def search(self, board, player):
        """
        Executes one simulation of MCTS.
        """
        game_result = self.game.getGameEnded(board, player)
        if game_result != 0:
            return -game_result  # Return the value for the previous player

        s = self.string_rep(board, player)

        if s not in self.Ps:
            # Leaf node - evaluate with neural network
            board_tensor = board_to_tensor(board, player).unsqueeze(0).to(device)
            policy, v = self.nnet(board_tensor)
            policy = torch.softmax(policy, dim=1).detach().cpu().numpy()[0]

            valids = self.game.getValidMoves(board, player)
            policy = policy * valids

            sum_p = np.sum(policy)
            if sum_p > 0:
                policy /= sum_p
            else:
                # If all valid moves are masked, make all valid moves equally probable
                policy = policy + valids
                policy /= np.sum(policy)

            self.Ps[s] = policy
            self.Ns[s] = 0
            return -v.item()

        # Select the action with the maximum UCB (Upper Confidence Bound)
        valids = self.game.getValidMoves(board, player)
        cur_best = -float('inf')
        best_act = -1

        for a in range(self.game.getActionSize()):
            if valids[a]:
                u = (self.Qsa.get((s, a), 0) +
                     self.args['cpuct'] * self.Ps[s][a] * np.sqrt(self.Ns[s]) / (1 + self.Nsa.get((s, a), 0)))
                if u > cur_best:
                    cur_best = u
                    best_act = a

        a = best_act
        next_board, next_player = self.game.getNextState(board, player, a)
        v = self.search(next_board, next_player)

        # Update Q, N values
        if (s, a) in self.Qsa:
            self.Qsa[(s, a)] = (self.Nsa[(s, a)] * self.Qsa[(s, a)] + v) / (self.Nsa[(s, a)] + 1)
            self.Nsa[(s, a)] += 1
        else:
            self.Qsa[(s, a)] = v
            self.Nsa[(s, a)] = 1

        self.Ns[s] += 1
        return -v

    def string_rep(self, board, player):
        """
        Generates a string representation of the board.
        Useful for efficient lookup in MCTS.
        """
        flat = []
        for row in board.pieces:
            for c, a in row:
                flat.append(str(c) + str(a))
        return str(player) + ''.join(flat)


In [7]:
# ============================
# Module 5: Arena and Self-Play Logic
# ============================

import random

class Arena:
    def __init__(self, player1, player2, game, display=None):
        """
        Arena is the battle ground for two agents:
        - player1: Agent 1 (e.g. MCTS or Human)
        - player2: Agent 2 (e.g. MCTS or Human)
        """
        self.player1 = player1
        self.player2 = player2
        self.game = game
        self.display = display or (lambda board: None)

    def playGame(self, verbose=False):
        """
        Executes one full game between two players.
        """
        players = [self.player2, None, self.player1]
        current_player = 1
        board = self.game.getInitBoard()
        pass_count = 0

        while True:
            if verbose:
                self.display(board)

            action = players[current_player + 1](board.copy())
            if action == self.game.getActionSize() - 1:  # Pass
                pass_count += 1
            else:
                pass_count = 0

            board, current_player = self.game.getNextState(board, current_player, action)

            # Check if the game has ended
            if self.game.getGameEnded(board, current_player, pass_count) != 0:
                if verbose:
                    self.display(board)
                return self.game.getGameEnded(board, current_player, pass_count)

# ============================
# Self-Play Logic
# ============================

def self_play_episode(game, nnet, args):
    """
    Executes one self-play game with MCTS guidance.
    """
    train_examples = []
    board = game.getInitBoard()
    player = 1
    episode_step = 0

    mcts = MCTS(game, nnet, args)

    while True:
        episode_step += 1

        # Get action probabilities
        pi = mcts.getActionProb(board, player, temp=1 if episode_step < 10 else 0.5)

        # Store the board state and policy
        state_tensor = board_to_tensor(board, player).numpy()
        train_examples.append([state_tensor, pi, None])  # Value will be set later

        # Choose action
        action = np.random.choice(len(pi), p=pi)

        # Make the move
        board, player = game.getNextState(board, player, action)

        # Check if game is over
        game_result = game.getGameEnded(board, player)
        if game_result != 0:
            # Update all examples with final result
            for i in range(len(train_examples)):
                train_examples[i][2] = game_result * ((-1) ** (i % 2))  # Flip result based on player
            return train_examples


In [8]:
# ╔════════════════════════════════════════╗
# ║ Cell 1 – (Random) game generator / loader
# ╚════════════════════════════════════════╝
import os, pickle, numpy as np
from tqdm import tqdm

NUM_RANDOM_GAMES = 10_000            # same as your original
RANDOM_PICKLE    = 'random_games.pkl'

def random_agent(board, game, player):
    valids  = game.getValidMoves(board, player)
    actions = np.where(valids == 1)[0]
    return np.random.choice(actions)

def generate_random_games(game, num_games=NUM_RANDOM_GAMES):
    print(f"Generating {num_games:,} random games …")
    data = []
    for _ in tqdm(range(num_games)):
        board  = game.getInitBoard()
        player = 1
        states, actions = [], []
        pass_count = 0
        while True:
            valids = game.getValidMoves(board, player)
            action = np.random.choice(np.where(valids == 1)[0])

            # -------- store snapshot ----------
            states.append(board_to_tensor(board, player).cpu().numpy())
            actions.append(action)
            # -----------------------------------

            board, player = game.getNextState(board, player, action)

            pass_count = pass_count + 1 if action == game.getActionSize() - 1 else 0
            if game.getGameEnded(board, player, pass_count) != 0:
                result = game.getGameEnded(board, player, pass_count)
                for i in range(len(states)):
                    signed_res =  result if i % 2 == 0 else -result
                    data.append((states[i], actions[i], signed_res))
                break
    print("✓ random-game generation complete")
    return data

# ---------- load or build ----------
if os.path.exists(RANDOM_PICKLE):
    print("Loading cached random games …")
    with open(RANDOM_PICKLE, 'rb') as f: random_data = pickle.load(f)
else:
    game        = GoGame(n=5)
    random_data = generate_random_games(game, NUM_RANDOM_GAMES)
    with open(RANDOM_PICKLE, 'wb') as f: pickle.dump(random_data, f)
print(f"Random dataset size: {len(random_data):,}")


Generating 10,000 random games …


100%|██████████| 10000/10000 [00:58<00:00, 170.09it/s]


✓ random-game generation complete
Random dataset size: 600,000


Supervised Pretraining on Random Data

In [9]:
# ╔════════════════════════════════════════╗
# ║ Cell 2 – fast pre-train w/ sampling + AMP
# ╚════════════════════════════════════════╝
import torch, torch.nn.functional as F
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import DataLoader, TensorDataset, Subset
import torch.optim as optim, numpy as np
from pathlib import Path
import math, random

PRETRAIN_CKPT = Path('checkpoints/pretrained_policy.pth')
BATCH_SIZE    = 256
EPOCHS        = 12
LR            = 3e-3
SAMPLE_POS    = 120_000          # ⇦  adjust: 0 = full data

def make_dataset(random_data):
    states, acts, res = zip(*random_data)
    arr = np.array(states)
    if arr.ndim == 5:
        arr = arr.squeeze(1)
    return TensorDataset(torch.tensor(arr,  dtype=torch.float32),
                         torch.tensor(acts, dtype=torch.long),
                         torch.tensor(res,  dtype=torch.float32).unsqueeze(1))

def supervised_pretrain(net, rnd_data):
    full_ds = make_dataset(rnd_data)
    if SAMPLE_POS and SAMPLE_POS < len(full_ds):
        idx = random.sample(range(len(full_ds)), SAMPLE_POS)
        ds  = Subset(full_ds, idx)
        print(f"Using random subset {SAMPLE_POS:,}/{len(full_ds):,}")
    else:
        ds = full_ds

    dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True,
                    pin_memory=True)

    net.train()
    opt = optim.AdamW(net.parameters(), lr=LR, weight_decay=1e-4)
    scaler = GradScaler()                     # for mixed precision

    for ep in range(1, EPOCHS+1):
        run_loss, correct, tot = 0,0,0
        for s,a,v in dl:
            s,a,v = s.to(device, non_blocking=True), a.to(device), v.to(device)

            opt.zero_grad()
            with autocast():
                logits, val = net(s)
                loss_pi = F.cross_entropy(logits, a)
                loss_v  = F.mse_loss(val, v)
                loss    = loss_pi + 0.25*loss_v

            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(net.parameters(), 2.0)
            scaler.step(opt)
            scaler.update()

            run_loss += loss.item()*s.size(0)
            correct  += (logits.argmax(1)==a).sum().item()
            tot      += a.size(0)

        print(f"Epoch {ep:2}/{EPOCHS}  loss {run_loss/tot:.4f}  "
              f"π-acc {correct/tot*100:5.1f}%")

    PRETRAIN_CKPT.parent.mkdir(exist_ok=True)
    torch.save(net.state_dict(), PRETRAIN_CKPT)
    print("✓ saved →", PRETRAIN_CKPT)

# ---- run ----
game = GoGame(n=5)
nnet = GoNNet(board_size=5).to(device)
supervised_pretrain(nnet, random_data)


Using random subset 120,000/600,000


  scaler = GradScaler()                     # for mixed precision
  with autocast():


Epoch  1/12  loss 3.0954  π-acc   4.9%
Epoch  2/12  loss 3.0437  π-acc   5.0%
Epoch  3/12  loss 3.0370  π-acc   5.2%
Epoch  4/12  loss 3.0332  π-acc   5.5%
Epoch  5/12  loss 3.0300  π-acc   5.7%
Epoch  6/12  loss 3.0265  π-acc   5.9%
Epoch  7/12  loss 3.0222  π-acc   6.3%
Epoch  8/12  loss 3.0182  π-acc   6.5%
Epoch  9/12  loss 3.0130  π-acc   6.8%
Epoch 10/12  loss 3.0059  π-acc   7.3%
Epoch 11/12  loss 2.9982  π-acc   7.7%
Epoch 12/12  loss 2.9899  π-acc   8.2%
✓ saved → checkpoints/pretrained_policy.pth


Self-Play with MCTS & RL Training


In [10]:
# ╔════════════════════════════════════════╗
# ║  CLEAN RL CELL – ready for full run    ║
# ╚════════════════════════════════════════╝
import torch, numpy as np, time
from tqdm.auto import trange, tqdm
from pathlib import Path
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim
import torch.nn.functional as F

# ------------ hyper-parameters -----------------
TRAIN_ITERS    = 8          # total RL iterations
GAMES_PER_ITER = 60         # self-play games each iter
args_mcts      = {'numMCTSSims': 40,  # ← 10 for speed, 40 for strength
                  'cpuct': 1.5}

BATCH_RL  = 128
EPOCHS_RL = 5
LR_RL     = 5e-4
CKPT_DIR  = Path('checkpoints'); CKPT_DIR.mkdir(exist_ok=True)
# -----------------------------------------------


# ---  board_to_tensor (FINAL, keep in one place) --------------------
def board_to_tensor(board, player):
    n = board.n
    player_stones   = np.zeros((n, n), dtype=np.float32)
    opponent_stones = np.zeros((n, n), dtype=np.float32)
    stone_ages      = np.zeros((n, n), dtype=np.float32)
    legal_moves     = np.zeros((n, n), dtype=np.float32)
    for y in range(n):
        for x in range(n):
            c, a = board.pieces[y][x]
            if c == player:       player_stones[y, x] = 1
            elif c == -player:    opponent_stones[y, x] = 1
            if c != 0:            stone_ages[y, x] = a / 3.0
    for x, y in board.get_legal_moves(player):
        legal_moves[y, x] = 1
    arr = np.stack([player_stones, opponent_stones, stone_ages, legal_moves], 0)
    return torch.tensor(arr, dtype=torch.float32)          # (4,5,5)  CPU
# --------------------------------------------------------------------


#  ------------------- train_nnet ------------------------------------
def train_nnet(net, train_examples):
    states, pis, vals = zip(*train_examples)
    states_t = torch.tensor(np.stack(states), dtype=torch.float32)     # (N,4,5,5)
    pis_t    = torch.tensor(np.stack(pis),    dtype=torch.float32)     # (N,26)
    vals_t   = torch.tensor(vals, dtype=torch.float32).unsqueeze(1)    # (N,1)

    ds = TensorDataset(states_t, pis_t, vals_t)
    dl = DataLoader(ds, batch_size=BATCH_RL, shuffle=True)

    net.train()
    opt = optim.Adam(net.parameters(), lr=LR_RL, weight_decay=1e-4)
    for ep in range(1, EPOCHS_RL+1):
        tot = 0.0
        for s, pi, v in tqdm(dl, desc=f"  Train-epoch {ep}/{EPOCHS_RL}", leave=False):
            s, pi, v = s.to(device), pi.to(device), v.to(device)
            opt.zero_grad()
            out_pi, out_v = net(s)
            loss = F.mse_loss(out_pi, pi) + F.mse_loss(out_v, v)
            loss.backward()
            opt.step()
            tot += loss.item()*s.size(0)
        print(f"    epoch {ep}: loss {tot/len(ds):.4f}")
# --------------------------------------------------------------------


# ------------------- RL iteration -----------------------------------
def train_iteration(net, it_no):
    print(f"\n=== RL Iter {it_no}/{TRAIN_ITERS} ({time.strftime('%H:%M:%S')}) ===")
    t0 = time.time()

    # -- self-play
    buffer = []
    for _ in trange(GAMES_PER_ITER, desc="Self-play"):
        buffer += self_play_episode(game, net, args_mcts)
    print(f"Collected {len(buffer):,} states  (t = {time.time()-t0:.1f}s)")

    # -- augment
    aug = []
    for s, pi, v in buffer:
        s_list, pi_list = augment_data(s, pi, game.n)   # numpy arrays
        aug.extend((ss, pp, v) for ss, pp in zip(s_list, pi_list))
    print(f"After augmentation: {len(aug):,}")

    # -- train
    train_nnet(net, aug)

    # -- save
    ckpt = CKPT_DIR / f"rl_iter_{it_no}.pth"
    torch.save(net.state_dict(), ckpt)
    print(f"✓ checkpoint saved → {ckpt}  (elapsed {(time.time()-t0)/60:.1f} min)")
# --------------------------------------------------------------------


# ---------------- run the loop --------------------------------------
game = GoGame(n=5)
nnet.to(device)                                 # pretrained weights already loaded
for it in range(1, TRAIN_ITERS+1):
    train_iteration(nnet, it)



=== RL Iter 1/8 (19:37:42) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 377.1s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 1.6494


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.2433


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.2231


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.2061


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.1854
✓ checkpoint saved → checkpoints/rl_iter_1.pth  (elapsed 6.4 min)

=== RL Iter 2/8 (19:44:09) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 372.3s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0842


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0008


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0005


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0004


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0003
✓ checkpoint saved → checkpoints/rl_iter_2.pth  (elapsed 6.4 min)

=== RL Iter 3/8 (19:50:31) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 371.5s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0003


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_3.pth  (elapsed 6.4 min)

=== RL Iter 4/8 (19:56:53) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 373.8s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0002


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_4.pth  (elapsed 6.4 min)

=== RL Iter 5/8 (20:03:17) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 367.0s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0002


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_5.pth  (elapsed 6.3 min)

=== RL Iter 6/8 (20:09:34) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 369.4s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0002


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_6.pth  (elapsed 6.3 min)

=== RL Iter 7/8 (20:15:53) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 363.3s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0002


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_7.pth  (elapsed 6.2 min)

=== RL Iter 8/8 (20:22:06) ===


Self-play:   0%|          | 0/60 [00:00<?, ?it/s]

Collected 3,600 states  (t = 362.6s)
After augmentation: 28,800


  Train-epoch 1/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 1: loss 0.0002


  Train-epoch 2/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 2: loss 0.0002


  Train-epoch 3/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 3: loss 0.0002


  Train-epoch 4/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 4: loss 0.0002


  Train-epoch 5/5:   0%|          | 0/225 [00:00<?, ?it/s]

    epoch 5: loss 0.0002
✓ checkpoint saved → checkpoints/rl_iter_8.pth  (elapsed 6.2 min)


 Play a Game Between Two MCTS Agents

In [11]:
def self_play_episode(game, nnet, args, temperature=1.0):
    train_examples = []
    board = game.getInitBoard()
    player = 1
    episode_step = 0

    mcts = MCTS(game, nnet, args)

    while True:
        episode_step += 1

        # Adjust temperature - start high for exploration, gradually reduce
        if episode_step < 10:
            temp = temperature
        else:
            temp = temperature * 0.75

        # Get action probabilities
        pi = mcts.getActionProb(board, player, temp=temp)

        # Store the board state and policy
        state_tensor = board_to_tensor(board, player).numpy()
        train_examples.append([state_tensor, pi, None])  # Value will be set later

        # Choose action - sometimes randomly for exploration
        if random.random() < 0.9:  # 90% of the time, choose from policy
            action = np.random.choice(len(pi), p=pi)
        else:  # 10% random valid move for additional exploration
            valid_moves = game.getValidMoves(board, player)
            valid_indices = np.where(valid_moves == 1)[0]
            action = np.random.choice(valid_indices)

        # Make the move
        board, player = game.getNextState(board, player, action)

        # Check if game is over
        game_result = game.getGameEnded(board, player)
        if game_result != 0:
            # Update all examples with final result
            for i in range(len(train_examples)):
                train_examples[i][2] = game_result * ((-1) ** (i % 2))  # Flip result based on player
            break

    return train_examples

In [12]:
game = GoGame(n=5)

 Save/Load Utilities

In [19]:
# ╔════════════════════════════════════════╗
# ║ Quick evaluation of checkpoints        ║
# ╚════════════════════════════════════════╝
# --- helper (only needed if Module 7 wasn't run) ------------------
def make_mcts_agent(nnet, sims=60, c=1.5):
    args = {'numMCTSSims': sims, 'cpuct': c}
    mcts = MCTS(game, nnet, args)
    def agent(board):            # deterministic (temp=0)
        probs = mcts.getActionProb(board, player=1, temp=0)
        return int(np.argmax(probs))
    return agent
# ------------------------------------------------------------------

# Quick evaluation of checkpoints
import torch, numpy as np
from pathlib import Path
from tqdm.auto import trange

PRETRAIN_CKPT   = Path('checkpoints/pretrained_policy.pth')
LATEST_RL_CKPT  = sorted(Path('checkpoints').glob('rl_iter_8.pth'))[-1]

net_pre = GoNNet(board_size=5).to(device)
net_pre.load_state_dict(torch.load(PRETRAIN_CKPT,  map_location=device)); net_pre.eval()

net_rl  = GoNNet(board_size=5).to(device)
net_rl.load_state_dict(torch.load(LATEST_RL_CKPT, map_location=device)); net_rl.eval()

pre_agent = make_mcts_agent(net_pre, sims=60)
rl_agent  = make_mcts_agent(net_rl,  sims=60)

# quick random agent (if Module 7 wasn’t run)
def random_agent(board, game=game, player=1):
    valids  = game.getValidMoves(board, player)
    return int(np.random.choice(np.where(valids == 1)[0]))

def quick_series(agentA, agentB, games=50):
    w = l = d = 0
    arena = Arena(agentA, agentB, game)
    for _ in trange(games, desc="Games"):
        res = arena.playGame()
        if   res == 1:  w += 1
        elif res == -1: l += 1
        else:           d += 1
    return w, l, d

for title, a in [("Pre-train", pre_agent), ("RL latest", rl_agent)]:
    w, l, d = quick_series(a, random_agent, games=50)
    print(f"{title:10s} vs Random → wins:{w:2d}  loses:{l:2d}  draws:{d:2d}")

w, l, d = quick_series(rl_agent, pre_agent, games=30)
print(f"\nRL latest vs Pre-train → wins:{w:2d}  loses:{l:2d}  draws:{d:2d}")



Games:   0%|          | 0/50 [00:00<?, ?it/s]

Pre-train  vs Random → wins:50  loses: 0  draws: 0


Games:   0%|          | 0/50 [00:00<?, ?it/s]

RL latest  vs Random → wins:50  loses: 0  draws: 0


Games:   0%|          | 0/30 [00:00<?, ?it/s]


RL latest vs Pre-train → wins:30  loses: 0  draws: 0


In [30]:
# ╔════════════════════════════════════════╗
# ║  In-notebook MicroGo visual interface  ║
# ╚════════════════════════════════════════╝
import json, uuid, IPython.display as D
from IPython.display import HTML
from google.colab import output

BOARD_SIZE = 5
cell_id    = "microgo-" + str(uuid.uuid4())

# ---------- JS / HTML ----------
html = f"""
<div id='{cell_id}'></div>
<script>
const N = {BOARD_SIZE}, cellSize = 60, stoneR = 20;
const root = document.getElementById('{cell_id}');
root.style.cssText = `width:${{N*cellSize}}px;height:${{N*cellSize}}px;
                      position:relative;border:1px solid #555`;

/* empty grid */
for (let y=0;y<N;++y)
 for (let x=0;x<N;++x){{
   const s=document.createElement('div');
   s.style.cssText=`width:${{cellSize}}px;height:${{cellSize}}px;
                    box-sizing:border-box;border:1px solid #aaa;
                    position:absolute;left:${{x*cellSize}}px;top:${{y*cellSize}}px`;
   s.dataset.x=x; s.dataset.y=y;
   s.onclick=()=>google.colab.kernel.invokeFunction('notebook.on_click',[x,y],{{}});
   root.appendChild(s);
}}

function redraw(stonesJSON){{
  [...root.querySelectorAll('svg')].forEach(e=>e.remove());
  for (const [x,y,c] of JSON.parse(stonesJSON)){{
      const svg=document.createElementNS('http://www.w3.org/2000/svg','svg');
      svg.setAttribute('width',cellSize);svg.setAttribute('height',cellSize);
      svg.style.cssText=`position:absolute;left:${{x*cellSize}}px;top:${{y*cellSize}}px`;
      const circ=document.createElementNS(svg.namespaceURI,'circle');
      circ.setAttribute('cx',cellSize/2);circ.setAttribute('cy',cellSize/2);
      circ.setAttribute('r',stoneR);
      circ.setAttribute('fill',c==1?'#eee':'#222');
      svg.appendChild(circ); root.appendChild(svg);
  }}
}}
</script>
"""

# ---------- Python ↔ JS ----------
current_board = game.getInitBoard()
cur_player    = 1                       # human (White) begins
ai_agent      = make_mcts_agent(nnet,sims=60)   # already-trained network

def board_json(b):
    stones=[]
    for y in range(BOARD_SIZE):
        for x in range(BOARD_SIZE):
            col,_ = b.pieces[y][x]
            if col: stones.append([x,y,col])
    return json.dumps(stones)

def winner_msg(res):                    # res >0 → White, <0 → Black, else draw
    return "You win! 🎉" if res>0 else "AI wins!" if res<0 else "Draw!"

def play_action(action):
    global current_board, cur_player
    if game.getValidMoves(current_board, cur_player)[action]==0:
        return False                    # illegal move
    current_board, cur_player = game.getNextState(current_board, cur_player, action)
    return True

def on_click(x, y):
    if not play_action(BOARD_SIZE*x + y):         # human move
        return
    # AI reply (if game not ended by human move)
    if game.getGameEnded(current_board, cur_player)==0:
        play_action(ai_agent(current_board))      # AI move
    # redraw & maybe declare winner
    output.eval_js(f"redraw({json.dumps(board_json(current_board))})")
    res = game.getGameEnded(current_board, cur_player)
    if res!=0:
        print(winner_msg(res))

output.register_callback('notebook.on_click', on_click)

# ------ show board ------
D.display(HTML(html))
output.eval_js(f"redraw({json.dumps(board_json(current_board))})")


You win! 🎉
