In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW
from datasets import load_dataset

# Checkpoint identifier shared by the tokenizer and the model
model_name = "gpt2"

# Tokenizer setup: the end-of-text token doubles as both the padding token
# and the beginning-of-sequence token.
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.bos_token = tokenizer.eos_token

# Language-model weights for the same checkpoint
model = GPT2LMHeadModel.from_pretrained(model_name)

In [None]:
# XSum training split, restricted to its first 1% so a training run stays short
dataset = load_dataset(path="xsum", split="train[:1%]")


# Tokenize inputs and targets for summarization
def tokenize_data(example, max_length=512):
    """Build one causal-LM training example from an XSum record.

    The encoded sequence is laid out as::

        <bos> Summarize: <document> TL;DR: <summary><eos>

    The previous version encoded only the document, so the reference summary
    never reached the model. Here the summary is appended and is the only
    part of the sequence that contributes to the loss.

    Args:
        example: Mapping with ``"document"`` and ``"summary"`` strings.
        max_length: Upper bound on the total number of tokens. The document
            is the part that gets truncated, so the summary is kept intact
            (the summary itself is capped at roughly half of the budget).

    Returns:
        Dict with ``"input_ids"`` and ``"labels"`` as equal-length lists of
        ints. Prompt positions in ``labels`` hold -100, the same ignore value
        ``collate_fn`` uses for padding. No manual shift is applied to the
        labels; GPT2LMHeadModel is expected to shift them internally.
    """
    document_text = tokenizer.bos_token + " Summarize: " + example["document"]

    # Marker that separates the document from its summary. It is part of the
    # prompt, so it is masked out of the loss together with the document.
    separator_ids = tokenizer(" TL;DR:").input_ids

    # Target: the reference summary, always terminated by EOS so the model
    # learns where a summary ends.
    summary_ids = tokenizer(
        " " + example["summary"],
        truncation=True,
        max_length=max(max_length // 2 - 1, 1),
    ).input_ids
    summary_ids = summary_ids + [tokenizer.eos_token_id]

    # Whatever budget is left goes to the document; trimming happens on the
    # document side only.
    document_budget = max(max_length - len(separator_ids) - len(summary_ids), 0)
    document_ids = tokenizer(
        document_text, truncation=True, max_length=max_length
    ).input_ids[:document_budget]

    prompt_ids = document_ids + separator_ids
    return {
        "input_ids": prompt_ids + summary_ids,
        "labels": [-100] * len(prompt_ids) + summary_ids,
    }


# Encode every record; the on-disk cache is bypassed so the mapping is
# recomputed on each run.
tokenized_dataset = dataset.map(
    function=tokenize_data,
    load_from_cache_file=False,
)

In [None]:
# Custom collate function for dynamic padding
def collate_fn(batch, pad_token_id=None):
    """Pad a list of tokenized examples into rectangular batch tensors.

    Args:
        batch: Sequence of items, each providing ``"input_ids"`` and
            ``"labels"`` as lists of ints.
        pad_token_id: Fill value for ``input_ids``. When omitted, the
            module-level ``tokenizer.pad_token_id`` is used.

    Returns:
        Dict of ``torch.long`` tensors, each padded on the right to the
        longest sequence in the batch:

        * ``"input_ids"``: padded with ``pad_token_id``.
        * ``"attention_mask"``: 1 on real tokens, 0 on padding.
        * ``"labels"``: padded with -100 so padding is excluded from the loss.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    # An explicit dtype keeps empty sequences integral (torch.tensor([]) would
    # otherwise default to float32).
    input_ids = [torch.tensor(item["input_ids"], dtype=torch.long) for item in batch]
    labels = [torch.tensor(item["labels"], dtype=torch.long) for item in batch]

    # The mask is built from sequence lengths rather than by comparing against
    # the pad id: the pad token is the EOS token, which may also occur as a
    # genuine token inside a sequence.
    attention_mask = [torch.ones(len(ids), dtype=torch.long) for ids in input_ids]

    pad_sequence = torch.nn.utils.rnn.pad_sequence
    return {
        "input_ids": pad_sequence(
            input_ids, batch_first=True, padding_value=pad_token_id
        ),
        "attention_mask": pad_sequence(
            attention_mask, batch_first=True, padding_value=0
        ),
        "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
    }


# Seed PyTorch so the shuffling order and dropout are repeatable across runs
torch.manual_seed(42)

# Create DataLoader with custom collate function
dataloader = DataLoader(
    tokenized_dataset, batch_size=2, shuffle=True, collate_fn=collate_fn
)

# Set up optimizer. torch.optim.AdamW is used in place of transformers.AdamW,
# which is deprecated and no longer shipped by recent transformers releases.
# eps and weight_decay are pinned to the values the transformers version
# defaulted to, so the update rule stays as it was.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0
)

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Training loop
num_epochs = 3
log_every = 50  # print the loss every N steps instead of flooding the output
model.train()
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    epoch_loss = 0.0
    steps_done = 0
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        labels = batch["labels"].to(device)
        # The mask is optional: when the collate function does not supply
        # one, None is passed through to the model unchanged.
        attention_mask = batch.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids, attention_mask=attention_mask, labels=labels
        )
        loss = outputs.loss

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        epoch_loss += loss_value
        steps_done += 1
        if steps_done % log_every == 0:
            print(f"Loss: {loss_value:.4f}")

    # Guard against an empty dataloader before averaging
    if steps_done:
        print(f"Mean loss: {epoch_loss / steps_done:.4f}")

print("Training complete!")