# Transformer-based Sudoku Solver

This notebook implements a Transformer-based Sudoku solver with positional embeddings.

Key improvements over the LSTM version:
1. **Transformer Architecture**: Better at capturing global dependencies (row, column, box constraints).
2. **Positional Embeddings**: Explicit row, column, and 3x3 box position encodings.
3. **Learning Rate Scheduling**: OneCycleLR with warmup.
4. **Gradient Clipping**: Prevents exploding gradients.
5. **Label Smoothing**: Reduces overconfidence and improves generalization.

In [None]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
    DataLoader,
    Dataset,
)
from tqdm import tqdm

## 1. Data Processing

Functions to download the dataset, filter it by puzzle source, augment the training split, and save the result as binary `.npy` files.

In [None]:
from datasets import load_dataset


def shuffle_sudoku(board_flat, solution_flat, rng=None):
    """Apply one random validity-preserving transformation to a puzzle pair.

    The same transformation is applied to the puzzle and to its solution, so
    the returned arrays still belong together. The transformation combines:
    a relabelling of the digits 1-9 (0 always stays 0), an optional transpose,
    a permutation of the three row bands and of the rows inside each band,
    and the equivalent permutation of column stacks and columns.

    Args:
        board_flat: Array-like of 81 integers in 0-9 (0 = empty cell).
        solution_flat: Array-like of 81 integers in 0-9 for the same puzzle.
        rng: Optional random source exposing ``permutation`` and ``random``,
            e.g. ``np.random.default_rng(seed)``. When omitted, NumPy's
            global RNG is used, so ``np.random.seed`` keeps controlling the
            result exactly as before.

    Returns:
        Tuple ``(board, solution)`` of flat arrays with 81 entries each.
    """
    # The np.random module itself offers permutation() and random(), which
    # lets the default path share the code below with explicit generators.
    rng = np.random if rng is None else rng

    board = np.asarray(board_flat).reshape(9, 9)
    sol = np.asarray(solution_flat).reshape(9, 9)

    # 1. Digit relabelling: index 0 maps to 0 so empty cells stay empty.
    digit_map = np.arange(10)
    digit_map[1:] = rng.permutation(np.arange(1, 10))

    # 2. Random transpose.
    if rng.random() < 0.5:
        board = board.T
        sol = sol.T

    # 3. Shuffle the bands (groups of 3 rows) and the rows within each band.
    bands = rng.permutation(3)
    row_perm = np.concatenate([b * 3 + rng.permutation(3) for b in bands])

    # 4. Shuffle the stacks (groups of 3 columns) and the columns within each.
    stacks = rng.permutation(3)
    col_perm = np.concatenate([s * 3 + rng.permutation(3) for s in stacks])

    # Apply the row and column permutations to both grids.
    board = board[row_perm, :][:, col_perm]
    sol = sol[row_perm, :][:, col_perm]

    # Relabel digits last; fancy indexing also yields fresh arrays.
    board = digit_map[board]
    sol = digit_map[sol]

    return board.flatten(), sol.flatten()


def preprocess_dataset(output_dir='data/processed', num_aug=1):
    """Download, filter, augment, and save the dataset as .npy files.

    For each of the 'train' and 'test' splits this writes
    ``{split}_questions.npy`` and ``{split}_answers.npy`` (uint8 arrays with
    81 values per puzzle, '.' in a question stored as 0) into ``output_dir``.
    Only the 'train' split is augmented.

    Args:
        output_dir: Directory that receives the four .npy files.
        num_aug: Number of shuffled copies added per training puzzle.
            Values <= 0 disable augmentation.

    Returns:
        None. Generation is skipped when all four output files already exist.
    """
    splits = ('train', 'test')
    expected_files = [
        os.path.join(output_dir, f'{split}_{kind}.npy')
        for split in splits
        for kind in ('questions', 'answers')
    ]

    # Check the files, not just the directory: the directory is created
    # before anything is saved, so an interrupted run leaves it behind and a
    # directory-only check would skip generation forever afterwards.
    if all(os.path.isfile(path) for path in expected_files):
        print(f'Dataset already exists at {output_dir}. Skipping generation.')
        return

    print('Loading dataset from HuggingFace...')
    ds = load_dataset('sapientinc/sudoku-extreme')

    # Keep only the puzzles whose 'source' field is one of these.
    easy_sources = ['puzzles0_kaggle', 'puzzles1_unbiased', 'puzzles2_17_clue']

    os.makedirs(output_dir, exist_ok=True)

    for split in splits:
        print(f'Processing {split} split...')
        split_ds = ds[split].filter(lambda x: x['source'] in easy_sources)

        questions = []
        answers = []

        print('Converting to integers and augmenting...')
        for item in tqdm(split_ds):
            q = np.array(
                [0 if c == '.' else int(c) for c in item['question']],
                dtype=np.uint8,
            )
            a = np.array([int(c) for c in item['answer']], dtype=np.uint8)

            questions.append(q)
            answers.append(a)

            # Augmentations (only for train)
            if split == 'train' and num_aug > 0:
                for _ in range(num_aug):
                    q_aug, a_aug = shuffle_sudoku(q, a)
                    questions.append(q_aug)
                    answers.append(a_aug)

        q_arr = np.array(questions, dtype=np.uint8)
        a_arr = np.array(answers, dtype=np.uint8)

        print(f'Saving {len(q_arr)} samples to {output_dir}...')
        np.save(os.path.join(output_dir, f'{split}_questions.npy'), q_arr)
        np.save(os.path.join(output_dir, f'{split}_answers.npy'), a_arr)


# Build the processed dataset with one augmented copy per training puzzle.
# The call returns early when the processed dataset already exists.
preprocess_dataset(num_aug=1)

## 2. Dataset Class

Loads data directly from `.npy` files.

In [None]:
class FastSudokuDataset(Dataset):
    """Map-style dataset backed by preprocessed Sudoku ``.npy`` arrays."""

    def __init__(self, data_dir, split):
        """Load the question and answer arrays of one split into memory.

        Args:
            data_dir: Directory containing the .npy files.
            split: Either 'train' or 'test'.
        """
        question_path = os.path.join(data_dir, f'{split}_questions.npy')
        answer_path = os.path.join(data_dir, f'{split}_answers.npy')
        self.questions = np.load(question_path)
        self.answers = np.load(answer_path)

    def __len__(self):
        """Return the number of puzzles in the split."""
        return len(self.questions)

    def __getitem__(self, idx):
        """Return one puzzle as a dict of tensors.

        Keys:
            'question': int64 cell values, 0 marking an unknown cell.
            'answer': int64 class indices, i.e. the stored answer minus 1.
            'mask': bool tensor, True where the question cell is 0.
        """
        cells = self.questions[idx].astype(np.int64)
        digits = self.answers[idx].astype(np.int64)

        return {
            'question': torch.from_numpy(cells),
            # Cross-entropy expects classes 0-8 while answers hold 1-9.
            'answer': torch.from_numpy(digits - 1),
            # Unknown cells are exactly the zeros of the question.
            'mask': torch.from_numpy(cells == 0),
        }

## 3. Transformer Model with Positional Embeddings

Key features:
- **Value Embedding**: Maps cell values (0-9) to hidden dimension.
- **Row/Column/Box Embeddings**: Encode spatial position in the 9x9 grid.
- **Transformer Encoder**: Self-attention captures global dependencies.

The positional embeddings help the model understand Sudoku constraints:
- Row embedding: cells in the same row share the same row index.
- Column embedding: cells in the same column share the same column index.
- Box embedding: cells in the same 3x3 box share the same box index.

In [None]:
class SudokuTransformer(nn.Module):
    """Transformer encoder over the 81 cells of a Sudoku grid.

    Every cell token is the sum of a value embedding and learned row, column
    and 3x3-box embeddings. A linear head turns each encoded token into 9
    logits, one per digit.
    """

    def __init__(
        self,
        hidden_size=256,
        num_layers=6,
        num_heads=8,
        dropout=0.1,
    ):
        """Build embeddings, encoder stack and output head.

        Args:
            hidden_size: Width of the embeddings and encoder hidden states.
            num_layers: Number of stacked encoder layers.
            num_heads: Attention heads per encoder layer.
            dropout: Dropout probability inside the encoder layers.
        """
        super().__init__()

        # Cell-state table: index 0 is "unknown", indices 1-9 are the digits.
        self.value_embed = nn.Embedding(10, hidden_size)

        # One learned table per structural coordinate of the grid.
        self.row_embed = nn.Embedding(9, hidden_size)
        self.col_embed = nn.Embedding(9, hidden_size)
        self.box_embed = nn.Embedding(9, hidden_size)

        # Encoder stack: GELU feed-forward of 4x width, batch-first tensors.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=hidden_size,
                nhead=num_heads,
                dim_feedforward=4 * hidden_size,
                dropout=dropout,
                activation='gelu',
                batch_first=True,
            ),
            num_layers=num_layers,
        )

        # Per-cell classifier over the 9 digits.
        self.fc = nn.Linear(hidden_size, 9)

        # Fixed coordinates of each flattened position 0..80. They are
        # registered as buffers so they move with the module but are not
        # trained.
        cell = torch.arange(81)
        row = cell // 9
        col = cell % 9
        box = (row // 3) * 3 + col // 3

        self.register_buffer('row_idx', row)
        self.register_buffer('col_idx', col)
        self.register_buffer('box_idx', box)

    def _position_embedding(self):
        """Return the (81, hidden_size) sum of row, column and box embeddings."""
        return (
            self.row_embed(self.row_idx)
            + self.col_embed(self.col_idx)
            + self.box_embed(self.box_idx)
        )

    def forward(self, x):
        """Compute per-cell digit logits.

        Args:
            x: Input tensor of shape (batch, 81) with values 0-9.

        Returns:
            Logits of shape (batch, 81, 9).
        """
        # (batch, 81, hidden) + (81, hidden) broadcasts over the batch.
        tokens = self.value_embed(x) + self._position_embedding()
        encoded = self.transformer(tokens)
        return self.fc(encoded)

## 4. Training Utilities

Loss function, accuracy metrics, and training loop.

In [None]:
def masked_loss(preds, targets, mask, label_smoothing=0.1):
    """Average cross-entropy over the cells selected by ``mask``.

    Args:
        preds: Model output logits, shape (batch, 81, 9).
        targets: Target class indices, shape (batch, 81).
        mask: Boolean mask, shape (batch, 81). True = cell counts.
        label_smoothing: Label smoothing factor passed to cross-entropy.

    Returns:
        Scalar tensor: summed loss of the selected cells divided by their
        count (plus 1e-6, so an all-False mask yields 0 instead of NaN).
    """
    flat_logits = preds.reshape(-1, 9)
    flat_targets = targets.reshape(-1)

    # Unreduced loss, one value per cell, folded back to the targets' shape.
    per_cell = F.cross_entropy(
        flat_logits,
        flat_targets,
        reduction='none',
        label_smoothing=label_smoothing,
    ).reshape(targets.shape)

    selected_total = (per_cell * mask.float()).sum()
    selected_count = mask.sum() + 1e-6
    return selected_total / selected_count


def compute_accuracy(predictions, targets, mask):
    """Measure accuracy per cell and per whole puzzle on masked cells.

    Args:
        predictions: Model output logits, shape (batch, 81, 9).
        targets: Target class indices, shape (batch, 81).
        mask: Boolean mask, shape (batch, 81). True = cell needs prediction.

    Returns:
        Tuple ``(cell_accuracy, puzzle_accuracy)`` of Python floats. A puzzle
        counts as solved when every one of its masked cells is correct.
    """
    guessed = predictions.argmax(dim=-1)

    # A hit is a masked cell whose arg-max class equals the target.
    hits = (guessed == targets) & mask

    hit_total = hits.sum().float()
    masked_total = mask.sum().float() + 1e-6
    cell_accuracy = hit_total / masked_total

    # Per puzzle: solved when the number of hits equals the number of
    # masked cells.
    solved = hits.sum(dim=1) == mask.sum(dim=1)
    puzzle_accuracy = solved.float().mean()

    return cell_accuracy.item(), puzzle_accuracy.item()

In [None]:
def train_epoch(model, loader, optimizer, scheduler, scaler, device, max_grad_norm=1.0):
    """Run one optimisation pass over ``loader``.

    Three execution paths are used depending on the device:
    CUDA with a scaler runs under autocast with loss scaling, MPS runs under
    float16 autocast without scaling, and everything else runs in full
    precision. Gradients are clipped before every optimizer step and the
    scheduler is stepped once per batch.

    Args:
        model: The model to train.
        loader: DataLoader yielding dicts with 'question', 'answer', 'mask'.
        optimizer: Optimizer instance.
        scheduler: Learning rate scheduler, stepped after every batch.
        scaler: GradScaler for mixed precision (or None).
        device: Device to train on.
        max_grad_norm: Maximum gradient norm for clipping.

    Returns:
        Dictionary with loss, cell_accuracy, puzzle_accuracy, each the mean
        of the per-batch values.
    """
    model.train()
    sums = {'loss': 0, 'cell_accuracy': 0, 'puzzle_accuracy': 0}
    steps = 0
    progress = tqdm(loader)

    def run_forward(cells, labels, unknown):
        """Return (logits, masked loss) for one batch."""
        logits = model(cells)
        return logits, masked_loss(logits, labels, unknown)

    def clip_gradients():
        """Clip the global gradient norm of all model parameters."""
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    for batch in progress:
        cells = batch['question'].to(device)
        labels = batch['answer'].to(device)
        unknown = batch['mask'].to(device)

        optimizer.zero_grad()

        if device.type == 'cuda' and scaler is not None:
            # Mixed precision with loss scaling.
            with torch.amp.autocast('cuda'):
                logits, loss = run_forward(cells, labels, unknown)

            scaler.scale(loss).backward()
            # Gradients must be unscaled before their norm is clipped.
            scaler.unscale_(optimizer)
            clip_gradients()
            scaler.step(optimizer)
            scaler.update()
        else:
            if device.type == 'mps':
                with torch.autocast(device_type='mps', dtype=torch.float16):
                    logits, loss = run_forward(cells, labels, unknown)
            else:
                logits, loss = run_forward(cells, labels, unknown)

            loss.backward()
            clip_gradients()
            optimizer.step()

        # The schedule advances per batch, not per epoch.
        scheduler.step()

        with torch.no_grad():
            cell_acc, puzzle_acc = compute_accuracy(logits, labels, unknown)

        loss_value = loss.item()
        sums['loss'] += loss_value
        sums['cell_accuracy'] += cell_acc
        sums['puzzle_accuracy'] += puzzle_acc
        steps += 1

        progress.set_postfix({
            'loss': f'{loss_value:.4f}',
            'cell': f'{cell_acc:.2%}',
            'puzzle': f'{puzzle_acc:.2%}',
            'lr': f"{scheduler.get_last_lr()[0]:.2e}",
        })

    return {name: total / steps for name, total in sums.items()}


@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate the model on a dataset.

    Metrics are aggregated over the whole dataset rather than averaged per
    batch: loss and cell accuracy are weighted by each batch's number of
    masked cells, and puzzle accuracy by each batch's number of puzzles. A
    smaller final batch therefore no longer carries the weight of a full one.
    The loss is computed without label smoothing.

    Args:
        model: The model to evaluate (switched to eval mode).
        loader: DataLoader yielding dicts with 'question', 'answer', 'mask'.
        device: Device to evaluate on.

    Returns:
        Dictionary with loss, cell_accuracy, puzzle_accuracy.

    Raises:
        ZeroDivisionError: If ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    cells_correct = 0.0
    puzzles_solved = 0.0
    total_masked = 0
    total_puzzles = 0

    for batch in tqdm(loader, desc='Evaluating'):
        q = batch['question'].to(device)
        a = batch['answer'].to(device)
        m = batch['mask'].to(device)

        preds = model(q)
        loss = masked_loss(preds, a, m, label_smoothing=0.0)
        cell_acc, puzzle_acc = compute_accuracy(preds, a, m)

        n_masked = int(m.sum().item())
        n_puzzles = q.shape[0]

        # Both helpers return per-batch means; multiplying by the batch's
        # own counts turns them back into sums that can be pooled exactly.
        loss_sum += loss.item() * n_masked
        cells_correct += cell_acc * n_masked
        puzzles_solved += puzzle_acc * n_puzzles
        total_masked += n_masked
        total_puzzles += n_puzzles

    # With no masked cell anywhere, loss and cell accuracy are reported as 0.
    masked_denominator = max(total_masked, 1)

    return {
        'loss': loss_sum / masked_denominator,
        'cell_accuracy': cells_correct / masked_denominator,
        'puzzle_accuracy': puzzles_solved / total_puzzles,
    }

## 5. Training

Configuration and main training loop with:
- AdamW optimizer with weight decay
- OneCycleLR scheduler with warmup
- Mixed precision training
- Gradient clipping

In [None]:
# Optimisation settings
BATCH_SIZE = 512
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 0.01
NUM_EPOCHS = 20
MAX_GRAD_NORM = 1.0

# Network size
HIDDEN_SIZE = 256
NUM_LAYERS = 6
NUM_HEADS = 8
DROPOUT = 0.1

# Accelerator preference: CUDA first, then Apple MPS, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

In [None]:
# Read both splits from the preprocessed .npy files.
train_ds = FastSudokuDataset('data/processed', 'train')
test_ds = FastSudokuDataset('data/processed', 'test')

# Only the training loader shuffles. Pinned host memory is requested only
# when the target device is CUDA.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0,
    pin_memory=device.type == 'cuda',
)
test_loader = DataLoader(
    test_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=device.type == 'cuda',
)

print(
    f'Train samples: {len(train_ds):,}',
    f'Test samples: {len(test_ds):,}',
    f'Train batches per epoch: {len(train_loader):,}',
    sep='\n',
)

In [None]:
# Build the network from the configured sizes, then move it to the device.
model = SudokuTransformer(
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
)
model = model.to(device)

# Report the number of trainable parameters.
num_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(f'Model parameters: {num_params:,}')

# torch.compile is used only on CUDA and only if this PyTorch build has it.
if device.type == 'cuda' and hasattr(torch, 'compile'):
    print('Compiling model with torch.compile...')
    model = torch.compile(model)

In [None]:
# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

# One-cycle schedule sized for NUM_EPOCHS passes over train_loader:
# the first 10% of steps ramp up to max_lr, the rest anneal with a cosine.
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    epochs=NUM_EPOCHS,
    steps_per_epoch=len(train_loader),
    pct_start=0.1,
    anneal_strategy='cos',
)

# Loss scaling is only set up for CUDA; other devices train without it.
if device.type == 'cuda':
    scaler = torch.amp.GradScaler('cuda')
else:
    scaler = None

In [None]:
# Training loop: train every epoch, evaluate periodically, and keep the
# checkpoint with the best test puzzle accuracy.
best_puzzle_acc = 0.0

for epoch in range(NUM_EPOCHS):
    print(f'\n{"=" * 60}')
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    print('=' * 60)

    # Train
    train_metrics = train_epoch(
        model, train_loader, optimizer, scheduler, scaler, device, MAX_GRAD_NORM
    )
    print(
        f"Train - Loss: {train_metrics['loss']:.4f}, "
        f"Cell Acc: {train_metrics['cell_accuracy']:.2%}, "
        f"Puzzle Acc: {train_metrics['puzzle_accuracy']:.2%}"
    )

    # Evaluate every 5 epochs or on last epoch
    if (epoch + 1) % 5 == 0 or epoch == NUM_EPOCHS - 1:
        test_metrics = evaluate(model, test_loader, device)
        print(
            f"Test  - Loss: {test_metrics['loss']:.4f}, "
            f"Cell Acc: {test_metrics['cell_accuracy']:.2%}, "
            f"Puzzle Acc: {test_metrics['puzzle_accuracy']:.2%}"
        )

        # Save best model
        if test_metrics['puzzle_accuracy'] > best_puzzle_acc:
            best_puzzle_acc = test_metrics['puzzle_accuracy']
            # A torch.compile-wrapped model exposes the real network as
            # `_orig_mod` and prefixes every state_dict key with it. Saving
            # the underlying module keeps the checkpoint loadable into a
            # plain SudokuTransformer whether or not compilation was used.
            torch.save(
                getattr(model, '_orig_mod', model).state_dict(),
                'best_transformer_model.pt',
            )
            print(f'  -> New best model saved! (Puzzle Acc: {best_puzzle_acc:.2%})')

print('\n' + '=' * 60)
print(f'Training complete! Best puzzle accuracy: {best_puzzle_acc:.2%}')
print('=' * 60)

## 6. Iterative Inference (Optional)

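For harder puzzles, fill the board one cell at a time: on each step, run the model, commit only the single most confident prediction among the still-empty cells, and then re-run the model on the updated board.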

In [None]:
@torch.no_grad()
def solve_iterative(model, puzzle, device, max_iters=81):
    """Fill a puzzle one cell per model call, most confident cell first.

    The model is put in eval mode. The input tensor is not modified; a copy
    on ``device`` is filled in and returned.

    Args:
        model: Trained SudokuTransformer model.
        puzzle: Tensor of shape (81,) with values 0-9 (0=unknown).
        device: Device to run inference on.
        max_iters: Maximum number of cells to fill.

    Returns:
        Tensor of shape (81,). Cells may remain 0 if ``max_iters`` is
        smaller than the number of unknown cells.
    """
    model.eval()
    board = puzzle.clone().to(device)

    for _step in range(max_iters):
        empty = board == 0
        if not empty.any():
            break

        # Run a batch of one: (1, 81, 9) logits -> (81, 9) probabilities.
        cell_logits = model(board[None])[0]
        cell_probs = F.softmax(cell_logits, dim=-1)
        confidence, digit_class = cell_probs.max(dim=-1)

        # Filled cells must never be selected.
        floor = torch.full_like(confidence, -float('inf'))
        confidence = torch.where(empty, confidence, floor)

        # Commit only the single most certain cell; classes 0-8 map to 1-9.
        chosen = torch.argmax(confidence)
        board[chosen] = digit_class[chosen] + 1

    return board


def display_sudoku(puzzle):
    """Print a 9x9 grid with box separators, showing empty cells as dots.

    Args:
        puzzle: Tensor or array of shape (81,) with values 0-9.
    """
    grid = puzzle.cpu().numpy() if isinstance(puzzle, torch.Tensor) else puzzle
    grid = grid.reshape(9, 9)

    for row_number, row in enumerate(grid):
        # Horizontal rule between the three bands of rows.
        if row_number in (3, 6):
            print('-' * 21)

        # Render each group of three cells, then join groups with a bar.
        segments = [
            ''.join(
                f"{cell if cell > 0 else '.'} "
                for cell in row[start:start + 3]
            )
            for start in (0, 3, 6)
        ]
        print('| '.join(segments))

In [None]:
# Example: solve the first test puzzle iteratively and compare the result
# with the ground truth.
sample = test_ds[0]
puzzle = sample['question']
answer = sample['answer'] + 1  # Convert back to 1-9

print('Original puzzle:')
display_sudoku(puzzle)

print('\nSolved (iterative):')
solved = solve_iterative(model, puzzle, device)
display_sudoku(solved)

print('\nGround truth:')
display_sudoku(answer)

# Check correctness on the cells that were unknown in the puzzle.
# bool(...) unwraps the 0-d tensor so the line prints True/False instead of
# tensor(True)/tensor(False).
mask = sample['mask']
correct = bool((solved.cpu() == answer)[mask].all())
print(f'\nCorrect: {correct}')