In [None]:
!pip install --upgrade datasets evaluate transformers==4.46.2 rouge_score accelerate==0.27.2 peft==0.10 --quiet

In [None]:
from transformers import (
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
)
from google.colab import drive
from datasets import load_dataset, Dataset
from evaluate import load
import pandas as pd


# Mount Google Drive under /content/drive/; the dataset pickle loaded further
# down is read from this mount.
drive.mount('/content/drive/')

import os
# Configure the PyTorch CUDA caching allocator through its environment variable.
# NOTE(review): presumably meant to reduce fragmentation-related OOMs with the
# very long LED inputs; it has to be set before the first CUDA allocation — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [None]:
def backup_colab_content_to_drive(folder_name='Colab Notebooks', src='/content', dest_root='/content/drive/MyDrive'):
  """Copy every top-level entry of `src` into `<dest_root>/<folder_name>`.

  The entry named 'drive' is skipped so that the mounted Google Drive is not
  copied into itself. The function is safe to call repeatedly: directories
  that already exist in the destination are merged/overwritten instead of
  raising `FileExistsError`.

  Args:
    folder_name: Name of the destination folder below `dest_root`.
    src: Directory whose content is backed up (defaults to the Colab
      working directory).
    dest_root: Directory that contains the backup folder (defaults to the
      root of the mounted Google Drive).
  """
  import shutil
  import os

  dest = f'{dest_root}/{folder_name}'
  os.makedirs(dest, exist_ok=True)

  for item in os.listdir(src):
    # Never recurse into the Drive mount point itself.
    if item == 'drive':
      continue
    s = os.path.join(src, item)
    d = os.path.join(dest, item)
    if os.path.isdir(s):
      # dirs_exist_ok=True: a second backup must update the existing copy
      # rather than crash because the target directory is already there.
      shutil.copytree(s, d, dirs_exist_ok=True)
    else:
      shutil.copy2(s, d)

  print(f'📁 Backup complete. Files saved to: {dest}')

In [None]:
# Load the summarization data from the mounted Google Drive.
# NOTE(review): unpickling can execute arbitrary code — only load this file if
# its origin is trusted.
file_path = "/content/drive/MyDrive/robot_dreams/final/sum_data.pickle"
df = pd.read_pickle(file_path)
# Preview the first two rows (the cells below read the 'text' and 'summary' columns).
df.head(2)

In [None]:
from statistics import median, mode
import matplotlib.pyplot as plt
import numpy as np

# Length of every source document as reported by len()
# (characters if the column holds strings — not tokenizer tokens).
text_len = list(map(len, df['text']))
print(f'Text median {median(text_len)}, 75 percentile {np.percentile(text_len, 75)}, 90 percentile {np.percentile(text_len, 90)}')

# Same statistic for the reference summaries.
sum_len = list(map(len, df['summary']))
print(f'Summary median {median(sum_len)}, 75 percentile {np.percentile(sum_len, 75)}, 90 percentile {np.percentile(sum_len, 90)}')

# One histogram per length distribution: documents first, then summaries.
for lengths in (text_len, sum_len):
  plt.hist(lengths, bins=1000)
  plt.show()

In [None]:
# Fixed seed so the sampled subset and the train/validation split are the same
# on every "Restart & Run All".
SEED = 42

# Work on a small random subset (at most 100 rows; fewer if the frame is
# smaller, instead of raising) and hold out 10% of it for validation.
dataset = Dataset.from_pandas(df.sample(n=min(100, len(df)), random_state=SEED))\
  .shuffle(seed=SEED)\
  .train_test_split(test_size=0.1, seed=SEED)

dataset

In [None]:
# Named handles for the two splits; the 'test' split produced by
# train_test_split is used as the validation set during training.
train_dataset = dataset['train']
val_dataset = dataset['test']

In [None]:
# Checkpoint to fine-tune.
MODEL_NAME = "allenai/led-base-16384"
# Maximum number of *tokens* fed to the encoder. led-base-16384 only has 16384
# learned encoder position embeddings, so anything larger (previously 30208)
# indexes past the embedding table and crashes the forward pass.
# Lower this (e.g. 8192) if the GPU runs out of memory.
MAX_INPUT_LENGTH = 16384
# Maximum number of tokens kept from each reference summary.
MAX_TARGET_LENGTH = 100
# Per-device batch size for training/evaluation and for dataset.map batches.
BATCH_SIZE = 2

In [None]:
# Tokenizer that belongs to the MODEL_NAME checkpoint; used for both the inputs
# and the target summaries, and for decoding in compute_metrics.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
def process_data_to_model_inputs(batch):
  """Turn a batch of raw examples into model features.

  Reads `batch["text"]` and `batch["summary"]`, tokenizes both with the
  notebook-level `tokenizer` (padded/truncated to MAX_INPUT_LENGTH and
  MAX_TARGET_LENGTH respectively) and adds the following keys to `batch`:

  * `input_ids`, `attention_mask` – encoder inputs.
  * `global_attention_mask` – 1 for the first token of every sample, 0 elsewhere.
  * `labels` – target token ids with padding replaced by -100 so that the
    loss ignores padded positions.

  Returns:
    The same `batch` mapping, extended with the keys above.
  """
  # tokenize the inputs and labels
  inputs = tokenizer(
    batch["text"],
    padding="max_length",
    truncation=True,
    max_length=MAX_INPUT_LENGTH,
  )
  outputs = tokenizer(
    batch["summary"],
    padding="max_length",
    truncation=True,
    max_length=MAX_TARGET_LENGTH,
  )

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask

  # Global attention on the first token of every sample only. Each sample gets
  # its own list (no shared references), so the rows can be modified
  # independently and an empty batch no longer raises an IndexError.
  batch["global_attention_mask"] = [
    [int(position == 0) for position in range(len(ids))]
    for ids in batch["input_ids"]
  ]

  # We have to make sure that the PAD token is ignored by the loss
  pad_id = tokenizer.pad_token_id
  batch["labels"] = [
    [-100 if token == pad_id else token for token in labels]
    for labels in outputs.input_ids
  ]

  return batch

