In [1]:
import sys
from from_root import from_root
import pandas as pd

# Ensure we can import from src/
sys.path.insert(0, str(from_root("src")))

from read_and_write_docs import read_jsonl, read_rds
from model_loading import load_model, distinct_special_chars, load_model_efficient

In [2]:
import os

# Directory holding the ModernBERT checkpoint. Set the MODERNBERT_PATH environment
# variable to point somewhere else without editing the notebook; the default is the
# original local volume path.
model_loc = os.environ.get("MODERNBERT_PATH", "/Volumes/BCross/models/ModernBERT/ModernBERT-base")

In [5]:
from transformers import AutoModelForMaskedLM, AutoTokenizer

# Both objects are read from the same checkpoint location (`model_loc`).
model = AutoModelForMaskedLM.from_pretrained(model_loc)
tokenizer = AutoTokenizer.from_pretrained(model_loc)

In [7]:
import torch

text = "The capital of France [MASK] [MASK]."
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():  # inference only: do not build an autograd graph
    outputs = model(**inputs)

# The text contains two mask tokens; `list.index` returns the first occurrence,
# so only the first masked slot is decoded and printed here.
masked_index = inputs["input_ids"][0].tolist().index(tokenizer.mask_token_id)
predicted_token_id = outputs.logits[0, masked_index].argmax(dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print("Predicted token:", predicted_token)


Predicted token:  is


In [13]:
# B1 — Joint multi-mask fill with beam search (ModernBERT or any MLM)
import math
from typing import List, Dict, Tuple, Optional, Iterable
import torch
from transformers import PreTrainedTokenizerBase, PreTrainedModel, AutoTokenizer, AutoModelForMaskedLM

@torch.no_grad()
def fill_masks_beam(
    model: "PreTrainedModel",
    tokenizer: "PreTrainedTokenizerBase",
    masked_text: str,
    top_k_per_mask: int = 25,   # widen/limit local options at each mask
    beam_size: int = 100,       # how many partial hypotheses to keep
    max_candidates: int = 50,   # truncate final list
    banned_token_ids: Optional[Iterable[int]] = None,
) -> List[Dict]:
    """
    Fill every mask token in ``masked_text`` and rank the combinations with a beam search.

    The model is run once on the masked text and each mask position's distribution
    is read from that single pass. A hypothesis score is therefore the sum of
    independent per-mask log-probabilities: a choice made at one mask does not
    re-condition the distribution at the other masks.

    Args:
        model: Masked-LM; called as ``model(**encoding)`` and must return an object
            with a ``.logits`` tensor. If it exposes a ``.device`` attribute the
            encoded inputs are moved there first.
        tokenizer: Tokenizer providing ``mask_token_id``, ``decode`` and
            ``convert_ids_to_tokens``. Ids in ``all_special_ids`` (when present)
            are never proposed as fills.
        masked_text: Text containing one or more mask tokens.
        top_k_per_mask: Number of highest-probability tokens considered per mask.
        beam_size: Number of partial hypotheses kept after each mask.
        max_candidates: Maximum number of results returned.
        banned_token_ids: Extra token ids that must not be proposed. Ids outside
            the model's vocabulary are ignored.

    Returns:
        A list of dicts ``{text, score, tokens, token_ids, mask_positions}`` sorted
        by ``score`` (sum of log-probs across all masks, higher is better). The list
        is empty when the text has no mask token, when any of the three size limits
        is below 1, or when every candidate at some mask is banned.

    Raises:
        ValueError: If the tokenizer has no ``mask_token_id``.
    """
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise ValueError("Tokenizer must define mask_token_id (MLM required).")
    # Non-positive limits can never yield a result; returning early also keeps a
    # negative limit from being interpreted as a "drop the tail" slice below.
    if min(top_k_per_mask, beam_size, max_candidates) < 1:
        return []

    enc = tokenizer(masked_text, return_tensors="pt", add_special_tokens=True, truncation=True)
    device = getattr(model, "device", None)
    if device is not None:
        # Keep inputs on the same device as the model weights.
        enc = {
            name: (value.to(device) if hasattr(value, "to") else value)
            for name, value in enc.items()
        }
    input_ids = enc["input_ids"]
    mask_positions = (input_ids[0] == mask_id).nonzero(as_tuple=False).flatten().tolist()
    if not mask_positions:
        return []

    log_probs = model(**enc).logits.log_softmax(-1)  # [1, seq_len, vocab]
    vocab_size = log_probs.size(-1)

    banned = set(getattr(tokenizer, "all_special_ids", []) or [])
    if banned_token_ids is not None:
        banned.update(int(tid) for tid in banned_token_ids)
    # Built once for all masks. Ids outside the vocabulary cannot be predicted,
    # and indexing them would raise, so they are dropped here.
    banned_idx = torch.tensor(
        sorted(tid for tid in banned if 0 <= tid < vocab_size),
        dtype=torch.long,
        device=log_probs.device,
    )

    # Precompute top-k candidates at each mask position
    per_mask_cands = []
    for pos in mask_positions:
        lp = log_probs[0, pos].clone()
        if banned_idx.numel():
            lp.index_fill_(0, banned_idx, float("-inf"))
        topk = torch.topk(lp, k=min(top_k_per_mask, lp.numel()))
        # When k exceeds the number of allowed tokens, top-k is padded with banned
        # (-inf) entries; drop them so a banned token is never returned.
        keep = torch.isfinite(topk.values)
        per_mask_cands.append((topk.indices[keep].tolist(), topk.values[keep].tolist()))

    # Beam over masks (left-to-right in token order)
    beam = [(0.0, [])]  # (cum_logprob, chosen_token_ids_so_far)
    for cand_ids, cand_lps in per_mask_cands:
        expanded = [
            (cum_lp + float(lp), chosen + [tid])
            for cum_lp, chosen in beam
            for tid, lp in zip(cand_ids, cand_lps)
        ]
        expanded.sort(key=lambda hyp: hyp[0], reverse=True)
        beam = expanded[:beam_size]

    # Materialize and deduplicate by decoded text (different token sequences can
    # decode to the same string; the best-scoring one is kept).
    base_ids = input_ids[0].tolist()
    best: Dict[str, Dict] = {}
    for cum_lp, choice_ids in beam:
        filled = list(base_ids)
        for pos, tid in zip(mask_positions, choice_ids):
            filled[pos] = tid
        text_out = tokenizer.decode(filled, skip_special_tokens=True)
        prev = best.get(text_out)
        if (prev is None) or (cum_lp > prev["score"]):
            best[text_out] = {
                "text": text_out,
                "score": cum_lp,
                "tokens": tokenizer.convert_ids_to_tokens(choice_ids),
                "token_ids": choice_ids,
                "mask_positions": list(mask_positions),
            }
    return sorted(best.values(), key=lambda r: r["score"], reverse=True)[:max_candidates]

# --- Example usage ---
# tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
# mdl = AutoModelForMaskedLM.from_pretrained("answerdotai/ModernBERT-base").eval()


In [15]:
# Two adjacent masks: a contraction suffix followed by a verb.
masked = "I didn[MASK] [MASK] time to finish the report."
beam_settings = dict(top_k_per_mask=30, beam_size=150, max_candidates=40)
cands = fill_masks_beam(model, tokenizer, masked, **beam_settings)
for cand in cands:
    print(f"{cand['text']}   (score={cand['score']:.2f})  tokens={cand['tokens']}")

I didn't have time to finish the report.   (score=-0.06)  tokens=["'t", 'Ġhave']
I didn't get time to finish the report.   (score=-3.87)  tokens=["'t", 'Ġget']
I didn't find time to finish the report.   (score=-4.45)  tokens=["'t", 'Ġfind']
I didn't take time to finish the report.   (score=-5.00)  tokens=["'t", 'Ġtake']
I didn't had time to finish the report.   (score=-5.91)  tokens=["'t", 'Ġhad']
I didn t have time to finish the report.   (score=-6.10)  tokens=['Ġt', 'Ġhave']
I didn't make time to finish the report.   (score=-6.16)  tokens=["'t", 'Ġmake']
I didn’ have time to finish the report.   (score=-6.68)  tokens=['âĢĻ', 'Ġhave']
I didn't enough time to finish the report.   (score=-7.33)  tokens=["'t", 'Ġenough']
I didn't leave time to finish the report.   (score=-7.47)  tokens=["'t", 'Ġleave']
I didn't need time to finish the report.   (score=-7.69)  tokens=["'t", 'Ġneed']
I didn not have time to finish the report.   (score=-7.90)  tokens=['Ġnot', 'Ġhave']
I didn't see time to f

In [16]:
# C1 — Try multiple lengths at one masked span and rank globally
from typing import Iterable, List, Dict

def variable_length_infill(
    model, tokenizer,
    masked_template: str,              # contains ONE [MASK] span
    length_options: Iterable[int] = (1, 2, 3, 4),
    per_length_topk: int = 10,
    normalize_by_masks: bool = True,   # de-bias longer spans
    **beam_kwargs
) -> List[Dict]:
    """
    Expand the first mask token of a template to several lengths and rank all fills together.

    For every length in ``length_options`` the first occurrence of the tokenizer's
    mask token is replaced by that many space-separated mask tokens, the expanded
    text is passed to ``fill_masks_beam`` (with ``beam_kwargs`` forwarded), and the
    first ``per_length_topk`` results are collected.

    Returns:
        A list of dicts ``{text, length, score, raw_score, tokens}`` sorted by
        ``score`` descending. ``raw_score`` is the score reported by
        ``fill_masks_beam``; ``score`` is ``raw_score / length`` when
        ``normalize_by_masks`` is true and equal to ``raw_score`` otherwise.
    """
    mask_token = tokenizer.mask_token
    assert mask_token in masked_template, "Template must contain one [MASK] span."

    ranked: List[Dict] = []
    for span_len in length_options:
        span = " ".join([mask_token] * span_len)
        # Only the first mask occurrence is widened; any later ones are left as-is.
        widened = masked_template.replace(mask_token, span, 1)
        hypotheses = fill_masks_beam(model, tokenizer, widened, **beam_kwargs)
        ranked.extend(
            {
                "text": hyp["text"],
                "length": span_len,
                "score": hyp["score"] / span_len if normalize_by_masks else hyp["score"],
                "raw_score": hyp["score"],
                "tokens": hyp["tokens"],
            }
            for hyp in hypotheses[:per_length_topk]
        )

    ranked.sort(key=lambda row: row["score"], reverse=True)
    return ranked


In [18]:
tmpl = "I [MASK] time to finish the report."  # one span to grow/shrink
best = variable_length_infill(
    model,
    tokenizer,
    tmpl,
    length_options=range(1, 5),
    per_length_topk=8,
    top_k_per_mask=30,
    beam_size=200,
    max_candidates=80,
)
for row in best:
    print(f"L={row['length']} :: {row['text']} (norm_score={row['score']:.2f})")

L=3 :: I don't have time to finish the report. (norm_score=-1.02)
L=1 :: I have time to finish the report. (norm_score=-1.03)
L=4 :: I don't have enough time to finish the report. (norm_score=-1.05)
L=4 :: I don not have enough time to finish the report. (norm_score=-1.07)
L=2 :: I have no time to finish the report. (norm_score=-1.11)
L=3 :: I don not have time to finish the report. (norm_score=-1.11)
L=1 :: I need time to finish the report. (norm_score=-1.13)
L=3 :: I did't have time to finish the report. (norm_score=-1.14)
L=3 :: I do't have time to finish the report. (norm_score=-1.14)
L=4 :: I don't have have time to finish the report. (norm_score=-1.18)
L=4 :: I don not have have time to finish the report. (norm_score=-1.20)
L=3 :: I did not have time to finish the report. (norm_score=-1.23)
L=3 :: I do not have time to finish the report. (norm_score=-1.23)
L=3 :: I have't have time to finish the report. (norm_score=-1.24)
L=3 :: I will't have time to finish the report. (norm_scor

In [19]:
# D1 — PLL scorer (ModernBERT) to rerank finished candidates
import torch

@torch.no_grad()
def pll_score(model, tokenizer, text: str) -> float:
    """
    Pseudo-log-likelihood of ``text`` under a masked language model.

    Each interior token position is replaced by the mask token in turn, the model
    is run on that masked copy, and the log-probability assigned to the original
    token at that position is added to the total. One forward pass is made per
    scored position.

    Args:
        model: Masked-LM; called with ``input_ids`` and ``attention_mask`` and must
            return an object with a ``.logits`` tensor. If it exposes a ``.device``
            attribute the inputs are moved there first.
        tokenizer: Tokenizer providing ``mask_token_id``.
        text: The finished sentence to score.

    Returns:
        Sum of per-position log-probabilities (higher is better); ``0.0`` when
        there is no interior position to score.

    Raises:
        ValueError: If the tokenizer has no ``mask_token_id``.
    """
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        raise ValueError("Tokenizer must define mask_token_id (MLM required).")

    enc = tokenizer(text, return_tensors="pt", add_special_tokens=True)
    ids = enc["input_ids"]
    attention_mask = enc["attention_mask"]
    device = getattr(model, "device", None)
    if device is not None:
        # Keep inputs on the same device as the model weights.
        ids = ids.to(device)
        attention_mask = attention_mask.to(device)

    total = 0.0
    # The first and last positions are skipped.
    # NOTE(review): this assumes the tokenizer wraps the text in exactly one leading
    # and one trailing special token ([CLS]/[SEP]-like) — confirm for other tokenizers.
    for i in range(1, ids.size(1) - 1):
        masked = ids.clone()
        target = masked[0, i].item()
        masked[0, i] = mask_id
        out = model(input_ids=masked, attention_mask=attention_mask)
        total += out.logits[0, i].log_softmax(-1)[target].item()
    return total


In [20]:
# Example: rerank top-N ModernBERT candidates from section B/C
# Score at most the first ten beam candidates, then list them best-first.
top_candidates = cands[:10]
scored = [(pll_score(model, tokenizer, cand["text"]), cand["text"]) for cand in top_candidates]
scored_desc = sorted(scored, key=lambda pair: pair[0], reverse=True)
for pll, sentence in scored_desc:
    print(f"{pll:.2f} :: {sentence}")

-11.02 :: I didn't have time to finish the report.
-15.29 :: I didn't get time to finish the report.
-16.22 :: I didn't find time to finish the report.
-17.65 :: I didn t have time to finish the report.
-18.63 :: I didn't make time to finish the report.
-19.23 :: I didn't leave time to finish the report.
-19.79 :: I didn't take time to finish the report.
-20.11 :: I didn't had time to finish the report.
-21.42 :: I didn't enough time to finish the report.
-27.79 :: I didn’ have time to finish the report.
