In [37]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import math

In [38]:
# ------------------------------------------------------------------------
# 1) A discrete and Sigmoid diffusion schedule
# ------------------------------------------------------------------------
class DiscreteDiffusionSchedule:
    """
    Simple linear schedule of alpha_t from t=1..T,
    where alpha_t = min_alpha + (max_alpha - min_alpha)*(t/T).
    """
    def __init__(self, T=10, min_alpha=0.1, max_alpha=0.7):
        self.T = T
        self.alphas = []
        for t in range(1, T+1):
            frac = t / T
            alpha_t = min_alpha + (max_alpha - min_alpha)*frac
            self.alphas.append(alpha_t)

    def __len__(self):
        return self.T

    def __getitem__(self, t):
        # t in [1..T], python indexing 0..T-1
        return self.alphas[t-1]

class SigmoidDiffusionSchedule:
    """
    Sigmoid schedule of alpha_t from t=1..T.

    alpha_t = min_alpha + (max_alpha - min_alpha)*sigmoid(k*(frac - 0.5)),
    where frac = (t-1)/(T-1).
    """
    def __init__(self, T=30, min_alpha=0.1, max_alpha=0.7, k=12.0):
        self.T = T
        self.alphas = []
        for t in range(1, T+1):
            # frac in [0..1]
            frac = (t - 1) / (T - 1)  
            # logistic
            s = 1 / (1 + math.exp(-k * (frac - 0.5)))
            alpha_t = min_alpha + (max_alpha - min_alpha) * s
            self.alphas.append(alpha_t)

    def __len__(self):
        return self.T

    def __getitem__(self, t):
        # t in [1..T], python indexing 0..T-1
        return self.alphas[t - 1]

In [4]:
# ------------------------------------------------------------------------
# 2) Forward noising that respects puzzle givens
# ------------------------------------------------------------------------
def forward_diffusion_with_puzzle(puzzle, solution, t, schedule, vocab_size, device):
    """
    puzzle:   (batch, 81) with digits in [0..9]. 0 means blank, non-zero means given.
    solution: (batch, 81) correct final solution
    t:        an integer in [1..T]
    schedule: contains alpha_t
    returns x_t: partially noised solution (batch, 81)
        - givens remain the same as solution's corresponding digit
        - blank positions get replaced with random digits w.p. alpha_t
    """
    alpha_t = schedule[t]  # fraction to noise
    puzzle = puzzle.to(device)
    solution = solution.to(device)

    # Where puzzle is nonzero => givens => do NOT overwrite
    givens_mask = (puzzle != 0)

    # We'll noise only the positions that are blank in the puzzle
    #   i.e. puzzle[i] == 0 => we can noise solution[i].
    blank_mask = (puzzle == 0)

    # Create random noise from [0..vocab_size-1] for the blank positions
    noise = torch.randint(0, vocab_size, solution.shape, device=device)

    # Decide which blank positions to replace with noise
    replace_mask = (torch.rand_like(solution.float()) < alpha_t) & blank_mask

    # x_t: start from the true solution, then replace with noise for some blank cells
    x_t = solution.clone()
    x_t[replace_mask] = noise[replace_mask]

    # givens remain the same as the correct solution digit at that position
    # (actually, this is already the default if puzzle != 0, but we do not overwrite them)
    # so x_t[givens_mask] = solution[givens_mask] # if you want to be explicit

    return x_t

In [5]:
# ------------------------------------------------------------------------
# 3) The model sees puzzle+partially noised solution as input
# ------------------------------------------------------------------------
class PuzzleDenoiser(nn.Module):
    """
    Transformer that takes [puzzle || x_t] => shape (batch, 162) tokens
    and outputs a distribution over solution tokens for the second half.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, puzzle, x_t):
        """
        puzzle, x_t: (batch, 81)
        We'll embed => shape (batch, 162, embed_dim)
        Return shape => (batch, 162, vocab_size),
        but we'll only evaluate the second half in the loss.
        """
        batch_size = puzzle.size(0)
        # Concatenate
        inp = torch.cat([puzzle, x_t], dim=1)  # (batch, 162)
        emb = self.embedding(inp)             # => (batch, 162, embed_dim)
        enc_out = self.encoder(emb)           # => (batch, 162, embed_dim)
        logits = self.output_layer(enc_out)   # => (batch, 162, vocab_size)
        return logits

In [7]:
# ------------------------------------------------------------------------
# 4) Diffusion train step
# ------------------------------------------------------------------------
def diffusion_train_step(
    model, 
    puzzle,      # (batch, 81)
    solution,    # (batch, 81)
    schedule,
    optimizer,
    loss_fn,
    vocab_size,
    device
):
    model.train()
    optimizer.zero_grad()

    # Move to device
    puzzle = puzzle.to(device)
    solution = solution.to(device)

    # Sample a random t
    T = schedule.T
    t = np.random.randint(1, T+1)

    # Create x_t that doesn't overwrite puzzle givens
    x_t = forward_diffusion_with_puzzle(
        puzzle, solution, t, schedule, vocab_size, device
    )

    # Forward pass [puzzle || x_t]
    logits = model(puzzle, x_t)  # shape (batch, 162, vocab_size)

    # We only care about the second half (positions 81..161) 
    # for the solution. The first half is puzzle context.
    logits_solution_part = logits[:, 81:, :]  # (batch, 81, vocab_size)

    # Compare to ground truth solution
    loss = loss_fn(
        logits_solution_part.reshape(-1, vocab_size),
        solution.reshape(-1)
    )
    loss.backward()
    optimizer.step()
    return loss.item()

In [9]:
# ------------------------------------------------------------------------
# 5) Validation step
# ------------------------------------------------------------------------
@torch.no_grad()
def diffusion_eval_step(
    model, 
    puzzle,
    solution,
    schedule,
    loss_fn,
    vocab_size,
    device
):
    model.eval()
    puzzle = puzzle.to(device)
    solution = solution.to(device)

    t = np.random.randint(1, schedule.T+1)
    x_t = forward_diffusion_with_puzzle(
        puzzle, solution, t, schedule, vocab_size, device
    )

    logits = model(puzzle, x_t)
    logits_solution_part = logits[:, 81:, :]  # only second half
    loss = loss_fn(
        logits_solution_part.reshape(-1, vocab_size),
        solution.reshape(-1)
    )
    return loss.item()

In [10]:
# ------------------------------------------------------------------------
# 6) Iterative decoding to fill blank cells
# ------------------------------------------------------------------------
@torch.no_grad()
def iterative_decode(
    model,
    puzzle,        # shape (batch, 81) puzzle givens
    schedule,
    vocab_size,
    device,
    steps=None
):
    """
    Steps:
      1) Start from x_T = fully noised in blank cells (or partial).
      2) For t in [T..1], model predicts solution => we partially adopt it in blank cells
         respecting puzzle givens.
      3) Return final x_0
    If steps=None => run the full T steps from schedule.
    """
    model.eval()
    puzzle = puzzle.to(device)
    batch_size, seq_len = puzzle.shape
    T = schedule.T if steps is None else steps

    # Initialize x_t: fully noised in blank positions, puzzle givens remain correct solution digits 
    # (We don't have the solution during inference, so let's use puzzle for givens 
    #  and random for blanks)
    x_t = puzzle.clone()
    blank_mask = (puzzle == 0)
    # fill blanks with random
    x_t[blank_mask] = torch.randint(0, vocab_size, x_t[blank_mask].shape, device=device)

    for t in range(T, 0, -1):
        # forward pass
        logits = model(puzzle, x_t)         # (batch, 162, vocab_size)
        logits_solution_part = logits[:, 81:, :]  # (batch, 81, vocab_size)

        # predicted solution tokens for second half
        pred_sol = logits_solution_part.argmax(dim=-1)  # (batch, 81)

        # Now adopt the model's predictions *only* in blank cells.
        # Keep puzzle givens as-is. But if puzzle[i] was 0, we update from pred_sol.
        x_t[blank_mask] = pred_sol[blank_mask]

        # Optionally, you can do partial or probabilistic "denoising" 
        # (like in the paper). For simplicity, we do a 1-step "take argmax" each iteration.

    return x_t  # hopefully a filled solution

In [29]:
@torch.no_grad()
def measure_sudoku_solve_rate(
    model,
    schedule,
    X_val,    # (N, 81) puzzles
    y_val,    # (N, 81) solutions
    device,
    vocab_size=10,
    batch_size=32
):
    """
    Runs iterative_decode on each puzzle in X_val,
    checks how many are solved 100% correctly.
    Returns a float in [0..1] for each batch.
    """
    model.eval()
    
    # If inputs are already batches, don't create new loader
    if len(X_val.shape) == 2 and X_val.shape[0] <= batch_size:
        puzzles = X_val.to(device)
        solutions = y_val.to(device)
        
        x_filled = iterative_decode(
            model=model,
            puzzle=puzzles, 
            schedule=schedule,
            vocab_size=vocab_size,
            device=device
        )
        
        eq_mask = (x_filled == solutions)
        batch_solved = eq_mask.all(dim=1).sum()
        solve_rate = batch_solved.item() / puzzles.size(0)
        return solve_rate
        
    # For full dataset evaluation, use DataLoader
    dataset = TensorDataset(X_val, y_val)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    solved_count = 0
    total_count = 0

    for puzzles, solutions in tqdm(loader, desc="Evaluating Solve Rate"):
        puzzles = puzzles.to(device)
        solutions = solutions.to(device)

        x_filled = iterative_decode(
            model=model,
            puzzle=puzzles,
            schedule=schedule,
            vocab_size=vocab_size,
            device=device
        )

        eq_mask = (x_filled == solutions)
        batch_solved = eq_mask.all(dim=1).sum()
        solved_count += batch_solved.item()
        total_count += puzzles.size(0)

    solve_rate = solved_count / total_count
    return solve_rate

In [46]:
@torch.no_grad()
def measure_token_accuracy(model, schedule, X_val, y_val, device, vocab_size=10, batch_size=32):
    model.eval()
    dataset = TensorDataset(X_val, y_val)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    correct = 0
    total = 0
    for puzzles, solutions in tqdm(loader, desc="Evaluating Token Accuracy"):
        puzzles = puzzles.to(device)
        solutions = solutions.to(device)
        x_filled = iterative_decode(model, puzzles, schedule, vocab_size, device)
        correct += (x_filled == solutions).sum().item()
        total += puzzles.numel()
    return correct / total

In [47]:
# ------------------------------------------------------------------------
# 7) Putting it all together: example training loop
# ------------------------------------------------------------------------
def train_puzzle_diffusion(
    X_train, y_train, 
    X_val,   y_val,
    vocab_size=10,
    T=30,
    embed_dim=256,
    num_heads=8,
    num_layers=4,
    batch_size=32,
    num_epochs=5
):
    """
    X_* are puzzle arrays of shape (N, 81).
    y_* are solution arrays of shape (N, 81).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    schedule = SigmoidDiffusionSchedule(
        T=T, 
        min_alpha=0.1, 
        max_alpha=0.7,
        k = 12.0
    )

    model = PuzzleDenoiser(
        vocab_size=vocab_size, 
        embed_dim=embed_dim, 
        num_heads=num_heads, 
        num_layers=num_layers
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    loss_fn = nn.CrossEntropyLoss()

    # Setup data loaders
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset   = TensorDataset(X_val,   y_val)
    train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader    = DataLoader(val_dataset,   batch_size=batch_size)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        for batch_puzzle, batch_solution in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
            loss_val = diffusion_train_step(
                model=model,
                puzzle=batch_puzzle,
                solution=batch_solution,
                schedule=schedule,
                optimizer=optimizer,
                loss_fn=loss_fn,
                vocab_size=vocab_size,
                device=device
            )
            train_losses.append(loss_val)
        avg_train_loss = np.mean(train_losses)

        # Validation
        model.eval()  # Set model to evaluation mode
        val_losses = []
        
        # Calculate accuracy on full validation set once per epoch
        val_accuracy = measure_sudoku_solve_rate(
            model,
            schedule,
            X_val,
            y_val,
            device=device,
            vocab_size=vocab_size,
            batch_size=batch_size
        )

        token_accuracy = measure_token_accuracy(
            model=model,
            schedule=schedule,
            X_val=X_val,
            y_val=y_val,
            device=device,
            vocab_size=vocab_size,
            batch_size=batch_size
        )

        with torch.no_grad():
            for batch_puzzle, batch_solution in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val  ]"):
                # Get validation loss
                vloss = diffusion_eval_step(
                    model=model,
                    puzzle=batch_puzzle,
                    solution=batch_solution, 
                    schedule=schedule,
                    loss_fn=loss_fn,
                    vocab_size=vocab_size,
                    device=device
                )
                val_losses.append(vloss)
                
        avg_val_loss = np.mean(val_losses)

        print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f}")
        print(f"Solve rate on validation set: {val_accuracy*100:.2f}%")
        print(f"Token-level accuracy on validation set: {token_accuracy*100:.2f}%")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "puzzle_diffuser_best.pt")
            print(f"  [*] Best model saved @ val_loss={avg_val_loss:.4f}")

    return model, schedule

