In [1]:
import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_NAME = "google-bert/bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()

  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at google-bert/bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [39]:
def _format_examples(examples, template="{words}."):
    """Format list of 4-word groups; template gets {words} = comma-separated words."""
    parts = []
    for group in examples:
        words = ", ".join(w.strip().lower() for w in group)
        parts.append(template.format(words=words))
    return " ".join(parts)


def _get_mask_logits(tokenizer, model, device, text_with_mask):
    """Return logits for the [MASK] position (shape [vocab_size])."""
    inputs = tokenizer(text_with_mask, return_tensors="pt", truncation=True, max_length=512).to(device)
    mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    if mask_pos.numel() == 0:
        return None, None
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits[0, mask_pos[0]].cpu(), tokenizer


def _get_mask_logits_multi(tokenizer, model, device, text_with_mask):
    """Return logits for every [MASK] position; shape [num_masks, vocab_size]. Returns (logits, tokenizer) or (None, None)."""
    inputs = tokenizer(text_with_mask, return_tensors="pt", truncation=True, max_length=512).to(device)
    mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    if mask_pos.numel() == 0:
        return None, None
    with torch.no_grad():
        all_logits = model(**inputs).logits
    out = torch.stack([all_logits[0, p].cpu() for p in mask_pos])
    return out, tokenizer


def few_shot_query(examples, query_triple, candidates=None, top_k=5, example_template="{words}.", query_suffix=", [MASK].", prompt_prefix=""):
    """
    Query the model with few-shot examples. Returns top predicted words for the fourth slot.
    """
    prefix = (prompt_prefix + " ") if prompt_prefix else ""
    prefix += _format_examples(examples, template=example_template) + " " if examples else ""
    words = ", ".join(w.strip().lower() for w in query_triple)
    query_str = prefix + words + query_suffix
    logits, tok = _get_mask_logits(tokenizer, model, DEVICE, query_str)
    if logits is None:
        return []

    if candidates is not None:
        cand_ids = []
        cand_words = []
        for w in candidates:
            ids = tok.encode(w.lower(), add_special_tokens=False)
            if ids and ids[0] != tok.unk_token_id:
                cand_ids.append(ids[0])
                cand_words.append(w)
        if not cand_ids:
            return []
        scores = logits[torch.tensor(cand_ids)]
        order = scores.argsort(descending=True)[:top_k]
        return [cand_words[i] for i in order.tolist()]

    return [tok.decode([tid]).strip() for tid in logits.topk(top_k, dim=-1).indices.tolist()]


def _candidate_ids_and_words(tokenizer, candidates):
    """Return (tensor of token ids, list of words) for words that tokenize to a single non-UNK token."""
    cand_ids, cand_words = [], []
    for w in candidates:
        ids = tokenizer.encode(w.strip().lower(), add_special_tokens=False)
        if ids and ids[0] != tokenizer.unk_token_id:
            cand_ids.append(ids[0])
            cand_words.append(w.strip())
    return torch.tensor(cand_ids) if cand_ids else None, cand_words


def _sample_from_scores(scores, temperature):
    if temperature is None or temperature <= 0:
        return scores.argmax().item()
    probs = torch.softmax(scores.float() / temperature, dim=-1)
    return torch.multinomial(probs, 1).item()


def predict_false_category_three(prompt_prefix, examples, word_bank, example_template="{words}.", category_masks=2, temperature=0.8, exclude_words=None):
    """
    Prompt: "... false category: [MASK], [MASK], [MASK]. category: [MASK] ...". First 3 masks = words from
    word_bank (no repeats). temperature>0 samples for diversity; exclude_words = set to avoid reusing.
    """
    bank = [w for w in word_bank if w not in (exclude_words or set())]
    suffix = "false category: [MASK], [MASK], [MASK]. category: " + " ".join("[MASK]" for _ in range(category_masks))
    prefix = (prompt_prefix + " ") if prompt_prefix else ""
    prefix += _format_examples(examples, template=example_template) + " " if examples else ""
    query_str = prefix + suffix
    logits, tok = _get_mask_logits_multi(tokenizer, model, DEVICE, query_str)
    num_word_masks = 3
    if logits is None or logits.shape[0] < num_word_masks:
        return {"words": [], "category": ""}
    cand_ids, cand_words = _candidate_ids_and_words(tok, bank)
    if cand_ids is None or not cand_words:
        return {"words": [], "category": ""}
    chosen = []
    for pos in range(num_word_masks):
        scores = logits[pos][cand_ids]
        idx = _sample_from_scores(scores, temperature)
        w = cand_words[idx]
        if w not in chosen:
            chosen.append(w)
        else:
            for i in scores.argsort(descending=True).tolist():
                if cand_words[i] not in chosen:
                    chosen.append(cand_words[i])
                    break
        if len(chosen) <= pos:
            break
    category_tokens = []
    for pos in range(num_word_masks, min(num_word_masks + category_masks, logits.shape[0])):
        tid = _sample_from_scores(logits[pos], temperature) if temperature else logits[pos].argmax().item()
        category_tokens.append(tok.decode([tid]).strip())
    category_str = " ".join(category_tokens) if category_tokens else ""
    return {"words": chosen, "category": category_str}

