In [11]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Choose RoBERTa (good balance of quality & resource use)
model_name = "roberta-base"

# Load tokenizer (converts words to IDs, handles <mask>)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load masked language model (predicts missing tokens)
model = AutoModelForMaskedLM.from_pretrained(model_name)

In [12]:
sentence = "<mask> apple <mask> <mask> <mask> <mask> <mask>."
inputs = tokenizer(sentence, return_tensors="pt")

In [13]:
import torch

with torch.no_grad():  # no gradients needed (faster, less memory)
    outputs = model(**inputs)
    logits = outputs.logits  # shape: [batch, seq_len, vocab_size]


In [14]:
mask_token_id = tokenizer.mask_token_id
mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=True)[1]
print("Mask positions:", mask_positions)

Mask positions: tensor([1, 3, 4, 5, 6, 7])


In [15]:
# Pick the 5th mask (index 4)
mask_index4 = mask_positions[0]

# Get logits for that mask
mask_logits = logits[0, mask_index4]

# Look up the token ID for "round"
round_id = tokenizer.convert_tokens_to_ids("round")
round_logit = mask_logits[round_id].item()

# Compare with average logit
avg_logit = mask_logits.mean().item()

print(f"Logit for 'round': {round_logit:.2f}")
print(f"Average logit across vocab: {avg_logit:.2f}")
print("Above average?", round_logit > avg_logit)


Logit for 'round': 5.32
Average logit across vocab: -1.72
Above average? True


In [16]:
topk = torch.topk(mask_logits, 100)
#test
for idx, score in zip(topk.indices, topk.values):
    print(f"{tokenizer.decode([idx]):<15} {score.item():.2f}")


An              16.85
The             16.85
This            14.62
One             13.99
My              13.86
an              13.81
Another         13.57
Every           13.21
the             13.17
How             13.10
Two             13.09
No              12.98
That            12.81
Apple           12.80
Fresh           12.64
American        12.59
Three           12.51
And             12.49
Green           12.47
For             12.41
More            12.40
Each            12.35
First           12.31
1               12.25
These           12.22
With            12.22
Some            12.22
Traditional     12.16
On              12.10
Our             12.09
As              12.00
Rare            11.98
What            11.93
Not             11.90
Your            11.88
Big             11.87
Raw             11.85
Classic         11.85
A               11.84
New             11.80
Red             11.78
and             11.77
Four            11.74
Next            11.72
When            11.70
Most      

In [17]:
word = "round"
word_id = tokenizer.convert_tokens_to_ids(word)

# Get logit score for that word
word_logit = mask_logits[word_id].item()

# Sort all logits (descending: highest first)
sorted_scores, sorted_ids = torch.sort(mask_logits, descending=True)

# Find rank of the word
rank = (sorted_ids == word_id).nonzero(as_tuple=True)[0].item() + 1  # +1 for 1-based rank
vocab_size = mask_logits.shape[0]

print(f"Word '{word}' has logit {word_logit:.2f}")
print(f"Rank: {rank} out of {vocab_size}")


Word 'round' has logit 5.32
Rank: 3465 out of 50265


In [None]:
word = "Trump"
tokens = tokenizer.tokenize(word)
print(tokens)
