# 1. Import Libraries

In [None]:
import torch
import evaluate
import numpy as np

from transformers import (
  T5TokenizerFast as T5Tokenizer,
  T5ForConditionalGeneration,
  TrainingArguments,
  Trainer
)
from datasets import load_dataset, concatenate_datasets

# 2. Dataset Processing & Information

In [None]:
TRAIN_SAMPLES = 2000
VAL_SIZE = 0.1
TEST_SIZE = 0.1

xsum = load_dataset('xsum', trust_remote_code=True, split='train')
cnn_dailymail = load_dataset('cnn_dailymail', '3.0.0', split='train')

xsum = xsum.select(range(TRAIN_SAMPLES))
cnn_dailymail = cnn_dailymail.select(range(TRAIN_SAMPLES))

xsum = xsum.remove_columns(['id'])
xsum = xsum.rename_columns({
  'document' : 'text',
  'summary' : 'summary'
})

cnn_dailymail = cnn_dailymail.remove_columns(['id'])
cnn_dailymail = cnn_dailymail.rename_columns({
  'article' : 'text',
  'highlights' : 'summary'
})

xsum_split = xsum.train_test_split(test_size=VAL_SIZE + TEST_SIZE, shuffle=True)
xsum_test_val = xsum_split['test'].train_test_split(test_size=TEST_SIZE / (VAL_SIZE + TEST_SIZE))

cnn_split = cnn_dailymail.train_test_split(test_size=VAL_SIZE + TEST_SIZE, shuffle=True)
cnn_test_val = cnn_split['test'].train_test_split(test_size=TEST_SIZE / (VAL_SIZE + TEST_SIZE))

dataset_train = concatenate_datasets([xsum_split['train'], cnn_split['train']])
dataset_valid = concatenate_datasets([xsum_test_val['train'], cnn_test_val['train']])
dataset_test  = concatenate_datasets([xsum_test_val['test'], cnn_test_val['test']])

dataset_train = dataset_train.shuffle(seed=42)
dataset_valid = dataset_valid.shuffle(seed=42)
dataset_test  = dataset_test.shuffle(seed=42)

In [4]:
print(dataset_train)
print(dataset_valid)

Dataset({
    features: ['text', 'summary'],
    num_rows: 3200
})
Dataset({
    features: ['text', 'summary'],
    num_rows: 800
})


In [None]:
def find_longest_length(dataset):
  max_length = 0
  counter_4k = 0
  counter_2k = 0
  counter_1k = 0
  counter_500 = 0
  for text in dataset:
    corpus = [word for word in text.split()]
    if len(corpus) > 4000:
      counter_4k += 1
    elif len(corpus) > 2000:
      counter_2k += 1
    elif len(corpus) > 1000:
      counter_1k += 1
    elif len(corpus) > 500:
      counter_500 += 1
    if len(corpus) > max_length:
      max_length = len(corpus)
  return max_length, counter_4k, counter_2k, counter_1k, counter_500

longest_article_length, counter_4k, counter_2k, counter_1k, counter_500 = find_longest_length(dataset_train['text'])
print(f"Longest Article Length: {longest_article_length} words")
print(f"Text (> 4000 words): {counter_4k}")
print(f"Text (> 2000 words): {counter_2k}")
print(f"Text (> 1000 words): {counter_1k}")
print(f"Text (> 500 words): {counter_500}")

print("")

longest_summary_length, counter_4k, counter_2k, counter_1k, counter_500 = find_longest_length(dataset_train['summary'])
print(f"Longest Summary Length: {longest_summary_length} words")
print(f"Summary (> 4000 words): {counter_4k}")
print(f"Summary (> 2000 words): {counter_2k}")
print(f"Summary (> 1000 words): {counter_1k}")
print(f"Summary (> 500 words): {counter_500}")

Longest article length: 2694 words
Text (> 4000 words): 0
Text (> 2000 words): 2
Text (> 1000 words): 239
Text (> 500 words): 1306
Longest summary length): 66 words
Summary (> 4000 words): 0
Summary (> 2000 words): 0
Summary (> 1000 words): 0
Summary (> 500 words): 0


In [None]:
def find_avg_sentence_length(dataset):
  sentence_lengths = []
  for text in dataset:
    corpus = [word for word in text.split()]
    sentence_lengths.append(len(corpus))
  return sum(sentence_lengths)/len(sentence_lengths)

