In [None]:
!pip install transformers datasets -qU
!pip install pytorch_lightning -q
!pip install wandb -q

In [None]:
import numpy as np
from tqdm.auto import tqdm
from typing import Tuple, Dict
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset, DataLoader


import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from transformers import AdamW, BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset, train_test_split

In [None]:
# Seed the global RNGs through Lightning so the run is repeatable.
pl.seed_everything(seed=42)

In [None]:
@dataclass
class Cfg:
    """Run configuration shared by the data pipeline, model and inference cells.

    Fields are type-annotated so ``@dataclass`` actually registers them
    (unannotated class attributes are ignored by the decorator). Every value is
    still a class attribute, so existing ``Cfg.X`` lookups work without
    instantiating the class.
    """

    # Hugging Face dataset name and its configuration/version string.
    DATASET_LOC: str = "cnn_dailymail"
    CONFIG: str = "3.0.0"
    # Pretrained checkpoint used for both the tokenizer and the model.
    MODEL_NAME: str = "facebook/bart-base"
    # Tokenizer options forwarded by ``prepare_input``.
    padding: str = "max_length"
    truncation: bool = True
    add_special_tokens: bool = True
    # DataLoader settings.
    batch_size: int = 8
    num_workers: int = 2

In [None]:
# Load only the first 15,000 examples of the training split.
cnn_dataset = load_dataset(
    path=Cfg.DATASET_LOC,
    name=Cfg.CONFIG,
    split="train[:15000]",
)

In [None]:
# 80/20 split of the subset; the "test" split is used as the validation set later.
# The explicit seed makes the split independent of the global RNG state, and the
# isinstance guard keeps the cell safe to re-run: after the first run cnn_dataset is
# already a DatasetDict (a dict subclass) and must not be split again.
if not isinstance(cnn_dataset, dict):
    cnn_dataset = cnn_dataset.train_test_split(test_size=0.2, seed=42)
cnn_dataset

In [None]:
def prepare_input(tokenizer: BartTokenizer, text: str, max_len: int) -> Dict:
    """Tokenize a single text into PyTorch tensors using the shared ``Cfg`` options.

    Padding, truncation and special-token handling come from ``Cfg``, so training
    data and inference inputs are encoded identically.

    Args:
        tokenizer (BartTokenizer): The BART tokenizer used to encode ``text``.
        text (str): The input text to be tokenized.
        max_len (int): Maximum sequence length. Longer inputs are truncated; with
            ``Cfg.padding == "max_length"`` shorter inputs are padded up to it.

    Returns:
        inputs (dict-like): The tokenizer's encoding with keys such as
            'input_ids' and 'attention_mask', each a tensor with a leading batch
            dimension of 1 (``return_tensors="pt"`` on a single string).
    """
    # Calling the tokenizer directly replaces the deprecated ``encode_plus``; for a
    # single string it produces the same encoding.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_len,
        padding=Cfg.padding,
        truncation=Cfg.truncation,
        add_special_tokens=Cfg.add_special_tokens,
    )
    return inputs

In [None]:
class SummaryDataset(Dataset):
    """Map-style dataset that tokenizes article/summary pairs on access.

    Args:
        data: Indexable collection of records, each with ``"article"`` and
            ``"highlights"`` string fields (here a split of the CNN/DailyMail
            ``DatasetDict``).
        tokenizer: Tokenizer handed to ``prepare_input``.
        text_max_len: Token length the article is truncated/padded to (default 512).
        summary_max_len: Token length the summary is truncated/padded to (default 256).
    """

    def __init__(self, data, tokenizer, text_max_len: int = 512, summary_max_len: int = 256):
        self.data = data
        self.tokenizer = tokenizer
        self.text_max_len = text_max_len
        self.summary_max_len = summary_max_len

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the record once instead of indexing ``self.data`` repeatedly.
        row = self.data[idx]
        article = row["article"]
        highlights = row["highlights"]

        text_encoding = prepare_input(self.tokenizer, article, self.text_max_len)
        summary_encoding = prepare_input(self.tokenizer, highlights, self.summary_max_len)

        return dict(
            # This sample's raw strings only (previously the entire columns were
            # returned for every item).
            text=article,
            summary=highlights,
            # flatten() drops the leading batch dimension added by return_tensors="pt".
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            summary_input_ids=summary_encoding["input_ids"].flatten(),
            summary_attention_mask=summary_encoding["attention_mask"].flatten(),
        )

