In [8]:
#LeetArxiv implementation of the paper: Sinkhorn Solves Sudoku
#Full walkthrough: https://leetarxiv.substack.com/p/sinkhorn-solves-sudoku
import numpy as np
import warnings

"""Ensure input matrix is square, 2D, non-negative, and has total support."""
"""The C version we wrote here has better checking:https://leetarxiv.substack.com/p/sinkhorn-solves-sudoku """
def CheckMatrix(P):
    """Validate a Sinkhorn-Knopp input and return it as a NumPy array.

    Args:
        P: Array-like to validate; it is converted with ``np.asarray``.

    Returns:
        ``P`` as a NumPy array.

    Raises:
        AssertionError: If ``P`` has a negative entry, is not 2-D, or is not
            square.

    Warns:
        UserWarning: If any row or any column of ``P`` sums to zero.

    Note:
        Total support itself is not verified here. Only the weaker condition
        "no all-zero row and no all-zero column" is checked, and a violation
        is reported as a warning, not an error.
    """
    P = np.asarray(P)

    # Raise explicitly rather than via `assert` so that the validation is not
    # stripped when Python runs with -O. The exception type is unchanged.
    if not np.all(P >= 0):
        raise AssertionError("Matrix P must be non-negative.")
    if P.ndim != 2:
        raise AssertionError(f"Matrix P must be 2-D, got {P.ndim} dimension(s).")
    if P.shape[0] != P.shape[1]:
        raise AssertionError(f"Matrix P must be square, got shape {P.shape}.")

    # With non-negative entries, a row/column sums to zero only if it is all zeros.
    has_zero_column = np.any(P.sum(axis=0) == 0)
    has_zero_row = np.any(P.sum(axis=1) == 0)
    if has_zero_column or has_zero_row:
        warnings.warn("Matrix P must have total support.", UserWarning)
    return P


def _normalize(P, max_iter=1000, epsilon=1e-3):
    N = P.shape[0]
    r = np.ones((N, 1))
    c = 1 / (P.T @ r)
    r = 1 / (P @ c)

    def stopping_criteria(P_eps):
        row_sums = np.sum(P_eps, axis=1)
        col_sums = np.sum(P_eps, axis=0)
        return (
            np.all((1 - epsilon) <= row_sums) and np.all(row_sums <= (1 + epsilon)) and
            np.all((1 - epsilon) <= col_sums) and np.all(col_sums <= (1 + epsilon))
        )

    def iterate(P, r, c, iteration=0):
        D1 = np.diag(np.squeeze(r))
        D2 = np.diag(np.squeeze(c))
        P_eps = D1 @ P @ D2

        if stopping_criteria(P_eps):
            return P_eps, D1, D2, iteration, "epsilon"
        if iteration >= max_iter:
            return P_eps, D1, D2, iteration, "max_iter"

        c_new = 1 / (P.T @ r)
        r_new = 1 / (P @ c_new)
        return iterate(P, r_new, c_new, iteration + 1)

    return iterate(P, r, c)


def sinkhorn_knopp(P, max_iter=1000, epsilon=1e-3):
    """Run Sinkhorn-Knopp on ``P`` and report how the iteration ended.

    Args:
        P: Matrix to balance; validated by ``CheckMatrix``.
        max_iter: Positive int or float; truncated with ``int()`` before use.
        epsilon: Int or float strictly between 0 and 1, the tolerance on the
            row and column sums.

    Returns:
        ``(P_ds, info)`` where ``P_ds`` is the scaled matrix and ``info`` is a
        dict with the keys ``"iterations"``, ``"stopping_condition"``,
        ``"D1"`` and ``"D2"``.

    Raises:
        AssertionError: If ``max_iter`` or ``epsilon`` fails the checks above.
    """
    numeric_types = (int, float)
    assert isinstance(max_iter, numeric_types) and max_iter > 0
    assert isinstance(epsilon, numeric_types) and 0 < epsilon < 1

    checked = CheckMatrix(P)
    outcome = _normalize(checked, int(max_iter), epsilon)
    P_ds, left_scaling, right_scaling, n_steps, reason = outcome

    info = dict(
        iterations=n_steps,
        stopping_condition=reason,
        D1=left_scaling,
        D2=right_scaling,
    )
    return P_ds, info

