In [1]:
!pip install transformers datasets -qU
!pip install pytorch_lightning -q
!pip install wandb -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.9/7.9 MB[0m [31m52.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m493.7/493.7 kB[0m [31m45.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m311.1/311.1 kB[0m [31m32.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m64.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━

In [2]:
import numpy as np
from tqdm.auto import tqdm
from typing import Tuple, Dict
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset, DataLoader


import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from transformers import AdamW, BartTokenizer, BartForConditionalGeneration
from datasets import load_dataset, train_test_split

ImportError: ignored

In [3]:
pl.seed_everything(42)

INFO:lightning_fabric.utilities.seed:Seed set to 42


42

In [4]:
@dataclass
class Cfg:
    DATASET_LOC = "cnn_dailymail"
    CONFIG = "3.0.0"
    MODEL_NAME = "facebook/bart-base"
    padding = "max_length"
    truncation = True
    add_special_tokens = True
    batch_size = 8
    num_workers = 2

In [5]:
cnn_dataset = load_dataset(Cfg.DATASET_LOC, Cfg.CONFIG, split="train[:15000]")

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

In [6]:
cnn_dataset = cnn_dataset.train_test_split(test_size=0.2)
cnn_dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 12000
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 3000
    })
})

In [7]:
def prepare_input(tokenizer: BartTokenizer, text: str, max_len) -> Dict:
    """Tokenize and prepare the input text using the provided tokenizer.

    Args:
        tokenizer (RobertaTokenizer): The Roberta tokenizer to encode the input.
        text (str): The input text to be tokenized.

    Returns:
        inputs (dict): A dictionary containing the tokenized input with keys such as 'input_ids',
            'attention_mask', etc.
    """
    inputs = tokenizer.encode_plus(
        text,
        return_tensors="pt",
        max_length=max_len,
        padding=Cfg.padding,
        truncation=Cfg.truncation,
        add_special_tokens=Cfg.add_special_tokens,
    )
    return inputs

In [8]:
class SummaryDataset(Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text_encoding = prepare_input(self.tokenizer, self.data[idx]["article"], 512)
        summary_encoding = prepare_input(self.tokenizer, self.data[idx]["highlights"], 256)

        return dict(
            text=self.data["article"],
            summary=self.data["highlights"],
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            summary_input_ids=summary_encoding["input_ids"].flatten(),
            summary_attention_mask=summary_encoding["attention_mask"].flatten(),
        )

In [9]:
class SummaryDataModule(pl.LightningDataModule):
    def __init__(self, data, tokenizer, batch_size):
        super().__init__()

        self.ds = data
        self.train_ds = data["train"]
        self.val_ds = data["test"]
        # self.test_ds = data['test']
        self.tokenizer = tokenizer
        self.batch_size = batch_size

    def setup(self, stage=None):
        self.train_dataset = SummaryDataset(self.train_ds, self.tokenizer)
        self.val_dataset = SummaryDataset(self.val_ds, self.tokenizer)
        # self.test_dataset = SummaryDataset(self.test_ds, self.tokenizer)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=Cfg.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=Cfg.num_workers)

    # def test_dataloader(self):
    #     return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=Cfg.num_workers)

In [10]:
tokenizer = BartTokenizer.from_pretrained(Cfg.MODEL_NAME)

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

In [11]:
data = SummaryDataModule(cnn_dataset, tokenizer, Cfg.batch_size)

In [12]:
class SummaryModel(pl.LightningModule):
    def __init__(self):
        super().__init__()

        self.model = BartForConditionalGeneration.from_pretrained(Cfg.MODEL_NAME, return_dict=True)

    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None):
        output = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
        )

        return output.loss, output.logits

    def training_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["summary_input_ids"]
        labels_attention_mask = batch["summary_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=labels_attention_mask,
        )

        self.log("train_loss", loss, prog_bar=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["text_input_ids"]
        attention_mask = batch["text_attention_mask"]
        labels = batch["summary_input_ids"]
        labels_attention_mask = batch["summary_attention_mask"]

        loss, outputs = self(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=labels_attention_mask,
        )

        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    # def test_step(self, batch, batch_idx):
    #     input_ids = batch["text_input_ids"]
    #     attention_mask = batch["text_attention_mask"]
    #     labels = batch["summary_input_ids"]
    #     labels_attention_mask = batch["summary_attention_mask"]

    #     loss, outputs = self(
    #         input_ids=input_ids,
    #         attention_mask=attention_mask,
    #         labels=labels,
    #         decoder_attention_mask=labels_attention_mask,
    #     )

    #     self.log("test_loss", loss, prog_bar=True, logger=True)
    #     return loss

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=0.0001)
        return optimizer

In [13]:
model = SummaryModel()

Downloading model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

In [14]:
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    dirpath="artifacts",
    filename="best-checkpoint",
    save_top_k=1,
    verbose=True,
)

