# Improved Sudoku Solver Training

This notebook implements an optimized pipeline for training a Sudoku solver. 
Key improvements over the original:
1.  **Offline Binary Data**: Pre-processes data to simple integer arrays (`.npy`), avoiding slow string parsing during training.
2.  **Embeddings**: Uses `nn.Embedding` instead of One-Hot encoding for better memory efficiency.
3.  **Data Augmentation**: Implements Sudoku-valid permutations (rows, cols, digits) to expand the dataset.
4.  **Mixed Precision**: Uses `torch.amp` for faster training.
5.  **Large Batch Size**: Enabled by the above optimizations.

In [1]:
import os
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

## 1. Data Processing
We defines functions to download, augment, and save the dataset as binary files.

In [2]:
from datasets import load_dataset

def shuffle_sudoku(board_flat, solution_flat):
    """Apply valid Sudoku transformations (permutations) to a board and solution."""
    # Reshape to 9x9
    board = board_flat.reshape(9, 9)
    sol = solution_flat.reshape(9, 9)
    
    # 1. Permute digits (1-9)
    # Create a mapping: 0->0 (unknown), 1-9 -> permuted 1-9
    digit_map = np.arange(10)
    digit_map[1:] = np.random.permutation(np.arange(1, 10))
    
    # 2. Random Transpose
    if np.random.rand() < 0.5:
        board = board.T
        sol = sol.T
        
    # 3. Permute Bands (groups of 3 rows)
    bands = np.random.permutation(3)
    row_perm = np.concatenate([b * 3 + np.random.permutation(3) for b in bands])
    
    # 4. Permute Stacks (groups of 3 cols)
    stacks = np.random.permutation(3)
    col_perm = np.concatenate([s * 3 + np.random.permutation(3) for s in stacks])
    
    # Apply permutations
    # We can do this by indexing: new_board[i, j] = old_board[row_perm[i], col_perm[j]]
    # Or simpler: reorder rows, then reorder cols
    board = board[row_perm, :][:, col_perm]
    sol = sol[row_perm, :][:, col_perm]
    
    # Map digits
    board = digit_map[board]
    sol = digit_map[sol]
    
    return board.flatten(), sol.flatten()

def preprocess_dataset(output_dir="data/processed", num_aug=1):
    """Download, filter, augment, and save dataset as .npy."""
    if os.path.exists(output_dir):
        print(f"Dataset already exists at {output_dir}. Skipping generation.")
        return
        
    print("Loading dataset from HuggingFace...")
    ds = load_dataset("sapientinc/sudoku-extreme")
    
    # Filter easy sources
    easy_sources = ['puzzles0_kaggle', 'puzzles1_unbiased', 'puzzles2_17_clue']
    
    os.makedirs(output_dir, exist_ok=True)
    
    for split in ['train', 'test']:
        print(f"Processing {split} split...")
        # Filter
        split_ds = ds[split].filter(lambda x: x['source'] in easy_sources)
        
        questions = []
        answers = []
        
        print("Converting to integers and augmenting...")
        for item in tqdm(split_ds):
            # Parse strings one last time
            # '.' -> 0, '1'-'9' -> 1-9
            q = np.array([0 if c == '.' else int(c) for c in item['question']], dtype=np.uint8)
            # Answer is 1-9. We keep it 1-9 for now (0 means N/A if needed, but answers are full)
            # Original code shifted answer to 0-8. Here we keep 1-9 to match input features,
            # but will shift for loss calculation if needed.
            a = np.array([int(c) for c in item['answer']], dtype=np.uint8)
            
            # Original data
            questions.append(q)
            answers.append(a)
            
            # Augmentations (only for train)
            if split == 'train' and num_aug > 0:
                for _ in range(num_aug):
                    q_aug, a_aug = shuffle_sudoku(q, a)
                    questions.append(q_aug)
                    answers.append(a_aug)
        
        # Save as .npy
        q_arr = np.array(questions, dtype=np.uint8)
        a_arr = np.array(answers, dtype=np.uint8)
        
        print(f"Saving {len(q_arr)} samples to {output_dir}...")
        np.save(os.path.join(output_dir, f"{split}_questions.npy"), q_arr)
        np.save(os.path.join(output_dir, f"{split}_answers.npy"), a_arr)