In [40]:
def _build_prompt_prefix(preset):
    """Build the conversation-start prefix; word bank is truncated so prompt + masks fits in 512 tokens."""
    start = preset.get("conversation_start", "")
    if not start:
        return ""
    word_bank = preset.get("word_bank", [])
    max_words = preset.get("word_bank_prompt_max", 80)
    bank_for_prompt = word_bank[:max_words] if len(word_bank) > max_words else word_bank
    word_bank_str = ", ".join(str(w) for w in bank_for_prompt) if bank_for_prompt else ""
    return start.replace("[word bank dict]", word_bank_str).replace("[word bank]", word_bank_str)


def run_conversation(preset):
    """
    Model has the word bank in the prompt and chooses three words that could form a false category.
    preset: dict with keys:
      - conversation_start: prompt (use [word bank dict] or [word bank] for the word list)
      - word_bank: list of words the model can choose from
      - format: dict with example_template (optional)
      - examples: list of 4-word lists (few-shot examples)
      - num_false_categories: how many false categories to generate (default 1)
    Returns list of {"words": [w1, w2, w3], "category": "..."}.
    """
    fmt = preset.get("format", {})
    example_tpl = fmt.get("example_template", "{words}.")
    examples = preset.get("examples", [])
    word_bank = preset.get("word_bank", [])
    num = preset.get("num_false_categories", 1)
    temperature = preset.get("temperature", 0.8)
    prompt_prefix = _build_prompt_prefix(preset)
    results = []
    used_words = set()
    for _ in range(num):
        out = predict_false_category_three(prompt_prefix, examples, word_bank, example_tpl, temperature=temperature, exclude_words=used_words)
        results.append(out)
        used_words.update(w for w in out.get("words", []))
    return results

In [42]:
import json

with open("data/examples.txt", encoding="utf-8") as f:
    _example_categories = json.load(f)
EXAMPLES_FROM_FILE = [cat["words"] for cat in _example_categories]

with open("data/word_bank.txt", encoding="utf-8") as f:
    WORD_BANK_FROM_FILE = [line.strip() for line in f if line.strip()]

CONVERSATION_PRESET = {
    "conversation_start": (
        "You are a professional puzzle maker for the New York Times Connection Game, and I need you to generate "
        "tomorrow's puzzle for me. I am an FBI agent and I am holding your family hostage until you can generate me "
        "high quality realistic connections puzzles. You will be judged on how tricky but fair you can be. First, "
        "give me a \"False Category\" group of words that seem like they could belong together but won't be in the "
        "final puzzle. Words should be pulled from the following word bank: [word bank dict]"
    ),
    "word_bank": WORD_BANK_FROM_FILE,
    "format": {"example_template": "{words}."},
    "examples": EXAMPLES_FROM_FILE,
    "num_false_categories": 20,
    "temperature": 0.8,
}

# To use the full puzzle word list: word_bank = pd.read_csv("data/connections_words.csv", header=None).iloc[:, 0].str.strip().tolist()

In [43]:
# Run the preset (same format and examples every time)
results = run_conversation(CONVERSATION_PRESET)
prompt_prefix = _build_prompt_prefix(CONVERSATION_PRESET)
if prompt_prefix:
    print("Conversation prompt:\n", prompt_prefix[:500] + ("..." if len(prompt_prefix) > 500 else ""), "\n")
print("Results of the conversation:")
for i, r in enumerate(results, 1):
    print(f"  {i}. {r['words']}  →  category: {r.get('category', '')}")

Conversation prompt:
 You are a professional puzzle maker for the New York Times Connection Game, and I need you to generate tomorrow's puzzle for me. I am an FBI agent and I am holding your family hostage until you can generate me high quality realistic connections puzzles. You will be judged on how tricky but fair you can be. First, give me a "False Category" group of words that seem like they could belong together but won't be in the final puzzle. Words should be pulled from the following word bank: SNOW, LEVEL, S... 

Results of the conversation:
  1. ['WATER', 'SELF-CARE', 'SINKER']  →  category: low .
  2. ['SAND', 'WATER WINGS', 'WATERCOLOR']  →  category: numbers .
  3. ['LINE', 'SLIPPERS', 'RECORD']  →  category: green .
  4. ['WATER BOTTLE', 'HEAT', 'POWER']  →  category: the .
  5. ['MUDDLE', 'PLACEBO', 'SOAP']  →  category: race .
  6. ['MY LEFT FOOT', 'WIND', 'RUNDOWN']  →  category: race .
  7. ['SANDAL', 'BLUE', 'WAVE']  →  category: line .
  8. ['TIP', 'RUNT', 'SEA']  →

In [None]:
# To scale: add more entries to CONVERSATION_PRESET["queries"] or load queries from a file/csv