In [22]:
import json
from dataclasses import dataclass
from functools import lru_cache

import torch
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset

# Run on GPU when one is visible, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_NAME = "microsoft/deberta-v3-small"

# Tokenizer and encoder are loaded once here and read as globals by the
# embedding helpers further down (embed_phrase uses tokenizer, model, DEVICE).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
model = AutoModel.from_pretrained(MODEL_NAME)
model = model.to(DEVICE)
model.eval()  # inference only: disables dropout

print("device:", DEVICE)
print("model:", MODEL_NAME)



device: cpu
model: microsoft/deberta-v3-small


In [23]:
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor, special_tokens_mask=None) -> torch.Tensor:
    mask = attention_mask.bool()
    if special_tokens_mask is not None:
        mask = mask & (~special_tokens_mask.bool().to(mask.device))
    if mask.sum() == 0:
        mask = attention_mask.bool()
    x = last_hidden_state[0][mask[0]]
    return x.mean(dim=0)


@lru_cache(maxsize=10000)
def embed_phrase(phrase: str) -> torch.Tensor:
    """Embed a word/phrase as an L2-normalized vector via masked mean pooling.

    The text is stripped and lower-cased, then tokenized with truncation at
    ``max_length=32``. It is run through ``model`` without gradients and
    mean-pooled over non-special tokens by ``_mean_pool``.

    Results are memoized by ``lru_cache`` keyed on the raw ``phrase`` argument,
    so repeat calls return the *same* tensor object. Treat it as read-only.

    Returns:
        float32 CPU tensor whose L2 norm is ~1 (a 1e-12 epsilon guards division by zero).
    """
    encoded = tokenizer(
        phrase.strip().lower(),
        return_tensors="pt",
        truncation=True,
        max_length=32,
        return_special_tokens_mask=True,
    )
    # The model does not accept special_tokens_mask, so take it out before the forward pass.
    specials = encoded.pop("special_tokens_mask", None)
    model_inputs = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
    if specials is not None:
        specials = specials.to(DEVICE)

    with torch.no_grad():
        hidden = model(**model_inputs).last_hidden_state

    pooled = _mean_pool(hidden, model_inputs["attention_mask"], specials)
    pooled = pooled.float().cpu()
    return pooled / (pooled.norm(p=2) + 1e-12)


def rank_candidates_by_similarity(triple: list[str], candidates: list[str]) -> list[str]:
    """Order ``candidates`` by cosine similarity to the centroid of ``triple``.

    The centroid is the mean of the triple's embeddings, re-normalized to unit
    length. Candidates are returned best-first. Equal scores keep their input
    order, because the sort is stable.
    """
    centroid = torch.stack([embed_phrase(member) for member in triple]).mean(dim=0)
    centroid = centroid / (centroid.norm(p=2) + 1e-12)

    similarity = [float(torch.dot(centroid, embed_phrase(cand))) for cand in candidates]
    order = sorted(range(len(candidates)), key=lambda k: similarity[k], reverse=True)
    return [candidates[k] for k in order]


def leave_one_out_queries(words4: list[str]) -> list[dict]:
    """Build the 4 leave-one-out queries for a 4-word group.

    Each query holds out one word as ``answer`` and keeps the other three, in
    their original order, as ``triple``. ``candidates`` is a fresh list of all
    four words. Input that does not have exactly 4 words yields ``[]``.
    """
    if len(words4) != 4:
        return []
    return [
        {
            "triple": [word for pos, word in enumerate(words4) if pos != held_out],
            "answer": words4[held_out],
            "candidates": list(words4),
        }
        for held_out in range(4)
    ]

In [29]:
from datasets import load_dataset


def load_connections_from_hf(split: str = "train"):
    """Load NYT Connections from Hugging Face (tm21cy/NYT-Connections) and return a split.

    The split has columns like: date, contest, words (16), answers (4 groups with description + words).

    Args:
        split: name of the split to return. If the dataset has no split with
            this name, the first available split is returned instead and a
            ``UserWarning`` says so. Previously this substitution was silent,
            which made it easy to evaluate on the wrong split unknowingly.

    Returns:
        ``load_dataset("tm21cy/NYT-Connections")[split]`` for the resolved split name.
    """
    import warnings

    ds = load_dataset("tm21cy/NYT-Connections")
    if split not in ds:
        available = list(ds.keys())
        fallback = available[0]
        warnings.warn(
            f"Split {split!r} not found in tm21cy/NYT-Connections "
            f"(available: {available}); using {fallback!r} instead.",
            stacklevel=2,
        )
        split = fallback
    return ds[split]



