In [23]:
from datasets import load_dataset

# Login using e.g. `huggingface-cli login` to access this dataset
# The bare `ds` on the last line makes the notebook display the loaded
# object, i.e. its splits, column names and row counts.
ds = load_dataset("sapientinc/sudoku-extreme")
ds

DatasetDict({
    train: Dataset({
        features: ['source', 'question', 'answer', 'rating'],
        num_rows: 3831994
    })
    test: Dataset({
        features: ['source', 'question', 'answer', 'rating'],
        num_rows: 422786
    })
})

In [24]:
# Filter dataset to only include easy sources
easy_sources = ['puzzles0_kaggle', 'puzzles1_unbiased', 'puzzles2_17_clue']
# `input_columns` makes `filter` pass just the `source` value to the callback
# (as a positional argument) instead of building a full row dict, so the other
# three columns are not read for every one of the ~4M rows.
easy_ds = ds.filter(
    lambda source: source in easy_sources,
    input_columns=['source'],
)
easy_ds

DatasetDict({
    train: Dataset({
        features: ['source', 'question', 'answer', 'rating'],
        num_rows: 1034600
    })
    test: Dataset({
        features: ['source', 'question', 'answer', 'rating'],
        num_rows: 114558
    })
})

In [25]:
# Peek at the first filtered training row to see the raw record layout
# ('.' marks an empty cell in the question string).
easy_ds['train'][0]


{'source': 'puzzles1_unbiased',
 'question': '..8..23412........9......5.........4..3..89.7....53....6.3.1.7...7.4.8...5.8.9...',
 'answer': '678592341235184796914736258182967534543218967796453182869321475327645819451879623',
 'rating': 0}

In [26]:
# Question and answer strings should each hold one character per grid cell
# (9 x 9 = 81).
len(easy_ds['train'][0]['question']), len(easy_ds['train'][0]['answer'])

(81, 81)

In [31]:
# =============================================================================
# COMPLETE TRAINING PIPELINE FOR SUDOKU RNN
# =============================================================================
# This pipeline trains an RNN to solve Sudoku puzzles.
# Key concepts:
#   - Input: 81 cells, each one-hot encoded (10 dims: 0=unknown, 1-9=digits)
#   - Output: 81 predictions, each classifying into 9 classes (digits 1-9)
#   - Masking: We only compute loss on cells that were originally unknown ('.')
# =============================================================================

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import (
    DataLoader,
    Dataset,
)
from tqdm import tqdm

In [32]:
# =============================================================================
# STEP 1: CUSTOM DATASET CLASS
# =============================================================================
# PyTorch Dataset that:
#   1. Converts question string to one-hot tensor (81, 10)
#   2. Converts answer to class indices (81,) - shifted to 0-8 for CrossEntropyLoss
#   3. Creates mask: True where cell was unknown (we need to predict these)
# =============================================================================


class SudokuDataset(Dataset):
    """Map-style dataset that turns Sudoku strings into training tensors.

    Indexing yields a dict with three tensors (sizes given for the usual
    81-character puzzle):
        - question: float one-hot encoding of the puzzle, (81, 10)
        - answer: long tensor of target classes 0-8, (81,)
        - mask: bool tensor, True where the puzzle cell was blank, (81,)
    """

    def __init__(self, hf_dataset):
        """Wrap an indexable collection of puzzle rows.

        Args:
            hf_dataset: HuggingFace dataset with 'question' and 'answer' columns.
        """
        self.data = hf_dataset

    def __len__(self):
        """Number of puzzles available."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode the puzzle stored at position ``idx``.

        Args:
            idx: Index of the sample.

        Returns:
            Dict with keys 'question' (one-hot input), 'answer' (class
            indices) and 'mask' (cells to predict).
        """
        record = self.data[idx]
        puzzle_chars, solution_chars = record['question'], record['answer']

        # Blank cells ('.') are stored as 0, given digits keep their value 1-9.
        cell_values = torch.tensor(
            [int(ch) if ch != '.' else 0 for ch in puzzle_chars],
            dtype=torch.long,
        )

        # Digits 1-9 become classes 0-8, the index range CrossEntropyLoss expects.
        solution_classes = torch.tensor(
            [int(ch) - 1 for ch in solution_chars],
            dtype=torch.long,
        )

        return {
            # One channel per possible cell value 0-9.
            'question': F.one_hot(cell_values, num_classes=10).float(),
            'answer': solution_classes,
            # True exactly where the puzzle had no digit.
            'mask': cell_values == 0,
        }