In [None]:
# Both splits are tokenized with exactly the same settings; the raw columns are
# dropped afterwards because only the model features are needed.
_map_kwargs = dict(
  batched=True,
  batch_size=BATCH_SIZE,
  remove_columns=["text", "summary", "__index_level_0__"],
)
train_dataset = train_dataset.map(process_data_to_model_inputs, **_map_kwargs)
val_dataset = val_dataset.map(process_data_to_model_inputs, **_map_kwargs)

# Expose the model features as torch tensors (set_format works in place).
_model_columns = ["input_ids", "attention_mask", "global_attention_mask", "labels"]
for _split in (train_dataset, val_dataset):
  _split.set_format(type="torch", columns=_model_columns)

In [None]:
# Load the pretrained LED checkpoint. The decoder KV cache is disabled because
# it is not useful during training and is incompatible with gradient checkpointing.
led = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME, use_cache=False)
# Trade compute for memory on the very long inputs. Passing
# `gradient_checkpointing=True` as a config kwarg to `from_pretrained` is
# deprecated; the explicit call is the supported way to enable it.
led.gradient_checkpointing_enable()

In [None]:
# Set the generate() hyperparameters used by `predict_with_generate` evaluation.
# They belong on `generation_config`; storing generation parameters on
# `model.config` is deprecated in recent transformers releases.
led.generation_config.num_beams = 2
led.generation_config.max_length = 512
# NOTE(review): references are truncated to MAX_TARGET_LENGTH (100) tokens while
# every generated summary is forced to be at least 100 tokens long — confirm
# that this minimum is intended.
led.generation_config.min_length = 100
led.generation_config.length_penalty = 2.0
led.generation_config.early_stopping = True
led.generation_config.no_repeat_ngram_size = 3

In [None]:
# ROUGE metric loaded through `evaluate.load`; used by compute_metrics below.
rouge = load("rouge")

In [None]:
def compute_metrics(pred):
  """Compute ROUGE-2 for a set of generated summaries.

  Args:
    pred: Prediction object exposing `predictions` (generated token ids) and
      `label_ids` (reference token ids, with -100 on ignored positions).

  Returns:
    A dict of rounded ROUGE-2 scores. With the `evaluate` ROUGE metric only the
    aggregated F-measure is available (`rouge2_fmeasure`); if the metric
    returns a legacy aggregate score with a `.mid` attribute, precision and
    recall are reported as well.
  """
  labels_ids = pred.label_ids
  pred_ids = pred.predictions

  # -100 marks ignored/padded positions and is not a valid token id, so it has
  # to be mapped back to the pad token before decoding.
  pad_id = tokenizer.pad_token_id
  pred_ids[pred_ids == -100] = pad_id
  labels_ids[labels_ids == -100] = pad_id

  pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
  label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

  rouge2 = rouge.compute(
    predictions=pred_str, references=label_str, rouge_types=["rouge2"]
  )["rouge2"]

  # Legacy metric objects return an aggregate with low/mid/high scores.
  if hasattr(rouge2, "mid"):
    mid = rouge2.mid
    return {
      "rouge2_precision": round(mid.precision, 4),
      "rouge2_recall": round(mid.recall, 4),
      "rouge2_fmeasure": round(mid.fmeasure, 4),
    }

  # `evaluate.load("rouge")` returns the aggregated F-measure as a plain float,
  # so accessing `.mid` on it would raise an AttributeError.
  return {"rouge2_fmeasure": round(float(rouge2), 4)}

In [None]:
# Training configuration. With per-device batch size BATCH_SIZE and 4
# accumulation steps, one optimizer step covers BATCH_SIZE * 4 samples.
training_args = Seq2SeqTrainingArguments(
  predict_with_generate=True,  # evaluate with generate() so ROUGE sees real summaries
  eval_strategy="steps",  # `evaluation_strategy` is the deprecated spelling
  per_device_train_batch_size=BATCH_SIZE,
  per_device_eval_batch_size=BATCH_SIZE,
  fp16=True,
  output_dir="./",
  logging_steps=5,
  eval_steps=10,
  save_steps=10,
  save_total_limit=2,  # keep only the two most recent checkpoints
  gradient_accumulation_steps=4,
  num_train_epochs=1,
)

trainer = Seq2SeqTrainer(
  model=led,
  processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
  args=training_args,
  compute_metrics=compute_metrics,
  train_dataset=train_dataset,
  eval_dataset=val_dataset,
)

In [None]:
import gc

# Run a full garbage collection pass before training starts.
# NOTE(review): presumably meant to free host memory left over from the
# preprocessing cells — confirm it is still needed.
gc.collect()

In [None]:
# Start fine-tuning with the trainer configured above.
trainer.train()