# DD2417 Text Summarizer

## Imports

In [1]:
from transformers import (
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)
from datasets import load_dataset, DatasetDict
import torch
import evaluate
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


## Setup

### Global Variables

In [2]:
# Checkpoint identifier shared by the tokenizer and the model below.
model_name = "google-t5/t5-base"

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Tokenizer and seq2seq model are both loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
model = model.to(device)



### Dataset

In [3]:
# Load the train and test splits of the "sep" configuration of WikiHow.
# NOTE(review): data_dir presumably points at locally downloaded WikiHow
# files — confirm they exist under ./data before running.
train, test = load_dataset(
    path="wikihow",
    name="sep",
    data_dir="./data",
    split=["train", "test"],
    trust_remote_code=True,
)

# Bundle both splits and keep only the two columns used below: the source
# text and its reference headline.
dataset = DatasetDict({"train": train, "test": test}).select_columns(
    ["text", "headline"]
)

# Prefix prepended to every input text before tokenization.
prefix = "summarize: "
# Token limits applied (with truncation) to the inputs and to the targets.
max_input_length, max_target_length = 512, 64


def preprocess_function(dataset):
    """Tokenize a batch of examples for seq2seq training.

    Args:
        dataset: A batch exposing a "text" and a "headline" column, each an
            iterable of strings (this function is used with
            ``map(..., batched=True)``).

    Returns:
        The tokenizer output for the prefixed texts, with an extra "labels"
        entry holding the token ids of the tokenized headlines.
    """
    # Prepend the task prefix to every source document.
    prompts = []
    for document in dataset["text"]:
        prompts.append(prefix + document)

    # Inputs and targets are truncated to their own maximum lengths.
    encoded = tokenizer(prompts, truncation=True, max_length=max_input_length)
    encoded_targets = tokenizer(
        text_target=dataset["headline"],
        truncation=True,
        max_length=max_target_length,
    )

    encoded["labels"] = encoded_targets["input_ids"]
    return encoded


# Tokenize every split in batches; the original columns are kept alongside
# the new tokenized ones.
tokenized_datasets = dataset.map(function=preprocess_function, batched=True)

# Collator that batches the tokenized examples for the seq2seq model.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

### Evaluation

In [4]:
# Load the "rouge" metric via the `evaluate` library; compute_metrics below
# calls rouge.compute() on the decoded predictions and references.
rouge = evaluate.load("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for an evaluation run.

    Args:
        eval_pred: A ``(predictions, labels)`` pair of token-id arrays. Either
            array may contain ``-100`` as a padding/ignore marker, and
            ``predictions`` may be a tuple whose first element holds the
            generated token ids.

    Returns:
        dict: The scores returned by ``rouge.compute`` plus ``"gen_len"``, the
        mean number of non-padding tokens per prediction.
    """
    predictions, labels = eval_pred
    # Some models return extra outputs next to the generated ids; keep the ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is not a valid token id and cannot be decoded. Map it to the pad
    # token in BOTH arrays (the original only cleaned the labels), so decoding
    # cannot fail and padding is not counted as generated tokens.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )

    # Generated length = number of non-padding tokens in each prediction.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return dict(result)

## Set Up Trainer and Train

In [5]:
# Enable bf16 mixed precision only when a CUDA device that supports bfloat16
# is present. Requesting it unconditionally is inconsistent with the CPU
# fallback used when choosing `device`, and is expected to make the training
# arguments fail validation on hardware without bfloat16 support.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = Seq2SeqTrainingArguments(
    output_dir="base",
    # Evaluate and log every 500 optimizer steps.
    evaluation_strategy="steps",
    eval_steps=500,
    logging_strategy="steps",
    logging_steps=500,
    # NOTE(review): save_steps is not set, so the library default applies;
    # confirm it lines up with eval_steps, as load_best_model_at_end needs.
    save_strategy="steps",
    # 8 samples per device x 8 accumulation steps = 64 samples per update.
    per_device_train_batch_size=8,
    gradient_accumulation_steps=8,
    save_total_limit=5,
    num_train_epochs=1,
    weight_decay=0.01,
    gradient_checkpointing=True,
    # Reload the checkpoint with the best ROUGE-1 once training finishes.
    load_best_model_at_end=True,
    metric_for_best_model="rouge1",
    # Evaluate on generated sequences so compute_metrics can decode them.
    predict_with_generate=True,
    optim="adafactor",
    bf16=use_bf16,
)

# Wire the model, training arguments, data, collator and metric function
# into one trainer.
# NOTE(review): the "test" split is also the evaluation set used for
# best-checkpoint selection — consider a separate validation split.
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
)

# Left as the last expression of the cell so the returned TrainOutput is shown.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
500,2.1451,1.859208,0.325768,0.147763,0.315274,0.315221,7.348624
1000,2.0448,1.830566,0.334693,0.15449,0.323821,0.323706,7.461667
1500,2.0103,1.81581,0.338711,0.158342,0.328082,0.328,7.399312
2000,1.9954,1.803923,0.341629,0.15967,0.33007,0.329948,7.543095
2500,1.9815,1.794847,0.342466,0.161269,0.331665,0.331574,7.42836
3000,1.9797,1.788172,0.344764,0.162198,0.333602,0.333546,7.474339
3500,1.9661,1.780981,0.345731,0.163514,0.334643,0.334593,7.524206
4000,1.9685,1.777582,0.347576,0.164803,0.336487,0.33646,7.526931
4500,1.9575,1.770485,0.348957,0.166013,0.337791,0.337702,7.527619
5000,1.9454,1.768858,0.350093,0.167349,0.338931,0.338898,7.480635


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


TrainOutput(global_step=16574, training_loss=1.9441355745194258, metrics={'train_runtime': 63053.583, 'train_samples_per_second': 16.823, 'train_steps_per_second': 0.263, 'total_flos': 2.7065262849490944e+17, 'train_loss': 1.9441355745194258, 'epoch': 1.0})