In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
import math
from torch.cuda.amp import autocast, GradScaler
import random

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# ------------------------------------------------------------------------
# 1) Diffusion schedules: linear (discrete), sigmoid, and one-flip
# ------------------------------------------------------------------------
class DiscreteDiffusionSchedule:
    """
    Linear noise schedule over the discrete steps t = 1..T.

    Step t is assigned alpha_t = min_alpha + (max_alpha - min_alpha) * (t / T).
    Steps are addressed 1-based: schedule[t] returns alpha_t.
    """
    def __init__(self, T=10, min_alpha=0.1, max_alpha=0.7):
        self.T = T
        # One entry per step, stored 0-based (alphas[t-1] belongs to step t).
        self.alphas = [
            min_alpha + (max_alpha - min_alpha)*(step / T)
            for step in range(1, T+1)
        ]

    def __len__(self):
        # Number of diffusion steps.
        return self.T

    def __getitem__(self, t):
        # Translate the 1-based step number into a 0-based list position.
        position = t - 1
        return self.alphas[position]

class SigmoidDiffusionSchedule:
    """
    Sigmoid schedule of alpha_t from t=1..T.

    alpha_t = min_alpha + (max_alpha - min_alpha)*sigmoid(k*(frac - 0.5)),
    where frac = (t-1)/(T-1).

    With a single step (T == 1) there is no interval to interpolate over, so
    frac is taken as 1.0, i.e. the only step sits at the high end of the curve.
    Steps are addressed 1-based: schedule[t] returns alpha_t.
    """
    def __init__(self, T=30, min_alpha=0.1, max_alpha=0.7, k=12.0):
        self.T = T
        self.alphas = []
        for t in range(1, T+1):
            # frac in [0..1]; guard T == 1, where (T - 1) would divide by zero
            frac = (t - 1) / (T - 1) if T > 1 else 1.0
            # logistic
            s = 1 / (1 + math.exp(-k * (frac - 0.5)))
            alpha_t = min_alpha + (max_alpha - min_alpha) * s
            self.alphas.append(alpha_t)

    def __len__(self):
        # Number of diffusion steps.
        return self.T

    def __getitem__(self, t):
        # t in [1..T], python indexing 0..T-1
        return self.alphas[t - 1]

class OneFlipDiffusionSchedule:
    """
    Schedule expressed as a flip count per step: step `t` flips `t` digits,
    capped at the puzzle size.
    """
    def __init__(self, T=50, puzzle_size=81):
        """
        Args:
            T (int): How many diffusion steps the schedule covers.
            puzzle_size (int): Number of cells in one puzzle (81 for a 9x9 Sudoku);
                no step flips more cells than this.
        """
        self.T = T
        self.puzzle_size = puzzle_size
        counts = []
        for step in range(1, T + 1):
            # Step `step` flips `step` digits unless that exceeds the board.
            counts.append(puzzle_size if puzzle_size < step else step)
        self.flip_counts = counts

    def __len__(self):
        # Number of diffusion steps.
        return self.T

    def __getitem__(self, t):
        """
        Look up the flip count of step t.

        Args:
            t (int): 1-based step index, t in [1..T].
        """
        position = t - 1
        return self.flip_counts[position]



In [3]:
# ------------------------------------------------------------------------
# 2) Forward noising that respects puzzle givens
# ------------------------------------------------------------------------
def forward_diffusion_with_puzzle(puzzle, solution, t, schedule, vocab_size, device):
    """
    Noise the solution directly to step t, touching only the puzzle's blank cells.

    puzzle:     (batch, 81) digits; 0 marks a blank cell, non-zero marks a given.
    solution:   (batch, 81) completed board.
    t:          step index passed to `schedule[t]`.
    schedule:   indexable; schedule[t] is the per-cell replacement probability.
    vocab_size: replacement digits are drawn uniformly from [0..vocab_size-1].
    device:     device the computation runs on.

    Returns x_t with the same shape as `solution`: each blank cell is replaced
    by a random digit with probability schedule[t]; cells that are givens in
    `puzzle` always keep the solution digit. `solution` itself is not modified.
    """
    noise_prob = schedule[t]
    puzzle = puzzle.to(device)
    solution = solution.to(device)

    # Draw a candidate replacement digit for every cell, then one uniform
    # sample per cell to decide whether that candidate is used.
    random_digits = torch.randint(0, vocab_size, solution.shape, device=device)
    coin = torch.rand_like(solution.float())

    # Only blank puzzle cells are eligible for corruption.
    is_blank = puzzle == 0
    corrupt = (coin < noise_prob) & is_blank

    noised = solution.clone()
    noised[corrupt] = random_digits[corrupt]
    return noised

