In [1]:
import torch
from transformers import AutoModel, AutoTokenizer

from conn import DeBERTaEncoder

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "microsoft/deberta-v3-small"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()
encoder = DeBERTaEncoder(model, tokenizer, DEVICE)

print("device:", DEVICE)
print("model:", MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm


device: cpu
model: microsoft/deberta-v3-small


In [2]:
from conn import load_connections_from_hf, gold_groups_from_row

hf_split = load_connections_from_hf()
print("puzzles:", len(hf_split))

K = 20
example_groups = []
for i in range(min(K, len(hf_split))):
    row = hf_split[i]
    for g in gold_groups_from_row(row):
        if len(g) == 4:
            example_groups.append(g)
print("example groups:", len(example_groups))

puzzles: 652
example groups: 80


In [3]:
from conn import solve_puzzle_few_shot


def solve_puzzle(words16):
    return solve_puzzle_few_shot(words16, encoder, example_groups, alpha=0.5)


row0 = hf_split[0]
words16 = row0["words"]
print("Puzzle date:", row0.get("date"))
print("All words:", words16)

pred_groups = solve_puzzle(words16)
print("\nPredicted groups:")
for g in pred_groups:
    print(g)

print("\nGold groups:")
for ans in row0["answers"]:
    print(ans["answerDescription"], "->", ans["words"])

Puzzle date: 2024-06-03 00:00:00
All words: ['LASER', 'PLUCK', 'THREAD', 'WAX', 'COIL', 'SPOOL', 'WIND', 'WRAP', 'HONEYCOMB', 'ORGANISM', 'SOLAR PANEL', 'SPREADSHEET', 'BALL', 'MOVIE', 'SCHOOL', 'VITAMIN']

Predicted groups:
['LASER', 'COIL', 'SPOOL', 'WRAP']
['SPREADSHEET', 'BALL', 'MOVIE', 'SCHOOL']
['PLUCK', 'WAX', 'WIND', 'HONEYCOMB']
['THREAD', 'ORGANISM', 'SOLAR PANEL', 'VITAMIN']

Gold groups:
REMOVE, AS BODY HAIR -> ['LASER', 'PLUCK', 'THREAD', 'WAX']
TWIST AROUND -> ['COIL', 'SPOOL', 'WIND', 'WRAP']
THINGS MADE OF CELLS -> ['HONEYCOMB', 'ORGANISM', 'SOLAR PANEL', 'SPREADSHEET']
B-___ -> ['BALL', 'MOVIE', 'SCHOOL', 'VITAMIN']


In [4]:
from conn import (
    accuracy_min_swaps,
    accuracy_zero_one,
    evaluate,
    gold_groups_from_row,
)

N_EVAL = 100
acc, n_eval = evaluate(
    hf_split,
    metric_fn=accuracy_zero_one,
    solver_fn=solve_puzzle,
    max_samples=N_EVAL,
    gold_from_row=gold_groups_from_row,
)
mean_swaps, _ = evaluate(
    hf_split,
    metric_fn=accuracy_min_swaps,
    solver_fn=solve_puzzle,
    max_samples=N_EVAL,
    gold_from_row=gold_groups_from_row,
)
print(f"Zero-one accuracy: {acc:.4f}  (n={n_eval}, requested={N_EVAL})")
print(f"Mean 1-1 swaps to correct: {mean_swaps:.2f}  (n={n_eval})")

Zero-one accuracy: 0.0000  (n=100, requested=100)
Mean 1-1 swaps to correct: 3.77  (n=100)
