In [2]:
!pip install git+https://github.com/keras-team/keras-hub.git -q
import os

os.environ["KERAS_BACKEND"] = "jax"

import time

import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 8
EPOCHS = 10  # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 1024
TRAIN_TEST_SPLIT = 0.9  # 90% train, 10% test

# Load BillSum dataset using TFDS
billsum_ds = tfds.load("billsum", split="ca_test", as_supervised=False)

# Get the total number of examples
total_examples = billsum_ds.cardinality().numpy()
print(f"Total examples in dataset: {total_examples}")

# Calculate split sizes
train_size = int(TRAIN_TEST_SPLIT * total_examples)
test_size = total_examples - train_size
print(f"Train size: {train_size}, Test size: {test_size}")

# Split the dataset
train_raw = billsum_ds.take(train_size)
test_raw = billsum_ds.skip(train_size)

# Prepare training dataset
train_ds = (
    train_raw.map(
        lambda example: {"encoder_text": example["text"], "decoder_text": example["summary"]}
    )
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Prepare test dataset
test_ds = (
    test_raw.map(
        lambda example: {"encoder_text": example["text"], "decoder_text": example["summary"]}
    )
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Initialize the model
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()

# Configure optimizer
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

# Train the model with validation
bart_lm.fit(train_ds, validation_data=test_ds, epochs=EPOCHS)


def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    if print_time_taken:
        print(f"Total Time Elapsed: {end - start:.2f}s")
    return output


# Prepare evaluation data
eval_ds = test_raw.take(100)

texts = []
ground_truth_summaries = []
for example in eval_ds:
    texts.append(example["text"].numpy())
    ground_truth_summaries.append(example["summary"].numpy())

# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# Generate summaries
generated_summaries = generate_text(
    bart_lm,
    eval_ds.map(lambda example: example["text"]).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)

# Display results
for text, generated_summary, ground_truth_summary in zip(
    texts[:3], generated_summaries[:3], ground_truth_summaries[:3]
):
    print("Text:", text[:200], "...")  # Print first 200 chars
    print("\nGenerated Summary:", generated_summary)
    print("\nGround Truth Summary:", ground_truth_summary)
    print("=============================")

# Evaluate on test set
print("\nEvaluating on test set...")
test_results = bart_lm.evaluate(test_ds)
print(f"Test Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")

  pid, fd = os.forkpty()


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Total examples in dataset: 1237
Train size: 1113, Test size: 124


Epoch 1/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m232s[0m 1s/step - accuracy: 0.5107 - loss: 2.3983 - val_accuracy: 0.5773 - val_loss: 1.9183
Epoch 2/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m152s[0m 829ms/step - accuracy: 0.5974 - loss: 1.7776 - val_accuracy: 0.5909 - val_loss: 1.8862
Epoch 3/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 815ms/step - accuracy: 0.6452 - loss: 1.4882 - val_accuracy: 0.6028 - val_loss: 1.8532
Epoch 4/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 816ms/step - accuracy: 0.6828 - loss: 1.2791 - val_accuracy: 0.6087 - val_loss: 1.8268
Epoch 5/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 815ms/step - accuracy: 0.7142 - loss: 1.1036 - val_accuracy: 0.6161 - val_loss: 1.8156
Epoch 6/10
[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m118s[0m 813ms/step - accuracy: 0.7455 - loss: 0.9507 - val_accuracy: 0.6208 - val_loss: 1.8776
Epoch 7

In [3]:
# Install required packages for evaluation
!pip install rouge-score bert-score nltk -q

import numpy as np
from rouge_score import rouge_scorer
from collections import defaultdict
import nltk
nltk.download('punkt', quiet=True)

# If you haven't already generated summaries, do it now
# (Skip this if you already have generated_summaries from your previous run)
if 'generated_summaries' not in locals():
    eval_ds = test_raw.take(100)
    texts = []
    ground_truth_summaries = []
    for example in eval_ds:
        texts.append(example["text"].numpy())
        ground_truth_summaries.append(example["summary"].numpy())
    
    generated_summaries = generate_text(
        bart_lm,
        eval_ds.map(lambda example: example["text"]).batch(8),
        max_length=MAX_GENERATION_LENGTH,
        print_time_taken=True,
    )

# Decode byte strings if necessary
def decode_if_bytes(text):
    if isinstance(text, bytes):
        return text.decode('utf-8')
    return text

generated_summaries_decoded = [decode_if_bytes(s) for s in generated_summaries]
ground_truth_summaries_decoded = [decode_if_bytes(s) for s in ground_truth_summaries]

# Initialize ROUGE scorer
scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)

# Calculate ROUGE scores for each summary
rouge_scores = defaultdict(list)

print("Calculating ROUGE scores...")
for generated, reference in zip(generated_summaries_decoded, ground_truth_summaries_decoded):
    scores = scorer.score(reference, generated)
    for metric, score in scores.items():
        rouge_scores[f"{metric}_precision"].append(score.precision)
        rouge_scores[f"{metric}_recall"].append(score.recall)
        rouge_scores[f"{metric}_fmeasure"].append(score.fmeasure)

# Calculate average scores
print("\n" + "="*50)
print("ROUGE SCORES (Average over all test samples)")
print("="*50)

for metric in ['rouge1', 'rouge2', 'rougeL']:
    print(f"\n{metric.upper()}:")
    print(f"  Precision: {np.mean(rouge_scores[f'{metric}_precision']):.4f}")
    print(f"  Recall:    {np.mean(rouge_scores[f'{metric}_recall']):.4f}")
    print(f"  F1-Score:  {np.mean(rouge_scores[f'{metric}_fmeasure']):.4f}")

# Additional metrics: Average length statistics
print("\n" + "="*50)
print("LENGTH STATISTICS")
print("="*50)

gen_lengths = [len(s.split()) for s in generated_summaries_decoded]
ref_lengths = [len(s.split()) for s in ground_truth_summaries_decoded]

print(f"\nGenerated Summaries:")
print(f"  Average length: {np.mean(gen_lengths):.2f} words")
print(f"  Min length: {np.min(gen_lengths)} words")
print(f"  Max length: {np.max(gen_lengths)} words")

print(f"\nReference Summaries:")
print(f"  Average length: {np.mean(ref_lengths):.2f} words")
print(f"  Min length: {np.min(ref_lengths)} words")
print(f"  Max length: {np.max(ref_lengths)} words")

# Show some example comparisons
print("\n" + "="*50)
print("EXAMPLE COMPARISONS (First 3)")
print("="*50)

for i in range(min(3, len(generated_summaries_decoded))):
    print(f"\n--- Example {i+1} ---")
    print(f"\nGenerated ({len(generated_summaries_decoded[i].split())} words):")
    print(generated_summaries_decoded[i][:300] + "..." if len(generated_summaries_decoded[i]) > 300 else generated_summaries_decoded[i])
    
    print(f"\nReference ({len(ground_truth_summaries_decoded[i].split())} words):")
    print(ground_truth_summaries_decoded[i][:300] + "..." if len(ground_truth_summaries_decoded[i]) > 300 else ground_truth_summaries_decoded[i])
    
    scores = scorer.score(ground_truth_summaries_decoded[i], generated_summaries_decoded[i])
    print(f"\nROUGE-1 F1: {scores['rouge1'].fmeasure:.4f}")
    print(f"ROUGE-2 F1: {scores['rouge2'].fmeasure:.4f}")
    print(f"ROUGE-L F1: {scores['rougeL'].fmeasure:.4f}")

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.1/61.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m105.2 MB/s[0m eta [36m0:00:00[0m00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m77.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.9 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━

In [2]:
!pip install git+https://github.com/keras-team/keras-hub.git -q
import os

os.environ["KERAS_BACKEND"] = "jax"

import time

import keras_hub
import keras
import tensorflow as tf
import tensorflow_datasets as tfds

BATCH_SIZE = 8
EPOCHS = 1  # Can be set to a higher value for better results
MAX_ENCODER_SEQUENCE_LENGTH = 512
MAX_DECODER_SEQUENCE_LENGTH = 128
MAX_GENERATION_LENGTH = 200
TRAIN_TEST_SPLIT = 0.9  # 90% train, 10% test

# Load BillSum dataset using TFDS
billsum_ds = tfds.load("billsum", split="ca_test", as_supervised=False)

# Get the total number of examples
total_examples = billsum_ds.cardinality().numpy()
print(f"Total examples in dataset: {total_examples}")

# Calculate split sizes
train_size = int(TRAIN_TEST_SPLIT * total_examples)
test_size = total_examples - train_size
print(f"Train size: {train_size}, Test size: {test_size}")

# Split the dataset
train_raw = billsum_ds.take(train_size)
test_raw = billsum_ds.skip(train_size)

# Preview the data
for example in train_raw.take(1):
    print("Text:", example["text"].numpy())
    print("Summary:", example["summary"].numpy())
    break

# Prepare training dataset
train_ds = (
    train_raw.map(
        lambda example: {"encoder_text": example["text"], "decoder_text": example["summary"]}
    )
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Prepare test dataset
test_ds = (
    test_raw.map(
        lambda example: {"encoder_text": example["text"], "decoder_text": example["summary"]}
    )
    .batch(BATCH_SIZE)
    .cache()
    .prefetch(tf.data.AUTOTUNE)
)

# Initialize the model
preprocessor = keras_hub.models.BartSeq2SeqLMPreprocessor.from_preset(
    "bart_base_en",
    encoder_sequence_length=MAX_ENCODER_SEQUENCE_LENGTH,
    decoder_sequence_length=MAX_DECODER_SEQUENCE_LENGTH,
)
bart_lm = keras_hub.models.BartSeq2SeqLM.from_preset(
    "bart_base_en", preprocessor=preprocessor
)

bart_lm.summary()

# Configure optimizer
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-5,
    weight_decay=0.01,
    epsilon=1e-6,
    global_clipnorm=1.0,  # Gradient clipping.
)
# Exclude layernorm and bias terms from weight decay.
optimizer.exclude_from_weight_decay(var_names=["bias"])
optimizer.exclude_from_weight_decay(var_names=["gamma"])
optimizer.exclude_from_weight_decay(var_names=["beta"])

loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

bart_lm.compile(
    optimizer=optimizer,
    loss=loss,
    weighted_metrics=["accuracy"],
)

# Train the model with validation
bart_lm.fit(train_ds, validation_data=test_ds, epochs=EPOCHS)


def generate_text(model, input_text, max_length=200, print_time_taken=False):
    start = time.time()
    output = model.generate(input_text, max_length=max_length)
    end = time.time()
    if print_time_taken:
        print(f"Total Time Elapsed: {end - start:.2f}s")
    return output


# Prepare evaluation data
eval_ds = test_raw.take(100)

texts = []
ground_truth_summaries = []
for example in eval_ds:
    texts.append(example["text"].numpy())
    ground_truth_summaries.append(example["summary"].numpy())

# Let's make a dummy call - the first call to XLA generally takes a bit longer.
_ = generate_text(bart_lm, "sample text", max_length=MAX_GENERATION_LENGTH)

# Generate summaries
generated_summaries = generate_text(
    bart_lm,
    eval_ds.map(lambda example: example["text"]).batch(8),
    max_length=MAX_GENERATION_LENGTH,
    print_time_taken=True,
)

# Display results
for text, generated_summary, ground_truth_summary in zip(
    texts[:5], generated_summaries[:5], ground_truth_summaries[:5]
):
    print("Text:", text[:200], "...")  # Print first 200 chars
    print("Generated Summary:", generated_summary)
    print("Ground Truth Summary:", ground_truth_summary)
    print("=============================")

# Evaluate on test set
print("\nEvaluating on test set...")
test_results = bart_lm.evaluate(test_ds)
print(f"Test Loss: {test_results[0]:.4f}")
print(f"Test Accuracy: {test_results[1]:.4f}")

  pid, fd = os.forkpty()


  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Total examples in dataset: 1237
Train size: 1113, Test size: 124
Text: b'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nSection 17941 of the Revenue and Taxation Code is amended to read:\n17941.\n(a) For each taxable year beginning on or after January 1, 1997, a limited liability company doing business in this state (as defined in Section 23101) shall pay annually to this state a tax for the privilege of doing business in this state in an amount equal to the applicable amount specified in\nparagraph (1) of\nsubdivision (d) of Section 23153 for the taxable year.\n(b) (1) In addition to any limited liability company that is doing business in this state and is therefore subject to the tax imposed by subdivision (a), for each taxable year beginning on or after January 1, 1997, a limited liabil

[1m140/140[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m220s[0m 1s/step - accuracy: 0.5114 - loss: 2.4013 - val_accuracy: 0.5753 - val_loss: 1.9198
Total Time Elapsed: 38.31s
Text: b'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nSection 25503.6 of the Business and Professions Code is amended to read:\n25503.6.\n(a) Notwithstanding any other provision of thi' ...
Generated Summary: Existing law, the Alcoholic Beverage Control and Alcoholic Beverage Control Act, establishes the Alcoholic Beverage Control and Alcoholic Beverage Control Board and the Alcoholic Beverage Control Commission. The act authorizes, on the part of the commission, the board to adopt a license to purchase advertising space and time from, or on behalf of, an on-sale retail licensee subject to specified conditions. Existing law, until December 31, 2017, requires the commission to establish an on-sale retail licensee for each of the above items. Existing law also permits an on-sale ret