# Run preprocessing (will skip if already exists)
preprocess_dataset(num_aug=1)  # 1 augmentation -> 2x dataset size

Loading dataset from HuggingFace...


Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


README.md: 0.00B [00:00, ?B/s]

train.csv:   0%|          | 0.00/719M [00:00<?, ?B/s]

test.csv:   0%|          | 0.00/79.4M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/3831994 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/422786 [00:00<?, ? examples/s]

Processing train split...


Filter:   0%|          | 0/3831994 [00:00<?, ? examples/s]

Converting to integers and augmenting...


100%|██████████| 1034600/1034600 [04:16<00:00, 4026.06it/s]


Saving 2069200 samples to data/processed...
Processing test split...


Filter:   0%|          | 0/422786 [00:00<?, ? examples/s]

Converting to integers and augmenting...


100%|██████████| 114558/114558 [00:13<00:00, 8811.91it/s]


Saving 114558 samples to data/processed...


## 2. Optimized Dataset Class
Loads data directly from memory-mapped `.npy` files. Instant access.

In [3]:
class FastSudokuDataset(Dataset):
    def __init__(self, data_dir, split):
        self.questions = np.load(os.path.join(data_dir, f"{split}_questions.npy"))
        self.answers = np.load(os.path.join(data_dir, f"{split}_answers.npy"))
        
    def __len__(self):
        return len(self.questions)
    
    def __getitem__(self, idx):
        # Data is already uint8 0-9
        # q: (81,)
        # a: (81,)
        q = self.questions[idx].astype(np.int64) # Long for embedding
        a = self.answers[idx].astype(np.int64)
        
        # Create mask (where q was 0)
        # Note: In embedding, we can just feed 0. 
        # But for loss we need mask.
        mask = (q == 0)
        
        # Target for Loss: PyTorch CrossEntropy expects 0-8 for classes 0-8.
        # Our answers are 1-9. So we subtract 1.
        target = a - 1
        
        return {
            'question': torch.from_numpy(q),    # (81,) Indices 0-9
            'answer': torch.from_numpy(target), # (81,) Indices 0-8
            'mask': torch.from_numpy(mask)      # (81,) Bool
        }

## 3. Improved Model with Embeddings
Replaced One-Hot input with `nn.Embedding`. 
- Input dimension dropped from `(B, 81, 10)` to `(B, 81)` indices.
- Memory usage significantly reduced.

In [4]:
class SudokuLSTM_Improved(nn.Module):
    def __init__(
        self,
        hidden_size=512,
        num_layers=6,
        dropout=0.3,
    ):
        super().__init__()
        # Embedding layer
        # 10 possible values in input: 0 (unknown) + 1-9 (digits)
        self.embedding = nn.Embedding(10, hidden_size)
        
        self.lstm = nn.LSTM(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout,
            bidirectional=True,
        )

        # Output maps to 9 classes (digits 1-9)
        self.fc = nn.Linear(hidden_size * 2, 9)

    def forward(self, x):
        # x: (batch, 81) indices
        # embed: (batch, 81, hidden)
        x = self.embedding(x)
        
        # lstm_out: (batch, 81, hidden * 2)
        lstm_out, _ = self.lstm(x)
        
        # out: (batch, 81, 9)
        out = self.fc(lstm_out)
        return out

## 4. Optimized Training Loop
- **AMP (Automatic Mixed Precision)**: `torch.amp.autocast`
- **Larger Batch Size**
- **Gradient Clipping**

In [5]:
# Config
BATCH_SIZE = 1024  # Increased from 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 20    # Can run many more due to speed

device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu')
print(f"Using device: {device}")

# Load Data
train_ds = FastSudokuDataset("data/processed", "train")
test_ds = FastSudokuDataset("data/processed", "test")

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Model Setup
model = SudokuLSTM_Improved().to(device)

