In [0]:
from transformers import GPT2Tokenizer, GPT2Model

# Load the pre-trained GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
model = GPT2Model.from_pretrained('gpt2-medium')

In [None]:
import torch

In [10]:
# Define the input sentence
input_sentence = 'Hello I am a man'

# Tokenize the input sentence
input_ids = tokenizer.encode(input_sentence, return_tensors='pt')

In [12]:
# Get the logits for each token in the input sentence
with torch.no_grad():
    outputs = model(input_ids)
    logits = outputs[0][0]

# Convert the logits to probabilities
probs = torch.softmax(logits, dim=-1)

# Print the logits and probabilities for each token in the input sentence
for i, token in enumerate(input_sentence.split()):
    token_id = tokenizer.convert_tokens_to_ids(token)
    logit = logits[i][token_id].item()
    prob = probs[i][token_id].item()
    print('Token:', token)
    print('Logit:', logit)
    print('Probability:', prob)


Token: Hello
Logit: -62.537235260009766
Probability: 0.0005615773261524737
Token: I
Logit: -93.89395904541016
Probability: 8.078628525254317e-06
Token: am
Logit: -88.47907257080078
Probability: 5.0811927394533996e-06
Token: a
Logit: -74.59211730957031
Probability: 6.097152891015867e-06
Token: man
Logit: -77.46145629882812
Probability: 1.7511711121187545e-05


In [5]:
input_sentence = "I think"
input_ids = tokenizer.encode(input_sentence, return_tensors='pt')

In [6]:
next_token_logits = model(input_ids)[0][:, -1, :]
probs = next_token_logits.softmax(dim=-1)

In [20]:
next_token = torch.argmax(probs, dim=-1)

In [21]:
k = 5
topk_tokens = torch.topk(probs, k=k, dim=-1).indices

In [22]:
topk_tokens

tensor([[340, 326, 356, 262, 314]])

In [23]:
topk_token_forms = tokenizer.convert_ids_to_tokens(topk_tokens[0].tolist())

In [24]:
for i in range(k):
    token_form = tokenizer.convert_ids_to_tokens(topk_tokens[0][i].item())
    print(f"Token form {i}: {token_form}")

Token form 0: Ġit
Token form 1: Ġthat
Token form 2: Ġwe
Token form 3: Ġthe
Token form 4: ĠI