In [14]:
# Load and preprocess data
import pandas as pd
from sklearn.model_selection import train_test_split

# Load sudoku data
df = pd.read_csv('./data/sudoku.csv')

# Convert strings to tensors
def preprocess_sudoku(puzzle_str):
    # Convert string to list of integers and then to tensor
    return torch.tensor([int(d) for d in puzzle_str], dtype=torch.long)

# Convert all puzzles and solutions
puzzles = torch.stack([preprocess_sudoku(p) for p in df['quizzes']])
solutions = torch.stack([preprocess_sudoku(s) for s in df['solutions']])

# Karpathy split (90/5/5)
train_size = 0.9
val_size = 0.05
test_size = 0.05

# First split into train and temp
X_train, X_temp, y_train, y_temp = train_test_split(
    puzzles, solutions, train_size=train_size, random_state=42
)

# Split temp into val and test
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)

In [45]:
model, schedule = train_puzzle_diffusion(
    X_train, y_train,
    X_val,   y_val,
    vocab_size=10,
    T=40,
    embed_dim=512,
    num_heads=8,
    num_layers=8,
    batch_size=64,
    num_epochs=50
)

Using device: cuda


Epoch 1 [Train]:   1%|▏         | 210/14063 [00:56<1:02:40,  3.68it/s]


