In [None]:
from functools import partial

import torch
import numpy as np
import evaluate
import nltk

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import DataCollatorForSeq2Seq
from datasets import load_dataset, load_from_disk

from nlp481 import T5BiLDModel, tokenizeDataset

In [None]:
# Large T5 checkpoint: the "big" half of the model pair combined into the BiLD model.
large_checkpoint_name = "google-t5/t5-large"
model_large = AutoModelForSeq2SeqLM.from_pretrained(large_checkpoint_name)

In [None]:
# Small T5 checkpoint: the "little" half of the model pair combined into the BiLD model.
small_checkpoint_name = "google-t5/t5-small"
model_small = AutoModelForSeq2SeqLM.from_pretrained(small_checkpoint_name)

In [None]:
# Place both checkpoints on the first GPU when CUDA is available; fall back to CPU
# instead of crashing on machines without a GPU (training will be very slow there).
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
model_large.to(DEVICE)
# nn.Module.to moves parameters in place; the trailing ';' hides the very long module repr.
model_small.to(DEVICE);

In [None]:
# Wrap the two T5 checkpoints into a single BiLD model (large model first, small model second).
model_bild = T5BiLDModel(model_large, model_small)

## Run the next two cells only if the tokenized dataset has not been downloaded and cached yet

In [None]:
cnn_dataset = load_dataset("cnn_dailymail", "1.0.0")
t5_tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")

# Bind the tokenizer up front so Dataset.map only has to supply each batch of examples.
tokenize_batch = partial(tokenizeDataset, tokenizer=t5_tokenizer)
tokenized_cnn = cnn_dataset.map(tokenize_batch, batched=True)

In [None]:
# Optional: Cache tokenized dataset
tokenized_cnn.save_to_disk("tokenized-datasets/cnn-dm")

## Run this cell instead of the download cells above when the tokenized dataset is already cached on disk

In [None]:
# Reload the tokenized splits written by the caching cell, plus the tokenizer that
# compute_metrics uses to decode predictions and labels.
cached_cnn_path = "tokenized-datasets/cnn-dm"
tokenized_cnn = load_from_disk(cached_cnn_path)
t5_tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")

In [None]:
training_args = Seq2SeqTrainingArguments(
    output_dir='./train-results',
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_steps=1,
    # With load_best_model_at_end, save_steps must be a round multiple of eval_steps.
    save_steps=5,
    # NOTE(review): eval_steps=1 runs generation over the whole validation split after
    # every optimizer step, which will dominate wall-clock time; consider a larger value
    # (and a matching save_steps) for a real run.
    eval_steps=1,
    max_steps=500000,
    # Renamed `eval_strategy` in newer transformers releases.
    evaluation_strategy="steps",
    predict_with_generate=True,
    # The string "none" disables logging integrations; passing None falls back to the
    # default of reporting to every installed integration.
    report_to="none",
    # Trainer looks up "eval_" + this name in the evaluation metrics. The ROUGE metric
    # returns "rougeL" (not "rouge_l"), so the old name raised a KeyError at the first
    # checkpoint save.
    metric_for_best_model="rougeL",
    load_best_model_at_end=True,
    # Saving every 5 steps for up to 500k steps would otherwise keep an unbounded number
    # of full checkpoints on disk; with load_best_model_at_end the best checkpoint is
    # always retained alongside the most recent one.
    save_total_limit=2,
)

In [None]:

metric = evaluate.load("rouge")

# nltk.sent_tokenize needs the Punkt sentence-splitter data. Newer nltk releases look for
# "punkt_tab", older ones for "punkt". By default nltk.download reports a failure and
# returns False instead of raising, so requesting both is safe.
for _punkt_resource in ("punkt", "punkt_tab"):
    nltk.download(_punkt_resource, quiet=True)


def _decode_for_rouge(token_ids):
    """Decode a batch of token ids into texts laid out one sentence per line.

    rougeLsum expects a newline after each sentence, so each decoded text is
    sentence-split with nltk and re-joined with newlines.
    """
    # -100 marks ignored positions: masked label tokens, and the padding the Trainer
    # inserts when it concatenates generated batches of different lengths. It is not a
    # valid token id, so map it to the pad token before decoding.
    token_ids = np.where(token_ids != -100, token_ids, t5_tokenizer.pad_token_id)
    texts = t5_tokenizer.batch_decode(token_ids, skip_special_tokens=True)
    return ["\n".join(nltk.sent_tokenize(text.strip())) for text in texts]


def compute_metrics(eval_preds):
    """Compute ROUGE scores between generated summaries and reference summaries.

    Args:
        eval_preds: ``(predictions, label_ids)`` pair supplied by the Trainer. With
            ``predict_with_generate=True`` the predictions are generated token ids.

    Returns:
        The dict produced by the ``evaluate`` ROUGE metric (keys such as ``rouge1``,
        ``rouge2``, ``rougeL``, ``rougeLsum``).
    """
    preds, labels = eval_preds
    # Some models return a tuple of outputs; the generated ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = _decode_for_rouge(preds)
    decoded_labels = _decode_for_rouge(labels)

    return metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

# Pad inputs per batch and pad labels with -100 so padded label positions are ignored by
# the loss and masked out again in compute_metrics. Without an explicit collator the
# Trainer's default DataCollatorWithPadding pads model inputs but not `labels`.
data_collator = DataCollatorForSeq2Seq(tokenizer=t5_tokenizer, label_pad_token_id=-100)

trainer = Seq2SeqTrainer(
    model=model_bild,
    args=training_args,
    train_dataset=tokenized_cnn["train"],
    eval_dataset=tokenized_cnn["validation"],
    tokenizer=t5_tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run training; the returned TrainOutput is kept and shown as the cell's output.
train_output = trainer.train()
train_output