In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

In [3]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

prompt = "Today I believe I can fly"
input_ids = tokenizer(prompt, return_tensors="pt")
input_ids = {k: v.to(device) for k, v in input_ids.items()}

Using cuda device


In [4]:

# Greedy search
outputs = model.generate(**input_ids, max_length=30)
greedy_out = tokenizer.batch_decode(outputs, skip_special_tokens=False)

# Beam search
outputs = model.generate(**input_ids, max_length=30, num_beams=5, early_stopping=True)
beam_out = tokenizer.batch_decode(outputs, skip_special_tokens=False)

# Top k sampling
outputs = model.generate(**input_ids, max_length=30, do_sample=True, top_k=50)
top_k_out = tokenizer.batch_decode(outputs, skip_special_tokens=False)

# Top p sampling
outputs = model.generate(**input_ids, max_length=30, do_sample=True, top_p=0.92)
top_p_out = tokenizer.batch_decode(outputs, skip_special_tokens=False)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [8]:
print(f"Greedy search: {greedy_out}")
print(f"Beam search: {beam_out}")
print(f"Top k sampling: {top_k_out}")
print(f"Top p sampling: {top_p_out}")

Greedy search: ['Today I believe I can fly. I can fly. I can fly. I can fly. I can fly. I can fly. I can fly']
Beam search: ['Today I believe I can fly.\n\nI believe I can fly.\n\nI believe I can fly.\n\nI believe I can fly']
Top k sampling: ["Today I believe I can fly. I'm not afraid. I'm trying to make my life fun and fun for others. When I see the bright"]
Top p sampling: ['Today I believe I can fly in all planes but only when necessary. I can not leave my home or my family in a box but can be taken']
