In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Choose RoBERTa (good balance of quality & resource use)
model_name = "roberta-base"

# Load tokenizer (converts words to IDs, handles <mask>)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load masked language model (predicts missing tokens)
model = AutoModelForMaskedLM.from_pretrained(model_name)

In [2]:
sentence = "Negotiations <mask> <mask> <mask> <mask> <mask> <mask>."
inputs = tokenizer(sentence, return_tensors="pt")

In [3]:
import torch

with torch.no_grad():  # no gradients needed (faster, less memory)
    outputs = model(**inputs)
    logits = outputs.logits  # shape: [batch, seq_len, vocab_size]


In [4]:
mask_token_id = tokenizer.mask_token_id
mask_positions = (inputs["input_ids"] == mask_token_id).nonzero(as_tuple=True)[1]
print("Mask positions:", mask_positions)

Mask positions: tensor([4, 5, 6, 7, 8, 9])


In [None]:
# Pick the 5th mask (index 4)
mask_index4 = mask_positions[4]

# Get logits for that mask
mask_logits = logits[0, mask_index4]

# Look up the token ID for "round"
round_id = tokenizer.convert_tokens_to_ids("table")
round_logit = mask_logits[round_id].item()

# Compare with average logit
avg_logit = mask_logits.mean().item()

print(f"Logit for 'table': {round_logit:.2f}")
print(f"Average logit across vocab: {avg_logit:.2f}")
print("Above average?", round_logit > avg_logit)


Logit for 'table': -0.27
Average logit across vocab: -2.25
Above average? True


In [38]:
topk = torch.topk(mask_logits, 500)
for idx, score in zip(topk.indices, topk.values):
    print(f"{tokenizer.decode([idx]):<15} {score.item():.2f}")


 the            8.62
 and            8.06
 in             7.71
 of             7.70
 be             7.27
 to             7.20
,               7.10
.               7.04
 are            6.75
 on             6.69
 for            6.68
 this           6.63
 with           6.56
 a              6.46
</s>            6.42
-               6.41
 is             6.30
 last           6.18
 European       6.06
 next           6.02
 international  6.00
 were           5.92
 further        5.91
 as             5.87
 two            5.82
 Russian        5.82
 other          5.81
's              5.72
 political      5.72
 was            5.71
 by             5.70
 been           5.69
 more           5.67
 new            5.66
 remain         5.57
 American       5.54
 not            5.38
 from           5.35
 at             5.32
 between        5.30
 June           5.29
 foreign        5.26
 three          5.25
 French         5.24
 EU             5.18
 government     5.14
 September      5.11
 national    

In [37]:
word = "Trump"
word_id = tokenizer.convert_tokens_to_ids(word)

# Get logit score for that word
word_logit = mask_logits[word_id].item()

# Sort all logits (descending: highest first)
sorted_scores, sorted_ids = torch.sort(mask_logits, descending=True)

# Find rank of the word
rank = (sorted_ids == word_id).nonzero(as_tuple=True)[0].item() + 1  # +1 for 1-based rank
vocab_size = mask_logits.shape[0]

print(f"Word '{word}' has logit {word_logit:.2f}")
print(f"Rank: {rank} out of {vocab_size}")


Word 'Trump' has logit 1.29
Rank: 2394 out of 50265


In [36]:
word = "Trump"
tokens = tokenizer.tokenize(word)
print(tokens)


['Trump']