def forward_diffusion_mixed(
    puzzle,
    x_prev,
    t,
    schedule,
    vocab_size,
    device,
    zero_bias=0.8,
    bias_increment=1.3
):
    """
    Markov forward step from x_{t-1} to x_t: pick random cells and flip them,
    with a bias toward flipping blank cells to zero.

    Givens are mostly preserved: a selected given cell is overwritten only
    about 1% of the time, a selected blank cell is always rewritten.

    Args:
        puzzle: (batch, 81) Original puzzle digits [0..9], 0 means blank, non-zero means given.
        x_prev: (batch, 81) The noised state from the previous step (x_{t-1}). Not modified.
        t:      The current step index in [1..T].
        schedule: A diffusion schedule. schedule[t] is read either as a flip
            count (values >= 1, e.g. OneFlipDiffusionSchedule) or, when it lies
            strictly between 0 and 1 (alpha-style schedules), as the fraction
            of cells to flip. The resulting count is clamped to [0, puzzle_size].
        vocab_size: Number of possible digit values, e.g. 10 for digits [0..9].
        device: Torch device ("cuda" or "cpu").
        zero_bias: Probability of flipping to zero vs a random digit in blank cells.
        bias_increment: Multiplicative factor to increase zero_bias each time we flip to a non-zero digit.

    Returns:
        x_t: (batch, 81) The new noised state (x_{t}) after flipping up to the
        scheduled number of digits, on `device` with the dtype of `x_prev`.
    """
    puzzle = puzzle.to(device)
    x_prev = x_prev.to(device)
    batch_size, puzzle_size = x_prev.shape

    # Number of flips determined by the schedule. A bare int() would truncate
    # alpha-style values such as 0.3 to zero flips, turning the whole forward
    # process into a no-op, so fractional values are scaled by the board size.
    step_value = float(schedule[t])
    if 0.0 < step_value < 1.0:
        num_flips = int(round(step_value * puzzle_size))
    else:
        num_flips = int(step_value)
    num_flips = max(0, min(num_flips, puzzle_size))

    # Work on plain Python lists so the inner loop does not index device
    # tensors element by element; the result is moved back once at the end.
    given_rows = (puzzle != 0).tolist()
    rows = x_prev.tolist()

    for b in range(batch_size):
        # Randomly select positions to flip
        flip_indices = torch.randperm(puzzle_size, device=device)[:num_flips].tolist()

        current_zero_bias = zero_bias  # Probability of flipping to zero
        row = rows[b]
        is_given = given_rows[b]

        for idx in flip_indices:
            if is_given[idx]:
                # Given digits: flip very rarely (e.g., 1% chance)
                if random.random() > 0.99:
                    row[idx] = random.randint(1, vocab_size - 1)
            elif random.random() < current_zero_bias:
                # Blank cell erased back to "unknown"
                row[idx] = 0
            else:
                # Blank cell replaced by a random non-zero digit
                row[idx] = random.randint(1, vocab_size - 1)
                # Increase zero_bias after flipping to a nonzero digit
                current_zero_bias = min(1.0, current_zero_bias * bias_increment)

    # reshape keeps the (batch, puzzle_size) layout even for an empty batch
    return torch.tensor(rows, dtype=x_prev.dtype, device=device).reshape(x_prev.shape)

In [14]:
def generate_forward_diffusion_path(puzzle, solution, schedule, vocab_size, device):
    """
    Build the full forward-diffusion chain x_0, x_1, ..., x_T.

    x_0 is a copy of `solution`; every later state is produced from the one
    before it by `forward_diffusion_mixed`, so consecutive entries form a
    consistent (x_{t-1}, x_t) pair.

    puzzle, solution: (batch, 81)
    schedule: object with a `T` attribute, indexed by step inside
              forward_diffusion_mixed (e.g. OneFlipDiffusionSchedule)
    vocab_size: e.g. 10
    device: 'cuda' or 'cpu'

    Returns: list of T + 1 tensors [x_0, ..., x_T], each shaped like `solution`.
    """
    puzzle = puzzle.to(device)
    solution = solution.to(device)

    # states[i] holds x_i; the clean solution is cloned so the caller's
    # tensor is never aliased by the chain.
    states = [solution.clone()]
    for t in range(1, schedule.T + 1):
        # Each step starts from the most recent state, not from the solution.
        states.append(
            forward_diffusion_mixed(puzzle, states[-1], t, schedule, vocab_size, device)
        )

    return states

