In [1]:
%%capture
!pip install transformers
!pip install tensorflow
!pip install rouge_score

In [4]:
!pip install -q "transformers<5" tensorflow --upgrade

In [1]:
import transformers
print(transformers.__version__)

4.57.6


In [2]:
from datasets import load_dataset
from transformers import (
    T5Tokenizer,
    TFT5ForConditionalGeneration,
    DataCollatorForSeq2Seq
)
import tensorflow as tf
import numpy as np
from rouge_score import rouge_scorer

In [4]:
from datasets import load_dataset
dataset = load_dataset("cnn_dailymail", "3.0.0")

3.0.0/train-00001-of-00003.parquet:   0%|          | 0.00/257M [00:00<?, ?B/s]

3.0.0/train-00002-of-00003.parquet:   0%|          | 0.00/259M [00:00<?, ?B/s]

3.0.0/validation-00000-of-00001.parquet:   0%|          | 0.00/34.7M [00:00<?, ?B/s]

3.0.0/test-00000-of-00001.parquet:   0%|          | 0.00/30.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/287113 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/13368 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/11490 [00:00<?, ? examples/s]

In [5]:
dataset

DatasetDict({
    train: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 287113
    })
    validation: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 13368
    })
    test: Dataset({
        features: ['article', 'highlights', 'id'],
        num_rows: 11490
    })
})

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

article_lengths = [len(x.split()) for x in dataset['train']['document']]
plt.figure(figsize=(10,5))
sns.histplot(article_lengths, bins=50, kde=True)
plt.title("Distribution of Article Lengths (in words)")
plt.xlabel("Number of words")
plt.ylabel("Count")
plt.show()

In [None]:
summary_lengths = [len(x.split()) for x in dataset['train']['summary']]
plt.figure(figsize=(10,5))
sns.histplot(summary_lengths, bins=30, color='orange', kde=True)
plt.title("Distribution of Summary Lengths (in words)")
plt.xlabel("Number of words")
plt.ylabel("Count")
plt.show()

In [None]:
import pandas as pd

df = pd.DataFrame({
    "article_len": article_lengths,
    "summary_len": summary_lengths
})

plt.figure(figsize=(8,6))
sns.scatterplot(x="article_len", y="summary_len", data=df, alpha=0.5)
plt.title("Article Length vs Summary Length")
plt.xlabel("Article Length (words)")
plt.ylabel("Summary Length (words)")
plt.show()

In [None]:
from collections import Counter
from nltk.corpus import stopwords
import nltk
nltk.download('stopwords')

stop_words = set(stopwords.words('english'))
words = " ".join(dataset['train']['document']).split()
words = [w.lower() for w in words if w.lower() not in stop_words]
counter = Counter(words)
common_words = counter.most_common(20)

plt.figure(figsize=(10,5))
sns.barplot(x=[x[1] for x in common_words], y=[x[0] for x in common_words])
plt.title("Top 20 Frequent Words in Articles")
plt.xlabel("Count")
plt.ylabel("Word")
plt.show()

In [6]:
dataset['train'].features

{'article': Value('string'),
 'highlights': Value('string'),
 'id': Value('string')}

In [7]:
# Taking the subset of the dataset for the finetuning purpose
train_subset = dataset["train"].select(range(10000))
validation_subset = dataset["validation"].select(range(1000))
test_subset = dataset["test"].select(range(1000))

In [8]:
checkpoint = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(checkpoint)

tokenizer_config.json:   0%|          | 0.00/2.32k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [9]:
def preprocess_function(examples):
    inputs = ["summarize: " + doc for doc in examples["article"]]

    model_inputs = tokenizer(
        inputs,
        max_length=512,
        truncation=True,
        padding="max_length"
    )

    labels = tokenizer(
        text_target=examples["highlights"],
        max_length=150,
        truncation=True,
        padding="max_length"
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [10]:
# Tokenize datasets
tokenized_train = train_subset.map(preprocess_function, batched=True)
tokenized_validation = validation_subset.map(preprocess_function, batched=True)
tokenized_test = test_subset.map(preprocess_function, batched=True)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

In [14]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, TFT5ForConditionalGeneration

checkpoint = "t5-small"

# Load PyTorch model
pt_model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Convert to TensorFlow
model = TFT5ForConditionalGeneration.from_pretrained(checkpoint, from_pt=True)

generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/242M [00:00<?, ?B/s]

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [15]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [18]:
!pip install evaluate

import evaluate
rouge = evaluate.load("rouge")

# Example
results = rouge.compute(
    predictions=["Government announces new fuel policy affecting the economy."],
    references=["The government introduces fuel policy expected to impact economy."]
)
print(results)

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6


Downloading builder script: 0.00B [00:00, ?B/s]

{'rouge1': np.float64(0.5882352941176471), 'rouge2': np.float64(0.13333333333333333), 'rougeL': np.float64(0.47058823529411764), 'rougeLsum': np.float64(0.47058823529411764)}


In [19]:
from transformers import AdamWeightDecay

# Learning rate, weight decay
optimizer = AdamWeightDecay(
    learning_rate=2e-5,
    weight_decay_rate=0.01
)

# Compile the model
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
)

