
# Fine-tuning BART for summarization with filtration

---

## Setup

---

**The comments for this will be mostly similar to that of the previous model I will add the additional comments wherever necessary**

In [5]:
!pip install transformers[torch]
!pip install accelerate -U

Collecting accelerate>=0.20.3 (from transformers[torch])
  Downloading accelerate-0.24.1-py3-none-any.whl (261 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m8.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
Successfully installed accelerate-0.24.1


In [1]:
# !pip install spacy==3.0.6     # install spaCy (version 3.0.6)
!python -m spacy download en_core_web_lg    # download en_core_web_lg model

import re
import spacy
nlp = spacy.load("en_core_web_lg")

# Defining the functions for filtering the data

def entity_based_filtered_sentences(example):
    # if not (text and summary):
    #     return ""
    text = example["article"]                                                   # The text parts of the input
    summary = example["highlights"]                                             # The summary parts of the input
    doc = nlp(summary)
    sentences_select = {}
    for sent in doc.sents:
        sentences_select[sent.text] = True
    for e in doc.ents:
        if e[0].ent_type_ in ['PERSON', 'FAC', 'GPE', 'ORG', 'NORP', 'LOC', 'EVENT']:
            # match_result = re.search(e.text,text)
            # print(e,match_result)
            # if match_result==None:
            if e.text.lower() not in text.lower():
                # print(e)
                sentences_select[e.sent.text] = False
    result = []
    for sent in doc.sents:
        if sentences_select[sent.text]:
            result.append(sent.text)
    filter_summary =  " ".join(result)
    example["highlights"] = filter_summary
    return example

def create_ent_augmented_target(example):
    text = example["article"]
    summary = example["highlights"]
    entity_summary = []
    doc = nlp(summary)
    entities = []
    for e in doc.ents:
        if e[0].ent_type_ in ['PERSON', 'FAC', 'GPE', 'ORG', 'NORP', 'LOC', 'EVENT']:
            if e.text.lower() in text.lower():
                entities.append(e.text)

    entity_summary =  " ".join(entities)
    entity_summary = entity_summary + " " + summary
    example["highlights"] = entity_summary
    return example

2023-12-01 06:04:20.722266: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-12-01 06:04:20.722331: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-12-01 06:04:20.722367: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-12-01 06:04:20.730497: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-12-01 06:04:23.896437: I tensorflow/c

In [4]:
%%capture
! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score
! pip install wandb

In [2]:
import json

In [3]:
# Same as earlier

import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [None]:
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

    wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


## Set language

---

English

In [4]:
language = "english"

## Model and tokenizer

---

Download model and tokenizer. Use default parameters or try custom values (see [HF Bart configuration](https://huggingface.co/transformers/_modules/transformers/configuration_bart.html) and [Fairseq Bart](https://github.com/pytorch/fairseq/tree/master/examples/bart)).

In [5]:
# Same as earlier

model_name = "facebook/bart-base"

model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set model parameters or use the default
# print(model.config)

# tokenization
encoder_max_length = 256
decoder_max_length = 64

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

## Data

---

### Download

In [6]:
data_complete = datasets.load_dataset("cnn_dailymail",'3.0.0',split="train[:10%]")

Downloading builder script:   0%|          | 0.00/8.33k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/9.88k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/15.1k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/5 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/159M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/376M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/12.3M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/661k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/572k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

In [7]:
data_complete[0]

{'article': 'LONDON, England (Reuters) -- Harry Potter star Daniel Radcliffe gains access to a reported £20 million ($41.1 million) fortune as he turns 18 on Monday, but he insists the money won\'t cast a spell on him. Daniel Radcliffe as Harry Potter in "Harry Potter and the Order of the Phoenix" To the disappointment of gossip columnists around the world, the young actor says he has no plans to fritter his cash away on fast cars, drink and celebrity parties. "I don\'t plan to be one of those people who, as soon as they turn 18, suddenly buy themselves a massive sports car collection or something similar," he told an Australian interviewer earlier this month. "I don\'t think I\'ll be particularly extravagant. "The things I like buying are things that cost about 10 pounds -- books and CDs and DVDs." At 18, Radcliffe will be able to gamble in a casino, buy a drink in a pub or see the horror film "Hostel: Part II," currently six places below his number one movie on the UK box office char

### Prepare

**Format and split into train and validation sets**

In [8]:
# Same as earlier

def flatten(example):
    return {
        "document": example["article"]["document"],
        "summary": example["article"]["summary"],
    }


def list2samples(example):
    documents = []
    summaries = []
    for sample in zip(example["article"], example["highlights"]):
        if len(sample[0]) > 0:
            documents += sample[0]
            summaries += sample[1]
    return {"document": documents, "summary": summaries}


# dataset = data.map(flatten, remove_columns=["article", "url"])
# dataset = data.map(list2samples, batched=True)

train_data_txt, validation_data_txt = data_complete.train_test_split(test_size=0.1).values()

In [9]:
train_data_txt[0]

{'article': "PARIS, France (CNN)  -- In a city famous for being the birthplace of the avant-garde, it can be hard to keep up with the latest trends. Here's a rough guide to what's hot right now in the French capital. This beat is Tecktonik: The latest dance craze to hit the Parisian streets. Tecktonik Parisian youths love their trends. The latest dance craze sweeping the city is Tecktonik, a fusion dance style usually accessorized with spiked hair and neon accessories. Look out for kids dancing in packs outside the Trocadero. You might even be lucky enough to spot a Tecktonik/breakdance dance-off. Le Scrapbooking Scrapbooking is the current craze amongst Paris' more sedate residents. Head to Le Temple Du Scrap (13 Rue Ernest Cresson) for pretty paper supplies, trimmings, ribbons and associated frippery. Bike around town Much of Paris is walkable, but the city's cheap bike-hire scheme, Velib, which launched in 2007, makes dashing around the city even easier. There are thousands of bikes

In [10]:

# Apply 'entity_based_filtered_sentences' to filter sentences based on entities
train_data_filtered = [entity_based_filtered_sentences(example) for example in train_data_txt]

# Convert the filtered data into dataset
train_data_filtered_dataset = datasets.Dataset.from_dict({
    "article": [example["article"] for example in train_data_filtered],
    "highlights": [example["highlights"] for example in train_data_filtered],
    "id":[example["id"] for example in train_data_filtered]
})


In [11]:
# Block to save the filtered data

import pickle

with open('train_data_filtered_dataset.pkl', 'wb') as file:
    pickle.dump(train_data_filtered_dataset, file)

from google.colab import files

files.download('train_data_filtered_dataset.pkl')



<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [12]:
# Apply 'create_ent_augmented_target' to create an augmented target with entities
train_data_augmented = [create_ent_augmented_target(example) for example in train_data_filtered]

# Convert the filtered data into dataset
train_data_augmented_dataset = datasets.Dataset.from_dict({
    "article": [example["article"] for example in train_data_augmented],
    "highlights": [example["highlights"] for example in train_data_augmented],
    "id":[example["id"] for example in train_data_augmented]
})

In [13]:
# Block to save the filtered data

import pickle

with open('train_data_augmented_dataset.pkl', 'wb') as file:
    pickle.dump(train_data_augmented_dataset, file)

from google.colab import files

files.download('train_data_augmented_dataset.pkl')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Showing the formed training data

print(type(train_data_augmented))
print(train_data_augmented[0])

<class 'list'>
{'article': '(CNN) -- A man, incensed that a 6-year-old girl chose to walk through a path reserved for upper caste villagers, pushed her into burning embers, police in north India said Wednesday. She was seriously burned. Dalits, or "untouchables," are victims of discrimination in India despite laws aimed at eliminating prejudice. The girl is a Dalit, or an "untouchable," according to India\'s traditional caste system. India\'s constitution outlaws caste-based discrimination, and barriers have broken down in large cities. Prejudice, however, persists in some rural areas of the country. The girl was walking with her mother down a path in the city of Mathura when she was accosted by a man in his late teens, said police superintendent R.K. Chaturvedi. "He scolded them both and pushed her," Chaturvedi said. The girl fell about 3 to 4 feet into pile of burning embers by the side of the road. The girl remained in critical condition Wednesday. The man confessed to the crime and

In [None]:
# Showing the formed validation data

print(type(validation_data_txt))
print(validation_data_txt[0])

<class 'datasets.arrow_dataset.Dataset'>
{'article': '(CNN) -- Opponents and supporters of Venezuela\'s government staged rival demonstrations Sunday in the streets of the capital to mark the anniversary of a popular revolt that overthrew a dictatorship in the South American country in 1958. In Caracas, supporters of Venezuelan President Hugo Chavez gathered around the presidential palace to listen to him speak on what is known there as "National Democracy Day." "Every day, there will be more democracy in Venezuela -- this democracy that gives more power to the people," Chavez said, as reported by the state-run AVN news agency. "Democracy is as necessary to socialism as oxygen is to living things." The president\'s supporters cheered and waved flags and banners. Critics of Chavez\'s government met in the eastern part of the capital, the state-run VTV network reported. Many waved white banners. One protester carried a poster that read: "Enough of the lies." Marches were also planned in 

**Preprocess and tokenize**

In [None]:
# Same as before

def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    source, target = batch["article"], batch["highlights"]
    source_tokenized = tokenizer(
        source, padding="max_length", truncation=True, max_length=max_source_length
    )
    target_tokenized = tokenizer(
        target, padding="max_length", truncation=True, max_length=max_target_length
    )

    batch = {k: v for k, v in source_tokenized.items()}
    # Ignore padding in the loss
    batch["labels"] = [
        [-100 if token == tokenizer.pad_token_id else token for token in l]
        for l in target_tokenized["input_ids"]
    ]
    return batch


train_data = train_data_augmented_dataset.map(
    lambda batch: batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    ),
    batched=True,
    remove_columns=train_data_txt.column_names,
)

validation_data = validation_data_txt.map(
    lambda batch: batch_tokenize_preprocess(
        batch, tokenizer, encoder_max_length, decoder_max_length
    ),
    batched=True,
    remove_columns=validation_data_txt.column_names,
)

Map:   0%|          | 0/25839 [00:00<?, ? examples/s]

Map:   0%|          | 0/2872 [00:00<?, ? examples/s]

## Training

---

### Metrics

In [None]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

nltk.download("punkt", quiet=True)

metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # rougeLSum expects newline after each sentence
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels


def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Extract a few results from ROUGE
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}

    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    return result

  metric = datasets.load_metric("rouge")


Downloading builder script:   0%|          | 0.00/2.17k [00:00<?, ?B/s]

### Training arguments

In [None]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer


In [None]:
!pip install accelerate==0.20.1


Collecting accelerate==0.20.1
  Downloading accelerate-0.20.1-py3-none-any.whl (227 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/227.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m92.2/227.5 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m227.5/227.5 kB[0m [31m4.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: accelerate
  Attempting uninstall: accelerate
    Found existing installation: accelerate 0.24.1
    Uninstalling accelerate-0.24.1:
      Successfully uninstalled accelerate-0.24.1
Successfully installed accelerate-0.20.1


In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=1,  # demo
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=4,  # demo
    per_device_eval_batch_size=4,
    # learning_rate=3e-05,
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True,
    logging_dir="logs",
    logging_steps=50,
    save_total_limit=3,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

### Train

Wandb integration

In [None]:
if WANDB_INTEGRATION:
    wandb_run = wandb.init(
        project="Fine-tune Bart on CNN-daily filtered",
        config={
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "learning_rate": training_args.learning_rate,
            "dataset": "Fine-tune Bart on CNN-daily filtered"
        },
    )

    now = datetime.now()
    current_time = now.strftime("%H%M%S")
    wandb_run.name = "run_" + "Fine-tune Bart on CNN-daily filtered" + "_" + current_time

[34m[1mwandb[0m: Currently logged in as: [33mjindalmohit351[0m ([33mmj2[0m). Use [1m`wandb login --relogin`[0m to force relogin


Evaluate before fine-tuning

In [None]:
trainer.evaluate()

You're using a BartTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 6.05169677734375,
 'eval_rouge1': 16.43,
 'eval_rouge2': 5.3818,
 'eval_rougeL': 13.1797,
 'eval_rougeLsum': 15.0031,
 'eval_gen_len': 20.0,
 'eval_runtime': 270.908,
 'eval_samples_per_second': 10.601,
 'eval_steps_per_second': 2.65}

Train the model

In [None]:
%%wandb
# uncomment to display Wandb charts

trainer.train()

Step,Training Loss
50,5.7135
100,4.8331
150,4.5262
200,4.257
250,4.0294
300,4.1264
350,4.0402
400,3.9232
450,4.0111
500,3.9975


TrainOutput(global_step=6460, training_loss=3.6951974845153996, metrics={'train_runtime': 1462.3612, 'train_samples_per_second': 17.669, 'train_steps_per_second': 4.418, 'total_flos': 3938745086115840.0, 'train_loss': 3.6951974845153996, 'epoch': 1.0})

Evaluate after fine-tuning

In [None]:
trainer.evaluate()



{'eval_loss': 3.6375041007995605,
 'eval_rouge1': 21.7051,
 'eval_rouge2': 8.5602,
 'eval_rougeL': 17.3585,
 'eval_rougeLsum': 20.0343,
 'eval_gen_len': 20.0,
 'eval_runtime': 262.5708,
 'eval_samples_per_second': 10.938,
 'eval_steps_per_second': 2.735,
 'epoch': 1.0}

In [None]:
if WANDB_INTEGRATION:
    wandb_run.finish()

VBox(children=(Label(value='0.001 MB of 0.021 MB uploaded\r'), FloatProgress(value=0.05643371529507677, max=1.…

0,1
eval/gen_len,▁▁
eval/loss,█▁
eval/rouge1,▁█
eval/rouge2,▁█
eval/rougeL,▁█
eval/rougeLsum,▁█
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████

0,1
eval/gen_len,20.0
eval/loss,3.6375
eval/rouge1,21.7051
eval/rouge2,8.5602
eval/rougeL,17.3585
eval/rougeLsum,20.0343
eval/runtime,262.5708
eval/samples_per_second,10.938
eval/steps_per_second,2.735
train/epoch,1.0


## Evaluation

---

**Generate summaries from the fine-tuned model and compare them with those generated from the original, pre-trained one.**

In [None]:
def generate_summary(test_samples, model):
    inputs = tokenizer(
        test_samples["article"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    input_ids = inputs.input_ids.to(model.device)
    attention_mask = inputs.attention_mask.to(model.device)
    outputs = model.generate(input_ids, attention_mask=attention_mask)
    output_str = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return outputs, output_str


model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)

test_samples = validation_data_txt.select(range(16))

summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
summaries_after_tuning = generate_summary(test_samples, model)[1]

In [None]:
print(
    tabulate(
        zip(
            range(len(summaries_after_tuning)),
            summaries_after_tuning,
            summaries_before_tuning,
        ),
        headers=["Id", "Summary after", "Summary before"],
    )
)
print("\nTarget summaries:\n")
print(
    tabulate(list(enumerate(test_samples["highlights"])), headers=["Id", "Target summary"])
)
print("\nSource documents:\n")
print(tabulate(list(enumerate(test_samples["article"])), headers=["Id", "Document"]))

  Id  Summary after                                                                                               Summary before
----  ----------------------------------------------------------------------------------------------------------  ---------------------------------------------------------------------------------------------------------
   0  Venezuelan Hugo Chavez Caracas NEW: Opponents and supporters of Venezuela's government                      (CNN) -- Opponents and supporters of Venezuela's government staged rival demonstrations Sunday in
   1  Daniel Tosh Daniel Tosh reportedly singled out a woman in his audience and suggested she get raped          (CNN) -- When the comedian Daniel Tosh reportedly singled out a woman in his audience
   2  Agnieszka Radwanska Madrid Masters Lucie Hradecka                                                           (CNN) -- Agnieszka Radwanska has only lost to
   3  Andry Rajoelina Antananarivo Marc Ravalomanana Rajo                      

In [None]:
trainer.save_model("finetune_cnn_dm_filtered")

In [None]:
!zip -r /content/finetune_cnn_dm_filtered.zip /content/finetune_cnn_dm_filtered

In [None]:
from google.colab import files
files.download('/content/finetune_cnn_dm_filtered.zip')