In [3]:
import pandas as pd
import tensorflow as tf
from datasets import load_dataset
from tensorflow.keras.optimizers import schedules, AdamW
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM

In [4]:
# Shuffle each CNN/DailyMail split (config '3.0.0') with seed 42, then keep the
# first 1000 rows of the train split and the first 100 rows of validation.
train_dataset = (
    load_dataset('cnn_dailymail', '3.0.0', split='train')
    .shuffle(seed=42)
    .select(range(1000))
)
val_dataset = (
    load_dataset('cnn_dailymail', '3.0.0', split='validation')
    .shuffle(seed=42)
    .select(range(100))
)

In [5]:
# Inspect the column names and dtypes of the training subset.
train_dataset.features

{'article': Value(dtype='string', id=None),
 'highlights': Value(dtype='string', id=None),
 'id': Value(dtype='string', id=None)}

In [6]:
# Checkpoint name shared by the tokenizer (loaded here) and the model
# (loaded from the same variable further down).
model_checkpoint = 't5-small'
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [7]:
prefix = 'summarize: '
max_input_length = 512
max_target_length = 128

def preprocess_function(sample):
    """Tokenize a batch of articles and their reference summaries.

    Args:
        sample: Batch mapping with 'article' and 'highlights' keys, each a
            list of strings (the form `Dataset.map(..., batched=True)` passes).

    Returns:
        The tokenizer output for the prefixed articles, padded/truncated to
        `max_input_length`, with an added 'labels' entry holding the summary
        token ids padded/truncated to `max_target_length`. Padding positions
        in 'labels' are set to -100 so they are excluded from the loss.
    """
    inputs = [prefix + t for t in sample['article']]
    model_inputs = tokenizer(inputs,
                             max_length=max_input_length,
                             truncation=True,
                             padding='max_length')
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(sample['highlights'],
                          max_length=max_target_length,
                          truncation=True,
                          padding='max_length')

    # Targets are padded to a fixed length, so most label positions are pad
    # tokens. Left as-is they would be scored by the loss and dominate it;
    # -100 is the label value the model's loss ignores.
    pad_id = tokenizer.pad_token_id
    model_inputs['labels'] = [
        [tok if tok != pad_id else -100 for tok in seq]
        for seq in labels['input_ids']
    ]
    return model_inputs

In [8]:
# Apply preprocess_function to both splits in batches (batched=True means the
# function receives lists of examples rather than one example at a time).
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True)

In [9]:
# Show the columns and row count of the tokenized training split.
tokenized_train_dataset

Dataset({
    features: ['article', 'highlights', 'id', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

In [10]:
batch_size = 8
# Only these columns are exported to the tf.data pipelines.
tf_dataset_columns = ['input_ids', 'attention_mask', 'labels']

# Training batches are shuffled; validation batches keep dataset order.
tf_train_dataset = tokenized_train_dataset.to_tf_dataset(
    batch_size=batch_size,
    columns=tf_dataset_columns,
    shuffle=True,
)
tf_val_dataset = tokenized_val_dataset.to_tf_dataset(
    batch_size=batch_size,
    columns=tf_dataset_columns,
    shuffle=False,
)

In [12]:
# Load the seq2seq model from the same checkpoint name used for the tokenizer.
model = TFAutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

All PyTorch model weights were used when initializing TFT5ForConditionalGeneration.

All the weights of TFT5ForConditionalGeneration were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.


In [13]:
# --- Training hyperparameters ---
num_epochs = 3
initial_learning_rate = 2e-5
end_learning_rate = 0.0

# The schedule spans every batch of every epoch (one decay step per batch).
num_train_steps = len(tf_train_dataset) * num_epochs

# Decay from initial_learning_rate down to end_learning_rate over the run.
lr_schedule = schedules.PolynomialDecay(
    decay_steps=num_train_steps,
    initial_learning_rate=initial_learning_rate,
    end_learning_rate=end_learning_rate,
)
optimizer = AdamW(weight_decay=0.01, learning_rate=lr_schedule)

# NOTE(review): the optimizer is attached and a private flag is flipped by
# hand here; model.compile() is never called in this cell. Confirm this is
# intentional before relying on Keras fit/evaluate with this model.
model.optimizer = optimizer
model._is_compiled = True
print('Model Compiled Successfully!!')

Model Compiled Successfully!!


In [None]:
import time

# Manual training loop: for each epoch, run one optimizer step per training
# batch, then report the mean loss over the validation batches.
#
# num_epochs is intentionally NOT redefined here: the learning-rate schedule's
# decay_steps was computed from the num_epochs set in the optimizer cell, so
# the loop must run for that same number of epochs.
print("Starting manual training loop...")
num_train_steps_per_epoch = len(tf_train_dataset)
for epoch in range(num_epochs):
    print(f"\n--- Starting Epoch {epoch + 1}/{num_epochs} ---")
    start_time = time.time()

    # --- Training ---
    for step, batch in enumerate(tf_train_dataset):
        with tf.GradientTape() as tape:
            # 1. Forward pass: the batch includes 'labels', so the model
            #    returns its own loss alongside the predictions.
            outputs = model(batch, training=True)
            loss = outputs.loss

        # 2. Calculate gradients
        grads = tape.gradient(loss, model.trainable_variables)

        # 3. Apply gradients (update weights)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        if (step + 1) % 20 == 0: # Print a log every 20 steps
            print(f"  Step {step + 1}/{num_train_steps_per_epoch}, Loss: {tf.reduce_mean(loss).numpy():.4f}")

    # --- Validation ---
    print("\nRunning validation...")
    val_losses = []

    for batch in tf_val_dataset:
        # Run in inference mode (no gradients)
        outputs = model(batch, training=False)
        # Reduce to a scalar exactly as the training log does, so this works
        # whatever shape the model returns its loss in.
        val_losses.append(float(tf.reduce_mean(outputs.loss).numpy()))

    # Guard against an empty validation set instead of dividing by zero.
    avg_val_loss = sum(val_losses) / len(val_losses) if val_losses else float('nan')
    # Epoch time covers both the training pass and the validation pass.
    epoch_time = time.time() - start_time

    print(f"--- Epoch {epoch + 1} Summary ---")
    print(f"Time: {epoch_time:.2f}s, Validation Loss: {avg_val_loss:.4f}")

print("\nTraining complete!")