In [9]:
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel

class TextGenerator:
    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = GPT2LMHeadModel.from_pretrained('gpt2-large')
        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
        self.max_len = 30
        self.top_k = 100
        self.top_p = 0.8

    def generate_text(self, prompt):
        input_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(self.device)
        sample_outputs = self.model.generate(
            input_ids,
            do_sample=True,
            max_length=self.max_len,
            top_k=self.top_k,
            top_p=self.top_p,
            num_return_sequences=1
        )
        return self.tokenizer.decode(sample_outputs[0], skip_special_tokens=True)

def calculate_perplexity(model, dataset, batch_size=1):
    dataloader = DataLoader(dataset, batch_size=batch_size)
    model.eval()
    total_loss = 0
    num_tokens = 0
    for batch in dataloader:
        # Extract the input tensor from the batch tuple
        input_tensor = batch[0].to(model.device)
        with torch.no_grad():
            outputs = model(input_tensor, labels=input_tensor)
            loss = outputs[0]
        total_loss += loss.item() * input_tensor.numel()
        num_tokens += input_tensor.numel()
    avg_loss = total_loss / num_tokens
    perplexity = torch.exp(torch.tensor(avg_loss))
    return perplexity.item()


# Example usage
text_generator = TextGenerator()
generated_text = text_generator.generate_text("Hello everyone")
print(f"Generated text: {generated_text}")

tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
text_data = ["This is some example text.", "Here is some more text."]
encoded_data = [tokenizer.encode(text, return_tensors='pt') for text in text_data]
dataset = TensorDataset(torch.cat(encoded_data))

perplexity = calculate_perplexity(text_generator.model, dataset)
print(f"Perplexity: {perplexity}")


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated text: Hello everyone,We are glad to inform you that the team has finally managed to find a way to solve the current issue with the current release version of
Perplexity: 64.04151916503906
