In [13]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

def analyze_masks(sentence, model, tokenizer, search_word=None, top_n=5):
    """
    Inspect a masked language model's predictions at every mask token.

    For each mask token in the sentence:
    - Show rank/logit/prob of a searched word (if provided).
    - Show top_n most probable candidates.

    Args:
        sentence: Text containing one or more of the tokenizer's mask tokens.
            Only a single sentence is analysed (batch row 0).
        model: Callable such that ``model(**inputs).logits`` is indexable as
            ``[batch, seq_len, vocab_size]``.
        tokenizer: Object providing ``__call__(sentence, return_tensors="pt")``,
            ``mask_token_id``, ``convert_tokens_to_ids`` and
            ``convert_ids_to_tokens``.
        search_word: Optional *vocabulary token* to look up. It must be given
            exactly as it appears in the vocabulary (the example in this file
            uses ``"Ġshovel"``), not as raw text.
        top_n: Number of highest-scoring candidates to report per mask. Values
            larger than the vocabulary size are clamped to it.

    Returns:
        dict mapping each mask's position in the token sequence (int) to
        ``{"top_predictions": [{"token", "logit", "prob"}, ...]}`` plus, when
        ``search_word`` is given, a ``"searched_word"`` dict with keys
        ``word``, ``logit``, ``prob``, ``rank`` (1-based) and ``vocab_size``.

    Raises:
        ValueError: If the tokenizer has no mask token, or if ``search_word``
            is not a vocabulary token (it would otherwise be silently replaced
            by the unknown token and produce misleading statistics).
    """
    mask_token_id = tokenizer.mask_token_id
    if mask_token_id is None:
        raise ValueError("tokenizer has no mask token; a masked-LM tokenizer is required")

    search_word_id = None
    if search_word:
        search_word_id = tokenizer.convert_tokens_to_ids(search_word)
        unk_id = getattr(tokenizer, "unk_token_id", None)
        unk_token = getattr(tokenizer, "unk_token", None)
        # An unknown token comes back as the <unk> id (or None); reporting
        # <unk>'s numbers under the caller's label would be wrong.
        if search_word_id is None or (search_word_id == unk_id and search_word != unk_token):
            raise ValueError(
                f"search_word {search_word!r} is not a single vocabulary token; "
                "pass the exact token string as produced by tokenizer.tokenize(...)"
            )

    inputs = tokenizer(sentence, return_tensors="pt")
    # Keep inputs on the same device as the model when both sides support it.
    device = getattr(model, "device", None)
    if device is not None and hasattr(inputs, "to"):
        inputs = inputs.to(device)

    with torch.no_grad():
        logits = model(**inputs).logits  # [batch, seq_len, vocab_size]

    # Only row 0 is analysed, so only look for masks in row 0; this keeps the
    # positions consistent with the logits[0, pos] lookups below.
    mask_positions = (inputs["input_ids"][0] == mask_token_id).nonzero(as_tuple=True)[0]

    results = {}
    for pos in mask_positions.tolist():
        mask_logits = logits[0, pos]
        probs = torch.softmax(mask_logits, dim=-1)
        vocab_size = mask_logits.shape[0]

        # Top-N predictions (clamped so an oversized top_n cannot crash topk)
        top_logits, top_ids = torch.topk(mask_logits, min(top_n, vocab_size))
        tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())

        entry = {
            "top_predictions": [
                {"token": token, "logit": logit, "prob": prob}
                for token, logit, prob in zip(
                    tokens, top_logits.tolist(), probs[top_ids].tolist()
                )
            ]
        }

        # If searching a specific word
        if search_word_id is not None:
            word_logit = mask_logits[search_word_id]
            # 1-based rank = number of strictly better-scoring tokens + 1;
            # avoids sorting the whole vocabulary for every mask.
            rank = int((mask_logits > word_logit).sum().item()) + 1

            entry["searched_word"] = {
                "word": search_word,
                "logit": word_logit.item(),
                "prob": probs[search_word_id].item(),
                "rank": rank,
                "vocab_size": vocab_size,
            }

        results[pos] = entry

    return results


# Example usage
if __name__ == "__main__":
    # Load the pretrained checkpoint and its matching tokenizer.
    model_name = "roberta-base"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMaskedLM.from_pretrained(model_name)

    # Mostly-masked sentence: only "worker" is given as context.
    sentence = "<mask> <mask> <mask> worker <mask> <mask> <mask> <mask> <mask> <mask>"

    # Track the vocabulary token "Ġshovel" and keep only the single best
    # candidate for each mask.
    preds = analyze_masks(sentence, model, tokenizer, search_word="Ġshovel", top_n=1)

    for position, report in preds.items():
        # Only report masks that sit in the first 20 token positions.
        if position >= 20:
            continue

        print(f"\nMask at position {position}:")

        searched = report.get("searched_word")
        if searched is not None:
            print(
                f"  Word '{searched['word']}': "
                f"logit={searched['logit']:.2f}, "
                f"prob={searched['prob']:.16f}, "
                f"rank={searched['rank']} / {searched['vocab_size']}"
            )

        print("  Top predictions:")
        for candidate in report["top_predictions"]:
            print(
                f"    {candidate['token']:>10s} "
                f"| logit={candidate['logit']:.2f} "
                f"| prob={candidate['prob']:.8f}"
            )



Mask at position 1:
  Word 'Ġshovel': logit=-0.65, prob=0.0000000413132319, rank=16789 / 50265
  Top predictions:
          </s> | logit=14.51 | prob=0.15902592

Mask at position 2:
  Word 'Ġshovel': logit=-0.43, prob=0.0000060175279941, rank=12895 / 50265
  Top predictions:
             A | logit=8.09 | prob=0.03012952

Mask at position 3:
  Word 'Ġshovel': logit=0.75, prob=0.0000060494016907, rank=8480 / 50265
  Top predictions:
    Ġconstruction | logit=9.73 | prob=0.04781074

Mask at position 5:
  Word 'Ġshovel': logit=2.55, prob=0.0000263756737695, rank=2415 / 50265
  Top predictions:
           Ġis | logit=10.60 | prob=0.08302768

Mask at position 6:
  Word 'Ġshovel': logit=1.10, prob=0.0000405749451602, rank=3031 / 50265
  Top predictions:
            Ġa | logit=8.21 | prob=0.04979895

Mask at position 7:
  Word 'Ġshovel': logit=1.11, prob=0.0000514188650413, rank=2461 / 50265
  Top predictions:
           Ġto | logit=7.64 | prob=0.03532176

Mask at position 8:
  Word 'Ġshovel'

In [15]:
# Accumulate the five small values one at a time, in their original order,
# so the floating-point total matches a plain left-to-right sum.
# NOTE(review): these look like per-mask probabilities copied from an earlier
# run; their source is not shown here, so confirm before relying on them.
y = 0.0
for term in (
    0.0000001693114484,
    0.0000001858393830,
    0.0000005164816912,
    0.0000040519726099,
    0.0000061137752709,
):
    y += term

# Scale the total by a fixed factor (meaning of 1.26e-05 not shown; TODO confirm).
z = y * 1.26e-05
z

1.3907099308283998e-10

In [11]:

# Check how the tokenizer splits a word written with a leading space.
# Relies on `tokenizer` already existing in the kernel (it is created in the
# example-usage cell), so that cell must be run first.
word = " digging"
tokens = tokenizer.tokenize(word)
# The recorded output for this cell is a single token: ['Ġdigging'].
print(tokens)

['Ġdigging']