# Wrap both splits of the filtered dataset
train_dataset = SudokuDataset(easy_ds['train'])
test_dataset = SudokuDataset(easy_ds['test'])

# Sanity-check the tensors produced for the first training puzzle
sample = train_dataset[0]
print('Sample shapes:')
# Expected shapes, in order: (81, 10), (81,), (81,)
for field in ('question', 'answer', 'mask'):
    print(f"  {field}: {sample[field].shape}")
# Number of blank cells the model has to fill in for this puzzle
print(f"  mask sum (unknown cells): {sample['mask'].sum().item()}")

Sample shapes:
  question: torch.Size([81, 10])
  answer: torch.Size([81])
  mask: torch.Size([81])
  mask sum (unknown cells): 56


In [None]:
# =============================================================================
# STEP 2: IMPROVED RNN MODEL
# =============================================================================
# Architecture:
#   Input: (batch, 81, 10) - 81 cells, each one-hot encoded
#   RNN: Processes sequence, outputs (batch, 81, hidden_size)
#   FC: Maps hidden to 9 classes, outputs (batch, 81, 9)
#
# Using LSTM instead of vanilla RNN for better gradient flow.
# Using bidirectional to capture constraints from both directions.
# =============================================================================


class SudokuLSTM(nn.Module):
    """Per-cell digit classifier built on a stacked LSTM.

    The cells of a puzzle are fed to the LSTM as one sequence; a shared
    linear layer then turns every time step's features into logits over
    the digit classes. With ``bidirectional=True`` each cell sees context
    from both earlier and later cells in the sequence.
    """

    def __init__(
        self,
        input_size=10,
        hidden_size=128,
        num_layers=2,
        num_classes=9,
        dropout=0.2,
        bidirectional=True,
    ):
        """Build the LSTM encoder and the classification head.

        Args:
            input_size: Dimension of input features (10 for one-hot).
            hidden_size: Number of LSTM hidden units per direction.
            num_layers: Number of stacked LSTM layers.
            num_classes: Number of output classes (9 for digits 1-9).
            dropout: Dropout rate between LSTM layers.
            bidirectional: Whether to use bidirectional LSTM.
        """
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional

        # Dropout acts between stacked layers, so it is switched off when
        # there is only a single layer.
        inter_layer_dropout = 0 if num_layers <= 1 else dropout

        # (batch, seq_len, input_size) -> (batch, seq_len, hidden_size * directions)
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=bidirectional,
        )

        # A bidirectional LSTM concatenates forward and backward features.
        num_directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_size * num_directions, num_classes)

    def forward(self, x):
        """Compute per-cell class logits.

        Args:
            x: Input tensor of shape (batch, 81, 10).

        Returns:
            Logits tensor of shape (batch, 81, 9).
        """
        # Only the per-step outputs are needed; the final (h, c) state is dropped.
        sequence_features, _ = self.lstm(x)
        return self.fc(sequence_features)


# Pick the compute device in priority order: CUDA, then MPS (Apple Silicon), then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)
print(f'Using device: {device}')

# Build the network with a wider and deeper LSTM than the class defaults
# (hidden_size 128 -> 512, num_layers 2 -> 6, dropout 0.2 -> 0.3) and move
# it to the selected device.
model = SudokuLSTM(
    input_size=10,
    hidden_size=512,
    num_layers=6,
    num_classes=9,
    dropout=0.3,
    bidirectional=True,
).to(device)

