
# Fine-tuning BART for Summarization

---

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

## Setup

---

In [None]:
!pip install ipywidgets
!jupyter nbextension enable --py widgetsnbextension

In [None]:
! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score
! pip install wandb

In [None]:
import torch
import numpy as np
import datasets

from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)

from tabulate import tabulate
import nltk
from datetime import datetime

In [None]:
# Toggle for Weights & Biases experiment tracking. `wandb` is imported (and the
# login prompt shown) only when the flag is true.
WANDB_INTEGRATION = True
if WANDB_INTEGRATION:
    import wandb

    wandb.login()

## Model and tokenizer

---

Hiperparámetros: 

In [None]:
# Checkpoint to fine-tune
model_name = "facebook/bart-base"

# Model and tokenizer.
# Dropout rates are copied from the config into each layer while the model is
# being built, so the override has to be passed at load time; assigning it on
# `model.config` afterwards does not reach the already-constructed layers.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, activation_dropout=0.0)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Remaining config overrides.
# NOTE(review): `classif_dropout` and the two `max_*_position_embeddings`
# attributes are not read anywhere in this notebook, and BART's own position
# limit is `max_position_embeddings` — confirm these assignments have any effect.
model.config.classif_dropout = 0.0
model.config.max_encoder_position_embeddings = 1024
model.config.max_decoder_position_embeddings = 16384

# Length bounds used when generating summaries.
# NOTE(review): training targets are truncated to `decoder_max_length` (64)
# tokens while generation is forced to emit at least 350 tokens — confirm that
# this mismatch is intended.
model.config.min_length = 350
model.config.max_length = 500
print(model.config)

# Tokenization limits (in tokens) for source reports and target summaries
encoder_max_length = 256
decoder_max_length = 64

loading configuration file https://huggingface.co/facebook/bart-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/f5310d276a6d1648d00c32fadc8bf7b4607e0fbd5b404fc4a0045960aa2bdfdb.a243ed957122436adb0b8d8e9d20f896f45c174b6324d625ca0a20a84f72a910
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id":

BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.0,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id": 2,
  "gradient_checkpointing": false,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2"
  },
  "init_std": 0.02,
  "is_encoder_decoder": true,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2
  },
  "max_decoder_position_embeddings": 16384,
  

## Data

---

### Descarga y Preparación de los Datos

### Carga del Dataset

In [None]:
# Raw (untokenized) GovReport summarization data.
train_data_txt = datasets.load_dataset("ccdv/govreport-summarization", split="train")
# NOTE(review): this evaluation set is named "validation" but is read from the
# "test" split — confirm that this is intended.
validation_data_txt = datasets.load_dataset("ccdv/govreport-summarization", split="test")

No config specified, defaulting to: gov_report_summarization_dataset/document
Reusing dataset gov_report_summarization_dataset (/root/.cache/huggingface/datasets/ccdv___gov_report_summarization_dataset/document/1.0.0/32c7f04ec577e9dc99b889eff9d64a321a3a00b7332e8733eb32d39bebbdc34f)
No config specified, defaulting to: gov_report_summarization_dataset/document
Reusing dataset gov_report_summarization_dataset (/root/.cache/huggingface/datasets/ccdv___gov_report_summarization_dataset/document/1.0.0/32c7f04ec577e9dc99b889eff9d64a321a3a00b7332e8733eb32d39bebbdc34f)


**Preprocess and tokenize**

In [None]:
def batch_tokenize_preprocess(batch, tokenizer, max_source_length, max_target_length):
    """Turn a batch of reports and summaries into model inputs and labels.

    Args:
        batch: Mapping with a "report" and a "summary" entry, each a list of
            strings.
        tokenizer: Callable tokenizer exposing ``pad_token_id``.
        max_source_length: Length every tokenized report is padded/truncated to.
        max_target_length: Length every tokenized summary is padded/truncated to.

    Returns:
        A dict with every field the tokenizer produced for the reports plus a
        "labels" entry holding the summary token ids, where padding positions
        are replaced by -100.
    """
    reports, summaries = batch["report"], batch["summary"]

    def _encode(texts, limit):
        # Same fixed-length padding/truncation policy for both sides.
        return tokenizer(texts, padding="max_length", truncation=True, max_length=limit)

    encoded_reports = _encode(reports, max_source_length)
    encoded_summaries = _encode(summaries, max_target_length)

    features = dict(encoded_reports.items())
    pad_id = tokenizer.pad_token_id
    # -100 marks positions that must not contribute to the loss.
    features["labels"] = [
        [tok if tok != pad_id else -100 for tok in ids]
        for ids in encoded_summaries["input_ids"]
    ]
    return features