In [13]:
# ------------------------------------------------------------------------
# 3) The model sees puzzle+partially noised solution as input
# ------------------------------------------------------------------------
class PuzzleDenoiser(nn.Module):
    """
    Transformer encoder that reads the puzzle and a noised solution side by
    side and emits per-token digit logits.

    The input sequence is [puzzle || x_t] (2 * puzzle_size tokens). Each token
    embedding is summed with a learned positional embedding and a learned
    embedding of the diffusion step t.
    """
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers, max_T=40, puzzle_size=81):
        """
        Args:
            vocab_size: number of distinct token values (also the logit width).
            embed_dim: embedding / transformer width.
            num_heads: attention heads per encoder layer.
            num_layers: number of encoder layers.
            max_T: largest step index the time embedding can represent.
            puzzle_size: cells per board; the model's sequence length is twice this.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.puzzle_size = puzzle_size
        self.seq_len = 2 * puzzle_size

        # Standard embeddings
        self.embedding = nn.Embedding(vocab_size, embed_dim)

        # 1) Time embedding for t in [1..max_T] (index 0 is allocated but unused by callers here)
        self.time_embedding = nn.Embedding(max_T + 1, embed_dim)

        # Positional embedding for puzzle+solution sequence (learned, small normal init)
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.seq_len, embed_dim))
        nn.init.normal_(self.pos_embedding, mean=0.0, std=0.02)

        # Layer norm; the same instance is applied both before and after the encoder
        self.post_pos_norm = nn.LayerNorm(embed_dim)

        # Transformer
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Output layer
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, puzzle, x_t, t):
        """
        puzzle, x_t: (batch, puzzle_size) integer tokens
        t: (batch,) integer steps in [1..max_T], describing which noising step

        Returns logits of shape (batch, 2 * puzzle_size, vocab_size); the
        second half of the sequence axis corresponds to the x_t positions.
        """
        # Concatenate [puzzle || x_t] => shape (batch, 2 * puzzle_size)
        inp = torch.cat([puzzle, x_t], dim=1)
        seq_len = inp.size(1)

        # Token embeddings for puzzle + x_t => (batch, seq_len, embed_dim)
        emb = self.embedding(inp)

        # Time embedding (batch, 1, embed_dim) broadcasts across all tokens
        emb = emb + self.time_embedding(t).unsqueeze(1)

        # Add positional embeddings, sliced to the actual sequence length
        cat_emb = emb + self.pos_embedding[:, :seq_len, :]

        # Apply layer normalization after positional embeddings
        cat_emb = self.post_pos_norm(cat_emb)

        # Pass through Transformer => (batch, seq_len, embed_dim)
        enc_out = self.encoder(cat_emb)

        # Re-apply the same LayerNorm (shared weights) to the encoder output
        enc_out = self.post_pos_norm(enc_out)

        # Output layer => (batch, seq_len, vocab_size)
        logits = self.output_layer(enc_out)

        return logits



In [24]:
# ------------------------------------------------------------------------
# 4) Diffusion train step
# ------------------------------------------------------------------------
def compute_subgoal_weights(puzzle, solution, vocab_size=10):
    """
    Per-cell difficulty weights derived from the puzzle's givens.

    For every blank cell (puzzle value 0) the weight is the number of digits in
    [1..vocab_size-1] that do not already appear among the givens of the cell's
    row, column, or 3x3 box. Given cells get weight 0. Each board's weights are
    then divided by that board's largest weight (floored at 1.0).

    `solution` only supplies the shape and device of the result.

    Returns:
        weights: float32 tensor shaped like `solution`, i.e. (batch, 81).
    """
    weights = torch.zeros_like(solution, dtype=torch.float32)
    digits = set(range(1, vocab_size))  # {1,...,9} when vocab_size is 10

    for b in range(puzzle.size(0)):
        # 9x9 view of this board on the CPU for row/column/box slicing
        grid = puzzle[b].view(9, 9).cpu().numpy()

        for r in range(9):
            for c in range(9):
                if grid[r, c] != 0:
                    # Given cell: its weight stays at the initial 0.0
                    continue

                # Top-left corner of the 3x3 box containing (r, c)
                top = (r // 3) * 3
                left = (c // 3) * 3

                seen = set(grid[r, :].tolist())
                seen |= set(grid[:, c].tolist())
                seen |= set(grid[top:top+3, left:left+3].reshape(-1).tolist())

                # `seen` may contain 0 (blank); it is not in `digits`, so it
                # never affects the difference below.
                weights[b, r * 9 + c] = float(len(digits - seen))

    # Normalize per board by its maximum candidate count
    peak = weights.max(dim=1, keepdim=True)[0].clamp(min=1.0)
    return weights / (peak + 1e-8)

def diffusion_train_step(
    model,
    puzzle,
    solution,
    schedule,
    optimizer,
    vocab_size,
    device,
    loss_fn,
    scaler  # GradScaler instance for mixed precision
):
    """
    Run one optimization step: noise the batch to a random step t and train
    the model to predict x_{t-1} from (puzzle, x_t, t).

    Args:
        model: called as model(puzzle, x_t, t) and expected to return logits of
            shape (batch, 162, vocab_size) whose last 81 positions belong to x_t.
        puzzle, solution: (batch, 81) integer tensors.
        schedule: object with a `T` attribute, consumed by forward_diffusion_mixed.
        optimizer: optimizer over the model's parameters.
        vocab_size: number of digit classes.
        device: device to train on.
        loss_fn: callable taking (logits[N, vocab_size], targets[N]).
        scaler: GradScaler used to scale the loss and step the optimizer.

    Returns:
        The detached scalar loss tensor (the caller decides when to call .item()).
    """
    model.train()

    puzzle = puzzle.to(device)
    solution = solution.to(device)

    # 1) Pick random t in [1..T]; the whole batch shares the same step
    T = schedule.T
    t_int = np.random.randint(1, T + 1)
    t_tensor = torch.full((puzzle.size(0),), t_int, dtype=torch.long, device=device)

    # 2) Only x_{t-1} and x_t are needed, so run the Markov chain just up to
    #    step t instead of materializing all T states.
    x_t_minus_1 = solution.clone()  # x_0
    for step in range(1, t_int):
        x_t_minus_1 = forward_diffusion_mixed(
            puzzle, x_t_minus_1, step, schedule, vocab_size, device
        )
    x_t = forward_diffusion_mixed(
        puzzle, x_t_minus_1, t_int, schedule, vocab_size, device
    )

    # 3) Forward pass; mixed precision only makes sense on CUDA
    use_amp = torch.device(device).type == "cuda"
    with torch.cuda.amp.autocast(enabled=use_amp):
        logits = model(puzzle, x_t, t_tensor)  # => (batch, 162, vocab_size)

        # The first 81 positions are the puzzle tokens; keep the x_t half
        logits_solution_part = logits[:, 81:, :]  # (batch, 81, vocab_size)

        # 4) CE loss with target = x_{t-1}
        ce_loss = loss_fn(
            logits_solution_part.reshape(-1, vocab_size),
            x_t_minus_1.reshape(-1)
        )

    # 5) Scale loss and backprop with GradScaler (outside the autocast region)
    optimizer.zero_grad()
    scaler.scale(ce_loss).backward()
    scaler.step(optimizer)
    scaler.update()

    # Detached so the returned value does not keep the autograd graph alive
    return ce_loss.detach()


In [16]:
# ------------------------------------------------------------------------
# 5) Validation step
# ------------------------------------------------------------------------
@torch.no_grad()
def diffusion_eval_step(
    model, 
    puzzle,
    solution,
    schedule,
    loss_fn,
    vocab_size,
    device
):
    """
    Evaluate the model on one batch at a randomly chosen diffusion step.

    The batch is noised with `forward_diffusion_with_puzzle` (so `schedule[t]`
    is used as a replacement probability) and the loss is measured against the
    clean solution.

    Args:
        model: called as model(puzzle, x_t, t); returns (batch, 162, vocab_size) logits.
        puzzle: (batch, 81) puzzle digits, 0 for blanks.
        solution: (batch, 81) solution digits.
        schedule: diffusion schedule exposing `T` and step indexing.
        loss_fn: loss callable, e.g. nn.CrossEntropyLoss().
        vocab_size: number of digit classes (digits 0-9 => 10).
        device: torch.device the evaluation runs on.

    Returns:
        The batch loss as a Python float.
    """
    model.eval()
    puzzle, solution = puzzle.to(device), solution.to(device)

    # One step index, shared by every example in the batch
    step = np.random.randint(1, schedule.T + 1)
    step_ids = torch.full((puzzle.size(0),), step, dtype=torch.long, device=device)

    # Corrupt the solution at that step, leaving givens intact
    noised = forward_diffusion_with_puzzle(
        puzzle, solution, step, schedule, vocab_size, device
    )

    # Keep only the logits of the solution half of the sequence
    solution_logits = model(puzzle, noised, step_ids)[:, 81:, :]

    flat_logits = solution_logits.reshape(-1, vocab_size)   # (batch*81, vocab_size)
    flat_targets = solution.reshape(-1)                     # (batch*81,)
    return loss_fn(flat_logits, flat_targets).item()

In [17]:
# ------------------------------------------------------------------------
# 6) Iterative decoding to fill blank cells
# ------------------------------------------------------------------------
@torch.no_grad()
def iterative_decode(model, puzzle, schedule, vocab_size, device):
    """
    Fill the blank cells of `puzzle` by running the model from step T down to 1.

    Blank cells (value 0) start as uniformly random digits in
    [0..vocab_size-1]; at every step they are overwritten with the argmax of
    the model's logits for the solution half of the sequence. Given cells are
    never changed.

    Args:
        model: called as model(puzzle, x_t, t); expected to return logits of
            shape (batch, 2 * cells, vocab_size), puzzle positions first.
        puzzle: (batch, cells) integer tensor, 0 marks a blank.
        schedule: object with a `T` attribute (number of reverse steps).
        vocab_size: number of digit classes.
        device: device to decode on.

    Returns:
        (batch, cells) tensor with the givens from `puzzle` and predicted
        digits in the blank positions.
    """
    model.eval()

    puzzle = puzzle.to(device)
    batch_size, puzzle_size = puzzle.shape
    T = schedule.T

    # Start from random digits in the blank positions, givens fixed. Noise is
    # drawn for the full board and selected with the mask, so boards with
    # different numbers of blanks in one batch are handled.
    blank_mask = (puzzle == 0)
    noise = torch.randint(0, vocab_size, puzzle.shape, dtype=puzzle.dtype, device=device)
    x_t = torch.where(blank_mask, noise, puzzle)

    for curr_t in range(T, 0, -1):
        t_tensor = torch.full((batch_size,), curr_t, dtype=torch.long, device=device)

        logits = model(puzzle, x_t, t_tensor)
        # Skip the puzzle half of the sequence; keep the solution half
        logits_solution_part = logits[:, puzzle_size:, :]
        pred_x_t_minus_1 = logits_solution_part.argmax(dim=-1)  # (batch, cells)

        # Update only blank cells
        x_t = torch.where(blank_mask, pred_x_t_minus_1.to(x_t.dtype), x_t)

    return x_t  # hopefully denoised solution

In [18]:
@torch.no_grad()
def validate_combined(
    model, 
    schedule,
    loader,         # DataLoader for validation
    device,
    loss_fn,
    vocab_size=10
):
    """
    Validates with:
      - Cross-entropy loss (predicting x_{t-1} from x_t)
      - Solve rate (iterative decode to x_0)
      - Token-level accuracy for filled-in cells

    Args:
        model: denoiser called as model(puzzle, x_t, t).
        schedule: diffusion schedule exposing `T`.
        loader: iterable of (puzzles, solutions) batches, each (batch, 81).
        device: device to evaluate on.
        loss_fn: loss callable taking (logits[N, vocab_size], targets[N]).
        vocab_size: number of digit classes.

    Returns:
        (avg_loss, solve_rate, token_acc): the sample-weighted mean loss, the
        fraction of boards decoded exactly, and the fraction of blank cells
        decoded correctly. If the loader yields no samples the result is
        (nan, 0.0, 0.0) instead of a ZeroDivisionError.
    """
    model.eval()

    total_loss = 0.0
    total_samples = 0
    solved_count = 0
    correct_token_count = 0
    total_token_count = 0

    with torch.cuda.amp.autocast():
        for puzzles, solutions in tqdm(loader, desc="Validation"):
            puzzles = puzzles.to(device)
            solutions = solutions.to(device)
            batch_size = puzzles.size(0)

            # 1) Generate entire chain [x_0, x_1, ..., x_T] with x_0 = solutions
            chain = generate_forward_diffusion_path(
                puzzle=puzzles,
                solution=solutions,
                schedule=schedule,
                vocab_size=vocab_size,
                device=device
            )
            # chain[t] => x_t

            # 2) Pick random t in [1..T]
            t_int = np.random.randint(1, schedule.T + 1)
            t_tensor = torch.full((batch_size,), t_int, dtype=torch.long, device=device)

            # 3) Retrieve x_t and x_{t-1} from the chain
            x_t         = chain[t_int]
            x_t_minus_1 = chain[t_int - 1]

            # 4) Forward pass
            logits = model(puzzles, x_t, t_tensor)  # => (batch, 162, vocab_size)
            logits_solution_part = logits[:, 81:, :]  # (batch, 81, vocab_size)

            # 5) Loss vs x_{t-1}, not the final solution
            loss = loss_fn(
                logits_solution_part.reshape(-1, vocab_size),
                x_t_minus_1.reshape(-1)
            )

            # Weight by batch size so a short final batch does not skew the mean
            total_loss += loss.item() * batch_size
            total_samples += batch_size

            # 6) Iterative decode for solve rate & token accuracy
            x_filled = iterative_decode(
                model=model,
                puzzle=puzzles,   # puzzle givens
                schedule=schedule,
                vocab_size=vocab_size,
                device=device
            )

            # 7) Solve rate: how many boards are 100% correct
            eq_mask = (x_filled == solutions)  # (batch, 81)
            batch_solved = eq_mask.all(dim=1).sum().item()
            solved_count += batch_solved

            # 8) Token accuracy on masked cells
            masked_mask = (puzzles == 0)               # (batch, 81)
            correct_masked = eq_mask & masked_mask      # Correct predictions on blank cells
            correct_token_count += correct_masked.sum().item()
            total_token_count += masked_mask.sum().item()

    if total_samples == 0:
        # Nothing was evaluated: report an undefined loss rather than 0.0 so
        # that a caller comparing against its best loss never treats an empty
        # validation set as an improvement.
        return float("nan"), 0.0, 0.0

    avg_loss = total_loss / total_samples
    solve_rate = solved_count / total_samples
    token_acc = (correct_token_count / total_token_count) if total_token_count > 0 else 0.0

    return avg_loss, solve_rate, token_acc


In [27]:
# ------------------------------------------------------------------------
# 7) Putting it all together: example training loop
# ------------------------------------------------------------------------
def train_puzzle_diffusion(
    X_train, y_train, 
    X_val,   y_val,
    vocab_size=10,
    T=40,
    embed_dim=512,
    num_heads=8,
    num_layers=8,
    batch_size=64,
    num_epochs=50,
    best_model_path=None,
    run_initial_validation=False
):
    """
    Trains the PuzzleDenoiser model with time-conditioned diffusion.

    Args:
        X_train, y_train: training puzzles and solutions, each (N, 81).
        X_val, y_val: validation puzzles and solutions, each (M, 81).
        vocab_size: number of digit classes.
        T: number of diffusion steps (also the model's max_T).
        embed_dim, num_heads, num_layers: transformer size.
        batch_size: batch size for both loaders.
        num_epochs: number of passes over the training set.
        best_model_path: optional checkpoint (a state_dict) to resume from;
            keys with or without the DataParallel "module." prefix are accepted.
        run_initial_validation: if True, run one validation pass before training.

    Returns:
        (model, schedule). The model is wrapped in nn.DataParallel when more
        than one GPU is visible. The best checkpoint by validation loss is
        written to "puzzle_diffuser_best.pt".
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # NOTE(review): the training step noises boards with
    # forward_diffusion_mixed, which reads schedule[t] to decide how many cells
    # to flip. Confirm that it interprets these alpha values as intended.
    schedule = SigmoidDiffusionSchedule(
        T=T, 
        min_alpha=0.02, 
        max_alpha=0.7,
        k=12.0
    )

    # Handle multi-GPU setup
    multi_gpu = torch.cuda.device_count() > 1
    if multi_gpu:
        print(f"Using {torch.cuda.device_count()} GPUs!")
    else:
        print("Using 1 GPU or CPU.")

    # Initialize the model with max_T matching the schedule
    model = PuzzleDenoiser(
        vocab_size=vocab_size, 
        embed_dim=embed_dim, 
        num_heads=num_heads, 
        num_layers=num_layers,
        max_T=T
    ).to(device)

    # Load best model if provided. Checkpoints saved from a DataParallel model
    # carry a "module." prefix on every key; strip it so the weights always
    # load into the bare model, whatever the current GPU count is.
    if best_model_path is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        state_dict = torch.load(best_model_path, map_location=device)
        prefix = 'module.'
        state_dict = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()
        }
        model.load_state_dict(state_dict)
        print(f"Loaded model from {best_model_path}")

    # Wrap only after the weights are in place
    if multi_gpu:
        model = nn.DataParallel(model)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    # Halve the learning rate when validation loss stops improving
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=3,
        verbose=True
    )
    loss_fn = nn.CrossEntropyLoss()
    # Loss scaling is only meaningful for CUDA mixed precision
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda"))

    # Setup data loaders
    train_dataset = TensorDataset(X_train, y_train)
    val_dataset   = TensorDataset(X_val,   y_val)
    train_loader  = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader    = DataLoader(val_dataset,   batch_size=batch_size)

    best_val_loss = float('inf')

    # Optional validation before training (useful when resuming from a checkpoint)
    if run_initial_validation:
        print("Running initial validation (Epoch 0):")
        avg_val_loss, val_solve_rate, val_token_acc = validate_combined(
            model=model,
            schedule=schedule,
            loader=val_loader,
            device=device,
            loss_fn=loss_fn,
            vocab_size=vocab_size
        )
        print(f"Initial val_loss={avg_val_loss:.4f}")
        print(f"Initial solve rate on validation set: {val_solve_rate*100:.2f}%")
        print(f"Initial token-level accuracy on validation set: {val_token_acc*100:.2f}%")

    for epoch in range(num_epochs):
        model.train()
        train_losses = []
        for batch_puzzle, batch_solution in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]"):
            # diffusion_train_step opens its own autocast region around the
            # forward pass; it is not wrapped again here so that the backward
            # pass and optimizer step run outside autocast.
            loss_val = diffusion_train_step(
                model=model,
                puzzle=batch_puzzle,
                solution=batch_solution,
                schedule=schedule,
                optimizer=optimizer,
                loss_fn=loss_fn,
                vocab_size=vocab_size,
                device=device,
                scaler=scaler
            )

            train_losses.append(loss_val.item())

        avg_train_loss = np.mean(train_losses)

        # Validation
        avg_val_loss, val_solve_rate, val_token_acc = validate_combined(
            model=model,
            schedule=schedule,
            loader=val_loader,
            device=device,
            loss_fn=loss_fn,
            vocab_size=vocab_size
        )

        print(f"Epoch {epoch+1}: train_loss={avg_train_loss:.4f}, val_loss={avg_val_loss:.4f}")
        print(f"Solve rate on validation set: {val_solve_rate*100:.2f}%")
        print(f"Token-level accuracy on validation set: {val_token_acc*100:.2f}%")

        # Update learning rate scheduler
        scheduler.step(avg_val_loss)

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "puzzle_diffuser_best.pt")
            print(f"  [*] Best model saved @ val_loss={avg_val_loss:.4f}")

    return model, schedule