# Report the model size and layer layout
total_params = sum(weights.numel() for weights in model.parameters())
print(f'Total parameters: {total_params:,}')
print(model)

Using device: cuda
Total parameters: 33,653,769
SudokuLSTM(
  (lstm): LSTM(10, 512, num_layers=6, batch_first=True, dropout=0.3, bidirectional=True)
  (fc): Linear(in_features=1024, out_features=9, bias=True)
)


In [34]:
# =============================================================================
# STEP 3: MASKED LOSS FUNCTION
# =============================================================================
# Key insight: We only want to compute loss on cells that were UNKNOWN.
# Known cells should not contribute to the loss because the model
# doesn't need to "learn" to keep them unchanged.
#
# How it works:
#   1. Compute CrossEntropyLoss for ALL cells (reduction='none')
#   2. Zero out the loss at known cells (mask False)
#   3. Average only over the masked positions
# =============================================================================


def masked_cross_entropy_loss(predictions, targets, mask):
    """Compute CrossEntropyLoss only on masked (unknown) cells.

    The sequence length and the number of classes are read from the inputs,
    so the function is not tied to 81 cells / 9 classes.

    Args:
        predictions: Model output logits, shape (batch, cells, classes),
            e.g. (batch, 81, 9).
        targets: Target class indices, shape (batch, cells).
        mask: Boolean mask, shape (batch, cells). True = compute loss.

    Returns:
        Scalar loss tensor: the mean cross-entropy over all masked positions,
        or 0 when the mask selects no position at all.
    """
    num_classes = predictions.shape[-1]
    keep = mask.bool()

    # cross_entropy wants (N, C) logits and (N,) targets. reshape (unlike
    # view) also accepts non-contiguous inputs.
    loss_per_position = F.cross_entropy(
        predictions.reshape(-1, num_classes),
        targets.reshape(-1),
        reduction='none',  # keep one value per cell so it can be masked
    ).reshape(targets.shape)

    # Overwrite (rather than multiply) the loss at known cells: inf * 0 would
    # be nan and poison the whole batch if a known cell had an infinite loss.
    masked_loss = loss_per_position.masked_fill(~keep, 0.0)

    # clamp keeps the division defined for an all-False mask (result 0) and
    # avoids a Python-side branch on a tensor value.
    num_masked = keep.sum().clamp(min=1)

    return masked_loss.sum() / num_masked


# Test the loss function with dummy data.
# A dedicated, seeded generator makes the printed value reproducible across
# runs without changing the global RNG state used elsewhere in the notebook.
dummy_rng = torch.Generator().manual_seed(0)
dummy_pred = torch.randn(2, 81, 9, generator=dummy_rng)  # 2 samples, 81 cells, 9 classes
dummy_target = torch.randint(0, 9, (2, 81), generator=dummy_rng)  # Random targets 0-8
dummy_mask = torch.rand(2, 81, generator=dummy_rng) > 0.5  # Random mask

loss = masked_cross_entropy_loss(dummy_pred, dummy_target, dummy_mask)
print(f'Test loss: {loss.item():.4f}')

Test loss: 2.5333


In [35]:
# =============================================================================
# STEP 4: ACCURACY METRICS
# =============================================================================
# We track two types of accuracy:
#   1. Cell accuracy: % of individual cells predicted correctly (masked only)
#   2. Puzzle accuracy: % of puzzles completely solved correctly
# =============================================================================