# Arguments shared by both tokenization passes; `Dataset.map` forwards them to
# `batch_tokenize_preprocess` as keyword arguments after the batch itself.
_tokenize_kwargs = {
    "tokenizer": tokenizer,
    "max_source_length": encoder_max_length,
    "max_target_length": decoder_max_length,
}

# Tokenize in batches and drop the raw text columns so that only model
# features remain.
train_data = train_data_txt.map(
    batch_tokenize_preprocess,
    batched=True,
    fn_kwargs=_tokenize_kwargs,
    remove_columns=train_data_txt.column_names,
)

validation_data = validation_data_txt.map(
    batch_tokenize_preprocess,
    batched=True,
    fn_kwargs=_tokenize_kwargs,
    remove_columns=validation_data_txt.column_names,
)

Loading cached processed dataset at /root/.cache/huggingface/datasets/ccdv___gov_report_summarization_dataset/document/1.0.0/32c7f04ec577e9dc99b889eff9d64a321a3a00b7332e8733eb32d39bebbdc34f/cache-363a195acc9fabc3.arrow


  0%|          | 0/1 [00:00<?, ?ba/s]

## Training

---

### Metrics

In [None]:
# Borrowed from https://github.com/huggingface/transformers/blob/master/examples/seq2seq/run_summarization.py

# Sentence-tokenizer data required by nltk.sent_tokenize in postprocess_text.
nltk.download("punkt", quiet=True)

# ROUGE scorer used by compute_metrics.
metric = datasets.load_metric("rouge")


def postprocess_text(preds, labels):
    """Normalize decoded predictions and references for ROUGE scoring.

    Each text is stripped of surrounding whitespace and re-joined with one
    sentence per line, the layout rougeLSum expects.

    Args:
        preds: Decoded model outputs (list of strings).
        labels: Decoded reference summaries (list of strings).

    Returns:
        A ``(preds, labels)`` pair of lists holding the reformatted strings.
    """

    def _one_sentence_per_line(texts):
        trimmed = [text.strip() for text in texts]
        return ["\n".join(nltk.sent_tokenize(text)) for text in trimmed]

    return _one_sentence_per_line(preds), _one_sentence_per_line(labels)


