In [3]:
# Import necessary libraries
import torch
from transformers import XLNetTokenizer, XLNetLMHeadModel
from torch.utils.data import Dataset, DataLoader

# Restore the fine-tuned XLNet checkpoint and get it ready for inference.
model_path = './xlnet_finetuned'  # directory holding both the weights and the tokenizer files
tokenizer = XLNetTokenizer.from_pretrained(model_path)
model = (
    XLNetLMHeadModel.from_pretrained(model_path)
    .eval()  # inference mode: disables dropout
    .to('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when present, otherwise CPU
)

# Define a dataset for handling input data for predictions
class InputDataset(Dataset):
    """Map-style dataset that tokenizes raw prompt strings on access.

    Each item is an ``(input_ids, attention_mask)`` pair of 1-D tensors.
    Every sequence is truncated and padded to ``max_len``, so the default
    ``DataLoader`` collation can stack items into a batch.

    Args:
        tokenizer: Hugging Face tokenizer exposing ``encode_plus``.
        texts: Indexable sequence of prompt strings.
        max_len: Length every encoded sequence is truncated/padded to.
        add_special_tokens: Whether the tokenizer inserts its special tokens.
            Defaults to ``True``, the original behavior.
            NOTE(review): XLNet's tokenizer places ``<sep> <cls>`` *after* the
            text, which is usually undesirable for free-running generation;
            consider passing ``False`` there. Confirm against how the
            checkpoint was fine-tuned.
    """

    def __init__(self, tokenizer, texts, max_len=512, *, add_special_tokens=True):
        self.tokenizer = tokenizer
        self.texts = texts
        self.max_len = max_len
        self.add_special_tokens = add_special_tokens

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        inputs = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=self.add_special_tokens,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )
        # return_tensors='pt' adds a leading batch dimension of size 1; drop
        # it so the DataLoader can add its own when collating.
        return inputs['input_ids'].squeeze(0), inputs['attention_mask'].squeeze(0)

# Example prompts to feed through the fine-tuned model.
sample_texts = [
    "Hello, this is the assistant. How can I assist you today?",
    "I'm interested in booking a flight to New York.",
]

# Wrap the prompts in a Dataset and batch them; with two prompts and a batch
# size of 2, everything arrives in a single batch.
# NOTE(review): prompts are encoded with the tokenizer's special tokens; for
# XLNet those follow the text, which may hurt generation quality.
input_dataset = InputDataset(tokenizer=tokenizer, texts=sample_texts)
input_dataloader = DataLoader(dataset=input_dataset, batch_size=2)

# Generate a continuation for every prompt, one DataLoader batch at a time.
device = model.device  # send inputs to wherever the model's weights live
offset = 0  # index in sample_texts of the current batch's first prompt
for batch in input_dataloader:
    input_ids, attention_mask = batch
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.no_grad():
        # max_new_tokens caps only the newly generated tokens (up to 50),
        # independent of the padded prompt length.
        # NOTE(review): no sampling or repetition options are passed, and the
        # recorded outputs collapse into repeated fragments; consider
        # do_sample=True or repetition_penalty once the prompt format is settled.
        outputs = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=50,
        )

    # enumerate(outputs) restarts at 0 for each batch, so indexing sample_texts
    # with it would mislabel prompts once there is more than one batch.
    # Slice by the running offset instead.
    batch_size = input_ids.size(0)
    for text, output in zip(sample_texts[offset:offset + batch_size], outputs):
        print(f"Input: {text}")
        print(f"Generated: {tokenizer.decode(output, skip_special_tokens=True)}\n")
    offset += batch_size



This is a friendly reminder - the current text generation call will exceed the model's predefined maximum length (-1). Depending on the model, you may observe exceptions, performance degradation, or nothing at all.


Input: Hello, this is the assistant. How can I assist you today?
Generated: Hello, this is the assistant. How can I assist you today?ssssssssssssssssssssssssssssssssssssssssssssssssss

Input: I'm interested in booking a flight to New York.
Generated: I'm interested in booking a flight to New York.iorioriorporporporiorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorioriorior

