<a href="https://colab.research.google.com/github/MaximusDonald/medipass/blob/main/Chess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Cellule 0 – Installation et imports

In [None]:
# ============================================================
# CELLULE 0 – Installation des dépendances
# ============================================================
!pip install python-chess --quiet
!pip install h5py --quiet

import h5py
import os
import json
import random
import numpy as np
import torch
import chess
import chess.pgn
import io
import pickle
from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Optional
from tqdm.auto import tqdm

# Reproducibility: seed Python's `random`, NumPy and PyTorch
# (CPU generator plus every CUDA device) with the same value.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Environment report: library versions and the accelerator in use.
print(f"PyTorch     : {torch.__version__}")
print(f"Chess       : {chess.__version__}")
print(f"CUDA dispo  : {torch.cuda.is_available()}")
# Falls back to the literal 'CPU' when no CUDA device is available.
print(f"GPU         : {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/6.1 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/6.1 MB[0m [31m33.9 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m6.1/6.1 MB[0m [31m89.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m67.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for chess (setup.py) ... [?25l[?25hdone
PyTorch     : 2.10.0+cu128
Chess       : 1.11.2
CUDA dispo  : True
GPU         : Tesla T4


# Cellule 1 – Montage Google Drive et configuration des chemins

In [None]:
# ============================================================
# CELLULE 1 – Google Drive + chemins
# ============================================================
from google.colab import drive
# Mount the user's Google Drive at /content/drive (Colab-only API).
drive.mount('/content/drive')

# ── EDIT these two paths to match your own Drive layout ──────
NDJSON_PATH   = "/content/drive/MyDrive/chess/lichess_games.ndjson"
OUTPUT_DIR    = "/content/drive/MyDrive/chess/dataset_processed"
# ─────────────────────────────────────────────────────────────

# Create the output folder if needed; no error if it already exists.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Artefact locations inside OUTPUT_DIR.
POSITIONS_PATH = os.path.join(OUTPUT_DIR, "positions.npz")
META_PATH      = os.path.join(OUTPUT_DIR, "meta.pkl")

print(f"Source      : {NDJSON_PATH}")
print(f"Destination : {OUTPUT_DIR}")
# Prints True/False: whether the source NDJSON file is present.
print(f"Fichier     : {os.path.exists(NDJSON_PATH)}")

Mounted at /content/drive
Source      : /content/drive/MyDrive/chess/lichess_games.ndjson
Destination : /content/drive/MyDrive/chess/dataset_processed
Fichier     : True


# Cellule 2 – Encodage UCI → index (4544 coups)

In [None]:
# ============================================================
# CELLULE 2 – Mapping UCI moves ↔ index (4544 coups)
# Énumération (case départ, case arrivée[, promotion]) — différent de l'encodage 8×8×73 = 4672 d'AlphaZero
# ============================================================

def build_move_index() -> Tuple[dict, dict]:
    """
    Build the move vocabulary used as the policy target.

    The vocabulary contains the UCI string of every (from, to) pair of
    distinct squares, plus the four promotion variants (q, r, b, n) of every
    pair going from rank index 6 to 7 or from rank index 1 to 0. The file
    distance is not checked, so geometrically impossible promotions are
    included as well. That yields 64*63 + 2*64*4 = 4544 entries, indexed in
    sorted (lexicographic) order so the mapping is stable across runs.

    Returns:
        uci2idx : dict str -> int (UCI move -> index)
        idx2uci : dict int -> str (index -> UCI move)

    Side effect: prints the vocabulary size.
    """
    promotion_pieces = (chess.QUEEN, chess.ROOK, chess.BISHOP, chess.KNIGHT)
    # (from_rank, to_rank) pairs for which promotion variants are generated
    promotion_ranks = {(6, 7), (1, 0)}

    vocabulary = set()
    for origin in chess.SQUARES:
        for target in chess.SQUARES:
            if origin == target:
                continue

            # Plain move
            vocabulary.add(chess.Move(origin, target).uci())

            # Promotion variants, only between the promotion ranks
            ranks = (chess.square_rank(origin), chess.square_rank(target))
            if ranks in promotion_ranks:
                vocabulary.update(
                    chess.Move(origin, target, piece).uci()
                    for piece in promotion_pieces
                )

    # Sorting makes the index assignment deterministic
    ordered = sorted(vocabulary)
    uci2idx = dict(zip(ordered, range(len(ordered))))
    idx2uci = dict(enumerate(ordered))

    print(f"Taille du vocabulaire de coups : {len(ordered)}")
    return uci2idx, idx2uci

UCI2IDX, IDX2UCI = build_move_index()

Taille du vocabulaire de coups : 4544


# Cellule 3 – Encodage d'une position en tenseur [16, 8, 8]

In [None]:
# ============================================================
# CELLULE 3 – Encodage position → tenseur [16, 8, 8]
# ============================================================

# Plane order of the piece types inside each colour group: P N B R Q K
PIECE_ORDER = [
    chess.PAWN, chess.KNIGHT, chess.BISHOP,
    chess.ROOK, chess.QUEEN,  chess.KING
]

def board_to_tensor(board: chess.Board) -> np.ndarray:
    """
    Encode a chess.Board as a float32 array of shape [16, 8, 8].

    A square index `sq` is stored at [rank, file] = [sq // 8, sq % 8].

    Planes:
      0– 5  : white pieces (P N B R Q K), 1.0 where a piece stands
      6–11  : black pieces (p n b r q k)
      12    : side to move (all 1.0 = white, all 0.0 = black)
      13    : castling rights, one cell per right:
              [0,7] white king-side, [0,0] white queen-side,
              [7,7] black king-side, [7,0] black queen-side
      14    : en-passant square (1.0 on `board.ep_square`, if any)
      15    : half-move clock / 100, capped at 1.0, broadcast to the plane
    """
    planes = np.zeros((16, 8, 8), dtype=np.float32)

    # Planes 0–11: one plane per (colour, piece type)
    for colour_offset, colour in ((0, chess.WHITE), (6, chess.BLACK)):
        for type_offset, piece_type in enumerate(PIECE_ORDER):
            for square in board.pieces(piece_type, colour):
                planes[colour_offset + type_offset,
                       square // 8, square % 8] = 1.0

    # Plane 12: side to move (already all zeros when black is to move)
    if board.turn == chess.WHITE:
        planes[12].fill(1.0)

    # Plane 13: castling rights on four fixed corner cells
    castling_cells = (
        (board.has_kingside_castling_rights(chess.WHITE),  0, 7),
        (board.has_queenside_castling_rights(chess.WHITE), 0, 0),
        (board.has_kingside_castling_rights(chess.BLACK),  7, 7),
        (board.has_queenside_castling_rights(chess.BLACK), 7, 0),
    )
    for granted, row, col in castling_cells:
        if granted:
            planes[13, row, col] = 1.0

    # Plane 14: en-passant target square
    ep_square = board.ep_square
    if ep_square is not None:
        planes[14, ep_square // 8, ep_square % 8] = 1.0

    # Plane 15: half-move clock, normalised on 100 and clipped to 1.0
    planes[15].fill(min(board.halfmove_clock / 100.0, 1.0))

    return planes


def move_to_index(move: chess.Move, uci2idx: dict) -> Optional[int]:
    """
    Look up the vocabulary index of a move.

    Args:
        move    : the move to encode; only its `uci()` string is used
        uci2idx : mapping UCI string -> index

    Returns:
        The index stored for `move.uci()`, or None when the move is not
        part of the vocabulary.
    """
    return uci2idx.get(move.uci())

# Cellule 4 – Parser NDJSON et extraction des positions

In [None]:
# ============================================================
# CELLULE 4 – Parser NDJSON en streaming + extraction
# ============================================================

def parse_ndjson_to_arrays(
    ndjson_path: str,
    uci2idx: dict,
    min_elo: int = 1800,
    max_games: Optional[int] = None,
    skip_first_n_moves: int = 0
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Stream an NDJSON game file and extract one sample per half-move.

    Every non-empty line is parsed as one JSON game object. A game is kept
    only if its "variant" is "standard", it is "rated", both players have a
    "rating" of at least `min_elo`, and its "moves" string is non-empty.
    Kept games are replayed from the initial position; for each half-move the
    position *before* the move, the vocabulary index of the move and the game
    result seen from the side to move are recorded. Moves missing from
    `uci2idx` are played on the board but produce no sample.

    All samples are accumulated in RAM before being returned.

    Args:
        ndjson_path       : path to the .ndjson file
        uci2idx           : mapping UCI string -> move index
        min_elo           : minimum rating required for both players
        max_games         : stop after this many kept games
                            (None or 0 = no limit)
        skip_first_n_moves: number of leading half-moves that are replayed
                            but not recorded (opening theory)

    Returns:
        positions  : np.ndarray float16 [N, 16, 8, 8]
        move_idxs  : np.ndarray int32   [N]
        outcomes   : np.ndarray float32 [N]  (1.0 win / 0.5 / 0.0 loss for
                     the side to move; 0.5 is used whenever the game has no
                     "winner" field)
        The shapes above hold for N == 0 as well.
    """

    positions_list  = []
    move_idxs_list  = []
    outcomes_list   = []

    games_parsed    = 0
    games_skipped   = 0
    positions_total = 0
    errors          = 0

    with open(ndjson_path, "r", encoding="utf-8") as f:
        for line in tqdm(f, desc="Parsing NDJSON"):
            line = line.strip()
            if not line:
                continue

            try:
                game = json.loads(line)
            except json.JSONDecodeError:
                errors += 1
                continue

            # ── Filters ───────────────────────────────────────────
            if game.get("variant", "") != "standard":
                games_skipped += 1
                continue

            if not game.get("rated", False):
                games_skipped += 1
                continue

            players  = game.get("players", {})
            w_rating = players.get("white", {}).get("rating", 0)
            b_rating = players.get("black", {}).get("rating", 0)
            if w_rating < min_elo or b_rating < min_elo:
                games_skipped += 1
                continue

            moves_san = game.get("moves", "").strip()
            if not moves_san:
                games_skipped += 1
                continue

            # ── Game result, from White's point of view ───────────
            winner = game.get("winner", None)
            if winner == "white":
                outcome_white = 1.0
            elif winner == "black":
                outcome_white = 0.0
            else:
                outcome_white = 0.5   # draw or no recorded winner

            # ── Replay of the game ────────────────────────────────
            board = chess.Board()
            try:
                # Wrap the bare move text in a minimal PGN so that
                # chess.pgn can parse it.
                pgn_io  = io.StringIO(
                    f'[Variant "Standard"]\n\n{moves_san}'
                )
                pgn_game = chess.pgn.read_game(pgn_io)
                if pgn_game is None:
                    games_skipped += 1
                    continue
                moves_list = list(pgn_game.mainline_moves())
            except Exception:
                # Best effort: a game that cannot be parsed is counted
                # as an error and skipped, never fatal for the whole run.
                errors += 1
                continue

            if len(moves_list) < skip_first_n_moves + 2:
                games_skipped += 1
                continue

            for move_num, move in enumerate(moves_list):
                if move_num < skip_first_n_moves:
                    board.push(move)
                    continue

                # Defensive legality check: stop replaying this game
                if move not in board.legal_moves:
                    break

                idx = move_to_index(move, uci2idx)
                if idx is not None:
                    # Cast to float16 right away: it is the dtype of the
                    # returned array, and buffering float32 would double
                    # the RAM held by the list.
                    positions_list.append(
                        board_to_tensor(board).astype(np.float16)
                    )
                    move_idxs_list.append(idx)
                    # Result from the point of view of the side to move
                    outcomes_list.append(
                        outcome_white
                        if board.turn == chess.WHITE
                        else 1.0 - outcome_white
                    )
                    positions_total += 1

                board.push(move)

            games_parsed += 1
            if max_games and games_parsed >= max_games:
                break

    print(f"\n{'='*50}")
    print(f"Parties parsées   : {games_parsed:>10,}")
    print(f"Parties ignorées  : {games_skipped:>10,}")
    print(f"Erreurs           : {errors:>10,}")
    print(f"Positions totales : {positions_total:>10,}")
    print(f"{'='*50}")

    # Conversion to numpy arrays. An empty list would otherwise give a
    # (0,)-shaped array, so the empty case is built explicitly to keep the
    # documented [N, 16, 8, 8] shape.
    if positions_list:
        positions = np.array(positions_list, dtype=np.float16)
    else:
        positions = np.zeros((0, 16, 8, 8), dtype=np.float16)
    move_idxs  = np.array(move_idxs_list,  dtype=np.int32)
    outcomes   = np.array(outcomes_list,   dtype=np.float32)

    return positions, move_idxs, outcomes

# Cellule 5 – Lancement du parsing et sauvegarde

In [6]:
# ============================================================
# CELLULE 5 bis – Parser NDJSON → HDF5 en streaming
# RAM utilisée à tout instant : < 500 Mo
# ============================================================

HDF5_PATH = os.path.join(OUTPUT_DIR, "chess_dataset.h5")
CHUNK_SIZE = 10_000   # positions écrites sur disque toutes les 10k

def parse_ndjson_to_hdf5(
    ndjson_path:        str,
    hdf5_path:          str,
    uci2idx:            dict,
    min_elo:            int           = 1800,
    max_games:          Optional[int] = None,
    skip_first_n_moves: int           = 0,
    chunk_size:         int           = CHUNK_SIZE
):
    """
    Stream an NDJSON game file into an HDF5 dataset, chunk by chunk.

    Game filtering and replay follow the same rules as
    `parse_ndjson_to_arrays`: only standard, rated games where both players
    are rated at least `min_elo` and the "moves" string is non-empty are
    kept; one sample is written per half-move whose UCI string is in
    `uci2idx`.

    The file at `hdf5_path` is (re)created with three resizable datasets:
      "positions" : float16 [N, 16, 8, 8], lzf-compressed
      "move_idxs" : int32   [N]
      "outcomes"  : float32 [N]  (result seen from the side to move)

    At most about `chunk_size` samples are buffered in RAM between writes.

    Args:
        ndjson_path       : path to the source .ndjson file
        hdf5_path         : path of the HDF5 file to create (overwritten)
        uci2idx           : mapping UCI string -> move index
        min_elo           : minimum rating required for both players
        max_games         : stop after this many kept games
                            (None or 0 = no limit)
        skip_first_n_moves: leading half-moves replayed but not recorded
        chunk_size        : buffer size in samples; also the HDF5 chunk
                            length along the first axis

    Returns:
        The total number of samples written.
    """

    # psutil only feeds the optional RAM read-out of the progress bar, so
    # the parse must not fail when it is not installed.
    try:
        import psutil
        process = psutil.Process()
    except ImportError:
        process = None

    # Temporary buffers, emptied every `chunk_size` samples
    buf_positions = []
    buf_moves     = []
    buf_outcomes  = []

    games_parsed    = 0
    games_skipped   = 0
    positions_total = 0
    errors          = 0

    def flush_buffer(h5f, total_written):
        """Append the buffered samples to the HDF5 file, then empty the buffers."""
        n = len(buf_moves)
        if n == 0:
            return total_written

        arr_pos = np.array(buf_positions, dtype=np.float16)
        arr_mov = np.array(buf_moves,     dtype=np.int32)
        arr_out = np.array(buf_outcomes,  dtype=np.float32)

        # Grow the datasets, then write the new rows at the end
        new_size = total_written + n
        h5f["positions"].resize(new_size, axis=0)
        h5f["move_idxs"].resize(new_size, axis=0)
        h5f["outcomes"].resize(new_size, axis=0)

        h5f["positions"][total_written:new_size] = arr_pos
        h5f["move_idxs"][total_written:new_size] = arr_mov
        h5f["outcomes"][total_written:new_size]  = arr_out

        buf_positions.clear()
        buf_moves.clear()
        buf_outcomes.clear()

        return new_size

    # Create the HDF5 file with extensible datasets
    with h5py.File(hdf5_path, "w") as h5f:
        # NOTE(review): one HDF5 chunk holds `chunk_size` positions
        # (10 000 x 16x8x8 float16 is about 20 MB uncompressed). If the file
        # is later read one random row at a time, every read presumably has
        # to load and decompress a whole chunk -- consider a smaller chunk
        # length; confirm against the h5py chunking guidance.
        h5f.create_dataset(
            "positions",
            shape=(0, 16, 8, 8),
            maxshape=(None, 16, 8, 8),
            dtype=np.float16,
            chunks=(chunk_size, 16, 8, 8),
            compression="lzf"   # fast compression filter
        )
        h5f.create_dataset(
            "move_idxs",
            shape=(0,),
            maxshape=(None,),
            dtype=np.int32,
            chunks=(chunk_size,)
        )
        h5f.create_dataset(
            "outcomes",
            shape=(0,),
            maxshape=(None,),
            dtype=np.float32,
            chunks=(chunk_size,)
        )

        total_written = 0

        with open(ndjson_path, "r", encoding="utf-8") as f:
            pbar = tqdm(f, desc="Parsing NDJSON", unit=" lignes")
            for line in pbar:
                line = line.strip()
                if not line:
                    continue

                # ── Parse JSON ───────────────────────────────────
                try:
                    game = json.loads(line)
                except json.JSONDecodeError:
                    errors += 1
                    continue

                # ── Filters ──────────────────────────────────────
                if game.get("variant", "") != "standard":
                    games_skipped += 1
                    continue

                if not game.get("rated", False):
                    games_skipped += 1
                    continue

                players  = game.get("players", {})
                w_rating = players.get("white", {}).get("rating", 0)
                b_rating = players.get("black", {}).get("rating", 0)
                if w_rating < min_elo or b_rating < min_elo:
                    games_skipped += 1
                    continue

                moves_san = game.get("moves", "").strip()
                if not moves_san:
                    games_skipped += 1
                    continue

                # ── Result, from White's point of view ───────────
                winner = game.get("winner", None)
                if winner == "white":
                    outcome_white = 1.0
                elif winner == "black":
                    outcome_white = 0.0
                else:
                    outcome_white = 0.5   # draw or no recorded winner

                # ── Replay ───────────────────────────────────────
                board = chess.Board()
                try:
                    # Wrap the bare move text in a minimal PGN so that
                    # chess.pgn can parse it.
                    pgn_io   = io.StringIO(
                        f'[Variant "Standard"]\n\n{moves_san}'
                    )
                    pgn_game = chess.pgn.read_game(pgn_io)
                    if pgn_game is None:
                        games_skipped += 1
                        continue
                    moves_list = list(pgn_game.mainline_moves())
                except Exception:
                    # Best effort: an unparsable game is counted and
                    # skipped, never fatal for the whole run.
                    errors += 1
                    continue

                if len(moves_list) < skip_first_n_moves + 2:
                    games_skipped += 1
                    continue

                for move_num, move in enumerate(moves_list):
                    if move_num < skip_first_n_moves:
                        board.push(move)
                        continue

                    # Defensive legality check: stop replaying this game
                    if move not in board.legal_moves:
                        break

                    idx = move_to_index(move, uci2idx)
                    if idx is not None:
                        # Buffer as float16 (the on-disk dtype) to halve
                        # the RAM held between two flushes.
                        buf_positions.append(
                            board_to_tensor(board).astype(np.float16)
                        )
                        buf_moves.append(idx)
                        # Result from the point of view of the side to move
                        buf_outcomes.append(
                            outcome_white
                            if board.turn == chess.WHITE
                            else 1.0 - outcome_white
                        )
                        positions_total += 1

                    board.push(move)

                    # ── Flush when the buffer is full ─────────────
                    if len(buf_moves) >= chunk_size:
                        total_written = flush_buffer(h5f, total_written)

                games_parsed += 1

                # Progress-bar update
                if games_parsed % 5000 == 0:
                    postfix = {
                        "parties": f"{games_parsed:,}",
                        "positions": f"{positions_total:,}",
                    }
                    if process is not None:
                        postfix["RAM_Mo"] = (
                            f"{process.memory_info().rss/1e6:.0f}"
                        )
                    pbar.set_postfix(postfix)

                if max_games and games_parsed >= max_games:
                    break

        # Final flush of whatever is left in the buffers
        total_written = flush_buffer(h5f, total_written)

    # The summary is printed once the HDF5 file is closed, so that the
    # reported size reflects everything that was written.
    print(f"\n{'='*50}")
    print(f"Parties parsées   : {games_parsed:>10,}")
    print(f"Parties ignorées  : {games_skipped:>10,}")
    print(f"Erreurs           : {errors:>10,}")
    print(f"Positions totales : {total_written:>10,}")
    print(f"{'='*50}")
    print(f"HDF5 sauvegardé   : {hdf5_path}")
    print(f"Taille fichier    : {os.path.getsize(hdf5_path)/1e6:.1f} Mo")

    return total_written


# Lancement
!pip install psutil --quiet
# Build the HDF5 dataset from the whole NDJSON file:
# no game limit and no opening half-moves skipped.
total_positions = parse_ndjson_to_hdf5(
    ndjson_path        = NDJSON_PATH,
    hdf5_path          = HDF5_PATH,
    uci2idx            = UCI2IDX,
    min_elo            = 1800,
    max_games          = None,
    skip_first_n_moves = 0,
    chunk_size         = CHUNK_SIZE
)

# Save the UCI <-> index mapping so that move indices stored in the
# dataset can be decoded later.
with open(META_PATH, "wb") as f:
    pickle.dump({"uci2idx": UCI2IDX, "idx2uci": IDX2UCI}, f)
print(f"Meta sauvegardé : {META_PATH}")

Parsing NDJSON: 0 lignes [00:00, ? lignes/s]


Parties parsées   :    519,139
Parties ignorées  :     79,531
Erreurs           :          0
Positions totales : 44,747,887
HDF5 sauvegardé   : /content/drive/MyDrive/chess/dataset_processed/chess_dataset.h5
Taille fichier    : 4210.9 Mo
Meta sauvegardé : /content/drive/MyDrive/chess/dataset_processed/meta.pkl


# Cellule 6 – Split train/val/test + Dataset PyTorch

In [7]:
# ============================================================
# CELLULE 6 bis – ChessHDF5Dataset + splits + DataLoaders
# ============================================================

class ChessHDF5Dataset(Dataset):
    """
    PyTorch dataset that reads samples from an HDF5 file on demand.

    Only the subset of rows listed in `indices` is exposed. The file is
    opened lazily on the first `__getitem__` call, not in `__init__`, so the
    object holds no open handle until it is actually used. With DataLoader
    workers each worker then presumably opens its own handle, provided no
    item was fetched in the parent process beforehand (to be confirmed for
    the multiprocessing start method in use).

    Each item is a tuple:
        position : float32 tensor, the stored "positions" row
        move_idx : int64 scalar tensor from "move_idxs"
        outcome  : float32 scalar tensor from "outcomes"
    """

    def __init__(self, hdf5_path: str, indices: np.ndarray):
        self.hdf5_path = hdf5_path
        self.indices   = indices
        self._h5f      = None   # opened on first access, see __getitem__

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int):
        # Lazy open (once per object / worker)
        if self._h5f is None:
            self._h5f = h5py.File(self.hdf5_path, "r")

        # Map the position inside this split to the row in the file
        real_idx = int(self.indices[idx])

        # Each array is read from the file exactly once per item
        position = self._h5f["positions"][real_idx].astype(np.float32)
        move_idx = int(self._h5f["move_idxs"][real_idx])
        outcome  = float(self._h5f["outcomes"][real_idx])

        return (
            torch.from_numpy(position),
            torch.tensor(move_idx, dtype=torch.long),
            torch.tensor(outcome,  dtype=torch.float32)
        )

    def __del__(self):
        # getattr: __init__ may not have run if construction failed
        handle = getattr(self, "_h5f", None)
        if handle is not None:
            try:
                handle.close()
            except Exception:
                # Deliberate best effort: errors raised while closing
                # during garbage collection cannot be handled usefully.
                pass


def create_hdf5_splits(
    hdf5_path:  str,
    n_total:    int,
    val_ratio:  float = 0.10,
    test_ratio: float = 0.05,
    seed:       int   = SEED
):
    """
    Randomly partition the rows 0..n_total-1 into train/val/test datasets.

    A seeded permutation of the row numbers is cut into three consecutive
    slices: train first, then validation, then test. The validation and
    test sizes are `int(n_total * ratio)`; training gets the remainder.

    Args:
        hdf5_path  : HDF5 file handed to every ChessHDF5Dataset
        n_total    : number of rows available in the file
        val_ratio  : fraction of rows used for validation
        test_ratio : fraction of rows used for testing
        seed       : seed of the NumPy generator that shuffles the rows

    Returns:
        (train_ds, val_ds, test_ds) as ChessHDF5Dataset objects.

    Side effect: prints the size and percentage of each split.
    """
    shuffled = np.random.default_rng(seed).permutation(n_total)

    n_val   = int(n_total * val_ratio)
    n_test  = int(n_total * test_ratio)
    n_train = n_total - (n_val + n_test)

    print(f"Split ({n_total:,} positions) :")
    print(f"  Train : {n_train:>10,}  ({n_train/n_total*100:.1f}%)")
    print(f"  Val   : {n_val:>10,}  ({n_val/n_total*100:.1f}%)")
    print(f"  Test  : {n_test:>10,}  ({n_test/n_total*100:.1f}%)")

    # Slice boundaries inside the shuffled row numbers
    val_start  = n_train
    test_start = n_train + n_val

    return (
        ChessHDF5Dataset(hdf5_path, shuffled[:val_start]),
        ChessHDF5Dataset(hdf5_path, shuffled[val_start:test_start]),
        ChessHDF5Dataset(hdf5_path, shuffled[test_start:]),
    )


# Number of samples per batch, shared by the three loaders
BATCH_SIZE = 512

# Split the rows written by the parsing cell into train / val / test
train_ds, val_ds, test_ds = create_hdf5_splits(HDF5_PATH, total_positions)

# Only the training loader reshuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False,
                          num_workers=2, pin_memory=True)

print(f"\nDataLoaders prêts — batches train : {len(train_loader):,}")

Split (44,747,887 positions) :
  Train : 38,035,705  (85.0%)
  Val   :  4,474,788  (10.0%)
  Test  :  2,237,394  (5.0%)

DataLoaders prêts — batches train : 74,289


# Cellule 7 – Vérification et test rapide

In [8]:
# ============================================================
# CELLULE 7 bis – Sanity check
# ============================================================
# Pull a single training batch and inspect its shapes, dtypes and values
sample_pos, sample_moves, sample_outcomes = next(iter(train_loader))

print(f"Positions  : {sample_pos.shape}  dtype={sample_pos.dtype}")
print(f"Moves      : {sample_moves.shape}  range=[{sample_moves.min()}, {sample_moves.max()}]")
# Distinct outcome labels present in the batch
print(f"Outcomes   : {sample_outcomes.unique()}")
# Decode the first move index back to its UCI string
print(f"UCI premier coup : {IDX2UCI[sample_moves[0].item()]}")

Positions  : torch.Size([512, 16, 8, 8])  dtype=torch.float32
Moves      : torch.Size([512])  range=[3, 4536]
Outcomes   : tensor([0.0000, 0.5000, 1.0000])
UCI premier coup : g8h8


# Cellule 8 – Architecture ResNet

In [9]:
# ============================================================
# CELLULE 8 – Architecture ResNet + Policy/Value Heads
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """
    Shape-preserving residual unit:
    Conv(3×3) → BN → ReLU → Conv(3×3) → BN → add the block input → ReLU

    Both convolutions keep the channel count, use padding 1 and have no
    bias (each is followed by a BatchNorm layer).
    """
    def __init__(self, channels: int):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1   = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn2   = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x

        # First conv stage
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        # Second conv stage; the activation comes after the skip connection
        out = self.conv2(out)
        out = self.bn2(out)

        out = out + shortcut
        return F.relu(out)


class ChessResNet(nn.Module):
    """
    Residual network with a policy head and a value head.

    Input   : [B, in_channels, 8, 8]
    Outputs : logits_policy [B, policy_size]  (raw scores, no softmax)
              logit_value   [B, 1]            (raw logit, no sigmoid; the
                                               sigmoid is applied by the
                                               caller, e.g. in the loss)

    Both heads flatten an 8×8 board, so the spatial size is fixed.
    With 128 filters, 10 blocks and policy_size=4544 the model has
    3,563,719 parameters.

    Args:
        in_channels : number of input planes
        num_filters : channel width of the trunk
        num_blocks  : number of residual blocks in the trunk
        policy_size : number of policy logits (size of the move vocabulary)
    """

    def __init__(
        self,
        in_channels:   int = 16,
        num_filters:   int = 128,
        num_blocks:    int = 10,
        policy_size:   int = 4672
    ):
        super().__init__()

        # ── Stem: initial convolution ─────────────────────────────
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, num_filters,
                      kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(num_filters),
            nn.ReLU(inplace=True)
        )

        # ── Trunk: residual blocks ────────────────────────────────
        self.res_blocks = nn.Sequential(
            *[ResidualBlock(num_filters) for _ in range(num_blocks)]
        )

        # ── Policy head ───────────────────────────────────────────
        # 1×1 conv: num_filters → 2 planes, then FC → policy_size logits
        self.policy_conv = nn.Sequential(
            nn.Conv2d(num_filters, 2, kernel_size=1, bias=False),
            nn.BatchNorm2d(2),
            nn.ReLU(inplace=True)
        )
        self.policy_fc = nn.Linear(2 * 8 * 8, policy_size)

        # ── Value head ────────────────────────────────────────────
        # 1×1 conv: num_filters → 1 plane, then FC(64) → FC(1) (raw logit)
        self.value_conv = nn.Sequential(
            nn.Conv2d(num_filters, 1, kernel_size=1, bias=False),
            nn.BatchNorm2d(1),
            nn.ReLU(inplace=True)
        )
        self.value_fc = nn.Sequential(
            nn.Linear(1 * 8 * 8, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 1)
        )

        # ── Weight initialisation ─────────────────────────────────
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal for convs, unit BatchNorm, Xavier-uniform for linears."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(
        self, x: torch.Tensor
    ):
        """Return `(logits_policy, logit_value)` for a batch of positions."""
        # Trunk
        x = self.stem(x)
        x = self.res_blocks(x)

        # Policy head. flatten() is used instead of view() so that
        # non-contiguous activations (e.g. channels_last) are accepted.
        p = self.policy_conv(x)
        p = torch.flatten(p, 1)            # [B, 2*8*8]
        logits_policy = self.policy_fc(p)  # [B, policy_size]

        # Value head
        v = self.value_conv(x)
        v = torch.flatten(v, 1)            # [B, 64]
        logit_value = self.value_fc(v)     # [B, 1]

        return logits_policy, logit_value


# ── Instantiation and sanity checks ──────────────────────────
# Use the GPU when available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device : {DEVICE}")

# The policy head is sized from the actual move vocabulary, which
# overrides the constructor's default policy_size.
model = ChessResNet(
    in_channels = 16,
    num_filters  = 128,
    num_blocks   = 10,
    policy_size  = len(UCI2IDX)
).to(DEVICE)

# Parameter counts (all parameters vs. those requiring gradients)
total_params = sum(p.numel() for p in model.parameters())
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Paramètres totaux    : {total_params:,}")
print(f"Paramètres entraîn.  : {train_params:,}")

# Smoke test: forward pass on a random batch of 4 inputs, no gradients
with torch.no_grad():
    dummy = torch.randn(4, 16, 8, 8).to(DEVICE)
    pol, val = model(dummy)
    print(f"Policy output shape  : {pol.shape}")
    print(f"Value output shape   : {val.shape}")
    print("Forward pass OK ✅")

Device : cuda
Paramètres totaux    : 3,563,719
Paramètres entraîn.  : 3,563,719
Policy output shape  : torch.Size([4, 4544])
Value output shape   : torch.Size([4, 1])
Forward pass OK ✅


# Cellule 9 – Fonctions de loss et métriques

In [None]:
# ============================================================
# CELLULE 9 – Loss functions + métriques
# ============================================================

def compute_loss(
    logits_policy: torch.Tensor,   # [B, policy_size]
    logit_value:   torch.Tensor,   # [B, 1]
    target_moves:  torch.Tensor,   # [B]     int64
    target_values: torch.Tensor,   # [B]     float32
    value_weight:  float = 1.0
) -> tuple:
    """
    Combine the policy and value objectives into one training loss.

    policy term : cross-entropy between `logits_policy` and `target_moves`
    value term  : mean squared error between sigmoid(`logit_value`) and
                  `target_values`
    total       : policy term + `value_weight` * value term

    Returns:
        (total_loss, policy_loss, value_loss), three scalar tensors.
    """
    # Value head: squash the raw logit to (0, 1) before comparing to the target
    predicted_value = logit_value.squeeze(1).sigmoid()   # [B]
    value_term = F.mse_loss(predicted_value, target_values)

    # Policy head: classification over the move vocabulary
    policy_term = F.cross_entropy(logits_policy, target_moves)

    combined = policy_term + value_weight * value_term
    return combined, policy_term, value_term


@torch.no_grad()
def compute_accuracy(
    logits_policy: torch.Tensor,
    target_moves:  torch.Tensor
) -> tuple:
    """
    Compute top-1 and top-5 accuracy of the policy logits.

    Args:
        logits_policy : [B, C] scores, one column per move class
        target_moves  : [B] index of the move actually played

    Returns:
        (acc1, acc5) as Python floats in [0, 1]. When there are fewer than
        5 classes, the top-5 accuracy is computed over all C classes instead
        of raising.
    """
    # Top-1: the highest-scoring class must be the target
    pred_top1 = logits_policy.argmax(dim=1)
    acc1      = (pred_top1 == target_moves).float().mean().item()

    # Top-5: the target must be among the k best classes; k is capped at
    # the number of classes because topk() rejects k > C.
    k          = min(5, logits_policy.size(1))
    top5_preds = logits_policy.topk(k, dim=1).indices  # [B, k]
    acc5       = (top5_preds == target_moves.unsqueeze(1)).any(dim=1).float().mean().item()

    return acc1, acc5

# Cellule 10 – Configuration entraînement

In [None]:
# ============================================================
# CELLULE 10 – Hyperparamètres + optimizer + scheduler
# ============================================================

# ── Hyperparameters ───────────────────────────────────────────
NUM_EPOCHS    = 5
LR_MAX        = 3e-3      # OneCycleLR peak learning rate
LR_MIN        = 1e-5      # learning rate reached at the last step
WEIGHT_DECAY  = 1e-4
VALUE_WEIGHT  = 1.0       # weight of the value loss in the total loss
GRAD_CLIP     = 1.0       # maximum gradient norm
CHECKPOINT_EVERY_N_BATCHES = 10_000   # save a checkpoint every 10k batches

CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# ── Optimizer ─────────────────────────────────────────────────
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr           = LR_MAX,
    weight_decay = WEIGHT_DECAY
)

# ── OneCycleLR scheduler ──────────────────────────────────────
# Ramps up to LR_MAX over the first 30% of the steps, then anneals down.
total_steps = NUM_EPOCHS * len(train_loader)

# OneCycleLR starts at   initial_lr = max_lr / div_factor
# and finishes at        min_lr     = initial_lr / final_div_factor.
# final_div_factor is therefore derived from the *initial* LR so that the
# schedule really ends at LR_MIN (dividing LR_MAX by LR_MIN would end
# div_factor times lower).
ONECYCLE_DIV_FACTOR = 25.0   # PyTorch's default, made explicit

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr          = LR_MAX,
    total_steps     = total_steps,
    pct_start       = 0.3,
    anneal_strategy = "cos",
    div_factor      = ONECYCLE_DIV_FACTOR,
    final_div_factor= (LR_MAX / ONECYCLE_DIV_FACTOR) / LR_MIN
)

# ── Mixed precision (AMP fp16) ────────────────────────────────
# torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler;
# scaling is only enabled when a CUDA device is present.
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())

print(f"Epochs          : {NUM_EPOCHS}")
print(f"Batches/epoch   : {len(train_loader):,}")
print(f"Steps totaux    : {total_steps:,}")
print(f"LR max          : {LR_MAX}")
print(f"LR min          : {LR_MIN}")
print(f"Checkpoints     : {CHECKPOINT_DIR}")

# Cellule 11 – Boucle d'entraînement complète

In [None]:
# ============================================================
# CELLULE 11 – Boucle d'entraînement
# ============================================================
import time

def save_checkpoint(
    epoch:        int,
    batch_idx:    int,
    model:        nn.Module,
    optimizer:    torch.optim.Optimizer,
    scheduler,
    scaler,
    metrics:      dict,
    path:         str
):
    """Write a training checkpoint to ``path`` using ``torch.save``.

    The saved dictionary holds the position in training (``epoch``,
    ``batch_idx``), the ``state_dict()`` of the model, optimizer, scheduler
    and gradient scaler, the caller-supplied ``metrics`` dictionary, and the
    module-level ``UCI2IDX`` object (presumably the move vocabulary — verify
    against its definition).

    Args:
        epoch: Epoch value stored as-is under ``"epoch"``.
        batch_idx: Batch value stored as-is under ``"batch_idx"``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        scaler: Gradient scaler whose ``state_dict()`` is saved.
        metrics: Arbitrary metrics dictionary stored alongside the states.
        path: Destination file path.
    """
    payload = {"epoch": epoch, "batch_idx": batch_idx}

    # Every stateful training component is snapshotted through state_dict().
    stateful_components = (
        ("model_state", model),
        ("optimizer_state", optimizer),
        ("scheduler_state", scheduler),
        ("scaler_state", scaler),
    )
    for key, component in stateful_components:
        payload[key] = component.state_dict()

    payload["metrics"] = metrics
    payload["uci2idx"] = UCI2IDX

    torch.save(payload, path)


def load_checkpoint(path: str, model, optimizer, scheduler, scaler):
    """Restore training state from a checkpoint file.

    Loads the file at ``path`` onto the module-level ``DEVICE`` and feeds the
    stored state dictionaries back into ``model``, ``optimizer``,
    ``scheduler`` and ``scaler`` (all modified in place).

    Returns:
        A ``(epoch, batch_idx, metrics)`` tuple with the values stored in the
        checkpoint.
    """
    state = torch.load(path, map_location=DEVICE)

    # Restore each component from its matching checkpoint entry.
    restore_plan = (
        (model, "model_state"),
        (optimizer, "optimizer_state"),
        (scheduler, "scheduler_state"),
        (scaler, "scaler_state"),
    )
    for component, key in restore_plan:
        component.load_state_dict(state[key])

    resume_epoch = state["epoch"]
    resume_batch = state["batch_idx"]
    print(f"Checkpoint chargé : epoch {resume_epoch}, "
          f"batch {resume_batch}")
    return resume_epoch, resume_batch, state["metrics"]


@torch.no_grad()
def validate(model, val_loader, value_weight, device):
    """Run one full pass over ``val_loader`` and return averaged metrics.

    The model is switched to eval mode for the duration of the pass and is
    put back into whatever mode (train or eval) it was in when the function
    was called, even if an exception interrupts the loop.

    Args:
        model: Network called as ``model(positions)`` and returning
            ``(logits_policy, logit_value)``.
        val_loader: Iterable yielding ``(positions, moves, outcomes)`` batches.
        value_weight: Forwarded to ``compute_loss``.
        device: Device the batches are moved to.

    Returns:
        Dict with keys ``val_loss``, ``val_policy_loss``, ``val_value_loss``,
        ``val_acc1`` and ``val_acc5``. Each value is the unweighted mean of
        the per-batch values. If ``val_loader`` yields no batch, every value
        is ``nan`` instead of raising ``ZeroDivisionError``.
    """
    was_training = model.training
    model.eval()
    total_loss = total_policy = total_value = 0.0
    total_acc1 = total_acc5  = 0.0
    n_batches  = 0

    try:
        for positions, moves, outcomes in tqdm(
            val_loader, desc="Validation", leave=False
        ):
            positions = positions.to(device, non_blocking=True)
            moves     = moves.to(device, non_blocking=True)
            outcomes  = outcomes.to(device, non_blocking=True)

            # torch.cuda.amp.autocast() is deprecated; this is the
            # equivalent non-deprecated spelling.
            with torch.amp.autocast("cuda"):
                logits_policy, logit_value = model(positions)
                loss, p_loss, v_loss = compute_loss(
                    logits_policy, logit_value,
                    moves, outcomes, value_weight
                )

            acc1, acc5 = compute_accuracy(logits_policy, moves)

            total_loss   += loss.item()
            total_policy += p_loss.item()
            total_value  += v_loss.item()
            total_acc1   += acc1
            total_acc5   += acc5
            n_batches    += 1
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)

    # Dividing by nan yields nan metrics when no batch was seen.
    denom = n_batches or float("nan")
    return {
        "val_loss":        total_loss   / denom,
        "val_policy_loss": total_policy / denom,
        "val_value_loss":  total_value  / denom,
        "val_acc1":        total_acc1   / denom,
        "val_acc5":        total_acc5   / denom,
    }


# ── Boucle principale ─────────────────────────────────────────
def train(
    model, optimizer, scheduler, scaler,
    train_loader, val_loader,
    num_epochs, value_weight, grad_clip,
    checkpoint_dir, checkpoint_every_n_batches,
    device,
    start_epoch=0, start_batch=0
):
    """Train ``model`` with AMP, periodic checkpoints and per-epoch validation.

    For every epoch the function iterates over ``train_loader``, runs a
    mixed-precision forward/backward pass with gradient-norm clipping, steps
    the optimizer (through ``scaler``) and the scheduler once per batch, and
    writes a checkpoint every ``checkpoint_every_n_batches`` batches. After
    each epoch it runs ``validate``, prints a summary, saves
    ``best_model.pt`` when validation acc@1 improves, and always saves
    ``epoch_<n>_final.pt``.

    Args:
        model, optimizer, scheduler, scaler: Training components; all are
            updated in place.
        train_loader: Iterable with ``len()`` yielding
            ``(positions, moves, outcomes)`` batches.
        val_loader: Loader forwarded to ``validate``.
        num_epochs: Total number of epochs (exclusive upper bound of the
            epoch index).
        value_weight: Forwarded to ``compute_loss`` and ``validate``.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        checkpoint_dir: Directory where checkpoints are written.
        checkpoint_every_n_batches: Period of the in-epoch checkpoints.
        device: Device the batches are moved to.
        start_epoch: Zero-based epoch index to resume from.
        start_batch: Index of the first batch to process within
            ``start_epoch``; earlier batches are read from the loader but
            skipped. A value at or beyond ``len(train_loader)`` (as obtained
            when resuming from an end-of-epoch checkpoint) means that epoch is
            already finished, so training resumes at the start of the next
            epoch.

    Returns:
        A list with one metrics dict per epoch run by this call (training
        metrics, epoch duration in minutes, and the validation metrics).

    Note:
        The best validation accuracy is tracked per call and starts at 0.0,
        so a resumed run may overwrite an existing ``best_model.pt``.
    """
    model.train()
    history = []
    best_val_acc1  = 0.0
    best_ckpt_path = os.path.join(checkpoint_dir, "best_model.pt")

    # A resume position past the last batch would skip every batch of the
    # epoch (leaving nothing to average); treat it as "epoch complete".
    batches_per_epoch = len(train_loader)
    if batches_per_epoch and start_batch >= batches_per_epoch:
        start_epoch += 1
        start_batch = 0

    for epoch in range(start_epoch, num_epochs):
        epoch_start = time.time()

        # Running metrics (reset at every epoch)
        run_loss = run_policy = run_value = 0.0
        run_acc1 = run_acc5  = 0.0
        n_batches = 0

        pbar = tqdm(
            enumerate(train_loader),
            total=batches_per_epoch,
            desc=f"Epoch {epoch+1}/{num_epochs}"
        )

        for batch_idx, (positions, moves, outcomes) in pbar:

            # Skip batches that were already processed before the resume point
            if epoch == start_epoch and batch_idx < start_batch:
                continue

            positions = positions.to(device, non_blocking=True)
            moves     = moves.to(device, non_blocking=True)
            outcomes  = outcomes.to(device, non_blocking=True)

            # ── Forward (fp16) ────────────────────────────────
            optimizer.zero_grad(set_to_none=True)
            # torch.cuda.amp.autocast() is deprecated; this is the
            # equivalent non-deprecated spelling.
            with torch.amp.autocast("cuda"):
                logits_policy, logit_value = model(positions)
                loss, p_loss, v_loss = compute_loss(
                    logits_policy, logit_value,
                    moves, outcomes, value_weight
                )

            # ── Backward + gradient clipping ──────────────────
            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping so that the norm
            # threshold applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()

            # ── Running metrics ───────────────────────────────
            acc1, acc5 = compute_accuracy(logits_policy, moves)
            run_loss   += loss.item()
            run_policy += p_loss.item()
            run_value  += v_loss.item()
            run_acc1   += acc1
            run_acc5   += acc5
            n_batches  += 1

            # Refresh the progress bar every 100 processed batches
            if n_batches % 100 == 0:
                pbar.set_postfix({
                    "loss":   f"{run_loss/n_batches:.4f}",
                    "acc@1":  f"{run_acc1/n_batches*100:.1f}%",
                    "acc@5":  f"{run_acc5/n_batches*100:.1f}%",
                    "lr":     f"{scheduler.get_last_lr()[0]:.2e}"
                })

            # ── Periodic checkpoint ───────────────────────────
            if (batch_idx + 1) % checkpoint_every_n_batches == 0:
                ckpt_path = os.path.join(
                    checkpoint_dir,
                    f"ckpt_e{epoch+1}_b{batch_idx+1}.pt"
                )
                save_checkpoint(
                    epoch, batch_idx, model, optimizer,
                    scheduler, scaler,
                    {"batch_loss": run_loss / n_batches},
                    ckpt_path
                )
                print(f"\n  → Checkpoint sauvegardé : {ckpt_path}")

        # ── End-of-epoch validation ───────────────────────────
        val_metrics = validate(
            model, val_loader, value_weight, device
        )

        epoch_time = (time.time() - epoch_start) / 60

        # Dividing by nan yields nan averages if the loader produced no batch,
        # instead of raising ZeroDivisionError at the end of the epoch.
        denom = n_batches or float("nan")

        train_metrics = {
            "epoch":             epoch + 1,
            "train_loss":        run_loss   / denom,
            "train_policy_loss": run_policy / denom,
            "train_value_loss":  run_value  / denom,
            "train_acc1":        run_acc1   / denom,
            "train_acc5":        run_acc5   / denom,
            "epoch_time_min":    epoch_time,
            **val_metrics
        }
        history.append(train_metrics)

        print(f"\n{'='*65}")
        print(f"Epoch {epoch+1}/{num_epochs}  ({epoch_time:.1f} min)")
        print(f"  Train  loss={train_metrics['train_loss']:.4f}  "
              f"acc@1={train_metrics['train_acc1']*100:.2f}%  "
              f"acc@5={train_metrics['train_acc5']*100:.2f}%")
        print(f"  Val    loss={val_metrics['val_loss']:.4f}  "
              f"acc@1={val_metrics['val_acc1']*100:.2f}%  "
              f"acc@5={val_metrics['val_acc5']*100:.2f}%")
        print(f"{'='*65}\n")

        # ── Save the best model so far ────────────────────────
        if val_metrics["val_acc1"] > best_val_acc1:
            best_val_acc1 = val_metrics["val_acc1"]
            save_checkpoint(
                epoch, batches_per_epoch, model, optimizer,
                scheduler, scaler, train_metrics, best_ckpt_path
            )
            print(f"  ★ Meilleur modèle sauvegardé "
                  f"(val_acc@1={best_val_acc1*100:.2f}%)\n")

        # End-of-epoch checkpoint (batch_idx == len(train_loader) marks the
        # epoch as fully processed)
        epoch_ckpt = os.path.join(
            checkpoint_dir, f"epoch_{epoch+1}_final.pt"
        )
        save_checkpoint(
            epoch, batches_per_epoch, model, optimizer,
            scheduler, scaler, train_metrics, epoch_ckpt
        )

    return history


# ── Launch ────────────────────────────────────────────────────
# Fresh run: start from the first batch of the first epoch.
history = train(
    model, optimizer, scheduler, scaler,
    train_loader, val_loader,
    num_epochs=NUM_EPOCHS,
    value_weight=VALUE_WEIGHT,
    grad_clip=GRAD_CLIP,
    checkpoint_dir=CHECKPOINT_DIR,
    checkpoint_every_n_batches=CHECKPOINT_EVERY_N_BATCHES,
    device=DEVICE,
    start_epoch=0,
    start_batch=0,
)

# Cellule 12 – Reprise après crash (optionnelle)

In [None]:
# ============================================================
# CELLULE 12 – Reprise depuis checkpoint si session plantée
# ============================================================

# Décommente et adapte le chemin si Colab crashe en cours de route

# RESUME_FROM = os.path.join(CHECKPOINT_DIR, "ckpt_e1_b10000.pt")
# start_epoch, start_batch, saved_metrics = load_checkpoint(
#     RESUME_FROM, model, optimizer, scheduler, scaler
# )
# history = train(
#     model, optimizer, scheduler, scaler,
#     train_loader, val_loader,
#     num_epochs    = NUM_EPOCHS,
#     value_weight  = VALUE_WEIGHT,
#     grad_clip     = GRAD_CLIP,
#     checkpoint_dir= CHECKPOINT_DIR,
#     checkpoint_every_n_batches = CHECKPOINT_EVERY_N_BATCHES,
#     device        = DEVICE,
#     start_epoch   = start_epoch,
#     start_batch   = start_batch + 1
# )

# Cellule 13 – Courbes d'apprentissage

In [None]:
# ============================================================
# CELLULE 13 – Visualisation des courbes
# ============================================================
import matplotlib.pyplot as plt

def plot_history(history: list):
    """Plot train/validation curves and save them as a PNG.

    Draws three side-by-side panels (total loss, policy accuracy@1 and
    policy accuracy@5, the accuracies being shown in percent) from the
    per-epoch metric dicts in ``history``. The figure is written to
    ``training_curves.png`` inside ``OUTPUT_DIR`` and then displayed.

    Args:
        history: List of per-epoch dicts containing ``epoch``,
            ``train_loss``/``val_loss``, ``train_acc1``/``val_acc1`` and
            ``train_acc5``/``val_acc5``.
    """
    def as_is(value):
        return value

    def as_percent(value):
        return value * 100

    # One row per panel:
    # (train key, val key, value transform, train label, val label, title)
    panels = (
        ("train_loss", "val_loss", as_is,
         "Train", "Val", "Loss totale"),
        ("train_acc1", "val_acc1", as_percent,
         "Train acc@1", "Val acc@1", "Policy Accuracy@1 (%)"),
        ("train_acc5", "val_acc5", as_percent,
         "Train acc@5", "Val acc@5", "Policy Accuracy@5 (%)"),
    )

    epochs = [entry["epoch"] for entry in history]
    _fig, axes = plt.subplots(1, 3, figsize=(16, 4))

    for ax, panel in zip(axes, panels):
        train_key, val_key, transform, train_label, val_label, title = panel
        train_curve = [transform(entry[train_key]) for entry in history]
        ax.plot(epochs, train_curve, label=train_label, marker="o")
        val_curve = [transform(entry[val_key]) for entry in history]
        ax.plot(epochs, val_curve, label=val_label, marker="s")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    out_path = os.path.join(OUTPUT_DIR, "training_curves.png")
    plt.savefig(out_path, dpi=120)
    plt.show()
    print("Courbes sauvegardées sur Drive.")

# Only plot when at least one epoch was recorded (an empty history list is
# falsy, so nothing is drawn or saved in that case).
if history:
    plot_history(history)