def compute_accuracy(predictions, targets, mask):
    """Measure accuracy on the cells the model had to fill in.

    Args:
        predictions: Model output logits, shape (batch, 81, 9).
        targets: Target class indices, shape (batch, 81).
        mask: Boolean mask, shape (batch, 81). True = cell needs prediction.

    Returns:
        Tuple of Python floats ``(cell_accuracy, puzzle_accuracy)``:
        the fraction of masked cells predicted correctly across the batch,
        and the fraction of puzzles whose masked cells are all correct.
    """
    # A hit is a masked cell whose most likely class equals the target.
    hits = predictions.argmax(dim=-1).eq(targets) & mask  # (batch, 81)

    # Per-puzzle tallies: a puzzle counts as solved when its hits cover
    # every masked cell.
    hits_per_puzzle = hits.sum(dim=1)
    needed_per_puzzle = mask.sum(dim=1)
    puzzle_accuracy = hits_per_puzzle.eq(needed_per_puzzle).float().mean()

    # Batch-wide ratio of hits to masked cells.
    cell_accuracy = hits.sum().float() / mask.sum().float()

    return cell_accuracy.item(), puzzle_accuracy.item()

In [36]:
# =============================================================================
# STEP 5: TRAINING AND EVALUATION FUNCTIONS
# =============================================================================


def train_one_epoch(model, dataloader, optimizer, device):
    """Train the model for one epoch.

    Epoch metrics are weighted by how much data each batch contributed, so
    a smaller final batch or batches with different numbers of unknown
    cells do not skew the averages: loss and cell accuracy are means over
    all unknown cells, puzzle accuracy is a mean over all puzzles.

    Args:
        model: The neural network model.
        dataloader: DataLoader yielding dicts with 'question', 'answer'
            and 'mask' tensors.
        optimizer: Optimizer for updating weights.
        device: Device to run on (cuda/cpu).

    Returns:
        Dictionary with keys 'loss', 'cell_accuracy' and 'puzzle_accuracy'
        for the epoch. Values are NaN if the dataloader produced no data.
    """
    model.train()  # Set model to training mode

    # Weighted running sums and their denominators
    loss_sum = 0.0
    cell_acc_sum = 0.0
    puzzle_acc_sum = 0.0
    cells_seen = 0
    puzzles_seen = 0

    # Progress bar
    pbar = tqdm(dataloader, desc='Training')

    for batch in pbar:
        # Move data to device
        questions = batch['question'].to(device)  # (batch, 81, 10)
        answers = batch['answer'].to(device)       # (batch, 81)
        masks = batch['mask'].to(device)           # (batch, 81)

        # Forward pass
        optimizer.zero_grad()
        predictions = model(questions)  # (batch, 81, 9)

        # Compute masked loss
        loss = masked_cross_entropy_loss(predictions, answers, masks)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Compute accuracy (on the same training-mode predictions)
        with torch.no_grad():
            cell_acc, puzzle_acc = compute_accuracy(predictions, answers, masks)

        loss_value = loss.item()
        batch_cells = int(masks.sum().item())
        batch_puzzles = questions.shape[0]

        # Per-cell quantities are weighted by the number of unknown cells.
        # A batch without any contributes nothing, which also keeps a NaN
        # cell accuracy out of the sums.
        if batch_cells > 0:
            loss_sum += loss_value * batch_cells
            cell_acc_sum += cell_acc * batch_cells
            cells_seen += batch_cells
        # Puzzle accuracy is weighted by the number of puzzles in the batch.
        if batch_puzzles > 0:
            puzzle_acc_sum += puzzle_acc * batch_puzzles
            puzzles_seen += batch_puzzles

        # Update progress bar with the current batch's figures
        pbar.set_postfix({
            'loss': f'{loss_value:.4f}',
            'cell_acc': f'{cell_acc:.2%}',
            'puzzle_acc': f'{puzzle_acc:.2%}',
        })

    undefined = float('nan')
    return {
        'loss': loss_sum / cells_seen if cells_seen else undefined,
        'cell_accuracy': cell_acc_sum / cells_seen if cells_seen else undefined,
        'puzzle_accuracy': puzzle_acc_sum / puzzles_seen if puzzles_seen else undefined,
    }