In [10]:
# Load and preprocess data
import pandas as pd
from sklearn.model_selection import train_test_split

# Load sudoku data. The cells below read the 'quizzes' and 'solutions'
# columns and iterate each value character by character, one digit per cell.
df = pd.read_csv('./data/sudoku.csv')

# Convert strings to tensors
def preprocess_sudoku(puzzle_str):
    """Turn a string of digit characters into a 1-D int64 tensor, one entry per character."""
    digits = list(map(int, puzzle_str))
    return torch.tensor(digits, dtype=torch.long)

# Convert all puzzles and solutions into (N, 81) tensors
puzzles = torch.stack([preprocess_sudoku(p) for p in df['quizzes']])
solutions = torch.stack([preprocess_sudoku(s) for s in df['solutions']])

# Karpathy split (90/5/5)
train_size = 0.9
val_size = 0.05
test_size = 0.05

# First split into train and temp (temp holds the val + test share)
X_train, X_temp, y_train, y_temp = train_test_split(
    puzzles, solutions, train_size=train_size, random_state=42
)

# Split temp into val and test. The test share of the remainder is derived
# from the fractions above (0.05 / 0.10 = 0.5) so that editing val_size or
# test_size actually changes the split.
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=test_size / (val_size + test_size), random_state=42
)

