## Introduction

This notebook fine-tunes BART to summarize conversations. BART (Bidirectional and Auto-Regressive
Transformers) is a Transformer-based encoder-decoder model. We fine-tune BART on the SAMSum conversation
dataset using KerasNLP, and then generate summaries with the fine-tuned model.

## Setup
Install KerasNLP and `py7zr`, the library used to extract the dataset's 7-Zip (`.7z`) archive.

In [1]:
!pip install git+https://github.com/keras-team/keras-nlp.git py7zr -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.0/67.0 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m950.8/950.8 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.2/5.2 MB[0m [31m34.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m49.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m412.3/412.3 kB[0m [31m33.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.9/138.9 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.7/49.7 kB[0m [31m5.0 MB/s[0m eta

In [2]:
import py7zr
import time

import keras_nlp
import tensorflow as tf
import tensorflow_datasets as tfds

import keras_core as keras

Using TensorFlow backend


In [3]:
# Tunable settings for the whole notebook.
BATCH_SIZE = 8  # Examples per training batch (passed to `.batch(...)`).
NUM_BATCHES = 600  # Number of training batches kept via `.take(...)`.
EPOCHS = 25  # Passes over `train_ds` in `fit`; also embedded in the saved model's filename.
MAX_ENCODER_SEQUENCE_LENGTH = 512  # `encoder_sequence_length` given to the preprocessor.
MAX_DECODER_SEQUENCE_LENGTH = 128  # `decoder_sequence_length` given to the preprocessor.
MAX_GENERATION_LENGTH = 40  # `max_length` used when generating summaries.

In [4]:
# Fetch the SAMSum corpus archive.
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)

# Unpack the `.7z` archive into the TFDS manual-download directory.
with py7zr.SevenZipFile(filename, mode="r") as archive:
    archive.extractall(path="/root/tensorflow_datasets/downloads/manual")

# Load the training split through TFDS as (dialogue, summary) pairs.
samsum_ds = tfds.load("samsum", split="train", as_supervised=True)

Downloading data from https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z
[1m2944100/2944100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...


Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/14732 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incomplete03ODJI/samsum-train.tfrecord*...:   0%|          | …

Generating validation examples...:   0%|          | 0/818 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incomplete03ODJI/samsum-validation.tfrecord*...:   0%|       …

Generating test examples...:   0%|          | 0/819 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/1.0.0.incomplete03ODJI/samsum-test.tfrecord*...:   0%|          | 0…

Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.


In [5]:
# Build the training pipeline: wrap each (dialogue, summary) pair in a dict
# with "encoder_text"/"decoder_text" keys (presumably the feature names the
# BART seq2seq preprocessor expects -- confirm against the KerasNLP docs),
# batch, and keep only the first NUM_BATCHES batches.
train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    # Truncate *before* caching. With `.cache().take(k)` the cache is never
    # read to the end, so tf.data discards the partially filled cache on
    # every epoch; `.take(k).cache()` yields the same elements and lets the
    # cache actually be reused.
    .take(NUM_BATCHES)
    .cache()
)

In [6]:
# Preprocessor for the "bart_base_en" preset, configured with the encoder and
# decoder sequence lengths chosen in the settings cell.
preprocessor = keras_nlp.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
# Seq2seq language model loaded from the same preset, with the preprocessor
# attached so it can be passed raw inputs directly.
bart_lm = keras_nlp.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

# Print the model architecture.
bart_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/tokenizer.json...
100%|██████████| 448/448 [00:00<00:00, 1.18MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/merges.txt...
100%|██████████| 446k/446k [00:00<00:00, 1.27MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/vocabulary.json...
100%|██████████| 0.99M/0.99M [00:00<00:00, 2.40MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/config.json...
100%|██████████| 483/483 [00:00<00:00, 1.23MB/s]
Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/model.weights.h5...
100%|██████████| 532M/532M [00:15<00:00, 36.4MB/s]
  return id(getattr(self, attr)) not in self._functional_layer_ids
  return id(getattr(self, attr)) not in self._functional_layer_ids


In [7]:
# AdamW optimizer with decoupled weight decay and global-norm gradient clipping.
optimizer = tf.keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay. This must be a single
# call: every call to `exclude_from_weight_decay` replaces the previously
# registered exclusion list, so three separate calls would leave only "beta"
# excluded.
optimizer.exclude_from_weight_decay(var_names=["bias", "gamma", "beta"])

# The model outputs logits, hence `from_logits=True`.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

Let's train the model!

In [8]:
import tensorflow as tf

# Abort early unless TensorFlow reports the first GPU as available.
device_name = tf.test.gpu_device_name()
if device_name == "/device:GPU:0":
    print(f"Found GPU at: {device_name}")
else:
    raise SystemError("GPU device not found")

Found GPU at: /device:GPU:0


In [None]:
# Fine-tune the model for EPOCHS passes over `train_ds`, with ops placed
# explicitly on the first GPU.
with tf.device('/device:GPU:0'):
  bart_lm.fit(train_ds, epochs=EPOCHS)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25

In [None]:
# Save the fine-tuned model; the filename embeds the number of epochs trained.
bart_lm.save(f"summaryBart_{EPOCHS}Epoch.keras")

In [None]:
!du -m "summaryBart_1Epoch.keras"


In [None]:

def generate_text(model, input_text, max_length=200, print_time_taken=False):
    """Generate text with `model`, optionally reporting the elapsed time.

    Args:
        model: Object exposing `generate(inputs, max_length=...)`.
        input_text: Input forwarded unchanged to `model.generate`.
        max_length: Forwarded to `model.generate` as `max_length`.
        print_time_taken: When True, print the wall-clock time the
            generation call took. When False (the default), print nothing.

    Returns:
        Whatever `model.generate` returns.
    """
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    elapsed = time.time() - start
    # Honor the flag: previously the timing was printed unconditionally.
    if print_time_taken:
        print(f"Total Time Elapsed: {elapsed:.2f}s")
    return output


# Load the first 100 validation examples as (dialogue, summary) pairs.
val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

# Materialize the raw dialogues and reference summaries so they can be shown
# next to the generated summaries later.
dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# Generate summaries for the validation dialogues. The batch size comes from
# the shared BATCH_SIZE setting rather than a hard-coded literal.
generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(BATCH_SIZE),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)

Let's see some of the summaries.

In [None]:
# Show the first five dialogues alongside the generated and reference summaries.
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    for label, text in (
        ("Dialogue:", dialogue),
        ("Generated Summary:", generated_summary),
        ("Ground Truth Summary:", ground_truth_summary),
    ):
        print(label, text)
    print("=============================")