def evaluate(model, dataloader, device):
    """Evaluate the model on a dataset.

    Metrics are weighted by how much data each batch contributed, so a
    smaller final batch does not skew the result: loss and cell accuracy
    are means over all unknown cells, puzzle accuracy is a mean over all
    puzzles.

    Args:
        model: The neural network model.
        dataloader: DataLoader yielding dicts with 'question', 'answer'
            and 'mask' tensors.
        device: Device to run on (cuda/cpu).

    Returns:
        Dictionary with keys 'loss', 'cell_accuracy' and 'puzzle_accuracy'.
        Values are NaN if the dataloader produced no data.
    """
    model.eval()  # Set model to evaluation mode

    # Weighted running sums and their denominators
    loss_sum = 0.0
    cell_acc_sum = 0.0
    puzzle_acc_sum = 0.0
    cells_seen = 0
    puzzles_seen = 0

    with torch.no_grad():  # No gradient computation during evaluation
        pbar = tqdm(dataloader, desc='Evaluating')

        for batch in pbar:
            questions = batch['question'].to(device)
            answers = batch['answer'].to(device)
            masks = batch['mask'].to(device)

            predictions = model(questions)
            loss = masked_cross_entropy_loss(predictions, answers, masks)
            cell_acc, puzzle_acc = compute_accuracy(predictions, answers, masks)

            batch_cells = int(masks.sum().item())
            batch_puzzles = questions.shape[0]

            # Per-cell quantities are weighted by the number of unknown
            # cells. A batch without any contributes nothing, which also
            # keeps a NaN cell accuracy out of the sums.
            if batch_cells > 0:
                loss_sum += loss.item() * batch_cells
                cell_acc_sum += cell_acc * batch_cells
                cells_seen += batch_cells
            # Puzzle accuracy is weighted by the number of puzzles.
            if batch_puzzles > 0:
                puzzle_acc_sum += puzzle_acc * batch_puzzles
                puzzles_seen += batch_puzzles

    undefined = float('nan')
    return {
        'loss': loss_sum / cells_seen if cells_seen else undefined,
        'cell_accuracy': cell_acc_sum / cells_seen if cells_seen else undefined,
        'puzzle_accuracy': puzzle_acc_sum / puzzles_seen if puzzles_seen else undefined,
    }

In [37]:
# =============================================================================
# STEP 6: MAIN TRAINING LOOP
# =============================================================================
# Hyperparameters and training configuration
# =============================================================================

# Hyperparameters
BATCH_SIZE = 128
LEARNING_RATE = 1e-3
NUM_EPOCHS = 2
# NOTE: num_workers=0 is required in Jupyter notebooks on macOS
# because multiprocessing can't pickle classes defined in the notebook.
NUM_WORKERS = 0
# Pinned host memory only speeds up copies to CUDA, and MPS (Apple Silicon)
# doesn't support it, so it is enabled only when the selected device is CUDA.
PIN_MEMORY = device.type == 'cuda'

# Create DataLoaders
# shuffle=True for training to randomize order each epoch
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # No need to shuffle test data
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

# Optimizer
# Adam is a good default choice for most deep learning tasks
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Learning rate scheduler (optional)
# Halves the LR when the monitored loss stops improving.
# NOTE: with patience=2 a reduction needs more than two consecutive
# non-improving epochs, so it cannot trigger while NUM_EPOCHS is 2.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

print(f'Training samples: {len(train_dataset):,}')
print(f'Test samples: {len(test_dataset):,}')
print(f'Batches per epoch: {len(train_loader):,}')

Training samples: 1,034,600
Test samples: 114,558
Batches per epoch: 8,083


In [None]:
# =============================================================================
# STEP 7: RUN TRAINING
# =============================================================================
# This is where the actual training happens.
# For each epoch:
#   1. Train on all training batches
#   2. Evaluate on test set
#   3. Update learning rate if needed
#   4. Save best model
# =============================================================================

best_puzzle_acc = 0.0
history = {'train': [], 'test': []}