In [None]:
# Training run A: T=40 with the larger transformer (embed_dim=512, 8 heads, 8 layers).
# train_puzzle_diffusion writes its best checkpoint to "puzzle_diffuser_best.pt",
# so this run and the next training cell overwrite each other's checkpoint and
# both rebind `model` / `schedule`.
model, schedule = train_puzzle_diffusion(
    X_train, y_train,
    X_val,   y_val,
    vocab_size=10,
    T=40,
    embed_dim=512,
    num_heads=8,
    num_layers=8,
    batch_size=64,
    num_epochs=50
    #best_model_path="puzzle_diffuser_best.pt"
)

In [28]:
# Training run B: T=80 with the smaller transformer (embed_dim=256, 4 heads, 4 layers).
# Rebinds `model` / `schedule` and reuses the same checkpoint file name as the
# previous training cell ("puzzle_diffuser_best.pt", set inside train_puzzle_diffusion).
model, schedule = train_puzzle_diffusion(
    X_train, y_train,
    X_val,   y_val,
    vocab_size=10,
    T=80,
    embed_dim=256,
    num_heads=4,
    num_layers=4,
    batch_size=64,
    num_epochs=100
    #best_model_path="puzzle_diffuser_best.pt"
)

  scaler = torch.cuda.amp.GradScaler()