In [30]:
# Fetch the puzzle split (default "train") and report how many puzzles it holds.
hf_split = load_connections_from_hf(split="train")
n_puzzles = len(hf_split)
print("puzzles:", n_puzzles)

puzzles: 652


In [31]:
from itertools import combinations
from torch.nn.functional import normalize


def group_similarity(embeddings: torch.Tensor) -> float:
    """Average pairwise cosine similarity inside a group (embeddings: [n, hidden]).

    Rows are L2-normalized, the Gram matrix is formed, and the mean is taken
    over every off-diagonal entry (each unordered pair counted twice, which
    does not change the mean). Self-similarity on the diagonal is excluded.
    """
    unit = normalize(embeddings, dim=1)
    gram = unit @ unit.T
    size = gram.size(0)
    off_diagonal = ~torch.eye(size, dtype=torch.bool)
    return float(gram[off_diagonal].mean())


def solve_puzzle(words16: list[str]) -> list[list[str]]:
    """Given 16 words from one Connections puzzle, greedily form 4 groups of 4 by max avg similarity.

    In each of the first three rounds, the 4-word combination of still-unassigned
    words with the highest average pairwise cosine similarity becomes a group.
    The four leftover words form the last group. Ties keep the earliest
    combination in ``itertools.combinations`` order.

    Args:
        words16: exactly 16 puzzle words/phrases.

    Returns:
        Four lists of four words each, in the order the groups were chosen.

    Raises:
        ValueError: if ``words16`` does not contain exactly 16 entries.
    """
    if len(words16) != 16:
        raise ValueError(f"Expected 16 words, got {len(words16)}")

    # Hoisted out of the search: embed each word once, unit-normalize, and build
    # the full 16x16 cosine-similarity matrix. Scoring a candidate group then
    # only indexes this matrix. It no longer re-normalizes and re-multiplies its
    # 4 vectors for every one of the ~2.4k combinations examined per puzzle.
    unit = normalize(torch.stack([embed_phrase(w) for w in words16], dim=0), dim=1)
    cos = unit @ unit.T
    off_diag = ~torch.eye(4, dtype=torch.bool, device=cos.device)

    def _group_score(combo: tuple[int, ...]) -> float:
        # Mean of the 12 off-diagonal entries of the 4x4 sub-matrix. This is the
        # same average pairwise cosine similarity that group_similarity()
        # computes, up to floating-point rounding.
        idx = list(combo)
        return float(cos[idx][:, idx][off_diag].mean())

    remaining = list(range(16))
    groups_idx: list[list[int]] = []
    for _ in range(3):  # first 3 groups; the last group is whatever remains
        # max() keeps the first combination with the highest score, the same
        # tie-break as a strict ">" scan. It always returns a combination, so NaN
        # scores cannot leave the pick as None.
        best_combo = list(max(combinations(remaining, 4), key=_group_score))
        groups_idx.append(best_combo)
        remaining = [i for i in remaining if i not in best_combo]

    groups_idx.append(remaining)
    return [[words16[i] for i in idxs] for idxs in groups_idx]


# Demo on the first puzzle: show the predicted grouping next to the gold answer key.
row0 = hf_split[0]
words16 = row0["words"]
print("Puzzle date:", row0.get("date"))
print("All words:", words16)

pred_groups = solve_puzzle(words16)

print("\nPredicted groups:")
for predicted in pred_groups:
    print(predicted)

print("\nGold groups:")
for gold_group in row0["answers"]:
    print(gold_group["answerDescription"], "->", gold_group["words"])

Puzzle date: 2024-06-03 00:00:00
All words: ['LASER', 'PLUCK', 'THREAD', 'WAX', 'COIL', 'SPOOL', 'WIND', 'WRAP', 'HONEYCOMB', 'ORGANISM', 'SOLAR PANEL', 'SPREADSHEET', 'BALL', 'MOVIE', 'SCHOOL', 'VITAMIN']

Predicted groups:
['LASER', 'WAX', 'COIL', 'SPOOL']
['SPREADSHEET', 'BALL', 'MOVIE', 'SCHOOL']
['PLUCK', 'WIND', 'HONEYCOMB', 'VITAMIN']
['THREAD', 'WRAP', 'ORGANISM', 'SOLAR PANEL']

Gold groups:
REMOVE, AS BODY HAIR -> ['LASER', 'PLUCK', 'THREAD', 'WAX']
TWIST AROUND -> ['COIL', 'SPOOL', 'WIND', 'WRAP']
THINGS MADE OF CELLS -> ['HONEYCOMB', 'ORGANISM', 'SOLAR PANEL', 'SPREADSHEET']
B-___ -> ['BALL', 'MOVIE', 'SCHOOL', 'VITAMIN']


