In [None]:
from typing import Dict, Optional

from datasets import load_dataset
import torch
from transformers import AutoConfig, AutoTokenizer, GPT2LMHeadModel

In [None]:
# Provide checkpoint path
ckpt_path = "../models/gpt2-finetuned.pt"
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("gpt2")
gpt2config = AutoConfig.from_pretrained("gpt2")
# Build the architecture from the config only; the weights come from the checkpoint below.
model = GPT2LMHeadModel(gpt2config)
# NOTE(security): torch.load unpickles the file, so only load checkpoints from a trusted source.
# Weights are first mapped to CPU so the checkpoint loads regardless of the device it was saved from.
# load_state_dict() returns a report of missing/unexpected keys, not the module, so the
# device move has to be called on `model` itself rather than chained onto its result.
model.load_state_dict(torch.load(ckpt_path, map_location=torch.device("cpu")))
# A freshly constructed nn.Module starts in training mode; switch to eval so dropout is
# disabled while generating. Written as an assignment so the cell does not echo the model repr.
model = model.to(device).eval()
# model = GPT2LMHeadModel.from_pretrained("gpt2").to(device)


In [None]:
# Optionally load the test set to sample prompts.
dataset = load_dataset(path="wikitext", name="wikitext-103-v1", split="test")
print(dataset)

# Length of the token blocks built by group_texts: half of the tokenizer's maximum length.
block_size = tokenizer.model_max_length // 2

def tokenize_function(examples):
    """Run the notebook-level ``tokenizer`` over the ``"text"`` entry of ``examples``.

    Returns whatever the tokenizer call returns for that entry.
    """
    texts = examples["text"]
    return tokenizer(texts)

def group_texts(examples, chunk_size: Optional[int] = None):
    """Concatenate every column of a batch and re-split it into fixed-size chunks.

    Args:
        examples: Mapping from column name to a list of per-row lists. It must
            contain an ``"input_ids"`` column, which is copied into ``"labels"``.
        chunk_size: Length of each output chunk. When ``None`` (the default) the
            notebook-level ``block_size`` is used.

    Returns:
        A dict with the same columns as ``examples`` plus ``"labels"``, where each
        column is a list of chunks of exactly ``chunk_size`` items. The trailing
        remainder that does not fill a whole chunk is dropped.

    Raises:
        ValueError: If the effective chunk size is not positive.
    """
    size = block_size if chunk_size is None else chunk_size
    if size <= 0:
        raise ValueError(f"chunk size must be positive, got {size}")

    # Flatten each column in a single linear pass. sum(list_of_lists, []) copies the
    # accumulator on every addition, which is quadratic in the number of tokens.
    concatenated_examples = {
        k: [item for row in examples[k] for item in row] for k in examples.keys()
    }
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the small remainder; we could add padding if the model supported
    # it instead of this drop. Customize this part to your needs.
    total_length = (total_length // size) * size
    # Split by chunks of `size`.
    result = {
        k: [t[i: i + size] for i in range(0, total_length, size)]
        for k, t in concatenated_examples.items()
    }
    # Labels start as a copy of the inputs (copy of the outer list of chunks).
    result["labels"] = result["input_ids"].copy()
    return result

# Tokenize every row (dropping the raw "text" column), then regroup the result with group_texts.
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)
test_dataset = tokenized_dataset.map(
    group_texts,
    batched=True,
)
print(len(test_dataset), test_dataset)

In [None]:
def generate(prompt: Dict, ref: Optional[torch.IntTensor] = None, max_new_tokens: int = 100):
    """Generate a continuation for ``prompt`` and print it, optionally with a reference.

    Args:
        prompt: Keyword arguments for ``model.generate``; every value must
            support ``.to(device)``.
        ref: Optional sequence of reference token ids (list or 1-D tensor).
            When given and non-empty, its first ``max_new_tokens`` ids are
            decoded and printed after the model output.
        max_new_tokens: Forwarded to ``model.generate``; also caps how many
            reference ids are printed.
    """
    prompt = {k: v.to(device) for k, v in prompt.items()}
    out = model.generate(**prompt, max_new_tokens=max_new_tokens)
    print("OUTPUT:\n", tokenizer.decode(out[0], skip_special_tokens=True))
    # Truth-testing a multi-element tensor raises, so check for None and emptiness
    # explicitly; this works for both plain lists and tensors.
    if ref is not None and len(ref) > 0:
        # Slicing already clamps to the sequence length, so no min() is needed.
        print("REFERENCE:\n", tokenizer.decode(ref[:max_new_tokens], skip_special_tokens=True))

# generate(input, test_dataset[0]['input_ids'])

In [None]:
def generate_randomly_from_dataset(dataset, max_new_tokens=100):
    """Pick a random row of ``dataset``, use its first tokens as a prompt, and generate.

    Prints the decoded prompt, then delegates to ``generate`` with the row's
    ``input_ids`` (from the prompt start onward) as the reference.

    Args:
        dataset: Indexable collection of rows; each row is a dict with an
            ``"input_ids"`` list and, optionally, an ``"attention_mask"`` list.
        max_new_tokens: Number of tokens to generate, forwarded to ``generate``.
    """
    idx = torch.randint(len(dataset), size=(1,)).item()
    # start_idx = torch.randint(block_size-max_new_tokens, size=(1,)).item()
    start_idx = 0
    prompt_length = 20  # torch.randint(10, 20, size=(1,)).item()

    # Fetch the row once instead of re-indexing the dataset for the reference.
    example = dataset[idx]
    # Only forward the columns generation needs. The row also carries "labels", a
    # training target; depending on the transformers version, extra keyword arguments
    # are passed through to the model's forward call, where prompt-length labels no
    # longer match the logits once decoding proceeds token by token.
    # int64 ids match what the tokenizer produces with return_tensors="pt".
    prompt = {
        key: torch.tensor(
            example[key][start_idx: start_idx + prompt_length], dtype=torch.long
        ).unsqueeze(dim=0)
        for key in ("input_ids", "attention_mask")
        if key in example
    }
    # Show the whole prompt; its length is governed by prompt_length, not max_new_tokens.
    print("PROMPT:\n", tokenizer.decode(prompt["input_ids"][0], skip_special_tokens=True))
    generate(prompt, example["input_ids"][start_idx:], max_new_tokens)

# Sample one random row of the grouped test split and print prompt, model output and reference.
generate_randomly_from_dataset(test_dataset)


In [None]:
# Generating based on user prompt
# NOTE: this notebook-level name shadows Python's built-in `input()` for the rest of the session.
input = "He is known as a great basketball player."
def generate_from_user_input(input: str, max_new_tokens: int = 100):
    """Print ``input`` as the prompt, then generate and print a continuation.

    Args:
        input: Prompt text; must be a non-empty string.
        max_new_tokens: Number of tokens to generate.

    Raises:
        ValueError: If ``input`` is empty. An empty string tokenizes to zero
            tokens, which leaves the model nothing to condition on.
    """
    if not input:
        raise ValueError("input must be a non-empty string")
    print("PROMPT:\n", input)
    prompt = tokenizer(input, return_tensors="pt")
    # Reuse the shared helper: it moves the tensors to `device`, runs model.generate
    # and prints the decoded output, so the two generation paths cannot drift apart.
    generate(prompt, max_new_tokens=max_new_tokens)

# Run generation on the example prompt string assigned to `input` above.
generate_from_user_input(input)
