In [5]:
import torch
import torch.nn as nn
import torch.optim as optim

In [2]:
def corrupt_sequence(sequence, noise_level, vocab_size):
    noise = torch.randint(0, vocab_size, sequence.shape)  # Random tokens
    mask = torch.rand(sequence.shape) < noise_level       # Noise mask
    corrupted = torch.where(mask, noise, sequence)        # Apply noise
    return corrupted

In [16]:
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, num_layers):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads,batch_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(embed_dim, vocab_size)

    def forward(self, x):
        x = self.embedding(x)  # Embed tokens
        x = self.encoder(x)    # Pass through transformer
        return self.output_layer(x)  # Predict clean tokens

In [17]:
vocab_size = 10  # digits 0-9

In [18]:
model = TransformerEncoder(vocab_size=vocab_size, embed_dim=256, num_heads=8, num_layers=6)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

def train_step(clean_sequence, model, optimizer, loss_fn, vocab_size):
    model.train()
    optimizer.zero_grad()

    # Generate noisy input
    noise_level = torch.rand(1).item()  # Random noise level
    corrupted_sequence = corrupt_sequence(clean_sequence, noise_level, vocab_size)

    # Predict denoised sequence
    logits = model(corrupted_sequence)
    loss = loss_fn(logits.view(-1, vocab_size), clean_sequence.view(-1))
    loss.backward()
    optimizer.step()

    return loss.item()

In [9]:
# Load and preprocess data
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split

# Load sudoku data
df = pd.read_csv('./data/sudoku.csv')

# Convert strings to tensors
def preprocess_sudoku(puzzle_str):
    # Convert string to list of integers and then to tensor
    return torch.tensor([int(d) for d in puzzle_str], dtype=torch.long)

# Convert all puzzles and solutions
puzzles = torch.stack([preprocess_sudoku(p) for p in df['quizzes']])
solutions = torch.stack([preprocess_sudoku(s) for s in df['solutions']])

# Karpathy split (90/5/5)
train_size = 0.9
val_size = 0.05
test_size = 0.05

# First split into train and temp
X_train, X_temp, y_train, y_temp = train_test_split(
    puzzles, solutions, train_size=train_size, random_state=42
)

# Split temp into val and test
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.5, random_state=42
)


In [19]:
X_val[0], y_val[0]

(tensor([8, 0, 2, 5, 0, 0, 6, 0, 0, 4, 0, 5, 1, 0, 0, 8, 9, 0, 0, 0, 0, 9, 0, 4,
         0, 3, 0, 7, 0, 0, 3, 0, 9, 5, 0, 1, 3, 4, 0, 0, 0, 0, 7, 2, 0, 0, 8, 0,
         0, 0, 6, 9, 0, 0, 0, 6, 3, 7, 1, 0, 4, 5, 0, 0, 0, 0, 0, 2, 0, 0, 0, 8,
         1, 7, 0, 0, 0, 0, 3, 0, 0]),
 tensor([8, 9, 2, 5, 3, 7, 6, 1, 4, 4, 3, 5, 1, 6, 2, 8, 9, 7, 6, 1, 7, 9, 8, 4,
         2, 3, 5, 7, 2, 6, 3, 4, 9, 5, 8, 1, 3, 4, 9, 8, 5, 1, 7, 2, 6, 5, 8, 1,
         2, 7, 6, 9, 4, 3, 2, 6, 3, 7, 1, 8, 4, 5, 9, 9, 5, 4, 6, 2, 3, 1, 7, 8,
         1, 7, 8, 4, 9, 5, 3, 6, 2]))

In [14]:

# Training loop
num_epochs = 100
batch_size = 32


# Create data loaders
from torch.utils.data import TensorDataset, DataLoader
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataset = TensorDataset(X_val, y_val)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

# Training loop with validation
from tqdm.auto import tqdm

best_val_loss = float('inf')
for epoch in range(num_epochs):
    # Training
    model.train()
    train_losses = []
    
    for batch_puzzles, batch_solutions in tqdm(train_loader, desc=f'Epoch {epoch+1}'):
        loss = train_step(batch_solutions, model, optimizer, loss_fn, vocab_size)
        train_losses.append(loss)
    
    avg_train_loss = sum(train_losses) / len(train_losses)
    
    # Validation
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch_puzzles, batch_solutions in val_loader:
            noise_level = torch.rand(1).item()
            corrupted = corrupt_sequence(batch_solutions, noise_level, vocab_size)
            logits = model(corrupted)
            loss = loss_fn(logits.view(-1, vocab_size), batch_solutions.view(-1))
            val_losses.append(loss.item())
    
    avg_val_loss = sum(val_losses) / len(val_losses)
    
    # Print metrics
    print(f'Epoch {epoch+1}:')
    print(f'  Training Loss: {avg_train_loss:.4f}')
    print(f'  Validation Loss: {avg_val_loss:.4f}')
    
    # Save best model
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), 'best_model.pt')


Epoch 1:   0%|          | 0/28125 [00:00<?, ?it/s]


ValueError: Expected input batch_size (1296000) to match target batch_size (2592).