In [1]:
!pip install git+https://github.com/keras-team/keras-hub.git py7zr -q

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.9/67.9 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m18.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m93.1/93.1 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m49.7/49.7 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m138.9/138.9 kB[0m [31m12.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m413.7/413.7 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m

In [2]:
import os

os.environ["KERAS_BACKEND"] = "jax"

import py7zr
import time

import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds


In [3]:
BATCH_SIZE = 8
NUM_BATCHES = 600
EPOCHS = 1
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 40

In [4]:
filename = keras.utils.get_file(
    "corpus.7z",
    origin="https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z",
)

with py7zr.SevenZipFile(filename, mode="r") as z:
    z.extractall(path="/root/tensorflow_datasets/downloads/manual")

samsum_ds = tfds.load("samsum", split="train", as_supervised=True)

Downloading data from https://huggingface.co/datasets/samsum/resolve/main/data/corpus.7z
[1m2944100/2944100[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
Downloading and preparing dataset Unknown size (download: Unknown size, generated: 10.71 MiB, total: 10.71 MiB) to /root/tensorflow_datasets/samsum/1.0.0...


Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/14732 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.G3SPUQ_1.0.0/samsum-train.tfrecord*...:   0%|          |…

Generating validation examples...:   0%|          | 0/818 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.G3SPUQ_1.0.0/samsum-validation.tfrecord*...:   0%|      …

Generating test examples...:   0%|          | 0/819 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/samsum/incomplete.G3SPUQ_1.0.0/samsum-test.tfrecord*...:   0%|          | …

Dataset samsum downloaded and prepared to /root/tensorflow_datasets/samsum/1.0.0. Subsequent calls will reuse this data.


In [5]:
for dialogue, summary in samsum_ds:
    print(dialogue.numpy())
    print(summary.numpy())
    break

b"Carter: Hey Alexis, I just wanted to let you know that I had a really nice time with you tonight. \r\nAlexis: Thanks Carter. Yeah, I really enjoyed myself as well. \r\nCarter: If you are up for it, I would really like to see you again soon.\r\nAlexis: Thanks Carter, I'm flattered. But I have a really busy week coming up.\r\nCarter: Yeah, no worries. I totally understand. But if you ever want to go grab dinner again, just let me know. \r\nAlexis: Yeah of course. Thanks again for tonight. \r\nCarter: Sure. Have a great night. "
b'Alexis and Carter met tonight. Carter would like to meet again, but Alexis is busy.'


In [6]:
train_ds = (
    samsum_ds.map(
        lambda dialogue, summary: {"encoder_text": dialogue, "decoder_text": summary}
    )
    .batch(BATCH_SIZE)
    .cache()
)
train_ds = train_ds.take(NUM_BATCHES)


In [7]:
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()

Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/config.json...


100%|██████████| 483/483 [00:00<00:00, 491kB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/tokenizer.json...


100%|██████████| 448/448 [00:00<00:00, 514kB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/vocabulary.json...


100%|██████████| 0.99M/0.99M [00:00<00:00, 3.10MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/assets/tokenizer/merges.txt...


100%|██████████| 446k/446k [00:00<00:00, 1.72MB/s]


Downloading from https://www.kaggle.com/api/v1/models/keras/bart/keras/bart_base_en/2/download/model.weights.h5...


100%|██████████| 532M/532M [00:18<00:00, 30.9MB/s]


In [8]:
bart_lm.fit(train_ds, epochs=EPOCHS)

[1m600/600[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m445s[0m 672ms/step - loss: 0.4300 - sparse_categorical_accuracy: 0.5528


<keras.src.callbacks.history.History at 0x7fd939505090>

In [9]:
def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    print(f"Total Time Elapsed: {end - start:.2f}s")
    return output

val_ds = tfds.load("samsum", split="validation", as_supervised=True)
val_ds = val_ds.take(100)

dialogues = []
ground_truth_summaries = []
for dialogue, summary in val_ds:
    dialogues.append(dialogue.numpy())
    ground_truth_summaries.append(summary.numpy())

_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

generated_summaries = generate_text(
    bart_lm,
    val_ds.map(lambda dialogue, _: dialogue).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)


Total Time Elapsed: 8.50s
Total Time Elapsed: 36.58s


In [10]:
for dialogue, generated_summary, ground_truth_summary in zip(
    dialogues[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("Diálogo:", dialogue)
    print("Sumário:", generated_summary)
    print("Sumário de verdade:", ground_truth_summary)
    print("=============================")

Diálogo: b'Tony: Is the boss in?\r\nClaire: Not yet.\r\nTony: Could let me know when he comes, please? \r\nClaire: Of course.\r\nTony: Thank you.'
Sumário: The boss hasn't come. Tony will let Claire know when he comes.
Sumário de verdade: b"The boss isn't in yet. Claire will let Tony know when he comes."
Diálogo: b"James: What shouldl I get her?\r\nTim: who?\r\nJames: gees Mary my girlfirend\r\nTim: Am I really the person you should be asking?\r\nJames: oh come on it's her birthday on Sat\r\nTim: ask Sandy\r\nTim: I honestly am not the right person to ask this\r\nJames: ugh fine!"
Sumário: James will be asking Sandy for her birthday on Sat.
Sumário de verdade: b"Mary's birthday is on Saturday. Her boyfriend, James, is looking for gift ideas. Tim suggests that he ask Sandy."
Diálogo: b"Mary: So, how's Israel? Have you been on the beach?\r\nKate: It's so expensive! But they say, it's Tel Aviv... Tomorrow we are going to Jerusalem.\r\nMary: I've heard Israel is expensive, Monica was there