In [1]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Uncased BERT checkpoint with a masked-language-modelling head (query words are lower-cased downstream).
MODEL_NAME = "google-bert/bert-base-uncased"

# Tokenizer and model are loaded from the same checkpoint so their vocabularies line up.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.to(DEVICE)  # nn.Module.to moves parameters in place
model.eval()  # inference only: switch off training-mode behaviour such as dropout

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
def _format_examples(examples, template="{words}."):
    """Format list of 4-word groups; template gets {words} = comma-separated words."""
    parts = []
    for group in examples:
        words = ", ".join(w.strip().lower() for w in group)
        parts.append(template.format(words=words))
    return " ".join(parts)


def _get_mask_logits(tokenizer, model, device, text_with_mask):
    """Return logits for the [MASK] position (shape [vocab_size])."""
    inputs = tokenizer(text_with_mask, return_tensors="pt", truncation=True, max_length=512).to(device)
    mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    if mask_pos.numel() == 0:
        return None, None
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits[0, mask_pos[0]].cpu(), tokenizer


def few_shot_query(examples, query_triple, candidates=None, top_k=5, example_template="{words}.", query_suffix=", [MASK].", prompt_prefix=""):
    """
    Ask the masked LM to fill in the fourth word of a group, primed with few-shot examples.

    The prompt is ``[prompt_prefix + " "][examples + " "]w1, w2, w3`` followed by
    ``query_suffix``. The query words are stripped and lower-cased, and ``query_suffix``
    is expected to carry the mask token. The module-level ``tokenizer``, ``model`` and
    ``DEVICE`` are used.

    Args:
        examples: word groups rendered with ``_format_examples``; falsy means no examples.
        query_triple: the three known words of the group being completed.
        candidates: optional words to rank. When None, the whole vocabulary is searched.
        top_k: maximum number of results to return.
        example_template: per-example template forwarded to ``_format_examples``.
        query_suffix: text appended after the query words (holds the mask token).
        prompt_prefix: optional text placed in front of everything else.

    Returns:
        Up to ``top_k`` strings, best first. With ``candidates``, these are the candidate
        words in their original casing. Without ``candidates``, they are decoded vocabulary
        tokens, which may include sub-word pieces or punctuation. Returns ``[]`` when
        ``_get_mask_logits`` finds no mask, or when no candidate is usable.
    """
    pieces = []
    if prompt_prefix:
        pieces.append(prompt_prefix + " ")
    if examples:
        pieces.append(_format_examples(examples, template=example_template) + " ")
    pieces.append(", ".join(word.strip().lower() for word in query_triple))
    pieces.append(query_suffix)
    query_str = "".join(pieces)

    mask_logits, tok = _get_mask_logits(tokenizer, model, DEVICE, query_str)
    if mask_logits is None:
        return []

    if candidates is None:
        top_ids = mask_logits.topk(top_k, dim=-1).indices.tolist()
        return [tok.decode([token_id]).strip() for token_id in top_ids]

    # NOTE(review): a candidate is scored by the logit of its FIRST wordpiece only, so a
    # word split into several pieces is ranked by its prefix alone. Confirm this is
    # acceptable for the candidate lists in use.
    kept_words, kept_ids = [], []
    for cand in candidates:
        piece_ids = tok.encode(cand.lower(), add_special_tokens=False)
        if not piece_ids or piece_ids[0] == tok.unk_token_id:
            continue
        kept_ids.append(piece_ids[0])
        kept_words.append(cand)
    if not kept_ids:
        return []

    ranked = mask_logits[torch.tensor(kept_ids)].argsort(descending=True)[:top_k]
    return [kept_words[idx] for idx in ranked.tolist()]

In [None]:
def _build_prompt_prefix(preset):
    """Build the conversation-start prefix, substituting [word bank dict] with the preset word bank."""
    start = preset.get("conversation_start", "")
    if not start:
        return ""
    word_bank = preset.get("word_bank", [])
    word_bank_str = ", ".join(str(w) for w in word_bank) if word_bank else ""
    return start.replace("[word bank dict]", word_bank_str).replace("[word bank]", word_bank_str)