for epoch in range(NUM_EPOCHS):
    print(f'\n{"="*60}')
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}')
    print(f'{"="*60}')

    # Train for one epoch
    train_metrics = train_one_epoch(model, train_loader, optimizer, device)
    print(f'\nTrain - Loss: {train_metrics["loss"]:.4f}, '
          f'Cell Acc: {train_metrics["cell_accuracy"]:.2%}, '
          f'Puzzle Acc: {train_metrics["puzzle_accuracy"]:.2%}')

    # Evaluate on test set
    test_metrics = evaluate(model, test_loader, device)
    print(f'Test  - Loss: {test_metrics["loss"]:.4f}, '
          f'Cell Acc: {test_metrics["cell_accuracy"]:.2%}, '
          f'Puzzle Acc: {test_metrics["puzzle_accuracy"]:.2%}')

    # Update learning rate based on test loss
    scheduler.step(test_metrics['loss'])

    # Save history
    history['train'].append(train_metrics)
    history['test'].append(test_metrics)

    # Save best model
    # The first epoch is always checkpointed: whole-puzzle accuracy can sit
    # at exactly 0.0 for a long time, and a strict `>` against the 0.0
    # starting value would then never write a checkpoint at all.
    # NOTE(review): the checkpoint is selected on the test split, so the
    # reported best test accuracy is optimistically biased.
    if epoch == 0 or test_metrics['puzzle_accuracy'] > best_puzzle_acc:
        best_puzzle_acc = test_metrics['puzzle_accuracy']
        torch.save(model.state_dict(), 'best_sudoku_model.pt')
        print(f'  → New best model saved! Puzzle accuracy: {best_puzzle_acc:.2%}')

print(f'\n{"="*60}')
print(f'Training complete! Best puzzle accuracy: {best_puzzle_acc:.2%}')
print(f'{"="*60}')


Epoch 1/2


Training:  31%|███       | 2494/8083 [30:34<1:07:58,  1.37it/s, loss=1.5235, cell_acc=31.72%, puzzle_acc=0.00%]

In [None]:
# =============================================================================
# STEP 8: INFERENCE - TEST ON A SINGLE PUZZLE
# =============================================================================
# After training, let's see how the model performs on a sample puzzle.
# =============================================================================


def solve_sudoku(model, puzzle_string, device):
    """Solve a single Sudoku puzzle using the trained model.

    Every blank cell is filled with the model's most likely digit; given
    digits are copied through unchanged. Both '.' and '0' are accepted as
    the blank marker, and both are replaced in the result.

    Note: the model is switched to eval mode and left in that mode.

    Args:
        model: Trained SudokuLSTM model.
        puzzle_string: 81-character string with '.' (or '0') for unknown cells.
        device: Device to run inference on.

    Returns:
        Tuple of (predicted_string, confidence_scores). The confidence of a
        filled-in cell is the largest softmax probability at that cell;
        given cells get a confidence of 1.0.

    Raises:
        ValueError: If the puzzle contains a character that is neither a
            digit nor '.'.
    """
    model.eval()

    # Encode the puzzle: blanks become 0, given digits keep their value.
    question_ints = [0 if c in '.0' else int(c) for c in puzzle_string]
    question_tensor = torch.tensor(question_ints, dtype=torch.long)
    question_onehot = F.one_hot(question_tensor, num_classes=10).float()

    # The same encoding decides which cells get filled in below, so a cell
    # that is fed to the model as "unknown" is never echoed back as '0'.
    is_unknown = [value == 0 for value in question_ints]

    # Add batch dimension and move to device
    input_tensor = question_onehot.unsqueeze(0).to(device)  # (1, 81, 10)

    with torch.no_grad():
        output = model(input_tensor)  # (1, 81, 9)
        probabilities = F.softmax(output, dim=-1)  # Convert to probabilities
        predictions = output.argmax(dim=-1)  # (1, 81) - class indices 0-8

    # One bulk conversion instead of an .item() call per cell.
    # Classes 0-8 map back to digits 1-9.
    predicted_digits = (predictions.squeeze(0).cpu() + 1).tolist()
    top_probabilities = probabilities.squeeze(0).cpu().max(dim=-1).values.tolist()

    # Build result string
    result = []
    confidences = []
    for given, unknown, digit, top_prob in zip(
        puzzle_string, is_unknown, predicted_digits, top_probabilities
    ):
        if unknown:
            # Unknown cell - use model prediction
            result.append(str(digit))
            confidences.append(top_prob)
        else:
            # Known cell - keep original
            result.append(given)
            confidences.append(1.0)

    return ''.join(result), confidences


