# 🐍 Mamba Sudoku Training (O(N) Linear Complexity)

Train a Sudoku solver using the **Mamba Depth Controller** with Selective State Space Models.

**Key Features:**
- O(N) complexity in sequence length N (vs. O(N²) for Transformer self-attention)
- Input-dependent state transitions (selective scan)
- Optional torch.compile optimization

**Reference:** [Mamba: Linear-Time Sequence Modeling with Selective State Spaces](https://arxiv.org/abs/2312.00752)
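
To make "input-dependent state transitions" concrete, here is a minimal single-channel sketch of a selective scan in plain Python. It is not the repository's `MambaDepthController` or the `mamba_ssm` kernel. The function name `selective_scan` and the scalar projections `w_delta`, `w_b`, `w_c` are hypothetical, chosen only for illustration. The point is that the step size and the input/output projections are recomputed from each token, while the whole sequence is still processed in one left-to-right pass, which is why the cost grows linearly with N.

```python
import math


def softplus(v):
    # Smooth positive function used to keep the step size delta > 0.
    return math.log1p(math.exp(v))


def selective_scan(xs, a, w_delta, w_b, w_c):
    """Single-channel selective scan (illustrative sketch, not the real kernel).

    xs      : list of floats, the input sequence of length N
    a       : list of negative floats, diagonal of the state matrix
    w_delta : float, hypothetical projection producing the step size from x
    w_b     : list of floats, hypothetical projection producing B_t from x
    w_c     : list of floats, hypothetical projection producing C_t from x
    """
    h = [0.0] * len(a)  # hidden state, fixed size regardless of N
    ys = []
    for x in xs:  # one pass over the sequence: O(N) steps
        delta = softplus(w_delta * x)  # input-dependent step size
        for i in range(len(a)):
            a_bar = math.exp(delta * a[i])  # discretized state transition
            b_bar = delta * (w_b[i] * x)    # discretized, input-dependent B_t
            h[i] = a_bar * h[i] + b_bar * x
        # Input-dependent readout C_t = w_c * x applied to the state.
        ys.append(sum(w_c[i] * x * h[i] for i in range(len(a))))
    return ys


if __name__ == "__main__":
    out = selective_scan([0.5, -1.0, 2.0, 0.0], a=[-1.0, -0.5],
                         w_delta=1.0, w_b=[1.0, 0.5], w_c=[1.0, 1.0])
    print(len(out))
```

Because the state `h` has a fixed size, memory does not grow with sequence length either; attention, by contrast, compares every token with every other token.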


In [None]:
# Clone repository and install dependencies
!git clone https://github.com/Eran-BA/PoT.git
%cd PoT
!pip install -q -r requirements.txt


In [None]:
# Check GPU and verify Mamba controller
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

from src.pot.core import MambaDepthController, create_controller
print("\nMamba controller available: ✓")
print(f"mamba_ssm CUDA kernels: {MambaDepthController.is_mamba_ssm_available()}")

# Quick test
controller = create_controller("mamba", d_model=256, n_heads=8)
print(f"Controller params: {sum(p.numel() for p in controller.parameters()):,}")


## 🚀 Train with Mamba Controller


In [None]:
# Run Mamba training
!python experiments/sudoku_poh_benchmark.py \
    --download \
    --model hybrid \
    --controller mamba \
    --optimize-mamba \
    --d-model 512 \
    --n-heads 8 \
    --H-cycles 2 \
    --L-cycles 8 \
    --epochs 500 \
    --batch-size 256 \
    --lr 3e-4 \
    --subsample 1000 \
    --num-aug 100 \
    --eval-interval 50 \
    --output experiments/results/mamba_sudoku


## 📈 Results


In [None]:
# Plot results
import json
import matplotlib.pyplot as plt

with open('experiments/results/mamba_sudoku/hybrid_results.json', 'r') as f:
    results = json.load(f)

print(f"Best Grid Accuracy: {results['best_grid_acc']:.2f}%")
print(f"Parameters: {results['parameters']:,}")

history = results['history']
epochs = [h['epoch'] for h in history]
train_acc = [h['train_grid_acc'] for h in history]
test_acc = [h['test_grid_acc'] for h in history]

plt.figure(figsize=(10, 5))
plt.plot(epochs, train_acc, label='Train')
plt.plot(epochs, test_acc, label='Test')
plt.xlabel('Epoch')
plt.ylabel('Grid Accuracy (%)')
plt.title('Mamba Controller - Sudoku Training')
plt.legend()
plt.grid(True)
plt.show()
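
The plot reports grid accuracy. As a rough illustration of how such a metric differs from per-cell accuracy, the sketch below assumes that a grid counts as correct only when every cell matches the target. This is an assumption about the metric, not the benchmark script's actual code, and the helper names `grid_accuracy` and `cell_accuracy` are hypothetical.

```python
def grid_accuracy(preds, targets):
    """Percent of puzzles whose predicted grid matches the target exactly."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    if not targets:
        return 0.0
    solved = sum(1 for p, t in zip(preds, targets) if list(p) == list(t))
    return 100.0 * solved / len(targets)


def cell_accuracy(preds, targets):
    """Percent of individual cells predicted correctly across all puzzles."""
    total = sum(len(t) for t in targets)
    if total == 0:
        return 0.0
    correct = sum(pc == tc
                  for p, t in zip(preds, targets)
                  for pc, tc in zip(p, t))
    return 100.0 * correct / total


if __name__ == "__main__":
    # Tiny 3-cell "grids" stand in for flattened 81-cell Sudoku boards.
    targets = [[1, 2, 3], [4, 5, 6]]
    preds = [[1, 2, 3], [4, 5, 0]]
    print(grid_accuracy(preds, targets), cell_accuracy(preds, targets))
```

Under this definition a single wrong cell makes the whole grid count as a miss, so grid accuracy can sit far below cell accuracy for the same predictions.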