Using device: cuda
Using 2 GPUs!


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():  # Enable mixed precision
Epoch 1 [Train]:   1%|▏         | 208/14063 [01:48<1:59:42,  1.93it/s]

In [11]:
#############################################
# 2) CREATE THE NEW SCHEDULE
#############################################
# NOTE: this rebinds the notebook-global `schedule`.
schedule = OneFlipDiffusionSchedule(
    T=80
)

#############################################
# 3) PICK A SAMPLE PUZZLE & SOLUTION
#############################################
sample_puzzle = X_train[0:1]
sample_solution = y_train[0:1]

print("Puzzle:", sample_puzzle)
print("Solution:", sample_solution)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#############################################
# 4) SHOW x_1, x_2, ... x_T
#############################################
vocab_size = 10  # digits 0-9
# Walk the Markov chain: each step noises the previous state, starting from
# x_0 = the clean solution, which is how the training code builds its chain.
x_t = sample_solution
for t in range(1, schedule.T + 1):
    x_t = forward_diffusion_mixed(sample_puzzle, x_t, t, schedule, vocab_size=vocab_size, device=device, bias_increment=1.3)
    print(f"x_t at t={t}:", x_t)

Puzzle: tensor([[0, 4, 2, 0, 0, 9, 0, 7, 5, 9, 0, 0, 0, 0, 7, 0, 0, 3, 3, 0, 5, 6, 1, 0,
         9, 0, 0, 0, 0, 4, 9, 7, 8, 0, 0, 6, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 1, 8,
         4, 0, 6, 2, 0, 0, 0, 6, 7, 0, 0, 0, 4, 3, 8, 8, 0, 0, 7, 0, 0, 0, 9, 2,
         0, 0, 0, 0, 5, 0, 0, 0, 0]])