def display_sudoku(puzzle_str, title='Sudoku'):
    """Print a Sudoku grid with borders around the 3x3 boxes.

    Args:
        puzzle_str: 81-character string representing the puzzle.
        title: Title to display above the puzzle.
    """
    border = '+-------+-------+-------+'

    print(f'\n{title}:')
    print(border)
    # Walk the string one grid row (9 characters) at a time.
    for start in range(0, 81, 9):
        cells = puzzle_str[start:start + 9]
        # Three groups of three cells, each group space-separated.
        bands = (' '.join(cells[k:k + 3]) for k in (0, 3, 6))
        print('| ' + ' | '.join(bands) + ' |')
        # Rows 3, 6 and 9 begin at offsets 18, 45 and 72; close the box there.
        if start % 27 == 18:
            print(border)


# Test on a sample from the test set
sample_idx = 0
sample = easy_ds['test'][sample_idx]
puzzle = sample['question']
solution = sample['answer']

print('Testing trained model on a sample puzzle:')
display_sudoku(puzzle, 'Input Puzzle')
display_sudoku(solution, 'Ground Truth')

# Solve using model
predicted, confidences = solve_sudoku(model, puzzle, device)
display_sudoku(predicted, 'Model Prediction')

# Check accuracy over the whole grid
total_cells = len(solution)
correct = sum(p == s for p, s in zip(predicted, solution))
print(f'\nAccuracy: {correct}/{total_cells} cells correct ({correct/total_cells:.1%})')
print(f'Average confidence: {sum(confidences)/len(confidences):.2%}')

# The whole-grid figures above are inflated by the given cells, which are
# copied through and reported with confidence 1.0. Also report the cells the
# model actually had to predict, matching the masked metrics used in training.
unknown_positions = [i for i, c in enumerate(puzzle) if c == '.']
if unknown_positions:
    correct_unknown = sum(predicted[i] == solution[i] for i in unknown_positions)
    mean_unknown_conf = sum(confidences[i] for i in unknown_positions) / len(unknown_positions)
    print(f'Unknown cells only: {correct_unknown}/{len(unknown_positions)} correct '
          f'({correct_unknown/len(unknown_positions):.1%})')
    print(f'Average confidence on unknown cells: {mean_unknown_conf:.2%}')

In [None]:
# =============================================================================
# STEP 9: VISUALIZE TRAINING HISTORY
# =============================================================================

import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# One x position per recorded epoch, starting at 1
epochs = range(1, len(history['train']) + 1)

# One panel per tracked metric: (key in the metrics dicts, axis label)
panels = [
    ('loss', 'Loss'),
    ('cell_accuracy', 'Cell Accuracy'),
    ('puzzle_accuracy', 'Puzzle Accuracy'),
]
# Train curve in blue, test curve in red on every panel
curves = [('train', 'b-', 'Train'), ('test', 'r-', 'Test')]

for ax, (metric, label) in zip(axes, panels):
    for split, line_style, legend_name in curves:
        ax.plot(epochs, [h[metric] for h in history[split]], line_style, label=legend_name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(label)
    ax.set_title(f'{label} over Epochs')
    ax.legend()
    ax.grid(True)

plt.tight_layout()
plt.show()