avg_text_length = find_avg_sentence_length(dataset_train['text'])
print(f"Average Text Length: {avg_text_length:.2f} words")

print("")

avg_summary_length = find_avg_sentence_length(dataset_train['summary'])
print(f"Average Summary Length: {avg_summary_length:.2f} words")

Average text length: 491.8721875 words
Average summary length: 32.2028125 words


# 3. Initialize Configurations & Parameters

In [None]:
MODEL_NAME = 'google/flan-t5-base'
OUT_DIR = 'results'
BATCH_SIZE = 4
EPOCHS = 10
MAX_LENGTH = 1024
GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 5e-4
SEED = 42

In [None]:
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

rouge = evaluate.load("rouge")

<torch._C.Generator at 0x297eca78890>

# 4. Dataset Tokenization

In [10]:
tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

In [None]:
def preprocess_function(examples):
  inputs = [f"summarize: {text}" for text in examples["text"]]

  model_inputs = tokenizer(
    inputs,
    max_length=MAX_LENGTH,
    truncation=True,
    padding="max_length",
  )

  labels = tokenizer(
    text_target=examples["summary"],
    max_length=MAX_LENGTH,
    truncation=True,
    padding="max_length"
  )

  model_inputs["labels"] = labels["input_ids"]
  return model_inputs

tokenized_train = dataset_train.map(preprocess_function, batched=True)
tokenized_valid = dataset_valid.map(preprocess_function, batched=True)

Map: 100%|██████████| 3200/3200 [00:04<00:00, 719.21 examples/s]
Map: 100%|██████████| 800/800 [00:01<00:00, 703.62 examples/s]


# 5. Model Initialization & Training

In [None]:
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)
model.gradient_checkpointing_enable()
model.config.use_cache = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Model Loaded to {str(device).upper()}")

total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} Total Parameters")

total_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} Total Trainable Parameters")

cuda
60,506,624 total parameters.
60,506,624 training parameters.


In [None]:
def compute_metrics(eval_pred):
  predictions, labels = eval_pred.predictions[0], eval_pred.label_ids

  decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  result = rouge.compute(
    predictions=decoded_preds,
    references=decoded_labels,
    use_stemmer=True,
    rouge_types=[
      'rouge1',
      'rouge2',
      'rougeL'
    ]
  )

  prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
  result["gen_len"] = np.mean(prediction_lens)

  return {k: round(v, 4) for k, v in result.items()}

def preprocess_logits_for_metrics(logits, labels):
  pred_ids = torch.argmax(logits[0], dim=-1)
  return pred_ids, labels

In [None]:
training_args = TrainingArguments(
  num_train_epochs=EPOCHS,
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
  learning_rate=LEARNING_RATE,
  warmup_steps=500,
  weight_decay=0.01,
  fp16=True,

  output_dir=OUT_DIR,
  save_strategy="epoch",
  save_total_limit=2,
  load_best_model_at_end=True,

  eval_strategy="epoch",
  logging_strategy="epoch",
  logging_dir=f"{OUT_DIR}/logs",
  report_to="none",

  dataloader_num_workers=4
)

trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=tokenized_train,
  eval_dataset=tokenized_valid,
  preprocess_logits_for_metrics=preprocess_logits_for_metrics,
  compute_metrics=compute_metrics
)

history = trainer.train()

trainer.save_model(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Gen Len
1,0.8234,0.446267,0.4579,0.1846,0.4167,47.0462
2,0.4697,0.439774,0.4655,0.1928,0.4257,47.0462
3,0.4432,0.436988,0.464,0.1928,0.4245,47.0462
4,0.423,0.436381,0.4673,0.194,0.4282,47.0462
5,0.4059,0.439172,0.467,0.194,0.4284,47.0462
6,0.3933,0.439691,0.4683,0.1954,0.4298,47.0462
7,0.3817,0.442058,0.4691,0.1972,0.4309,47.0462
8,0.3718,0.442578,0.4678,0.1957,0.43,47.0462
9,0.3657,0.443152,0.4691,0.1974,0.4309,47.0462
10,0.3616,0.44437,0.4682,0.1967,0.4297,47.0462


# 6. Model Evaluation

In [None]:
tokenized_test = dataset_test.map(preprocess_function, batched=True)
test_results = trainer.evaluate(tokenized_test)

print("Model Evaluation on Test Set:")
for key, value in test_results.items():
  if "rouge" in key:
    print(f"{key}: {value}")