In [None]:
import torch
import numpy as np
from transformers import T5ForConditionalGeneration, T5Tokenizer
from datasets import load_dataset
from rouge import Rouge
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

In [None]:
# Fetch XSum from the Hugging Face hub; no split is requested, so the
# result is indexed by split name ("train", "test") in later cells.
dataset = load_dataset(path="xsum")

In [None]:
# Instantiate the T5 tokenizer and the seq2seq model from the same
# "t5-small" checkpoint (tokenizer first, then the model).
tokenizer, model = (
    loader.from_pretrained("t5-small")
    for loader in (T5Tokenizer, T5ForConditionalGeneration)
)

In [None]:
# Preprocess the dataset
def preprocess_batch(batch):
    """Tokenize one batch of records into fixed-length inputs and labels.

    Args:
        batch: Mapping with a "document" list (articles) and a "summary"
            list (reference summaries) of equal length.

    Returns:
        dict with "input_ids" and "attention_mask" (each row 512 long) for
        the "summarize: "-prefixed articles, and "labels" (each row 150
        long) holding the token ids of the summaries. All values are plain
        nested Python lists.
    """
    # Both tokenizer calls share everything except the length cap.
    shared_kwargs = dict(truncation=True, padding='max_length', return_tensors="pt")

    prompts = ["summarize: " + article for article in batch["document"]]
    encoded_inputs = tokenizer(prompts, max_length=512, **shared_kwargs)
    encoded_targets = tokenizer(batch["summary"], max_length=150, **shared_kwargs)

    features = {
        key: encoded_inputs[key].tolist()
        for key in ("input_ids", "attention_mask")
    }
    features["labels"] = encoded_targets["input_ids"].tolist()
    return features

# Tokenize 16 records at a time and drop the raw text columns, then keep
# a handle on the training split for the training loop.
tokenized_dataset = dataset.map(
    preprocess_batch,
    batched=True,
    batch_size=16,
    remove_columns=["document", "summary"],
)
train_dataset = tokenized_dataset["train"]


In [None]:
# Hyperparameters
# NOTE(review): alpha and gamma are not read by any cell visible here;
# confirm whether they are still needed.
alpha, gamma = 0.9, 0.99
n_epochs, batch_size = 3, 8

# ROUGE scorer shared by the reward function
rouge = Rouge()

# Use the GPU when one is available, otherwise the CPU, and move the model there.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)

In [None]:
# Calculate the Rouge-L score
def compute_rouge_l(pred_summary, ref_summary):
    """Return the ROUGE-L F-measure of one predicted summary against one reference.

    Uses the module-level ``rouge`` scorer (``rouge.Rouge``), whose API is
    ``get_scores(hypothesis, reference)``; it has no ``compute`` method.

    Args:
        pred_summary: Generated summary text.
        ref_summary: Reference summary text.

    Returns:
        The ROUGE-L F-measure as a float, or 0.0 when either text is empty
        or the scorer rejects the pair as degenerate.
    """
    # The scorer raises on empty text; an empty prediction earns no reward.
    if not pred_summary.strip() or not ref_summary.strip():
        return 0.0
    try:
        scores = rouge.get_scores(pred_summary, ref_summary)
    except ValueError:
        # Raised for inputs that contain no scorable sentence (e.g. ".").
        return 0.0
    # get_scores returns one dict per hypothesis/reference pair.
    return scores[0]["rouge-l"]["f"]

In [None]:
# Train the model with a REINFORCE-style policy gradient (this is not PPO:
# there is no clipped ratio, value head or reference policy). The reward
# for each generated summary is its ROUGE-L score against the reference.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
pad_token_id = tokenizer.pad_token_id

for epoch in range(n_epochs):
    epoch_loss = 0.0
    n_batches = 0
    for i in range(0, len(train_dataset), batch_size):
        batch = train_dataset[i:i+batch_size]
        # Slicing the dataset yields nested Python lists; as_tensor converts
        # them (and also accepts tensors if a torch format is configured).
        input_ids = torch.as_tensor(batch["input_ids"]).to(device)
        attention_mask = torch.as_tensor(batch["attention_mask"]).to(device)

        # Generate summaries with dropout disabled so the decoded text
        # reflects the current policy deterministically.
        # NOTE(review): beam search is kept from the original cell; a strict
        # REINFORCE estimator would sample from the policy instead.
        model.eval()
        with torch.no_grad():
            summary_ids = model.generate(input_ids, attention_mask=attention_mask, num_beams=4, max_length=150, early_stopping=True)
        pred_summaries = [tokenizer.decode(s, skip_special_tokens=True) for s in summary_ids]
        ref_summaries = [tokenizer.decode(s, skip_special_tokens=True) for s in batch["labels"]]

        # Calculate rewards (one scalar per example)
        rewards = torch.tensor(
            [compute_rouge_l(pred, ref) for pred, ref in zip(pred_summaries, ref_summaries)],
            dtype=torch.float,
            device=device,
        )

        # Compute the policy gradient loss on the *generated* sequences,
        # since those are what the rewards were computed for.
        model.train()
        optimizer.zero_grad()

        # Teacher-force the generated ids: position t of the decoder input
        # predicts position t + 1 of the generated sequence.
        decoder_input_ids = summary_ids[:, :-1]
        targets = summary_ids[:, 1:]
        logits = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
        ).logits

        # Normalise logits into log-probabilities before picking the
        # log-probability of each generated token.
        token_log_probs = torch.log_softmax(logits, dim=-1).gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        # Padding after the end of a sequence must not contribute.
        target_mask = (targets != pad_token_id).to(token_log_probs.dtype)
        seq_log_probs = (token_log_probs * target_mask).sum(dim=1)

        # Mean over the actual batch, which may be short at the end of the data.
        loss = -(rewards * seq_log_probs).mean()

        # Backpropagate the loss
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # Each accumulated loss is already a per-batch mean, so average over batches.
    print(f"Epoch {epoch + 1} Loss: {epoch_loss / max(n_batches, 1):.4f}")

In [None]:
# Evaluate the model
def generate_summary(text):
    """Generate a summary for one raw article with the module-level model.

    Args:
        text: Article text without the task prefix; "summarize: " is
            prepended here.

    Returns:
        The decoded summary string with special tokens removed.

    Side effects:
        Puts ``model`` into evaluation mode so dropout left enabled by a
        preceding training loop does not perturb generation.
    """
    input_text = "summarize: " + text
    # A single sequence needs no padding; it is only truncated to the
    # 512-token limit used elsewhere in this notebook.
    inputs = tokenizer(input_text, max_length=512, truncation=True, return_tensors="pt").to(device)
    model.eval()
    with torch.no_grad():
        summary_ids = model.generate(inputs["input_ids"], attention_mask=inputs["attention_mask"], num_beams=4, max_length=150, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

rouge = Rouge()
# Evaluation must run with dropout disabled.
model.eval()

# Evaluate on the raw test records: generate_summary adds the task prefix
# and tokenizes itself, so feeding it text decoded from already-prefixed,
# truncated input_ids would double the prefix.
test_dataset = dataset["test"].select(range(1000))
predictions = []
references = []

for example in test_dataset:
    predictions.append(generate_summary(example["document"]))
    references.append(example["summary"])

# rouge.Rouge scores with get_scores(hyps, refs); avg=True returns one dict
# of averaged rouge-1 / rouge-2 / rouge-l scores, and ignore_empty=True
# skips pairs with an empty side instead of raising.
rouge_scores = rouge.get_scores(predictions, references, avg=True, ignore_empty=True)

print("Rouge Scores:", rouge_scores)

# Calculate BLEU scores (sentence level, smoothed so short outputs with no
# higher-order n-gram overlap do not collapse to zero)
smooth = SmoothingFunction().method1
bleu_scores = [sentence_bleu([ref.split()], pred.split(), smoothing_function=smooth) for ref, pred in zip(references, predictions)]
avg_bleu = np.mean(bleu_scores)

print("Average BLEU Score:", avg_bleu)