In [None]:
class SummaryDataModule(pl.LightningDataModule):
    """LightningDataModule over a ``DatasetDict`` with ``"train"`` and ``"test"`` splits.

    The ``"test"`` split is used as the validation set.

    Args:
        data: Mapping with ``"train"`` and ``"test"`` splits.
        tokenizer: Tokenizer passed to ``SummaryDataset``.
        batch_size: Batch size for both dataloaders.
        num_workers: DataLoader worker processes; ``None`` (default) uses
            ``Cfg.num_workers``.
    """

    def __init__(self, data, tokenizer, batch_size, num_workers=None):
        super().__init__()

        self.ds = data
        self.train_ds = data["train"]
        self.val_ds = data["test"]
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        # Fall back to the global config so existing call sites behave as before.
        self.num_workers = Cfg.num_workers if num_workers is None else num_workers

    def setup(self, stage=None):
        """Wrap the raw splits in tokenizing ``SummaryDataset``s."""
        self.train_dataset = SummaryDataset(self.train_ds, self.tokenizer)
        self.val_dataset = SummaryDataset(self.val_ds, self.tokenizer)

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
        )

    def val_dataloader(self):
        """Ordered loader over the held-out ("test") split."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
        )

In [None]:
# Tokenizer matching the pretrained BART checkpoint named in the config.
tokenizer = BartTokenizer.from_pretrained(
    pretrained_model_name_or_path=Cfg.MODEL_NAME
)

In [None]:
# Datamodule over the train/test splits; "test" serves as the validation set.
data = SummaryDataModule(
    data=cnn_dataset,
    tokenizer=tokenizer,
    batch_size=Cfg.batch_size,
)

In [None]:
class SummaryModel(pl.LightningModule):
    """LightningModule that fine-tunes ``BartForConditionalGeneration`` for summarization.

    Args:
        lr: AdamW learning rate (default ``1e-4``, the previously hard-coded value).
    """

    def __init__(self, lr: float = 1e-4):
        super().__init__()
        # Store ``lr`` in checkpoints so ``load_from_checkpoint`` restores it.
        self.save_hyperparameters()
        self.lr = lr

        self.model = BartForConditionalGeneration.from_pretrained(Cfg.MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        """Run the wrapped BART model and return ``(loss, logits)``."""
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def _shared_step(self, batch):
        """Compute the seq2seq loss for one batch produced by ``SummaryDataset``."""
        labels_attention_mask = batch["summary_attention_mask"]
        # Summaries are padded to max_length; without masking, the padding positions
        # dominate the cross-entropy. -100 is the ignore index of the HF loss, and
        # BART maps -100 back to the pad id when it builds decoder inputs from labels.
        labels = batch["summary_input_ids"].masked_fill(labels_attention_mask == 0, -100)

        loss, _ = self(
            input_ids=batch["text_input_ids"],
            attention_mask=batch["text_attention_mask"],
            labels=labels,
            decoder_attention_mask=labels_attention_mask,
        )
        return loss

    def training_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self._shared_step(batch)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def configure_optimizers(self):
        # torch.optim.AdamW replaces the deprecated ``transformers.AdamW``; eps=1e-6 and
        # weight_decay=0.0 reproduce that class's defaults (torch defaults differ).
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=1e-6, weight_decay=0.0)

In [None]:
# Instantiate the LightningModule; its constructor loads the pretrained weights named by Cfg.MODEL_NAME.
model = SummaryModel()

In [None]:
# Keep only the single best checkpoint (lowest val_loss), saved under artifacts/
# with the base name "best-checkpoint".
checkpoint_callback = ModelCheckpoint(
    dirpath="artifacts",
    filename="best-checkpoint",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
)

# Metrics logged with logger=True are sent to this Weights & Biases project.
wandb_logger = WandbLogger(project="Text_Summarization-bart-cnn")

In [None]:
# One epoch on a single GPU, checkpointing on val_loss and logging to W&B.
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
)

In [None]:
# Train and validate using the datamodule's loaders.
trainer.fit(model=model, datamodule=data)

In [None]:
# Load the checkpoint the callback actually selected. A hard-coded path can go stale:
# when the file already exists, Lightning saves a versioned name such as
# "best-checkpoint-v1.ckpt" on re-runs.
trained_model = SummaryModel.load_from_checkpoint(checkpoint_callback.best_model_path)
# Switch to inference mode (eval + no gradients).
trained_model.freeze()

In [None]:
def summarize(text, max_length=128, num_beams=2):
    """Generate a summary of ``text`` with the fine-tuned model.

    Args:
        text: Raw input text; encoded to at most 512 tokens via ``prepare_input``.
        max_length: Maximum generated length in tokens (default 128).
        num_beams: Beam-search width (default 2).

    Returns:
        The decoded summary string.
    """
    # Follow the model's device instead of hard-coding "cuda:0" (works on CPU too).
    text_encoding = prepare_input(tokenizer, text, 512).to(trained_model.device)

    generated_ids = trained_model.model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=max_length,
        num_beams=num_beams,
        repetition_penalty=2.5,
        length_penalty=1.75,
        early_stopping=True,
    )

    # The correct keyword is `clean_up_tokenization_spaces`; the previous misspelling
    # `cleanup_tokenization_spaces` was not a recognized argument and had no effect.
    preds = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

    return "".join(preds)

In [None]:
# Index the row first: on many `datasets` versions `split["article"]` builds the
# whole column just to read one entry.
cnn_dataset["test"][5]["article"]

In [None]:
import pprint

# Pretty-printer used below to display long article/summary strings.
pp = pprint.PrettyPrinter(indent=4, width=100)

In [None]:
# Reference summary for the same example; row-first indexing reads a single record.
pp.pprint(cnn_dataset["test"][5]["highlights"])

In [None]:
# Model summary of the same article; row-first indexing reads a single record.
pp.pprint(summarize(cnn_dataset["test"][5]["article"]))

In [None]:
# Summarize a single out-of-domain sentence.
eiffel_sentence = "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930."
eiffel_summary = summarize(eiffel_sentence)
pp.pprint(eiffel_summary)

In [None]:
import re

# Named eiffel_text rather than `input`, which would shadow the builtin.
eiffel_text = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."

# Split only at a period followed by whitespace, so decimals such as "5.2 metres" stay
# intact. Also drop empty pieces; a plain split(".") yields "" after the final period.
eiffel_sentences = [s for s in re.split(r"(?<=\.)\s+", eiffel_text) if s.strip()]

sentence_summaries = [summarize(sentence) for sentence in eiffel_sentences]

for sentence_summary in sentence_summaries:
    pp.pprint(sentence_summary)

# Join with a space so the per-sentence summaries don't run into each other.
combined_summary = " ".join(sentence_summaries)

pp.pprint(summarize(combined_summary))