In [35]:
from itertools import permutations


def _gold_groups_from_row(row) -> list[list[str]]:
    """Extract the 4 answer groups (each 4 words) from a puzzle row."""
    return [list(g.get("words", [])) for g in row.get("answers", []) if len(g.get("words", [])) == 4]


def _norm(g: list) -> frozenset:
    return frozenset(w.strip() for w in g)


def accuracy_zero_one(pred_groups: list[list[str]], gold_groups: list[list[str]]) -> float:
    """1.0 if predicted groups match gold exactly (as sets of 4 words), else 0.0."""
    if len(pred_groups) != 4 or len(gold_groups) != 4:
        return 0.0
    pred_sets = {_norm(g) for g in pred_groups}
    gold_sets = {_norm(g) for g in gold_groups}
    return 1.0 if pred_sets == gold_sets else 0.0


def accuracy_min_swaps(pred_groups: list[list[str]], gold_groups: list[list[str]]) -> float:
    """Minimum number of 1-for-1 word swaps between groups to turn predicted into gold.

    Matches pred groups to gold groups (best bijection), counts misplaced words, returns ceil(misplaced/2).
    """
    if len(pred_groups) != 4 or len(gold_groups) != 4:
        return float("inf")
    pred_sets = [_norm(g) for g in pred_groups]
    gold_sets = [_norm(g) for g in gold_groups]
    best_misplaced = 16
    for perm in permutations(range(4)):
        misplaced = 0
        for i in range(4):
            j = perm[i]
            misplaced += 4 - len(pred_sets[i] & gold_sets[j])
        best_misplaced = min(best_misplaced, misplaced)
    return (best_misplaced + 1) // 2  # min 1-1 swaps: each swap fixes 2 words

In [36]:
def evaluate(split, metric_fn=None, solver_fn=None, max_samples=None):
    """Run a solver over puzzles and average a per-puzzle metric.

    Args:
        split: indexable collection of puzzle rows with ``words`` and ``answers``.
        metric_fn: ``metric_fn(pred_groups, gold_groups) -> float``
            (default ``accuracy_zero_one``).
        solver_fn: ``solver_fn(words16) -> list[list[str]]`` (default ``solve_puzzle``).
        max_samples: only the first ``max_samples`` rows are considered (None = all).

    Rows are skipped when they lack exactly 16 words or exactly 4 four-word gold groups.

    Returns:
        ``(mean_score, n_scored)``. ``mean_score`` is 0.0 when no row was scored.
    """
    metric = accuracy_zero_one if metric_fn is None else metric_fn
    solver = solve_puzzle if solver_fn is None else solver_fn
    limit = len(split) if max_samples is None else min(max_samples, len(split))

    scores = []
    for idx in range(limit):
        row = split[idx]
        words16 = row.get("words", [])
        if len(words16) != 16:
            continue
        gold = _gold_groups_from_row(row)
        if len(gold) != 4:
            continue
        scores.append(metric(solver(words16), gold))

    mean_score = sum(scores) / len(scores) if scores else 0.0
    return mean_score, len(scores)


N_EVAL = 10

# Both metrics score the same predictions. Without sharing them, the second
# evaluate() call re-runs the full combinatorial search for every puzzle, so
# memoize the solver for this cell (re-running the cell starts a fresh cache).
_solved_cache: dict[tuple[str, ...], list[list[str]]] = {}


def _solve_puzzle_once(words16: list[str]) -> list[list[str]]:
    """solve_puzzle memoized on the exact word tuple (lists are unhashable, so no lru_cache)."""
    key = tuple(words16)
    if key not in _solved_cache:
        _solved_cache[key] = solve_puzzle(words16)
    return _solved_cache[key]


acc, n_eval = evaluate(hf_split, metric_fn=accuracy_zero_one, solver_fn=_solve_puzzle_once, max_samples=N_EVAL)
mean_swaps, _ = evaluate(hf_split, metric_fn=accuracy_min_swaps, solver_fn=_solve_puzzle_once, max_samples=N_EVAL)
print(f"Zero-one accuracy: {acc:.4f}  (n={n_eval}, requested={N_EVAL})")
print(f"Mean 1-1 swaps to correct: {mean_swaps:.2f}  (n={n_eval})")

Zero-one accuracy: 0.0000  (n=10, requested=10)
Mean 1-1 swaps to correct: 3.90  (n=10)
