In [None]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from pathlib import Path
import os

In [None]:
# Checkpoint identifier handed to `from_pretrained` for both the tokenizer and the model.
model_name = "distilgpt2"

# Fine-tuned weights are looked up in ./models, relative to the kernel's
# current working directory (so the result depends on where the kernel runs).
model_save_path = Path.cwd() / 'models'
weights_path = str(model_save_path) + '/gpt2_retrained_good_split.pt'

# Use the GPU when PyTorch reports one as available, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Echo the resolved settings so a wrong path or device is visible immediately.
print(weights_path)
print(device)

In [None]:
# Tokenizer matching the base checkpoint; the EOS token is reused as the
# padding token so that a pad token is defined.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token

# Instantiate the base checkpoint first, then replace its parameters with the
# state dict stored at `weights_path`.
model = GPT2LMHeadModel.from_pretrained(model_name)
model.load_state_dict(
    # map_location moves the stored tensors onto `device` while loading.
    # NOTE(review): torch.load unpickles the file -- only point this at a
    # checkpoint from a trusted source.
    torch.load(weights_path, map_location=device)
)

# Move to the target device and switch to evaluation mode. Both calls return
# the module itself, so the cell still ends by displaying the model.
model.to(device).eval()

In [None]:
# Conditioning text the model continues from.
prompt = "Genre: Heavy Metal\n\n"

# Call the tokenizer directly (instead of `.encode`) so the attention mask is
# returned along with the ids. The pad token id equals the EOS token id here,
# so `generate` cannot infer the mask on its own and would emit a warning.
encoded = tokenizer(prompt, return_tensors="pt").to(device)
input_ids = encoded["input_ids"]

# Sample two continuations without tracking gradients.
with torch.inference_mode():
    outputs = model.generate(
        input_ids,
        attention_mask=encoded["attention_mask"],
        do_sample=True,           # stochastic sampling rather than greedy decoding
        max_length=512,           # upper bound on total length, prompt included
        top_p=0.9,
        top_k=50,
        temperature=0.9,
        repetition_penalty=1.1,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        num_return_sequences=2,
    )

# Print every sample under a numbered banner. `text` keeps the last decoded
# sample after the loop; it is not printed a second time.
for i, output in enumerate(outputs, 1):
    text = tokenizer.decode(output, skip_special_tokens=True)
    print(f"\n{'='*60}")
    print(f"SAMPLE {i}")
    print(f"{'='*60}")
    print(text)