KeyboardInterrupt: 

In [43]:
#############################################
# 2) CREATE THE NEW SCHEDULE
#############################################
schedule = SigmoidDiffusionSchedule(
    T=30,
    min_alpha=0.1,
    max_alpha=0.7,
    k=12.0
)

#############################################
# 3) PICK A SAMPLE PUZZLE & SOLUTION
#############################################
sample_puzzle = X_train[0:1]
sample_solution = y_train[0:1]

print("Puzzle:", sample_puzzle)
print("Solution:", sample_solution)

#############################################
# 4) SHOW x_1, x_2, ... x_T
#############################################
vocab_size = 10  # digits 0-9
for t in range(1, schedule.T+1):
    alpha_t = schedule[t]  # just to see the noise fraction
    x_t = forward_diffusion_with_puzzle(
        puzzle=sample_puzzle, 
        solution=sample_solution, 
        t=t, 
        schedule=schedule, 
        vocab_size=vocab_size, 
        device=device
    )
    print(f"x_{t} (alpha={alpha_t:.3f}) =", x_t)

Puzzle: tensor([[0, 4, 2, 0, 0, 9, 0, 7, 5, 9, 0, 0, 0, 0, 7, 0, 0, 3, 3, 0, 5, 6, 1, 0,
         9, 0, 0, 0, 0, 4, 9, 7, 8, 0, 0, 6, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 8,
         4, 0, 6, 2, 0, 0, 0, 6, 7, 0, 0, 0, 4, 3, 8, 8, 0, 0, 7, 0, 0, 0, 9, 2,
         0, 0, 0, 0, 5, 0, 0, 0, 0]])
Solution: tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 2, 3, 3, 7, 5, 6, 1, 2,
         9, 8, 4, 2, 5, 4, 9, 7, 8, 3, 1, 6, 6, 9, 3, 1, 2, 5, 8, 4, 7, 7, 1, 8,
         4, 3, 6, 2, 5, 9, 5, 6, 7, 2, 9, 1, 4, 3, 8, 8, 3, 1, 7, 6, 4, 5, 9, 2,
         4, 2, 9, 8, 5, 3, 7, 6, 1]])
x_1 (alpha=0.101) = tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 2, 3, 3, 7, 5, 6, 1, 2,
         9, 8, 4, 2, 5, 4, 9, 7, 8, 3, 1, 6, 6, 9, 3, 1, 2, 5, 8, 4, 7, 4, 1, 8,
         4, 1, 6, 2, 5, 9, 5, 6, 7, 2, 9, 1, 4, 3, 8, 8, 3, 1, 7, 9, 4, 5, 9, 2,
         4, 9, 9, 8, 5, 3, 7, 6, 1]], device='cuda:0')
x_2 (alpha=0.102) = tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 2, 3, 3, 7, 5, 6, 1, 2,
 

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Then to do iterative decoding:
puzzle_batch = X_val[0:2]  # for example
x_filled = iterative_decode(
    model, 
    puzzle=puzzle_batch, 
    schedule=schedule, 
    vocab_size=10,
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
)
print("Puzzle: ", puzzle_batch)
print("Decoded solution: ", x_filled)
print("Original solution: ", y_val[0:2])

In [30]:
accuracy = measure_sudoku_solve_rate(
    model,
    schedule,
    X_val[0:100],
    y_val[0:100],
    device=device,
    vocab_size=10,
    batch_size=32
)

print(f"Solve rate on validation set: {accuracy*100:.2f}%")

Evaluating Solve Rate: 100%|██████████| 4/4 [00:00<00:00,  6.24it/s]

Solve rate on validation set: 0.00%





In [49]:
token_accuracy = measure_token_accuracy(
    model,
    schedule,
    X_val[0:100],
    y_val[0:100],
    device=device,
    vocab_size=10,
    batch_size=32
)

print(f"Token-level accuracy on validation set: {token_accuracy*100:.2f}%")

Evaluating Token Accuracy: 100%|██████████| 4/4 [00:01<00:00,  3.96it/s]

Token-level accuracy on validation set: 48.02%