wandb_logger = WandbLogger(project="Text_Summarization-bart-cnn")

In [18]:
trainer = pl.Trainer(
    logger=wandb_logger,
    callbacks=checkpoint_callback,
    max_epochs=1,
    accelerator="gpu",
    devices=1,
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [19]:
trainer.fit(model, data)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type                         | Params
-------------------------------------------------------
0 | model | BartForConditionalGeneration | 139 M 
-------------------------------------------------------
139 M     Trainable params
0         Non-trainable params
139 M     Total params
557.682   Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:Epoch 0, global step 1500: 'val_loss' reached 0.47670 (best 0.47670), saving model to '/content/artifacts/best-checkpoint.ckpt' as top 1
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=1` reached.


In [20]:
trained_model = SummaryModel.load_from_checkpoint("/content/artifacts/best-checkpoint.ckpt")
trained_model.freeze()

In [65]:
def summarize(text):
    text_encoding = prepare_input(tokenizer, text, 512).to("cuda:0")

    generated_ids = trained_model.model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=128,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.75,
        early_stopping=True,
    )

    preds = [tokenizer.decode(gen_id, skip_special_tokens=True, cleanup_tokenization_spaces=True) for gen_id in generated_ids]

    return "".join(preds)

In [23]:
cnn_dataset["test"]["article"][5]

'(CNN) -- Barcelona and Real Madrid have both played down suggestions that Saturday\'s "El Clasico" showdown will decide the Spanish league title. With eight matches left in the season, the two bitter rivals are locked on 77 points at the top of the table with Real ahead by just one goal on "for-and-against" differential. Barcelona coach Josep Guardiola  said his team would treat the match -- traditionally the biggest fixtures in the La Liga schedule -- as "a final" but insisted that defeat would not be terminal for either team\'s title hopes. "If there were only three or four games to go I would say it is an almost decisive match, but when there are seven left afterwards it\'s not so much -- but it is very important," he told reporters on Friday. "It\'s not a final, but we need to play as if it was one. It\'s a game where the winner will strike a blow to the other." Barcelona triumphed 6-2 in the Spanish capital last season to move seven points clear with four games to play, and cruis

In [45]:
import pprint

pp = pprint.PrettyPrinter(width=100, indent=4)

In [46]:
pp.pprint(cnn_dataset["test"]["highlights"][5])

('Real Madrid and Barcelona locked on 77 points at the top of the Spanish league table .\n'
 'Real have slight edge by just one goal on "for-and-against" differential with eight games left '
 '.\n'
 "Both teams' coaches insist that victory will not decide the La Liga crown .\n"
 'Defending champions Barcelona won the 79th "El Clasico" 1-0 at home in November .')


In [58]:
pp.pprint(summarize(cnn_dataset["test"]["article"][5]))

('Barcelona and Real Madrid play down suggestions that Saturday\'s "El Clasico" will decide the '
 'Spanish league title.\n'
 'The two rivals are locked on 77 points at the top of the table with Real ahead by just one goal '
 'on "for-and-against" differential.\n'
 'Real have won 50 of the 79 encounters between the two teams since 1929, losing just 15 times.')


In [67]:
pp.pprint(
    summarize(
        "During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930."
    )
)

('The Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in '
 'the world.\n'
 'The Chrysler Building in New York City was finished in 1930.')


In [70]:
input = "The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."

input_list = input.split(".")

output_list = [summarize(sentence) for sentence in input_list]

for sentence in output_list:
    pp.pprint(sentence)

o = "".join(output_list)

pp.pprint(summarize(o))

('The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building.\n'
 'It is the tallest structure in Paris, with 1,062 ft of height.')
('The base is square, measuring 125 metres (410 ft) on each side.\n'
 'Its base is a square of 125 metres in diameter; it measures 125 metres wide.')
('The Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in '
 'the world.\n'
 'It was built in 1930 and finished in 1930.\n'
 'The Chrysler Building in New York City was finished in 1931.')
('The structure reaches a height of 300 metres.\n'
 'It was the first structure to reach a height above 300 metres in three years.')
('The tower is now taller than the Chrysler Building by 5 feet.\n'
 'It is also taller than a Chrysler building by 5 inches.')
('2 metres (17 ft) tall and 17 ft long, respectively.\n'
 'The height of the two-metre wide area is about 17 ft deep.')
'The Eiffel Tower is the second tallest free-standing structure in France after t