In [20]:
# Convert Hugging Face tokenized dataset to tf.data.Dataset
def convert_to_tf_dataset(tokenized_dataset):
    columns = ["input_ids", "attention_mask", "labels"]
    tf_dataset = tokenized_dataset.to_tf_dataset(
        columns=columns,
        shuffle=True,
        batch_size=16,
        collate_fn=None
    )
    return tf_dataset

train_dataset = convert_to_tf_dataset(tokenized_train)
val_dataset   = convert_to_tf_dataset(tokenized_validation)

In [22]:
from tensorflow.keras.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    filepath="./results/t5_tf_checkpoint.weights.h5",
    save_weights_only=True,
    save_best_only=True
)

In [26]:
import os
import numpy as np

os.makedirs("./results", exist_ok=True)

best_val_loss = np.inf
patience = 2
wait = 0

num_epochs = 3

for epoch in range(num_epochs):
    print(f"\n=== Epoch {epoch+1}/{num_epochs} ===")

    # Train for one epoch
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=1,   # train one epoch at a time
        verbose=1
    )

    # Get validation loss
    val_loss = history.history['val_loss'][-1]
    print(f"Validation loss: {val_loss:.4f}")

    # Save weights manually (checkpoint)
    model.save_weights(f"./results/t5_epoch{epoch+1}.weights.h5")

    # Manual early stopping
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
        # Save best weights separately
        model.save_weights("./results/t5_best.weights.h5")
    else:
        wait += 1
        if wait >= patience:
            print("Early stopping triggered!")
            break


=== Epoch 1/3 ===
Validation loss: 0.7728





=== Epoch 2/3 ===
Validation loss: 0.7297





=== Epoch 3/3 ===
Validation loss: 0.7251




In [27]:
checkpoint = "t5-small"
model = TFT5ForConditionalGeneration.from_pretrained(checkpoint, from_pt=True)

# Load saved weights
model.load_weights("./results/t5_best.weights.h5")

# Generate summaries on validation batch
for batch in val_dataset.take(1):
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']

    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=150,  # generation_max_length
        num_beams=4,     # generation_num_beams
        early_stopping=True
    )
    summaries = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    print(summaries[:3])

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


['The Jewish community in Iran does not hide its heritage. About 20 people were in attendance, usually from local businesses around the synagogue. Many left Iran after the Islamic Revolution in 1979 that brought Ayatollah Khomeini to power.', "Japan's Maritime Self-Defense Force (MSDF) has delivered the Izumo. It is as large as the storied Yamato-class battleships which fought U.S. naval forces in the Pacific theater of World War II. Japanese neighbors and rivals questioned the legitimacy of such a ship for purely defensive purposes.", 'Fernando Alonso will not race in the Australian Grand Prix on the advice of doctors. The Spaniard lost control of his McLaren at the penultimate winter test in Barcelona. Alonso is now recovering at home but doctors have indicated returning to Melbourne three weeks after the high impact could be too risky.']


In [None]:
import evaluate
import numpy as np

# Load ROUGE metric
rouge = evaluate.load("rouge")

predictions = []
references = []

# Choose number of batches to process for fast evaluation
num_batches = 50  # Adjust: 50-100 batches should take only a few minutes

for batch in val_dataset.take(num_batches):
    # Convert tensors to numpy arrays
    input_ids = batch['input_ids'].numpy()
    attention_mask = batch['attention_mask'].numpy()
    labels = batch['labels'].numpy() if 'labels' in batch else None

    # Generate summaries in batch (faster than single example)
    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=150,   # same as training
        num_beams=2,      # reduce beams for speed
        early_stopping=True
    )

    # Decode generated summaries and references
    preds = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

    if labels is not None:
        refs = tokenizer.batch_decode(labels, skip_special_tokens=True)
        references.extend(refs)

    predictions.extend(preds)

# Compute ROUGE
results = rouge.compute(predictions=predictions, references=references)
print(results)

In [None]:
plt.figure(figsize=(8,5))
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title("Training & Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
plt.figure(figsize=(8,5))
plt.plot(rouge1_scores, label='ROUGE-1')
plt.plot(rouge2_scores, label='ROUGE-2')
plt.plot(rougeL_scores, label='ROUGE-L')
plt.title("ROUGE Score per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Score")
plt.legend()
plt.show()

In [None]:
import pandas as pd

df_samples = pd.DataFrame({
    "Article": dataset['validation']['document'][:5],
    "Reference": dataset['validation']['summary'][:5],
    "Generated": generated_summaries[:5]
})
df_samples