In [14]:
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

# Load the pretrained RoBERTa masked-language model and its matching tokenizer.
model_name = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Sentence to probe; each "<mask>" occurrence is analysed separately further down.
sentence = "The <mask> is the largest desert in the world."
inputs = tokenizer(sentence, return_tensors="pt")

# Inference only: skip gradient tracking and keep just the prediction scores.
with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits  # [batch, seq_len, vocab_size]

# Positions (second-axis indices of input_ids) where the mask token occurs.
mask_token_id = tokenizer.mask_token_id
mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=True)[1]
print("Mask positions:", mask_positions.tolist())

# Candidate word, written as a vocabulary token (not raw text). The string must
# match a vocabulary entry exactly, including any prefix marker such as "Ġ".
word = "ĠAntarctica"
word_id = tokenizer.convert_tokens_to_ids(word)

# A token missing from the vocabulary is silently mapped to the unknown-token
# id, which would make every statistic below describe <unk> instead of `word`.
if word_id is None or (
    word_id == tokenizer.unk_token_id and word != tokenizer.unk_token
):
    raise ValueError(
        f"{word!r} is not a single token in the {model_name!r} vocabulary; "
        "check tokenizer.tokenize() for the exact token string."
    )

# Number of alternatives to display
top_n = 6

# For every masked position, collect the candidate word's logit and rank
# together with the model's top-N predictions.
results = []
for pos in mask_positions:
    # Logit vector over the vocabulary at this position (batch row 0).
    mask_logits = logits[0, pos]
    vocab_size = mask_logits.shape[0]

    # Word logit
    word_logit = mask_logits[word_id].item()

    # 1-based rank: count of tokens scoring strictly higher, plus one.
    # This avoids sorting the whole vocabulary just to locate one entry and
    # gives a deterministic answer when several tokens share the same logit.
    rank = int((mask_logits > mask_logits[word_id]).sum().item()) + 1

    # Average logit for context
    avg_logit = mask_logits.mean().item()

    # Top N predictions; k is clamped because topk raises when k exceeds the
    # size of the dimension being searched.
    top_scores, top_ids = torch.topk(mask_logits, min(top_n, vocab_size))
    top_tokens = tokenizer.convert_ids_to_tokens(top_ids.tolist())
    top_predictions = list(zip(top_tokens, top_scores.tolist()))

    results.append({
        "mask_index": pos.item(),
        "word_logit": word_logit,
        "avg_logit": avg_logit,
        "rank": rank,
        "vocab_size": vocab_size,
        "top_predictions": top_predictions
    })

# Print one summary line per masked position, followed by its top predictions.
# The labels interpolate `word` so the report always names the token that was
# actually scored rather than a hard-coded placeholder.
for r in results:
    print(
        f"\nMask {r['mask_index']:2d} | "
        f"logit({word})={r['word_logit']:.2f} | "
        f"avg={r['avg_logit']:.2f} | "
        f"rank({word})={r['rank']} / {r['vocab_size']}"
    )
    print("   Top predictions:")
    for tok, score in r["top_predictions"]:
        # Left-align the token in a 15-character column so the logits line up.
        print(f"      {tok:15s} (logit={score:.2f})")


Mask positions: [2]

Mask  2 | logit(table)=6.75 | avg=-2.31 | rank(table)=354 / 50265
   Top predictions:
      ĠSahara         (logit=16.91)
      Ġdesert         (logit=14.36)
      ĠLevant         (logit=12.65)
      ĠMoon           (logit=12.41)
      ĠAmazon         (logit=12.04)
      ĠSinai          (logit=11.97)


In [57]:
# Sanity check: tokenize the raw text (with a leading space) to see which
# vocabulary token(s) it maps to. The recorded output shows a single token,
# 'ĠAntarctica', i.e. the string used as the candidate `word` when scoring.
# NOTE(review): this reassigns the notebook-level name `word` to raw text
# rather than a token string, and the execution count (57) is out of sequence
# with the scoring cell (14) — confirm the notebook survives Restart & Run All.
word = " Antarctica"
tokens = tokenizer.tokenize(word)
print(tokens)

['ĠAntarctica']
