# Automating Text Summarization for News Articles - Fine-tuning T5ForConditionalGeneration on the CNN/DailyMail Dataset

## Install Required Libraries

In [None]:
pip install torch pytorch-lightning transformers datasets


## Prepare the Dataset

In [None]:
from datasets import load_dataset
# Load configuration "3.0.0" of the CNN/DailyMail dataset; the result is
# indexed by split name further down ("train", "validation", "test").
dataset = load_dataset(path="cnn_dailymail", name="3.0.0")

## Define the Data Module

In [3]:
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule
from transformers import T5Tokenizer

class SummarizationDataModule(LightningDataModule):
    """Data module that turns article/summary records into T5 training batches.

    Every record is read through its ``'article'`` (source text) and
    ``'highlights'`` (reference summary) fields.

    Args:
        dataset: Mapping that provides the ``"train"``, ``"validation"`` and
            ``"test"`` splits.
        tokenizer_name: Name or path handed to ``T5Tokenizer.from_pretrained``.
        batch_size: Number of records per batch for all three dataloaders.
        source_prefix: Text prepended to every article before tokenization.
            It defaults to the same ``"summarize: "`` prefix that is used at
            inference time so training and inference inputs agree.
        max_source_length: Token limit for articles (longer ones are truncated).
        max_target_length: Token limit for summaries (longer ones are truncated).
    """

    def __init__(self, dataset, tokenizer_name='t5-small', batch_size=4,
                 source_prefix="summarize: ", max_source_length=512,
                 max_target_length=150):
        super().__init__()
        self.dataset = dataset
        self.tokenizer = T5Tokenizer.from_pretrained(tokenizer_name)
        self.batch_size = batch_size
        self.source_prefix = source_prefix
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length

    def setup(self, stage=None):
        """Bind the three dataset splits to the attributes the loaders read."""
        self.train_dataset = self.dataset["train"]
        self.val_dataset = self.dataset["validation"]
        self.test_dataset = self.dataset["test"]

    def collate_fn(self, batch):
        """Tokenize a list of records into one padded batch.

        Args:
            batch: List of records with ``'article'`` and ``'highlights'`` keys.

        Returns:
            dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors.
        """
        articles = [self.source_prefix + item['article'] for item in batch]
        summaries = [item['highlights'] for item in batch]

        # Pad only up to the longest item in this batch rather than always to
        # the maximum length; padded positions are masked out either way.
        encodings = self.tokenizer(
            articles,
            max_length=self.max_source_length,
            truncation=True,
            padding="longest",
            return_tensors="pt",
        )
        labels = self.tokenizer(
            summaries,
            max_length=self.max_target_length,
            truncation=True,
            padding="longest",
            return_tensors="pt",
        ).input_ids

        # Ask the tokenizer for its pad id instead of hard-coding 0, then
        # replace it with -100 so pad tokens are ignored in loss computation.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return dict(
            input_ids=encodings.input_ids,
            attention_mask=encodings.attention_mask,
            labels=labels,
        )

    def _make_loader(self, split, shuffle=False):
        """Build a DataLoader over ``split`` with the shared batch settings."""
        return DataLoader(
            split,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        """Return the training loader; records are reshuffled every epoch."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader (fixed order)."""
        return self._make_loader(self.test_dataset)

## Define the Model

In [4]:
from pytorch_lightning import LightningModule
from transformers import T5ForConditionalGeneration, AdamW

class NewsSummarizer(LightningModule):
    """Lightning wrapper that fine-tunes ``T5ForConditionalGeneration``.

    Args:
        model_name: Name or path handed to
            ``T5ForConditionalGeneration.from_pretrained``.
        learning_rate: Learning rate given to the AdamW optimizer.
        weight_decay: Weight decay given to the AdamW optimizer. Defaults to
            0.0 so that no regularization is applied unless requested.
    """

    def __init__(self, model_name='t5-small', learning_rate=2e-4, weight_decay=0.0):
        super().__init__()
        self.model = T5ForConditionalGeneration.from_pretrained(model_name)
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return ``(loss, logits)``.

        ``loss`` is whatever the wrapped model reports for the given
        ``labels`` (the model output's ``loss`` attribute).
        """
        output = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        return output.loss, output.logits

    def _step_loss(self, batch, metric_name):
        """Compute the loss for ``batch`` and log it under ``metric_name``."""
        loss, _ = self(batch["input_ids"], batch["attention_mask"], batch["labels"])
        self.log(metric_name, loss, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._step_loss(batch, "train_loss")

    def validation_step(self, batch, batch_idx):
        """Log the validation loss for one batch as ``val_loss``."""
        return self._step_loss(batch, "val_loss")

    def test_step(self, batch, batch_idx):
        """Log the test loss for one batch as ``test_loss``."""
        return self._step_loss(batch, "test_loss")

    def configure_optimizers(self):
        """Create the AdamW optimizer over all parameters of this module."""
        # Use PyTorch's AdamW; the AdamW class shipped with transformers is
        # deprecated in favor of this implementation.
        from torch.optim import AdamW as TorchAdamW

        return TorchAdamW(
            self.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay,
        )


## Train the Model

In [None]:
from pytorch_lightning import Trainer
import torch

data_module = SummarizationDataModule(dataset)
model = NewsSummarizer()

# Choose the accelerator from what is actually available. Hard-coding
# accelerator="gpu" fails on machines without CUDA, and 0 is not a usable
# device count, so fall back to the CPU accelerator with a single device.
trainer = Trainer(
    max_epochs=3,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
trainer.fit(model, datamodule=data_module)


## Inference

In [None]:
def generate_summary(article, model, tokenizer, device='cuda'):
    """Generate a summary of ``article`` using beam search.

    Args:
        article: Raw article text. The ``"summarize: "`` task prefix is
            prepended here, so callers pass the plain text.
        model: Object exposing ``eval()``, ``to(device)`` and the underlying
            generation model as ``model.model``.
        tokenizer: Tokenizer providing ``encode`` and ``decode``.
        device: Device to run on. If a CUDA device is requested but CUDA is
            not available, the CPU is used instead of raising.

    Returns:
        The decoded summary string with special tokens removed.
    """
    import torch  # local import so the function does not rely on another cell

    if str(device).startswith("cuda") and not torch.cuda.is_available():
        device = "cpu"

    model.eval()
    model.to(device)
    inputs = tokenizer.encode(
        "summarize: " + article,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        summary_ids = model.model.generate(
            inputs,
            max_length=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

# Example usage: summarize a single article with the model trained above,
# reusing the tokenizer that the data module already loaded.
article = "Your new article text here."
summary = generate_summary(article, model=model, tokenizer=data_module.tokenizer)
print(summary)