# Compile model if available (Linux/CUDA usually)
if hasattr(torch, 'compile') and device.type == 'cuda':
    print("Compiling model...")
    model = torch.compile(model)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

def masked_loss(preds, targets, mask):
    """Compute CrossEntropyLoss only on masked (unknown) cells."""
    # preds: (B, 81, 9)
    # targets: (B, 81)
    # mask: (B, 81)
    loss = F.cross_entropy(preds.reshape(-1, 9), targets.reshape(-1), reduction='none')
    loss = loss.reshape(targets.shape)
    masked_loss = loss * mask.float()
    return masked_loss.sum() / (mask.sum() + 1e-6)


def compute_accuracy(predictions, targets, mask):
    """Compute cell-level and puzzle-level accuracy.

    Args:
        predictions: Model output logits, shape (batch, 81, 9).
        targets: Target class indices, shape (batch, 81).
        mask: Boolean mask, shape (batch, 81). True = cell needs prediction.

    Returns:
        Tuple of (cell_accuracy, puzzle_accuracy).
    """
    # Get predicted classes (argmax over the 9 classes)
    predicted_classes = predictions.argmax(dim=-1)  # (batch, 81)

    # Cell accuracy: correct predictions among masked cells
    correct = (predicted_classes == targets) & mask  # Both correct AND masked
    cell_accuracy = correct.sum().float() / (mask.sum().float() + 1e-6)

    # Puzzle accuracy: all masked cells correct for each puzzle
    correct_per_puzzle = correct.sum(dim=1)  # (batch,) - correct cells per puzzle
    masked_per_puzzle = mask.sum(dim=1)       # (batch,) - masked cells per puzzle
    puzzles_solved = (correct_per_puzzle == masked_per_puzzle).float()
    puzzle_accuracy = puzzles_solved.mean()

    return cell_accuracy.item(), puzzle_accuracy.item()


def train_epoch(model, loader, optimizer, scaler):
    """Train for one epoch and return metrics."""
    model.train()
    total_loss = 0
    total_cell_acc = 0
    total_puzzle_acc = 0
    num_batches = 0
    pbar = tqdm(loader)
    
    for batch in pbar:
        q = batch['question'].to(device)
        a = batch['answer'].to(device)
        m = batch['mask'].to(device)
        
        optimizer.zero_grad()
        
        # Mixed Precision Context
        if device.type == 'cuda':
            with torch.amp.autocast('cuda'):
                preds = model(q)
                loss = masked_loss(preds, a, m)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        
        elif device.type == 'mps': # Mac Optimized
             with torch.autocast(device_type='mps', dtype=torch.float16):
                 preds = model(q)
                 loss = masked_loss(preds, a, m)
             loss.backward()
             optimizer.step()
             
        else: # CPU
            preds = model(q)
            loss = masked_loss(preds, a, m)
            loss.backward()
            optimizer.step()
        
        # Compute accuracy
        with torch.no_grad():
            cell_acc, puzzle_acc = compute_accuracy(preds, a, m)
            
        total_loss += loss.item()
        total_cell_acc += cell_acc
        total_puzzle_acc += puzzle_acc
        num_batches += 1
        
        pbar.set_postfix({
            'loss': f'{loss.item():.4f}',
            'cell_acc': f'{cell_acc:.2%}',
            'puzzle_acc': f'{puzzle_acc:.2%}',
        })
        
    return {
        'loss': total_loss / num_batches,
        'cell_accuracy': total_cell_acc / num_batches,
        'puzzle_accuracy': total_puzzle_acc / num_batches,
    }


# Start Training
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
    metrics = train_epoch(model, train_loader, optimizer, scaler)
    print(f"Average Loss: {metrics['loss']:.4f}, "
          f"Cell Acc: {metrics['cell_accuracy']:.2%}, "
          f"Puzzle Acc: {metrics['puzzle_accuracy']:.2%}")

Using device: cuda
Compiling model...


  scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None


Epoch 1/20


  4%|▍         | 76/2021 [01:30<38:42,  1.19s/it, loss=2.0504] 


KeyboardInterrupt: 