def compute_metrics(eval_preds):
    """Compute ROUGE scores and the mean generated length for an evaluation run.

    Args:
        eval_preds: ``(predictions, labels)`` pair handed over by the trainer.
            ``predictions`` may itself be a tuple, in which case its first
            element holds the generated token ids.

    Returns:
        Dict mapping each ROUGE variant to its mid F-measure (scaled to 0-100)
        plus ``gen_len``, the mean number of non-padding tokens per prediction.
        All values are rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    pad_id = tokenizer.pad_token_id
    # -100 is a padding/ignore marker, not a vocabulary id, so it cannot be
    # decoded. Labels always carry it; predictions may carry it as well when
    # batches of different lengths are gathered, so both are sanitized.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

    rouge_scores = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    # Keep only the mid F-measure of each ROUGE variant, as a percentage.
    result = {key: value.mid.fmeasure * 100 for key, value in rouge_scores.items()}

    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

### Training arguments

In [None]:
# Hyperparameters and bookkeeping options for the fine-tuning run.
training_args = Seq2SeqTrainingArguments(
    output_dir="results",
    num_train_epochs=20,  # low for now
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=32,  # powers of 2
    per_device_eval_batch_size=32,
    learning_rate=3e-04, # 3e-05
    warmup_steps=500,
    weight_decay=0.1,
    label_smoothing_factor=0.1,
    predict_with_generate=True, # for ROUGE metrics
    logging_dir="logs",
    logging_steps=500,
    save_total_limit=3,
)

# Collator that batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

# Trainer wiring together the model, data, collator and metric function.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_data,
    eval_dataset=validation_data,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


### Train

Wandb integration

In [None]:
# Displays the imported module as a quick check. `wandb` is only imported when
# WANDB_INTEGRATION is true, so this cell raises NameError otherwise.
wandb

<module 'wandb' from '/usr/local/lib/python3.7/dist-packages/wandb/__init__.py'>

In [None]:
# Start a Weights & Biases run (only when tracking is enabled) and record the
# key hyperparameters alongside it.
if WANDB_INTEGRATION:
    wandb_run = wandb.init(
        project="BART_GOVREP_FT",
        config={
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "learning_rate": training_args.learning_rate,
            "dataset": "ccdv/govreport-summarization",
        },
    )

    # Name the run after the current wall-clock time (HHMMSS, no date part).
    now = datetime.now()
    current_time = now.strftime("%H%M%S")
    wandb_run.name = "run_" + current_time

Evaluate before fine-tuning

In [None]:
# Baseline: score the model on the trainer's evaluation set before fine-tuning.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 973
  Batch size = 32


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


{'eval_gen_len': 352.5231,
 'eval_loss': 4.9700188636779785,
 'eval_rouge1': 20.2761,
 'eval_rouge2': 9.3443,
 'eval_rougeL': 14.5055,
 'eval_rougeLsum': 18.6406,
 'eval_runtime': 1157.6208,
 'eval_samples_per_second': 0.841,
 'eval_steps_per_second': 0.027}

Train the model

In [None]:
# TensorFlow is imported in this notebook only for GPU diagnostics.
import tensorflow as tf

# Kept as the cell's last expression so the notebook displays the result;
# placed before the import, its value was silently discarded.
torch.cuda.is_available()

In [None]:
# Show the GPU device name as seen by TensorFlow.
# NOTE(review): initializing TensorFlow's GPU runtime may reserve GPU memory
# right before PyTorch training starts — confirm, and consider a torch-only check.
tf.test.gpu_device_name()

'/device:GPU:0'

In [None]:
#%%wandb
# uncomment to display Wandb charts

# Fine-tune the model with the arguments and datasets configured on `trainer`.
trainer.train()

***** Running training *****
  Num examples = 17517
  Num Epochs = 8
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 4384


Step,Training Loss
500,3.7838
1000,3.4254
1500,3.1321
2000,2.9018
2500,2.6931
3000,2.5136
3500,2.3531
4000,2.2354


Saving model checkpoint to results/checkpoint-500
Configuration saved in results/checkpoint-500/config.json
Model weights saved in results/checkpoint-500/pytorch_model.bin
tokenizer config file saved in results/checkpoint-500/tokenizer_config.json
Special tokens file saved in results/checkpoint-500/special_tokens_map.json
Deleting older checkpoint [results/checkpoint-3000] due to args.save_total_limit
Saving model checkpoint to results/checkpoint-1000
Configuration saved in results/checkpoint-1000/config.json
Model weights saved in results/checkpoint-1000/pytorch_model.bin
tokenizer config file saved in results/checkpoint-1000/tokenizer_config.json
Special tokens file saved in results/checkpoint-1000/special_tokens_map.json
Deleting older checkpoint [results/checkpoint-3500] due to args.save_total_limit
Saving model checkpoint to results/checkpoint-1500
Configuration saved in results/checkpoint-1500/config.json
Model weights saved in results/checkpoint-1500/pytorch_model.bin
tokenizer 

TrainOutput(global_step=4384, training_loss=2.8161683883110102, metrics={'train_runtime': 3010.1358, 'train_samples_per_second': 46.555, 'train_steps_per_second': 1.456, 'total_flos': 2.136150707798016e+16, 'train_loss': 2.8161683883110102, 'epoch': 8.0})

Evaluate after fine-tuning

In [None]:
# Score the fine-tuned model on the same evaluation set as the baseline run.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 973
  Batch size = 32


{'epoch': 8.0,
 'eval_gen_len': 351.0,
 'eval_loss': 3.6922292709350586,
 'eval_rouge1': 17.8533,
 'eval_rouge2': 7.0441,
 'eval_rougeL': 13.0016,
 'eval_rougeLsum': 16.6201,
 'eval_runtime': 846.3294,
 'eval_samples_per_second': 1.15,
 'eval_steps_per_second': 0.037}

In [None]:
# Close the Weights & Biases run started earlier (only when tracking is enabled).
if WANDB_INTEGRATION:
    wandb_run.finish()




VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
eval/gen_len,█▁
eval/loss,█▁
eval/rouge1,█▁
eval/rouge2,█▁
eval/rougeL,█▁
eval/rougeLsum,█▁
eval/runtime,█▁
eval/samples_per_second,▁█
eval/steps_per_second,▁█
train/epoch,▁▂▃▄▅▆▆▇██

0,1
eval/gen_len,351.0
eval/loss,3.69223
eval/rouge1,17.8533
eval/rouge2,7.0441
eval/rougeL,13.0016
eval/rougeLsum,16.6201
eval/runtime,846.3294
eval/samples_per_second,1.15
eval/steps_per_second,0.037
train/epoch,8.0


## Evaluation

---

**Generate summaries from the fine-tuned model and compare them with those generated from the original, pre-trained one.**

In [None]:
def generate_summary(test_samples, model):
    """Generate summaries for the "report" field of ``test_samples``.

    Reports are tokenized with the notebook-level ``tokenizer`` and
    padded/truncated to ``encoder_max_length`` before generation.

    Args:
        test_samples: Object whose "report" entry holds the source documents.
        model: Seq2seq model used for generation; inputs are moved to its device.

    Returns:
        A ``(token_ids, texts)`` pair: the raw generated ids and their decoded
        strings with special tokens removed.
    """
    encoded = tokenizer(
        test_samples["report"],
        padding="max_length",
        truncation=True,
        max_length=encoder_max_length,
        return_tensors="pt",
    )
    device = model.device
    generated_ids = model.generate(
        encoded.input_ids.to(device),
        attention_mask=encoded.attention_mask.to(device),
    )
    decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return generated_ids, decoded


# Fresh copy of the pre-trained checkpoint, used as the "before" baseline.
# NOTE(review): this copy is loaded with the checkpoint's default config, while
# the fine-tuned `model` had min_length/max_length overridden on its config —
# confirm the two are generated under comparable settings.
model_before_tuning = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Small, fixed sample of evaluation documents to inspect by eye.
test_samples = validation_data_txt.select(range(16))

# Keep only the decoded strings (second element of generate_summary's result).
summaries_before_tuning = generate_summary(test_samples, model_before_tuning)[1]
summaries_after_tuning = generate_summary(test_samples, model)[1]

loading configuration file https://huggingface.co/facebook/bart-base/resolve/main/config.json from cache at /root/.cache/huggingface/transformers/f5310d276a6d1648d00c32fadc8bf7b4607e0fbd5b404fc4a0045960aa2bdfdb.a243ed957122436adb0b8d8e9d20f896f45c174b6324d625ca0a20a84f72a910
Model config BartConfig {
  "_name_or_path": "facebook/bart-base",
  "activation_dropout": 0.1,
  "activation_function": "gelu",
  "add_bias_logits": false,
  "add_final_layer_norm": false,
  "architectures": [
    "BartModel"
  ],
  "attention_dropout": 0.1,
  "bos_token_id": 0,
  "classif_dropout": 0.1,
  "classifier_dropout": 0.0,
  "d_model": 768,
  "decoder_attention_heads": 12,
  "decoder_ffn_dim": 3072,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 6,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "early_stopping": true,
  "encoder_attention_heads": 12,
  "encoder_ffn_dim": 3072,
  "encoder_layerdrop": 0.0,
  "encoder_layers": 6,
  "eos_token_id": 2,
  "forced_bos_token_id": 0,
  "forced_eos_token_id":

In [None]:
# Side-by-side view of the fine-tuned and pre-trained summaries, one row per
# sampled document.
print(
    tabulate(
        [
            (idx, tuned, baseline)
            for idx, (tuned, baseline) in enumerate(
                zip(summaries_after_tuning, summaries_before_tuning)
            )
        ],
        headers=["Id", "Summary after", "Summary before"],
    )
)

# Reference summaries for the same documents.
print("\nTarget summaries:\n")
print(
    tabulate(
        [(idx, target) for idx, target in enumerate(test_samples["summary"])],
        headers=["Id", "Target summary"],
    )
)
# Optional: also show the source documents.
# print("\nSource documents:\n")
# print(tabulate(list(enumerate(test_samples["report"])), headers=["Id", "Document"]))

  Id  Summary after                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                     

In [None]:
# Colab-only helper for downloading files from the runtime to the browser.
from google.colab import files

# Persist the fine-tuned weights (state dict only) to a local file.
torch.save(model.state_dict(), 'checkpoint.pth')

# download checkpoint file
files.download('checkpoint.pth')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# Diagnostic: list the compute devices visible to TensorFlow.
from tensorflow.python.client import device_lib
device_lib.list_local_devices()