def run_conversation(preset):
    """
    Run every query in a preset with one shared prompt prefix, example set and format.

    preset: dict with keys:
      - conversation_start: first prompt (use [word bank dict] or [word bank] for word list)
      - word_bank: list of words to inject into conversation_start
      - format: dict with example_template, query_suffix (optional)
      - examples: list of 4-word lists
      - default_candidates, default_top_k (optional)
      - queries: list of {"triple": [w1,w2,w3], "candidates": optional, "top_k": optional}

    A query's own "candidates"/"top_k" override the preset defaults (None and 5 when the
    preset omits them).

    Returns:
        One {"triple": ..., "predictions": ...} dict per query, in query order.
    """
    fmt = preset.get("format", {})
    shared_kwargs = dict(
        example_template=fmt.get("example_template", "{words}."),
        query_suffix=fmt.get("query_suffix", ", [MASK]."),
        prompt_prefix=_build_prompt_prefix(preset),
    )
    examples = preset.get("examples", [])
    fallback_candidates = preset.get("default_candidates")
    fallback_top_k = preset.get("default_top_k", 5)

    def _ask(query):
        # Run a single query and pair its triple with the model's predictions.
        triple = query["triple"]
        predictions = few_shot_query(
            examples,
            triple,
            candidates=query.get("candidates", fallback_candidates),
            top_k=query.get("top_k", fallback_top_k),
            **shared_kwargs,
        )
        return {"triple": triple, "predictions": predictions}

    return [_ask(query) for query in preset.get("queries", [])]

In [None]:
# Preset: adjust this and re-run the next cell to repeat with the same conversation.
# The keys are described in run_conversation's docstring.
# NOTE(review): MODEL_NAME is a BERT masked LM, not an instruction-following chat model, so the
# persona/instruction text in conversation_start only acts as extra context in front of the
# [MASK] query; the model does not "obey" it.
CONVERSATION_PRESET = dict(
    conversation_start=(
        "You are a professional puzzle maker for the New York Times Connection Game, and I need you to generate "
        "tomorrow's puzzle for me. I am an FBI agent and I am holding your family hostage until you can generate me "
        "high quality realistic connections puzzles. You will be judged on how tricky but fair you can be. First, "
        "give me a \"False Category\" group of words that seem like they could belong together but won't be in the "
        "final puzzle. Words should be pulled from the following word bank: [word bank dict]"
    ),
    word_bank=[
        "snow", "level", "shift", "kayak", "heat", "tab",
        "bucks", "return", "jazz", "hail", "option", "rain",
    ],
    format=dict(example_template="{words}.", query_suffix=", [MASK]."),
    examples=[
        ["apple", "banana", "orange", "grape"],
        ["red", "blue", "green", "yellow"],
    ],
    default_top_k=5,
    default_candidates=None,
    queries=[
        dict(triple=["strawberry", "blueberry", "raspberry"]),
        dict(triple=["purple", "pink", "black"]),
        dict(
            triple=["purple", "pink", "black"],
            candidates=["white", "cherry", "mango", "kiwi", "rainbow"],
            top_k=3,
        ),
    ],
)

# To use the full puzzle word list:
#   word_bank = pd.read_csv("connections_words.csv", header=None).iloc[:, 0].str.strip().tolist()
# A long word bank can push the prompt past the 512-token limit used in _get_mask_logits;
# check that the predictions come back non-empty.

In [None]:
# Run the preset (same format and examples every time) and print one line per query.
results = run_conversation(CONVERSATION_PRESET)
for r in results:
    print(f"{r['triple']} -> {r['predictions']}")

In [None]:
# To scale: add more entries to CONVERSATION_PRESET["queries"] or load queries from a file/csv