def test_sinkhorn_knopp():
    """Balance a small 5x5 example, print the result and verify its sums.

    Raises:
        AssertionError: If the iteration did not stop on the tolerance or if
            any row/column sum is further than the default tolerance from 1.
    """
    P = np.array([
        [0.000, 9.000, 6.000, 3.000, 8.000],
        [5.000, 6.000, 1.000, 1.000, 5.000],
        [9.000, 8.000, 4.000, 8.000, 1.000],
        [0.000, 3.000, 0.000, 4.000, 4.000],
        [4.000, 4.000, 7.000, 6.000, 3.000],
    ])

    print("Original Matrix:\n", P)

    # Functional Version
    P_ds_func, meta = sinkhorn_knopp(P)
    row_sums = np.sum(P_ds_func, axis=1)
    col_sums = np.sum(P_ds_func, axis=0)

    print("\n Doubly stochastic matrix:\n", P_ds_func)
    print("Row sums:", row_sums)
    print("Column sums:", col_sums)

    # Check the result instead of only printing it; 1e-3 is the default
    # epsilon used by sinkhorn_knopp above.
    assert meta["stopping_condition"] == "epsilon", meta["stopping_condition"]
    np.testing.assert_allclose(row_sums, 1.0, atol=1e-3)
    np.testing.assert_allclose(col_sums, 1.0, atol=1e-3)


# Run the demo when the cell executes.
test_sinkhorn_knopp()

Original Matrix:
 [[0. 9. 6. 3. 8.]
 [5. 6. 1. 1. 5.]
 [9. 8. 4. 8. 1.]
 [0. 3. 0. 4. 4.]
 [4. 4. 7. 6. 3.]]

 Doubly stochastic matrix:
 [[0.         0.25679466 0.35602847 0.11239826 0.27477861]
 [0.39528436 0.23542506 0.08160024 0.05152242 0.23616792]
 [0.39283441 0.17330808 0.18020997 0.2275693  0.02607824]
 [0.         0.22957702 0.         0.40194071 0.36848227]
 [0.21149306 0.10496824 0.38201989 0.2067493  0.09476951]]
Row sums: [1. 1. 1. 1. 1.]
Column sums: [0.99961182 1.00007306 0.99985856 1.00017999 1.00027656]


# Step 2: Download Sudoku Dataset

In [9]:
from typing import Optional
import os
import csv
import json
import numpy as np
import pydantic

from pydantic import BaseModel
from tqdm import tqdm
from huggingface_hub import hf_hub_download
from typing import List, Optional

import numpy as np


# Inverse lookup for the transform ids accepted by `dihedral_transform`:
# DIHEDRAL_INVERSE[tid] is the id of the transform that undoes `tid`.
# Only the two quarter-turn rotations (ids 1 and 3) map to each other; the
# identity (0), the half-turn (2) and the reflections (4-7) undo themselves.
DIHEDRAL_INVERSE = [0, 3, 2, 1, 4, 5, 6, 7]


class PuzzleDatasetMetadata(pydantic.BaseModel):
    """Metadata describing one saved dataset split.

    `convert_subset` builds one instance per split and writes it to
    ``dataset.json``; the per-field notes below state the values it uses.
    """
    # Id of the padding token (convert_subset writes 0).
    pad_id: int
    # Label id to be ignored, or None (convert_subset writes 0).
    # NOTE(review): presumably consumed by the training loss — confirm with the reader of dataset.json.
    ignore_label_id: Optional[int]
    # Identifier id standing for "blank" (convert_subset writes 0).
    blank_identifier_id: int
    # Number of distinct token ids (convert_subset writes 11: PAD plus "0".."9").
    vocab_size: int
    # Tokens per example (convert_subset writes 81, one per Sudoku cell).
    seq_len: int
    # Number of distinct puzzle identifiers (convert_subset writes 1).
    num_puzzle_identifiers: int
    # Number of groups (convert_subset writes len(group_indices) - 1).
    total_groups: int
    # Average number of examples per puzzle (convert_subset writes 1).
    mean_puzzle_examples: float
    # Number of puzzles (convert_subset writes len(group_indices) - 1, same as total_groups).
    total_puzzles: int
    # Names of the sets stored for the split (convert_subset writes ["all"]).
    sets: List[str]


def dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Apply one of the 8 dihedral symmetries (rotations and reflections).

    Args:
        arr: Array to transform.
        tid: Transform id. 0 is the identity, 1-3 are ``np.rot90`` with
            ``k=tid``, 4 is a left-right flip, 5 an up-down flip, 6 the
            transpose and 7 the anti-diagonal reflection.

    Returns:
        The transformed array. For ``tid == 0`` and for any id outside 0-7
        the input ``arr`` itself is returned unchanged.
    """
    operations = (
        (1, lambda a: np.rot90(a, k=1)),
        (2, lambda a: np.rot90(a, k=2)),
        (3, lambda a: np.rot90(a, k=3)),
        (4, np.fliplr),                                # horizontal flip
        (5, np.flipud),                                # vertical flip
        (6, lambda a: a.T),                            # main-diagonal reflection
        (7, lambda a: np.fliplr(np.rot90(a, k=1))),    # anti-diagonal reflection
    )
    for candidate, operation in operations:
        if tid == candidate:
            return operation(arr)

    # Identity (tid == 0) and unknown ids both hand the input back untouched.
    return arr
    
    
def inverse_dihedral_transform(arr: np.ndarray, tid: int) -> np.ndarray:
    """Undo ``dihedral_transform(arr, tid)``.

    Looks up the inverse transform id in ``DIHEDRAL_INVERSE`` and applies it
    to ``arr``.
    """
    inverse_tid = DIHEDRAL_INVERSE[tid]
    return dihedral_transform(arr, inverse_tid)


class DataProcessConfig(BaseModel):
    """Settings read by `convert_subset` when building the dataset."""
    # Hugging Face dataset repo the "<set_name>.csv" files are downloaded from.
    source_repo: str = "sapientinc/sudoku-extreme"
    # Directory under which one sub-directory per split is written.
    output_dir: str = "data/sudoku-extreme-full"

    # If set, the "train" split is randomly subsampled to at most this many puzzles.
    subsample_size: Optional[int] = None
    # If set, only CSV rows whose rating is >= this value are kept.
    min_difficulty: Optional[int] = None
    # Number of shuffled copies generated per puzzle; applied to the "train" split only.
    num_aug: int = 0


def shuffle_sudoku(board: np.ndarray, solution: np.ndarray):
    """Apply one random Sudoku symmetry to a puzzle and its solution.

    The same transformation is applied to both grids: an optional transpose,
    a row permutation (bands shuffled, rows shuffled inside each band), the
    analogous column permutation, and a relabelling of the digits 1-9 that
    leaves 0 (blank) unchanged. Randomness comes from the global ``np.random``
    state.

    Args:
        board: 9x9 puzzle grid with 0 for blanks.
        solution: 9x9 solved grid.

    Returns:
        Tuple ``(new_board, new_solution)`` of transformed 9x9 arrays.
    """
    def _grouped_permutation() -> np.ndarray:
        # Order the three groups first, then order the three lines in each group.
        group_order = np.random.permutation(3)
        return np.concatenate([3 * g + np.random.permutation(3) for g in group_order])

    # The random draws happen in this fixed order: digits, transpose, rows, columns.
    relabel = np.pad(np.random.permutation(np.arange(1, 10)), (1, 0))
    do_transpose = np.random.rand() < 0.5
    row_order = _grouped_permutation()
    col_order = _grouped_permutation()

    # source_cell[9 * i + j] is the flat index of the old cell that moves to (i, j).
    source_cell = (row_order[:, None] * 9 + col_order[None, :]).reshape(-1)

    def _transform(grid: np.ndarray) -> np.ndarray:
        oriented = grid.T if do_transpose else grid
        moved = oriented.flatten()[source_cell].reshape(9, 9)
        return relabel[moved]

    return _transform(board), _transform(solution)


def convert_subset(set_name: str, config: DataProcessConfig):
    """Download one split of the Sudoku CSV and save it as NumPy arrays plus JSON metadata.

    Reads ``<set_name>.csv`` from the Hugging Face dataset repo
    ``config.source_repo``, optionally filters rows by rating, optionally
    subsamples and augments the "train" split, and writes:

    * ``<output_dir>/<set_name>/dataset.json`` — the ``PuzzleDatasetMetadata``;
    * ``<output_dir>/<set_name>/all__<key>.npy`` for ``inputs``, ``labels``,
      ``group_indices``, ``puzzle_indices`` and ``puzzle_identifiers``;
    * ``<output_dir>/identifiers.json`` — the identifier names (visualization only).

    Cell values are stored as ``digit + 1`` so that 0 stays free for padding;
    a blank cell (``.`` in the CSV) is therefore stored as 1.

    Args:
        set_name: Split name; also the CSV file stem. Subsampling and
            augmentation apply only when it equals ``"train"``.
        config: Source repo, output directory and sampling/augmentation options.

    Raises:
        AssertionError: If a puzzle or answer string is not 81 characters long,
            or if a decoded cell value falls outside 0-9.
    """
    # Read CSV
    inputs = []
    labels = []

    with open(hf_hub_download(config.source_repo, f"{set_name}.csv", repo_type="dataset"), newline="") as csvfile:
        reader = csv.reader(csvfile)
        next(reader)  # Skip header
        for _source, q, a, rating in reader:
            if (config.min_difficulty is None) or (int(rating) >= config.min_difficulty):
                assert len(q) == 81 and len(a) == 81

                # Decode the ASCII digit strings into 9x9 grids; '.' marks a blank and becomes 0.
                inputs.append(np.frombuffer(q.replace('.', '0').encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))
                labels.append(np.frombuffer(a.encode(), dtype=np.uint8).reshape(9, 9) - ord('0'))

    # If subsample_size is specified for the training set,
    # randomly sample the desired number of examples.
    if set_name == "train" and config.subsample_size is not None:
        total_samples = len(inputs)
        if config.subsample_size < total_samples:
            indices = np.random.choice(total_samples, size=config.subsample_size, replace=False)
            inputs = [inputs[i] for i in indices]
            labels = [labels[i] for i in indices]

    # Generate dataset
    num_augments = config.num_aug if set_name == "train" else 0

    results = {k: [] for k in ["inputs", "labels", "puzzle_identifiers", "puzzle_indices", "group_indices"]}
    puzzle_id = 0
    example_id = 0

    # Both index lists hold running end offsets, so they start at 0.
    results["puzzle_indices"].append(0)
    results["group_indices"].append(0)

    for orig_inp, orig_out in zip(tqdm(inputs), labels):
        for aug_idx in range(1 + num_augments):
            # First index is not augmented
            if aug_idx == 0:
                inp, out = orig_inp, orig_out
            else:
                inp, out = shuffle_sudoku(orig_inp, orig_out)

            # Push puzzle (only single example)
            results["inputs"].append(inp)
            results["labels"].append(out)
            example_id += 1
            puzzle_id += 1

            results["puzzle_indices"].append(example_id)
            results["puzzle_identifiers"].append(0)

        # Push group: an original puzzle and its augmentations form one group.
        results["group_indices"].append(puzzle_id)

    # To Numpy
    def _seq_to_numpy(seq):
        # Stack the 9x9 grids into one row of 81 values per example.
        arr = np.concatenate(seq).reshape(len(seq), -1)

        assert np.all((arr >= 0) & (arr <= 9))
        # Shift by one so that id 0 remains reserved for padding.
        return arr + 1

    results = {
        "inputs": _seq_to_numpy(results["inputs"]),
        "labels": _seq_to_numpy(results["labels"]),

        "group_indices": np.array(results["group_indices"], dtype=np.int32),
        "puzzle_indices": np.array(results["puzzle_indices"], dtype=np.int32),
        "puzzle_identifiers": np.array(results["puzzle_identifiers"], dtype=np.int32),
    }

    # Metadata
    metadata = PuzzleDatasetMetadata(
        seq_len=81,
        vocab_size=10 + 1,  # PAD + "0" ... "9"
        pad_id=0,
        ignore_label_id=0,
        blank_identifier_id=0,
        num_puzzle_identifiers=1,
        total_groups=len(results["group_indices"]) - 1,
        mean_puzzle_examples=1,
        total_puzzles=len(results["group_indices"]) - 1,
        sets=["all"]
    )

    # Save metadata as JSON.
    save_dir = os.path.join(config.output_dir, set_name)
    os.makedirs(save_dir, exist_ok=True)

    # pydantic v2 deprecates `.dict()` in favour of `.model_dump()`; use the
    # new name when it exists and fall back to the old one on pydantic v1.
    if hasattr(metadata, "model_dump"):
        metadata_dict = metadata.model_dump()
    else:
        metadata_dict = metadata.dict()

    with open(os.path.join(save_dir, "dataset.json"), "w") as f:
        json.dump(metadata_dict, f)

    # Save data
    for k, v in results.items():
        np.save(os.path.join(save_dir, f"all__{k}.npy"), v)

    # Save IDs mapping (for visualization only)
    with open(os.path.join(config.output_dir, "identifiers.json"), "w") as f:
        json.dump(["<blank>"], f)


# Build settings for this notebook: a 1000-puzzle training subsample with two
# shuffled copies per puzzle, no difficulty filter, written to ./sudoku-data.
config = DataProcessConfig(
    source_repo="sapientinc/sudoku-extreme",
    output_dir="sudoku-data",
    subsample_size=1000,
    min_difficulty=None,
    num_aug=2,
)

os.makedirs(config.output_dir, exist_ok=True)

# Convert the training split first, then the test split.
for split_name in ("train", "test"):
    convert_subset(split_name, config)


100%|██████████| 1000/1000 [00:00<00:00, 2906.11it/s]


test.csv:   0%|          | 0.00/79.4M [00:00<?, ?B/s]

100%|██████████| 422786/422786 [00:00<00:00, 791739.15it/s]