Solution: tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 2, 3, 3, 7, 5, 6, 1, 2,
         9, 8, 4, 2, 5, 4, 9, 7, 8, 3, 1, 6, 6, 9, 3, 1, 2, 5, 8, 4, 7, 7, 1, 8,
         4, 3, 6, 2, 5, 9, 5, 6, 7, 2, 9, 1, 4, 3, 8, 8, 3, 1, 7, 6, 4, 5, 9, 2,
         4, 2, 9, 8, 5, 3, 7, 6, 1]])
x_t at t=1: tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 2, 3, 3, 7, 5, 6, 1, 2,
         9, 8, 4, 2, 5, 4, 9, 7, 8, 3, 1, 6, 6, 9, 3, 1, 2, 5, 8, 4, 7, 7, 1, 8,
         4, 3, 6, 2, 5, 9, 5, 6, 7, 2, 9, 1, 4, 3, 8, 8, 3, 1, 7, 6, 4, 5, 9, 2,
         4, 2, 9, 8, 5, 3, 7, 6, 1]], device='cuda:0')
x_t at t=2: tensor([[1, 4, 2, 3, 8, 9, 6, 7, 5, 9, 8, 6, 5, 4, 7, 1, 6, 3, 3, 7, 5, 6, 1, 2,
         9, 8, 4,

In [13]:
# Decode a couple of validation puzzles with the trained model and print them
# next to the reference solutions.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
puzzle_batch = X_val[:2]  # first two validation puzzles
x_filled = iterative_decode(model, puzzle=puzzle_batch, schedule=schedule, vocab_size=10, device=device)
print("Puzzle: ", puzzle_batch)
print("Decoded solution: ", x_filled)
print("Original solution: ", y_val[:2])

NameError: name 'iterative_decode' is not defined

In [None]:
# NOTE(review): `measure_sudoku_solve_rate` is not defined in any visible cell
# of this notebook - TODO confirm where it comes from, or replace this cell
# with validate_combined, which also reports a solve rate.
# Also relies on `device` having been set by an earlier cell.
accuracy = measure_sudoku_solve_rate(
    model,
    schedule,
    X_val[0:100],
    y_val[0:100],
    device=device,
    vocab_size=10,
    batch_size=32
)

print(f"Solve rate on validation set: {accuracy*100:.2f}%")

In [None]:
# Create a small validation dataset for testing (first 1000 validation boards)
X_val_test = X_val[:1000]
y_val_test = y_val[:1000]
test_val_loader = DataLoader(
    TensorDataset(X_val_test, y_val_test),
    batch_size=32,
    shuffle=False
)

# Initialize loss function
loss_fn = nn.CrossEntropyLoss()

# Run validation.
# Depends on notebook globals set by earlier cells: `model` and `schedule`
# (training / schedule cells), `device` and `vocab_size` (schedule demo cell).
avg_val_loss, val_solve_rate, val_token_acc = validate_combined(
    model=model,
    schedule=schedule, 
    loader=test_val_loader,
    device=device,
    loss_fn=loss_fn,
    vocab_size=vocab_size
)

print(f"Validation loss: {avg_val_loss:.4f}")
print(f"Solve rate on validation set: {val_solve_rate*100:.2f}%")
print(f"Token-level accuracy on validation set: {val_token_acc*100:.2f}%")