In [2]:
from dataset import generate_one_sample
from transformers import T5ForConditionalGeneration
import torch
from tqdm import tqdm

model = T5ForConditionalGeneration.from_pretrained("Somsung/t5-sudoku-solver")

# Measure exact-match accuracy of the solver for puzzles with different numbers
# of pre-filled ("non-empty") cells, using freshly generated samples.
n_non_empty_cells_range = range(6, 12)
n_samples_per_clue = 100  # Number of samples to test for each n_non_empty_cells value
# Upper bound on generated tokens per puzzle. Presumably the solution tokens
# plus EOS; TODO confirm against the encoding used by generate_one_sample.
max_new_tokens = 17

for n_non_empty_cells in n_non_empty_cells_range:
    print(f"\nTesting with {n_non_empty_cells} clues...")

    correct_predictions = 0
    total_predictions = 0

    for _ in tqdm(range(n_samples_per_clue), desc=f"Number of non-empty cells={n_non_empty_cells}"):
        input_ids, labels = generate_one_sample(n_clues=n_non_empty_cells)
        input_tensor = torch.tensor([input_ids]).to(model.device)

        # Greedy decoding (do_sample=False) keeps the prediction deterministic.
        with torch.no_grad():
            output = model.generate(input_tensor, max_new_tokens=max_new_tokens, do_sample=False)

        # Drop the first generated token (the decoder start token, which is the
        # pad token for T5), then require an exact match with the labels.
        predicted_solution = output[0][1:].tolist()
        if predicted_solution == labels:
            correct_predictions += 1

        total_predictions += 1

    # Compute once per setting, after the loop. The guard also keeps `accuracy`
    # bound when n_samples_per_clue == 0 (previously a NameError below).
    accuracy = correct_predictions / total_predictions if total_predictions else 0.0

    print(f"Results for {n_non_empty_cells} non-empty cells:")
    print(f"  Accuracy: {correct_predictions}/{total_predictions} ({accuracy:.2%})")


Testing with 6 clues...


Number of non-empty cells=6: 100%|██████████| 100/100 [00:08<00:00, 11.77it/s]


Results for 6 non-empty cells:
  Accuracy: 85/100 (85.00%)

Testing with 7 clues...


Number of non-empty cells=7: 100%|██████████| 100/100 [00:08<00:00, 12.17it/s]


Results for 7 non-empty cells:
  Accuracy: 95/100 (95.00%)

Testing with 8 clues...


Number of non-empty cells=8: 100%|██████████| 100/100 [00:08<00:00, 12.19it/s]


Results for 8 non-empty cells:
  Accuracy: 95/100 (95.00%)

Testing with 9 clues...


Number of non-empty cells=9: 100%|██████████| 100/100 [00:08<00:00, 12.01it/s]


Results for 9 non-empty cells:
  Accuracy: 97/100 (97.00%)

Testing with 10 clues...


Number of non-empty cells=10: 100%|██████████| 100/100 [00:08<00:00, 11.92it/s]


Results for 10 non-empty cells:
  Accuracy: 100/100 (100.00%)

Testing with 11 clues...


Number of non-empty cells=11: 100%|██████████| 100/100 [00:08<00:00, 11.68it/s]

Results for 11 non-empty cells:
  Accuracy: 100/100 (100.00%)



