In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found anywhere under the Kaggle input mount.
for found_path in (
    os.path.join(root, name)
    for root, _, names in os.walk('/kaggle/input')
    for name in names
):
    print(found_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/chess-50k/chess_positions.pt


In [3]:
!pip install chess

Collecting chess
  Downloading chess-1.11.2.tar.gz (6.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 59.1 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: chess
  Building wheel for chess (setup.py) ... done
  Created wheel for chess: filename=chess-1.11.2-py3-none-any.whl size=147775 sha256=3c589611ff10ca93dab73a27b31928281ee264e855b1b1b6e295fe7823a54d2c
  Stored in directory: /root/.cache/pip/wheels/fb/5d/5c/59a62d8a695285e59ec9c1f66add6f8a9ac4152499a2be0113
Successfully built chess
Installing collected packages: chess
Successfully installed chess-1.11.2


In [9]:
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
import wandb

import math
import torch.nn as nn
from typing import Optional
import chess
import chess.svg
from PIL import Image
from io import BytesIO
from torch.utils.data import Dataset, DataLoader

In [10]:
# --- Data ---
# Pre-built dataset of board tensors (a read-only Kaggle input, despite the name).
OUTPUT_FILE = '/kaggle/input/chess-50k/chess_positions.pt'
ABSORBING_STATE_INT = 13
VOCAB_SIZE = 14  # 12 pieces + 1 empty + 1 absorbing

# Token <-> FEN piece-letter tables used by tensor_to_fen / fen_to_tensor.
# Token 0 is an empty square and ABSORBING_STATE_INT is the diffusion mask token,
# so neither of them has a piece letter.
# NOTE(review): these tables were referenced but never defined in this notebook
# (the cell that defined them appears to have been deleted). The letter order
# below is an assumption and MUST match the encoding used by
# create_chess_dataset.py — TODO confirm.
PIECE_SYMBOLS = "PNBRQKpnbrqk"
PIECE_TO_INT = {symbol: token for token, symbol in enumerate(PIECE_SYMBOLS, start=1)}
INT_TO_PIECE = {token: symbol for symbol, token in PIECE_TO_INT.items()}
assert len(PIECE_TO_INT) + 2 == VOCAB_SIZE and ABSORBING_STATE_INT == VOCAB_SIZE - 1

BATCH_SIZE = 128
NUM_WORKERS = 4

# Model
MODEL_DIM = 512
MODEL_DEPTH = 8
MODEL_HEADS = 8

# Diffusion
NUM_TIMESTEPS = 1000

# Training
SEED = 42
torch.manual_seed(SEED)  # seeds PyTorch's RNGs on all devices (init, timestep sampling, shuffling)
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
PRECISION = "16-mixed" if DEVICE == "cuda" else "32"

# Logging & Sampling
PROJECT_NAME = "chess-d3pm"
# NOTE(review): sampling is triggered by optimizer steps. If the dataset really
# holds ~50k positions, one epoch is ~391 steps at BATCH_SIZE=128, so with
# NUM_EPOCHS = 1 this interval is never reached — confirm the dataset size.
SAMPLE_EVERY_N_STEPS = 500
NUM_SAMPLES_TO_GENERATE = 4

In [11]:
def tensor_to_fen(board_tensor: torch.Tensor, active_color='w', castling='KQkq', en_passant='-', halfmove_clock=0, fullmove_number=1, int_to_piece: Optional[dict] = None) -> str:
    """
    Converts a 1D tensor of 64 tokens back into a FEN string.

    Square ``rank * 8 + file`` holds the token for that square (index 0 is a1,
    index 63 is h8). Token 0 is an empty square; every other token is looked up
    in ``int_to_piece``.

    Note: This function only reconstructs the piece placement part of the FEN.
    Other game state information (turn, castling, etc.) must be provided.

    Args:
        board_tensor (torch.Tensor): A 1D tensor of shape (64,) representing the board.
        active_color (str): 'w' or 'b'.
        castling (str): Castling availability (e.g., 'KQkq', '-', 'Kq').
        en_passant (str): En passant target square (e.g., 'e3', '-').
        halfmove_clock (int): Halfmove clock value.
        fullmove_number (int): Fullmove number.
        int_to_piece (dict, optional): Mapping from token to FEN piece letter.
            Defaults to the notebook-level ``INT_TO_PIECE`` table.

    Returns:
        str: The full FEN string for the position.

    Raises:
        ValueError: If the tensor does not have shape (64,), or if it contains a
            non-empty token with no piece letter (e.g. the absorbing/mask token).
    """
    if board_tensor.shape != (64,):
        raise ValueError("Input tensor must have shape (64,)")
    if int_to_piece is None:
        int_to_piece = INT_TO_PIECE

    # One host transfer instead of 64 `.item()` calls (each one syncs a GPU tensor).
    tokens = board_tensor.tolist()

    fen_parts = []
    for rank_index in range(7, -1, -1):  # FEN lists rank 8 first, down to rank 1
        rank_fen = ""
        empty_count = 0
        for file_index in range(8):  # Files 'a' to 'h'
            square_index = rank_index * 8 + file_index
            piece_int = tokens[square_index]

            if piece_int == 0:
                empty_count += 1
                continue
            if piece_int not in int_to_piece:
                raise ValueError(
                    f"Token {piece_int} at square index {square_index} has no piece symbol "
                    f"(e.g. an unresolved absorbing/mask token)"
                )
            if empty_count > 0:
                rank_fen += str(empty_count)
                empty_count = 0
            rank_fen += int_to_piece[piece_int]

        if empty_count > 0:
            rank_fen += str(empty_count)

        fen_parts.append(rank_fen)

    piece_placement = "/".join(fen_parts)

    # Combine all parts of the FEN string
    return f"{piece_placement} {active_color} {castling} {en_passant} {halfmove_clock} {fullmove_number}"

def fen_to_tensor(fen_string: str, piece_to_int: Optional[dict] = None) -> torch.Tensor:
    """
    Parses a FEN string and converts it into a 1D tensor of 64 tokens.
    Each token represents a square on the board, with an integer value
    corresponding to the piece on it (0 for empty). Square ``rank * 8 + file``
    is used, so index 0 is a1 and index 63 is h8. Only the piece-placement
    field is read; side to move, castling rights, etc. are ignored.

    Args:
        fen_string (str): The FEN string for a board position.
        piece_to_int (dict, optional): Mapping from FEN piece letter to token.
            Defaults to the notebook-level ``PIECE_TO_INT`` table.

    Returns:
        torch.Tensor: A 1D long tensor of shape (64,) with integer piece representations.

    Raises:
        ValueError: If the placement field is missing, does not consist of exactly
            8 ranks of exactly 8 squares each, or contains an unknown piece letter.
    """
    if piece_to_int is None:
        piece_to_int = PIECE_TO_INT

    fields = fen_string.split()
    if not fields:
        raise ValueError("FEN string is empty")
    placement = fields[0]
    ranks = placement.split('/')
    if len(ranks) != 8:
        raise ValueError(f"FEN piece placement must have 8 ranks, got {len(ranks)}: {placement!r}")

    board_tensor = torch.zeros(64, dtype=torch.long)

    # FEN lists rank 8 first, so the k-th rank string describes board rank 7 - k.
    for offset, rank_fen in enumerate(ranks):
        rank_index = 7 - offset
        file_index = 0
        for char in rank_fen:
            if char in "12345678":
                file_index += int(char)
                continue
            if char not in piece_to_int:
                raise ValueError(f"Unknown piece symbol {char!r} in FEN placement {placement!r}")
            if file_index >= 8:
                # Without this check the piece would silently spill into the next rank.
                raise ValueError(f"Rank {rank_index + 1} has more than 8 squares in {placement!r}")
            board_tensor[rank_index * 8 + file_index] = piece_to_int[char]
            file_index += 1
        if file_index != 8:
            raise ValueError(f"Rank {rank_index + 1} does not describe exactly 8 squares in {placement!r}")

    return board_tensor

In [12]:
def display_board_from_tensor(board_tensor: torch.Tensor, size=400, save_path=None, show_image=False):
    """
    Generates a visual representation of the board from a tensor.

    The tensor is converted to FEN (default game-state fields, which do not affect
    piece placement), rendered to SVG with python-chess, then rasterised to PNG
    with CairoSVG when it is available.

    Args:
        board_tensor (torch.Tensor): The 1D tensor of shape (64,) representing the board.
        size (int): The size of the output image in pixels.
        save_path (str, optional): Path to save the image file (e.g., 'board.png').
                                   If None, the image is not saved. Defaults to None.
                                   When CairoSVG is unavailable, the SVG is written only
                                   if the path ends with ".svg".
        show_image (bool, optional): If True, attempts to open the image in the default
                                     system viewer. Defaults to False.

    Returns:
        PIL.Image.Image or None: The rendered PNG image, or None when CairoSVG (or
        the native cairo library it needs) is not available.
    """
    # Convert the tensor to a FEN string. We use default values for game state
    # as they don't affect the visual piece placement.
    fen = tensor_to_fen(board_tensor)

    # Create a board object from the FEN and render it as SVG.
    board = chess.Board(fen)
    svg_data = chess.svg.board(board=board, size=size)

    # Only the import is guarded: cairocffi raises OSError (not ImportError) when
    # the native libcairo shared library is missing, and genuine errors during
    # rendering/saving should surface instead of being mistaken for "not installed".
    try:
        from cairosvg import svg2png
    except (ImportError, OSError):
        print("CairoSVG not found. Cannot display or save image.")
        print("Board FEN:", fen)
        # The SVG data can still be useful for debugging
        if save_path and save_path.endswith(".svg"):
            with open(save_path, "w") as f:
                f.write(svg_data)
            print(f"Saved board as SVG to: {save_path}")
        else:
            print("Board SVG data:\n", svg_data)
        return None

    png_data = svg2png(bytestring=svg_data.encode('utf-8'))
    img = Image.open(BytesIO(png_data))

    if save_path:
        # Ensure the directory exists before saving
        output_dir = os.path.dirname(save_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        img.save(save_path)
        print(f"Board image saved to: {save_path}")

    if show_image:
        img.show()  # Opens the image in the system's default image viewer

    return img

In [13]:
class ChessDataset(Dataset):
    """
    A PyTorch Dataset for loading pre-processed chess board tensors.

    The whole file is loaded into memory once and item ``idx`` is ``data[idx]``
    (presumably a (64,) integer tensor of square tokens — this depends on how
    create_chess_dataset.py saved the file; TODO confirm).
    """
    def __init__(self, tensor_file=OUTPUT_FILE):
        """
        Args:
            tensor_file (str): Path to the .pt file containing the board tensors.

        Raises:
            FileNotFoundError: If ``tensor_file`` does not exist.
        """
        if not os.path.exists(tensor_file):
            raise FileNotFoundError(
                f"Dataset file not found: {tensor_file}. "
                f"Please run create_chess_dataset.py first to generate it."
            )
        print(f"Loading dataset from {tensor_file}...")
        # map_location="cpu": tensors saved from a GPU would otherwise be restored
        # onto that GPU (or fail to load on a CPU-only host), while DataLoader
        # workers and pin_memory expect CPU tensors.
        # NOTE(review): torch.load unpickles the file — only load trusted datasets.
        self.data = torch.load(tensor_file, map_location="cpu")
        print("Dataset loaded successfully.")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [14]:
class D3PM(nn.Module):
    """
    The core D3PM engine for discrete state spaces using an absorbing state.

    Wraps an x_0-prediction network and provides the forward corruption process
    q(x_t | x_0) (``q_sample``), the posterior q(x_{t-1} | x_t, x_0)
    (``q_posterior_logits``), the training loss (``forward``) and ancestral
    sampling (``sample`` / ``sample_with_history``).

    Timesteps are 1-based: t ranges over 1..n_T, and position ``t - 1`` of the
    precomputed matrix buffers holds the matrices for timestep t.
    """
    def __init__(self, x0_model: nn.Module, n_T: int, num_classes: int, hybrid_loss_coeff=0.0,) -> None:
        """
        Args:
            x0_model (nn.Module): The neural network that predicts the original data x_0 from a noisy input x_t.
                                  It is called as ``x0_model(x_t, t, cond)`` and must return logits of shape
                                  (B, SeqLen, num_classes); ``forward`` permutes them to (B, num_classes, SeqLen)
                                  for the cross-entropy loss.
            n_T (int): The total number of diffusion timesteps.
            num_classes (int): The total number of discrete states in the vocabulary, INCLUDING the absorbing state.
                               The absorbing state is taken to be the LAST index, ``num_classes - 1``.
                               For chess, this will be 12 pieces + 1 empty + 1 absorbing = 14.
            hybrid_loss_coeff (float, optional): The coefficient for the variational bound loss.
                                                 Defaults to 0.0, which means only CrossEntropyLoss is used.
        """
        super(D3PM, self).__init__()
        self.x0_model = x0_model
        self.n_T = n_T
        self.num_classes = num_classes
        self.hybrid_loss_coeff = hybrid_loss_coeff
        # Small floor added before every log() so that log(0) never occurs.
        self.eps = 1e-6

        # --- Set up the noise schedule and transition matrices for an absorbing state ---

        # Cosine noise schedule: beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}, capped at 0.999.
        # NOTE(review): alpha_bar is cos(...) here, not cos(...)**2 as in the
        # Improved-DDPM cosine schedule — confirm this is intended.
        steps = torch.arange(n_T + 1, dtype=torch.float64) / n_T
        alpha_bar = torch.cos((steps + 0.008) / 1.008 * torch.pi / 2)
        self.beta_t = torch.minimum(
            1 - alpha_bar[1:] / alpha_bar[:-1], torch.ones_like(alpha_bar[1:]) * 0.999
        )

        q_onestep_mats = []
        # The absorbing state is assumed to be the last class index (num_classes - 1)
        absorbing_state_idx = self.num_classes - 1

        for beta in self.beta_t:
            # Create a one-step transition matrix for the absorbing state diffusion process.
            # Row i is q(x_t | x_{t-1} = i): stay at i with prob (1 - beta), or jump to the
            # absorbing state with prob beta. The absorbing row gets (1 - beta) + beta = 1
            # on its diagonal, so once absorbed a token never leaves.
            mat = torch.eye(self.num_classes, dtype=torch.float64) * (1 - beta)
            mat[:, absorbing_state_idx] += beta
            q_onestep_mats.append(mat)

        q_one_step_mats = torch.stack(q_onestep_mats, dim=0)  # (n_T, C, C)

        # This will be used for q_posterior_logits: row x_t of Q_t^T is column x_t of
        # Q_t, i.e. the vector q(x_t | x_{t-1} = .) over all previous states.
        q_one_step_transposed = q_one_step_mats.transpose(1, 2)

        # Calculate the cumulative transition matrices q(x_t | x_0) by matrix multiplication:
        # q_mats[k] = Q_1 @ Q_2 @ ... @ Q_{k+1}, i.e. the matrix for timestep t = k + 1.
        q_mat_t = q_one_step_mats[0]
        q_mats = [q_mat_t]
        for idx in range(1, self.n_T):
            q_mat_t = q_mat_t @ q_one_step_mats[idx]
            q_mats.append(q_mat_t)
        q_mats = torch.stack(q_mats, dim=0)

        # Register buffers so they are moved to the correct device with the model.
        # Both are built in float64; as (persistent) buffers they are also stored in
        # the state_dict / checkpoints.
        self.register_buffer("q_one_step_transposed", q_one_step_transposed)
        self.register_buffer("q_mats", q_mats)

        assert self.q_mats.shape == (self.n_T, self.num_classes, self.num_classes)

    def _at(self, a: torch.Tensor, t: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        Helper function to select rows from a based on t and x.

        For a stack of matrices ``a`` of shape (n_T, C, C), returns
        ``a[t[i] - 1, x[i, ...], :]`` for every element of ``x``, giving a tensor of
        shape ``x.shape + (C,)``. ``t`` has shape (B,) and is 1-based; ``x`` must be
        an integer tensor whose first dimension is the batch.
        """
        bs = t.shape[0]
        t_broadcast = t.reshape(bs, *([1] * (x.dim() - 1)))
        # out[i, j, ..., m] = a[t[i] - 1, x[i, j, ...], m]
        return a[t_broadcast - 1, x, :]

    def q_posterior_logits(self, x_0: torch.Tensor, x_t: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Calculates the logits of the posterior distribution q(x_{t-1} | x_t, x_0).

        The result is the unnormalised log q(x_t | x_{t-1}) + log q(x_{t-1} | x_0),
        where q(x_{t-1} | x_0) is taken under softmax of the x_0 logits. For batch
        entries with t == 1 the x_0 logits themselves are returned instead.

        Args:
            x_0 (torch.Tensor): The original clean data. Either an integer tensor of token
                                ids (same shape as x_t) or logits of shape x_t.shape + (num_classes,).
            x_t (torch.Tensor): The noisy data at timestep t (integer token ids).
            t (torch.Tensor): The current timestep, shape (B,), 1-based.

        Returns:
            torch.Tensor: Logits of shape x_t.shape + (num_classes,).
        """
        # If x_0 is integer, convert it to one-hot logits (log of one-hot + eps)
        if x_0.dtype in [torch.int64, torch.int32]:
            x_0_logits = torch.log(
                torch.nn.functional.one_hot(x_0, self.num_classes) + self.eps
            )
        else:
            x_0_logits = x_0.clone()

        # Equation (3) from the D3PM paper, simplified for our case.
        # fact1[..., j] = q(x_t | x_{t-1} = j)
        fact1 = self._at(self.q_one_step_transposed, t, x_t)

        # We need q_mats for t-2, so handle the t=1 case.
        # q_mats[t - 2] is the cumulative matrix for timestep t - 1; for t == 1 index 0
        # is only a placeholder, and that result is discarded by the torch.where below.
        safe_t = torch.max(t, torch.ones_like(t) * 2)
        qmats2 = self.q_mats[safe_t - 2].to(dtype=torch.float32)

        # fact2[..., d] = sum_c p(x_0 = c) * Qbar_{t-1}[c, d] = q(x_{t-1} = d | x_0)
        softmaxed_x0 = torch.softmax(x_0_logits, dim=-1)
        fact2 = torch.einsum("b...c,bcd->b...d", softmaxed_x0, qmats2)

        out = torch.log(fact1 + self.eps) + torch.log(fact2 + self.eps)

        # For t=1, the posterior is just the distribution over x_0
        t_broadcast = t.reshape(t.shape[0], *([1] * (x_t.dim())))
        return torch.where(t_broadcast == 1, x_0_logits, out)

    def vb(self, dist1: torch.Tensor, dist2: torch.Tensor) -> torch.Tensor:
        """
        Calculates the KL-divergence for the variational bound loss.

        Returns KL(softmax(dist1) || softmax(dist2)), summed over the last (class)
        dimension and averaged over all other positions. Adding ``eps`` to the
        logits has no effect on softmax/log_softmax, which are shift-invariant.
        """
        dist1_flat = dist1.flatten(start_dim=0, end_dim=-2)
        dist2_flat = dist2.flatten(start_dim=0, end_dim=-2)

        kl_div = torch.softmax(dist1_flat + self.eps, dim=-1) * (
            torch.log_softmax(dist1_flat + self.eps, dim=-1)
            - torch.log_softmax(dist2_flat + self.eps, dim=-1)
        )
        return kl_div.sum(dim=-1).mean()

    def q_sample(self, x_0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """
        The forward process q(x_t | x_0). Corrupts the clean input x_0 to a noisy x_t.

        Args:
            x_0 (torch.Tensor): Integer token ids.
            t (torch.Tensor): 1-based timesteps, shape (B,).
            noise (torch.Tensor): Uniform samples of shape x_0.shape + (num_classes,),
                                  turned into Gumbel noise below.

        Returns:
            torch.Tensor: Integer token ids with the same shape as x_0.
        """
        logits = torch.log(self._at(self.q_mats, t, x_0) + self.eps)

        # Use Gumbel-Max trick for sampling from the categorical distribution
        gumbel_noise = -torch.log(-torch.log(torch.clip(noise, self.eps, 1.0)))
        return torch.argmax(logits + gumbel_noise, dim=-1)

    def model_predict(self, x_t: torch.Tensor, t: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Calls the underlying x0_model to predict the logits of the original data x_0.
        """
        return self.x0_model(x_t, t, cond)

    def forward(self, x: torch.Tensor, cond: Optional[torch.Tensor] = None) -> tuple[torch.Tensor, dict]:
        """
        The main training step.
        1. Samples a timestep t.
        2. Corrupts the input x to x_t.
        3. Predicts the original x_0 using the model.
        4. Calculates the loss.

        Args:
            x (torch.Tensor): Clean integer token ids of shape (B, SeqLen).
            cond (torch.Tensor, optional): Passed through to the x0 model.

        Returns:
            tuple: (total loss tensor, dict with float "vb_loss" and "ce_loss").
        """
        # Sample a random timestep for each item in the batch, uniformly in 1..n_T
        t = torch.randint(1, self.n_T + 1, (x.shape[0],), device=x.device)

        # Create x_t by running the forward diffusion process
        noise = torch.rand((*x.shape, self.num_classes), device=x.device)
        x_t = self.q_sample(x, t, noise)

        # Get the model's prediction of the original x_0's logits
        predicted_x0_logits = self.model_predict(x_t, t, cond)

        # --- Calculate Loss ---
        # 1. The primary loss: Cross-Entropy between predicted x_0 and true x_0
        ce_loss = torch.nn.functional.cross_entropy(
            predicted_x0_logits.permute(0, 2, 1), # Shape: (B, Vocab, SeqLen)
            x                                      # Shape: (B, SeqLen)
        )

        # 2. The auxiliary variational bound loss (optional).
        # When disabled, vb_loss stays a CPU scalar 0.0 and contributes nothing.
        vb_loss = torch.tensor(0.0)
        if self.hybrid_loss_coeff > 0:
            true_q_posterior = self.q_posterior_logits(x, x_t, t)
            pred_q_posterior = self.q_posterior_logits(predicted_x0_logits, x_t, t)
            vb_loss = self.vb(true_q_posterior, pred_q_posterior)

        total_loss = ce_loss + self.hybrid_loss_coeff * vb_loss

        return total_loss, {
            "vb_loss": vb_loss.detach().item(),
            "ce_loss": ce_loss.detach().item(),
        }

    @torch.no_grad()
    def p_sample(self, x_t: torch.Tensor, t: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        The reverse process step p(x_{t-1} | x_t). Samples x_{t-1} given x_t.

        The model's x_0 prediction is plugged into q(x_{t-1} | x_t, x_0) and a
        category is drawn per position with the Gumbel-Max trick. At t == 1 no
        noise is added, so the step returns the argmax of the predicted x_0 logits.
        """
        predicted_x0_logits = self.model_predict(x_t, t, cond)
        pred_q_posterior_logits = self.q_posterior_logits(predicted_x0_logits, x_t, t)

        # Use Gumbel-Max trick for sampling
        noise = torch.rand_like(pred_q_posterior_logits)
        gumbel_noise = -torch.log(-torch.log(torch.clip(noise, self.eps, 1.0)))

        # Don't add noise at the last step (t=1); mask has shape (B, 1, ..., 1) to broadcast
        not_first_step = (t != 1).float().reshape(x_t.shape[0], *([1] * (x_t.dim())))
        sample = torch.argmax(
            pred_q_posterior_logits + gumbel_noise * not_first_step, dim=-1
        )
        return sample

    @torch.no_grad()
    def sample(self, initial_noise: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Generates a full sample from noise by iterating through the reverse process.

        Runs ``p_sample`` for t = n_T, n_T - 1, ..., 1 (n_T model calls).

        Args:
            initial_noise (torch.Tensor): A tensor of shape (B, SeqLen) filled with the absorbing state index.
            cond (torch.Tensor, optional): Conditioning information, if any. Defaults to None.

        Returns:
            torch.Tensor: Integer token ids with the same shape as ``initial_noise``.
        """
        x = initial_noise
        for i in reversed(range(1, self.n_T + 1)):
            t = torch.full((x.shape[0],), i, device=x.device, dtype=torch.long)
            x = self.p_sample(x, t, cond)
        return x

    @torch.no_grad()
    def sample_with_history(self, initial_noise: torch.Tensor, cond: Optional[torch.Tensor] = None, stride: int = 10) -> list[torch.Tensor]:
        """
        Generates a full sample and saves intermediate steps.

        The state after the step at t = i is recorded (moved to CPU) whenever
        ``(i - 1) % stride == 0``; this always includes the final sample (i == 1).

        Returns:
            list[torch.Tensor]: Recorded states, from noisiest to the final sample.
        """
        x = initial_noise
        history = []
        for i in reversed(range(1, self.n_T + 1)):
            t = torch.full((x.shape[0],), i, device=x.device, dtype=torch.long)
            x = self.p_sample(x, t, cond)
            if (i - 1) % stride == 0 or i == 1:
                history.append(x.cpu())
        return history


In [15]:
def modulate(x, shift, scale):
    """
    Apply adaptive-LayerNorm (adaLN) modulation: ``x * (1 + scale) + shift``.

    ``shift`` and ``scale`` hold one vector per batch element; a singleton axis is
    inserted at dim 1 so that they broadcast over every position of ``x``.
    """
    per_position_scale = scale.unsqueeze(1)
    per_position_shift = shift.unsqueeze(1)
    return x * (1 + per_position_scale) + per_position_shift

class TimestepEmbedder(nn.Module):
    """
    Maps integer diffusion timesteps to ``hidden_size``-dimensional vectors: a fixed
    sinusoidal feature of width ``frequency_embedding_size`` followed by a
    two-layer SiLU MLP.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Sinusoidal features ``[cos(t * f_k), sin(t * f_k)]`` with geometrically spaced
        frequencies ``f_k = max_period ** (-k / (dim // 2))``. An odd ``dim`` is padded
        with a trailing zero column.

        Args:
            t (torch.Tensor): A 1-D tensor of N timesteps, one per batch element.
            dim (int): The output width.
            max_period (int): Sets the lowest frequency of the embedding.

        Returns:
            torch.Tensor: A float32 tensor of shape (N, dim) (for dim >= 2) on ``t``'s device.
        """
        half = dim // 2
        log_scaled = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32)
        freqs = torch.exp(log_scaled / half).to(device=t.device)
        phases = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(phases), torch.sin(phases)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        return self.mlp(self.timestep_embedding(t, self.frequency_embedding_size))

class TransformerBlock(nn.Module):
    """
    A pre-norm Transformer block with Adaptive Layer Normalization (adaLN).

    Each sub-layer (self-attention, then MLP) sees ``norm(x) * (1 + scale) + shift``,
    where the per-sample (shift, scale) pairs are predicted from ``adaln_input``,
    and its output is added back to ``x`` as a residual.

    Args:
        hidden_size (int): Model width D.
        num_heads (int): Number of attention heads (must divide D).
        mlp_ratio (float): MLP hidden width as a multiple of D.
        **block_kwargs: Accepted for API compatibility; currently ignored.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim),
            nn.GELU(),
            nn.Linear(mlp_hidden_dim, hidden_size),
        )
        # This single MLP layer generates all the conditioning parameters (scale/shift)
        # for the entire block. 2 for norm1, 2 for norm2.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 4 * hidden_size, bias=True)
        )

    def forward(self, x: torch.Tensor, adaln_input: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Token features of shape (B, L, D).
            adaln_input (torch.Tensor): Conditioning vectors of shape (B, D).

        Returns:
            torch.Tensor: Updated features of shape (B, L, D).
        """
        # Generate scale and shift parameters from the conditioning embedding
        shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1)

        # Attention block with adaLN.
        # need_weights=False: the attention map is discarded anyway, and not requesting
        # it lets nn.MultiheadAttention use its fused scaled_dot_product_attention path.
        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1, need_weights=False)
        x = x + attn_output

        # MLP block with adaLN
        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
        x = x + self.mlp(x_norm2)

        return x

class FinalLayer(nn.Module):
    """
    Output head of the DiT: an adaLN-modulated LayerNorm followed by a linear
    projection from ``hidden_size`` to ``out_channels`` logits per position.
    """
    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # Conditioning vector -> (shift, scale) for the final normalisation.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=1)
        normalized = modulate(self.norm_final(x), shift, scale)
        return self.linear(normalized)

class ChessDiT(nn.Module):
    """
    Diffusion Transformer (DiT) denoiser for 8x8 chess boards.

    Given a partially-absorbed board x_t (64 square tokens) and the diffusion
    timestep t, it returns per-square logits over the vocabulary for the clean
    board x_0.

    Args:
        vocab_size (int): Number of token ids (14 for chess: 12 pieces + empty + absorbing).
        hidden_size (int): Model width D.
        depth (int): Number of adaLN Transformer blocks.
        num_heads (int): Attention heads per block.
        mlp_ratio (float): Hidden width of each block's MLP relative to D.
    """
    def __init__(
        self,
        vocab_size: int,
        hidden_size: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
    ):
        super().__init__()
        self.vocab_size, self.hidden_size, self.num_heads = vocab_size, hidden_size, num_heads

        # Token, square-position (learned, one vector per square) and timestep embeddings.
        # Registration order is kept fixed so parameter names and init order stay stable.
        self.piece_embedder = nn.Embedding(vocab_size, hidden_size)
        self.pos_embedder = nn.Parameter(torch.randn(1, 64, hidden_size))
        self.t_embedder = TimestepEmbedder(hidden_size)

        self.blocks = nn.ModuleList(
            [TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
        )
        self.final_layer = FinalLayer(hidden_size, vocab_size)

        self.initialize_weights()

    def initialize_weights(self):
        """DiT-style init: small-normal embeddings, Xavier linears, zeroed output projections."""
        nn.init.normal_(self.pos_embedder, std=0.02)
        nn.init.normal_(self.piece_embedder.weight, std=0.02)

        def _xavier_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_xavier_linear)

        # Zero the last projection of every block MLP and of the output head: each
        # MLP branch then contributes nothing at init and the initial logits are all 0.
        zeroed_layers = [block.mlp[-1] for block in self.blocks] + [self.final_layer.linear]
        for layer in zeroed_layers:
            nn.init.constant_(layer.bias, 0)
            nn.init.constant_(layer.weight, 0)

    def forward(self, x_t: torch.Tensor, t: torch.Tensor, cond: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Forward pass of the ChessDiT model.

        Args:
            x_t (torch.Tensor): Noisy board tokens of shape (B, 64).
            t (torch.Tensor): Timesteps of shape (B,).
            cond (Optional[torch.Tensor]): Unused (unconditional model); kept for API consistency.

        Returns:
            torch.Tensor: Predicted clean-board logits of shape (B, 64, vocab_size).
        """
        t_emb = self.t_embedder(t)                        # (B,) -> (B, D)
        h = self.piece_embedder(x_t) + self.pos_embedder  # (B, 64) -> (B, 64, D)
        for block in self.blocks:
            h = block(h, adaln_input=t_emb)
        # The output head is conditioned on the timestep as well.
        return self.final_layer(h, t_emb)                 # (B, 64, vocab_size)

In [16]:
class D3PMLightning(pl.LightningModule):
    """
    PyTorch Lightning wrapper around the D3PM engine and its ChessDiT denoiser.

    Owns the training step, the optimizer and a small sampling helper; all of the
    diffusion maths lives in ``self.d3pm``.
    """
    def __init__(self, learning_rate: float):
        super().__init__()
        # Records `learning_rate` in self.hparams and in checkpoints.
        self.save_hyperparameters()

        denoiser = ChessDiT(
            vocab_size=VOCAB_SIZE,
            hidden_size=MODEL_DIM,
            depth=MODEL_DEPTH,
            num_heads=MODEL_HEADS,
        )
        # Cross-entropy-only objective: the variational-bound term is disabled.
        self.d3pm = D3PM(
            x0_model=denoiser,
            n_T=NUM_TIMESTEPS,
            num_classes=VOCAB_SIZE,
            hybrid_loss_coeff=0.0,
        )

    def training_step(self, batch, batch_idx):
        # `batch` holds clean boards (presumably (B, 64) token ids). D3PM.forward
        # samples t, corrupts the boards, predicts x_0 and returns the loss.
        loss, components = self.d3pm(batch)

        self.log("train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        for key in ("ce_loss", "vb_loss"):
            self.log(key, components[key], on_step=True, logger=True)

        return loss

    def configure_optimizers(self):
        # D3PM's only trainable parameters are those of the wrapped x0 model.
        return torch.optim.AdamW(self.d3pm.x0_model.parameters(), lr=self.hparams.learning_rate)

    @torch.no_grad()
    def sample(self, num_samples: int) -> torch.Tensor:
        """ Generate `num_samples` boards, starting from fully absorbed (masked) boards. """
        fully_masked = torch.full(
            (num_samples, 64),
            ABSORBING_STATE_INT,
            dtype=torch.long,
            device=self.device,
        )
        return self.d3pm.sample(fully_masked)

class SamplingCallback(pl.Callback):
    """
    A PyTorch Lightning Callback to periodically generate and log board samples.

    Every ``sample_every_n_steps`` optimizer steps it draws ``num_samples`` boards
    via ``pl_module.sample``, renders the first one and logs its image and FEN to
    Weights & Biases. Sampling is a diagnostic only: a sample that cannot be
    rendered is reported but never interrupts training.
    """
    def __init__(self, sample_every_n_steps: int, num_samples: int):
        """
        Args:
            sample_every_n_steps (int): Sampling interval in optimizer steps (must be > 0).
            num_samples (int): Number of boards generated per sampling event.

        Raises:
            ValueError: If ``sample_every_n_steps`` is not positive.
        """
        super().__init__()
        if sample_every_n_steps <= 0:
            raise ValueError("sample_every_n_steps must be a positive integer")
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        # Global step of the last sampling event. With gradient accumulation several
        # batches share one global_step and must not each trigger sampling.
        self._last_sampled_step = -1

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        """ Called after each training step. """
        # Only rank 0 samples, and only when logging to W&B: `wandb.Image` and
        # `logger.experiment.log` below are W&B-specific APIs.
        if not trainer.is_global_zero or not isinstance(trainer.logger, WandbLogger):
            return

        # global_step counts completed optimizer steps; this hook runs after the
        # batch's optimizer step, so the current batch is already included.
        step = trainer.global_step
        if step == 0 or step % self.sample_every_n_steps != 0 or step == self._last_sampled_step:
            return
        self._last_sampled_step = step

        print(f"\n--- Sampling at step {step} ---")
        was_training = pl_module.training
        pl_module.eval()  # Set model to evaluation mode
        try:
            generated_boards = pl_module.sample(self.num_samples)
        finally:
            # Restore the previous mode even if sampling raises.
            pl_module.train(was_training)

        # Visualize the first generated board
        first_board_tensor = generated_boards[0].cpu()
        try:
            img = display_board_from_tensor(first_board_tensor, show_image=False)
            reconstructed_fen = tensor_to_fen(first_board_tensor)
        except (KeyError, ValueError) as err:
            # e.g. a leftover absorbing/mask token that has no piece letter.
            print(f"Could not render sampled board: {err}")
            print("--- End Sampling ---")
            return

        # Log image and FEN in a single call so they land on the same W&B step.
        payload = {"generated_fen": reconstructed_fen}
        if img:
            payload["generated_sample"] = wandb.Image(img, caption=f"Step {step}")
        trainer.logger.experiment.log(payload)

        print(f"Sampled FEN: {reconstructed_fen}")
        print("--- End Sampling ---")

In [None]:
# --- 1. Setup Data ---
# ChessDataset raises FileNotFoundError itself when OUTPUT_FILE is missing,
# so the constructor is the single existence check.
dataset = ChessDataset(tensor_file=OUTPUT_FILE)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    # Page-locked host memory only speeds up host->GPU copies; skip it on CPU runs.
    pin_memory=(DEVICE == "cuda"),
)

# --- 2. Setup Model ---
model = D3PMLightning(learning_rate=LEARNING_RATE)

# --- 3. Setup Logging & Callbacks ---
# log_model="all" uploads every saved checkpoint to W&B as an artifact.
wandb_logger = WandbLogger(project=PROJECT_NAME, log_model="all")

# Keep the 3 checkpoints with the lowest training loss.
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    filename="chess-d3pm-{epoch:02d}-{train_loss:.2f}",
    save_top_k=3,
    monitor="train_loss",
    mode="min",
)

# Periodically generate boards and log them to W&B.
sampling_callback = SamplingCallback(
    sample_every_n_steps=SAMPLE_EVERY_N_STEPS,
    num_samples=NUM_SAMPLES_TO_GENERATE
)

# Log the learning rate at every optimizer step.
lr_monitor = LearningRateMonitor(logging_interval='step')

# --- 4. Setup Trainer and Start Training ---
trainer = pl.Trainer(
    accelerator=DEVICE,
    precision=PRECISION,
    max_epochs=NUM_EPOCHS,
    logger=wandb_logger,
    callbacks=[checkpoint_callback, sampling_callback, lr_monitor],
    log_every_n_steps=10,
)

print("--- Starting Training ---")
try:
    trainer.fit(model, dataloader)
finally:
    # Close the W&B run so remaining data is uploaded and re-running this cell
    # starts a fresh run instead of attaching to the still-open one.
    wandb.finish()
print("--- Training Finished ---")

Loading dataset from /kaggle/input/chess-50k/chess_positions.pt...
Dataset loaded successfully.
--- Starting Training ---


<IPython.core.display.Javascript object>