# Text Generation

In [None]:
import pandas as pd
import torch 

from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Name of the pretrained checkpoint used for both the tokenizer and the model.
model_name = "gpt2-xl"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load the causal language model, then move its weights to the chosen device.
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

In [None]:
def predict_greedy(model: torch.nn.Module, input_text: str, max_new_tokens: int, choices_per_step: int, device=None) -> pd.DataFrame:
    """Greedily extend ``input_text`` and record the top candidates per step.

    At every step the model is run on the text generated so far, the
    probabilities of the next token are computed from the logits of the last
    position, and the single most probable token is appended to the input.

    Uses the module-level ``tokenizer`` to encode and decode text.

    Args:
        model: Causal language model; called as ``model(input_ids=...)`` and
            expected to return an object with a ``logits`` attribute.
        input_text: Prompt to continue.
        max_new_tokens: Number of tokens to generate (one table row each).
        choices_per_step: Number of most probable candidate tokens to report
            per step. Must not exceed the vocabulary size.
        device: Device the input ids are moved to. When ``None`` (default),
            the device of the model's first parameter is used so inputs and
            weights always live on the same device.

    Returns:
        A DataFrame with one row per generated token. Column ``"input"``
        holds the text fed to the model at that step and columns
        ``"Choice 1"`` .. ``"Choice N"`` hold the candidate tokens with their
        probability in percent.
    """
    if device is None:
        # Follow the model instead of assuming CPU: a model that was moved to
        # the GPU would otherwise fail with a device-mismatch error.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    iterations = []
    # Turn text into tokens
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    with torch.no_grad():
        for _ in range(max_new_tokens):
            iteration = {"input": tokenizer.decode(input_ids[0])}
            output = model(input_ids=input_ids)

            # Logits of the first (only) sequence at its last position.
            next_token_logits = output.logits[0, -1, :]
            next_token_probs = torch.softmax(next_token_logits, dim=-1)
            sorted_ids = torch.argsort(next_token_probs, dim=-1, descending=True)

            for choice_idx in range(choices_per_step):
                token_id = sorted_ids[choice_idx]
                # .item() yields a plain Python float without a NumPy round-trip.
                token_prob = next_token_probs[token_id].item()
                token_choice = f"{tokenizer.decode(token_id)} ({100*token_prob:.2f}%)"
                iteration[f"Choice {choice_idx+1}"] = token_choice

            # Greedy step: append the most probable token as a (1, 1) tensor.
            best_token = sorted_ids[:1].unsqueeze(0)
            input_ids = torch.cat([input_ids, best_token], dim=-1)
            iterations.append(iteration)

    return pd.DataFrame(iterations)


In [None]:
# Pass the device explicitly so the input ids end up where the model lives.
predict_greedy(model=model, input_text="The transformer is", max_new_tokens=10, choices_per_step=5, device=device)

In [None]:
def generate(model: torch.nn.Module, input_text: str, max_new_tokens: int, do_sample=True, device: "str | torch.device | None" = None) -> str:
    """Continue ``input_text`` with ``model.generate`` and return the text.

    Uses the module-level ``tokenizer`` to encode and decode text.

    Args:
        model: Model exposing a ``generate`` method.
        input_text: Prompt to continue.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_new_tokens``.
        do_sample: Forwarded to ``model.generate`` as ``do_sample``.
        device: Device the inputs are moved to. When ``None`` (default), the
            device of the model's first parameter is used so inputs and
            weights always live on the same device.

    Returns:
        The decoded first sequence returned by ``model.generate``.
    """
    if device is None:
        # Follow the model instead of assuming CPU: a model that was moved to
        # the GPU would otherwise fail with a device-mismatch error.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    # Turn text into tokens
    encoded = tokenizer(input_text, return_tensors="pt")
    input_ids = encoded.input_ids.to(device)
    # Hand over the attention mask explicitly rather than letting generate guess it.
    attention_mask = encoded.attention_mask.to(device)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
    )
    return tokenizer.decode(output[0])

In [None]:
# Compare deterministic decoding with sampling on the same prompt. The inputs
# are placed on the same device as the model by passing `device` explicitly.
prompt = "The transformer is"
new_tokens = 10

print("# Without sampling:")
for _ in range(3):
    print("- " + generate(model=model, input_text=prompt, max_new_tokens=new_tokens, do_sample=False, device=device))

print("# With sampling:")
for _ in range(3):
    print("- " + generate(model=model, input_text=prompt, max_new_tokens=new_tokens, do_sample=True, device=device))