In [None]:
!pip install git+https://github.com/huggingface/transformers
!pip install datasets
!pip install sentencepiece
!pip install accelerate
!pip install evaluate

In [None]:
MODEL = "google/mt5-base"   # pretrained checkpoint to fine-tune
REPO = "mt5-base-uqa"       # Trainer output_dir; checkpoints are written under it
EPOCHS = 6                  # number of fine-tuning epochs

# 1. Load Dataset

In [None]:
def filter_function(example):
    """Return True for answerable examples, i.e. when ``is_impossible`` is falsy."""
    impossible = example['is_impossible']
    return False if impossible else True

In [None]:
from datasets import Dataset, load_from_disk

# Load the saved UQA DatasetDict and keep only answerable questions in both splits.
dataset = load_from_disk("UQA")
dataset.update({split: dataset[split].filter(filter_function) for split in ("train", "validation")})
dataset

In [None]:
from transformers import MT5Tokenizer
import torch

In [None]:
# Tokenizer for the pretrained MODEL checkpoint; convert_to_features reads this global.
tokenizer = MT5Tokenizer.from_pretrained(MODEL)

In [None]:
def add_eos_to_examples(example):
    """Attach the seq2seq input and target strings to one training example.

    Sets ``input_text`` to ``"question: <question>  context: <context>"`` and
    ``target_text`` to the answer rendered as a string. Despite the name, no
    EOS token is inserted here (presumably the tokenizer appends it during
    encoding -- TODO confirm). The example dict is mutated and returned.
    """
    question, context = example['question'], example['context']
    example['input_text'] = 'question: {}  context: {}'.format(question, context)
    example['target_text'] = '%s' % example['answer']
    return example

def convert_to_features(example_batch, tok=None, max_input_length=512, max_target_length=30):
    """Tokenize a batch of (input_text, target_text) pairs into model features.

    Args:
        example_batch: batch mapping (as passed by ``Dataset.map(batched=True)``)
            with list-valued ``input_text`` and ``target_text`` columns.
        tok: tokenizer to use; defaults to the module-level ``tokenizer``.
        max_input_length: pad/truncate length for the encoder inputs.
        max_target_length: pad/truncate length for the labels.

    Returns:
        dict with ``input_ids``, ``attention_mask`` and ``labels`` (lists of
        int lists). Padding positions in ``labels`` are set to -100 so the
        cross-entropy loss ignores them instead of training on pad tokens.
    """
    if tok is None:
        tok = tokenizer
    input_encodings = tok(
        example_batch['input_text'], truncation=True, padding="max_length", max_length=max_input_length
    )
    target_encodings = tok(
        example_batch['target_text'], truncation=True, padding="max_length", max_length=max_target_length
    )

    # Labels were padded with pad_token_id; -100 is the ignore index of the loss.
    pad_id = tok.pad_token_id
    labels = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in target_encodings['input_ids']
    ]

    return {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'labels': labels,
    }

In [None]:
# Build the prompt/target strings, then tokenize. The validation split bypasses
# the datasets cache, so both of its map steps are always recomputed.
train_dataset = (
    dataset["train"]
    .map(add_eos_to_examples)
    .map(convert_to_features, batched=True)
)
valid_dataset = (
    dataset["validation"]
    .map(add_eos_to_examples, load_from_cache_file=False)
    .map(convert_to_features, batched=True, load_from_cache_file=False)
)

# Expose only the model-facing columns, formatted as torch tensors.
columns = ['input_ids', 'attention_mask', 'labels']
train_dataset.set_format(type='torch', columns=columns)
valid_dataset.set_format(type='torch', columns=columns)

In [None]:
# Persist (pickle) the tokenized splits to disk.
torch.save(obj=train_dataset, f='train_data.pt')
torch.save(obj=valid_dataset, f='valid_data.pt')

In [None]:
# Sanity check: number of answerable examples in each tokenized split.
len(train_dataset), len(valid_dataset)

# 2. Training

In [None]:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Start training from the pretrained MODEL weights; the tokenizer is re-loaded
# from the same checkpoint as the one used for feature conversion.
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL)
tokenizer = MT5Tokenizer.from_pretrained(MODEL)

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Evaluate and save a checkpoint (under REPO) once per epoch.
training_args = Seq2SeqTrainingArguments(
    output_dir=REPO,
    num_train_epochs=EPOCHS,  # the config cell defines EPOCHS (EPOCH was undefined)
    eval_strategy="epoch",  # `evaluation_strategy` was removed in recent transformers releases
    save_strategy="epoch",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)

In [None]:
from transformers import DataCollatorForSeq2Seq, Trainer

# Features are already padded to fixed lengths, so the collator mainly stacks
# them into batches.
data_collator = DataCollatorForSeq2Seq(tokenizer)

trainer = Seq2SeqTrainer(
    model=model,
    processing_class=tokenizer,  # replaces the deprecated/removed `tokenizer=` argument
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    data_collator=data_collator,
)

In [None]:
# Fine-tune. With the per-epoch eval/save strategies in training_args, this
# evaluates and writes a checkpoint under REPO after every epoch.
trainer.train()

# 3. Evaluation

In [None]:
from tqdm import tqdm

In [None]:
from datasets import Dataset, load_from_disk

# Reload the raw (untokenized) splits for evaluation, keeping only answerable questions.
dataset = load_from_disk("UQA")
dataset.update({split: dataset[split].filter(filter_function) for split in ("train", "validation")})
dataset

In [None]:
from datasets import Dataset, DatasetDict

def merge_duplicate_ids(dataset):
    """Collapse rows that share an ``id`` into one record per question.

    The first row seen for an id supplies ``title``, ``context``, ``question``
    and ``is_impossible``. ``answer`` and ``answer_start`` become lists that
    gather the values from every row with that id, in row order.

    Args:
        dataset: object with a ``to_dict()`` method returning column lists
            (e.g. a ``datasets.Dataset``).

    Returns:
        list of dicts, one per distinct id, in order of first appearance.
    """
    columns = dataset.to_dict()
    merged = {}

    for row, key in enumerate(columns['id']):
        record = merged.get(key)
        if record is not None:
            # Repeated id: only collect the additional reference answer.
            record['answer'].append(columns['answer'][row])
            record['answer_start'].append(columns['answer_start'][row])
            continue
        merged[key] = {
            'id': key,
            'title': columns['title'][row],
            'context': columns['context'][row],
            'question': columns['question'][row],
            'is_impossible': columns['is_impossible'][row],
            'answer': [columns['answer'][row]],
            'answer_start': [columns['answer_start'][row]],
        }

    return list(merged.values())

merged_validation_data = merge_duplicate_ids(dataset['validation'])

# Transpose the per-question records into the column mapping Dataset.from_dict
# expects, keeping the key order of the first record.
merged_validation_dataset = Dataset.from_dict(
    {
        column: [record[column] for record in merged_validation_data]
        for column in merged_validation_data[0].keys()
    }
)
merged_validation_dataset

In [None]:
def add_eos_to_examples(example):
    """Evaluation variant: set only ``input_text`` for one example.

    This redefines (shadows) the training-time helper of the same name. After
    merging, ``answer`` holds a list of reference answers, so no
    ``target_text`` is built. Despite the name, no EOS token is added here.
    The example dict is mutated and returned.
    """
    question, context = example['question'], example['context']
    example['input_text'] = 'question: {}  context: {}'.format(question, context)
    return example

In [None]:
# Add the "question: ...  context: ..." prompt to each merged validation record;
# load_from_cache_file=False forces the map to be recomputed.
valid_dataset = merged_validation_dataset.map(add_eos_to_examples, load_from_cache_file=False)

In [None]:
# Imported under an alias because a helper function named `evaluate` is defined
# later in this notebook and shadows the module name; with the alias, this cell
# still works when re-run after that definition.
import evaluate as hf_evaluate

metric = hf_evaluate.load("squad")

In [None]:
from glob import glob

def evaluate(model_dir, dataset, device="cuda", max_new_tokens=30):
    """Score every saved ``checkpoint-*`` directory under ``model_dir`` with ``metric``.

    For each checkpoint (in ascending step order), the tokenizer and model are
    loaded from it and an answer is generated for every example. The
    predictions are then scored against all reference answers of each example
    with the module-level SQuAD ``metric``. The checkpoint path and its scores
    are printed.

    Args:
        model_dir: directory containing ``checkpoint-<step>`` subdirectories.
        dataset: iterable of records with ``id``, ``input_text`` and list-valued
            ``answer`` / ``answer_start`` fields.
        device: torch device used for generation (default ``"cuda"``).
        max_new_tokens: cap on generated answer length (default 30).

    Returns:
        dict mapping checkpoint path -> metric result.

    Raises:
        FileNotFoundError: if ``model_dir`` contains no ``checkpoint-*`` entries.
    """
    def _step(path):
        # "…/checkpoint-1234" -> 1234; unexpected suffixes sort first.
        suffix = path.rsplit("-", 1)[-1]
        return int(suffix) if suffix.isdigit() else -1

    checkpoints = sorted(glob(f"{model_dir}/checkpoint-*"), key=_step)
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-* directories found in {model_dir!r}")

    # References do not depend on the checkpoint, so build them once, in the
    # dict-of-lists "answers" layout documented by the SQuAD metric.
    references = [
        {
            "id": data["id"],
            "answers": {
                "text": list(data["answer"]),
                "answer_start": list(data["answer_start"]),
            },
        }
        for data in dataset
    ]

    results = {}
    for checkpoint in checkpoints:
        tokenizer = MT5Tokenizer.from_pretrained(checkpoint)
        model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

        predictions = []
        for data in tqdm(dataset, desc=checkpoint):
            inputs = tokenizer(data["input_text"], return_tensors="pt").to(device)
            outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
            pred = tokenizer.decode(outputs[0], skip_special_tokens=True)
            predictions.append({"id": data["id"], "prediction_text": pred})

        score = metric.compute(predictions=predictions, references=references)
        print(checkpoint)
        print(score)
        results[checkpoint] = score

        # Release this checkpoint's weights before the next one is loaded.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return results

In [None]:
# Run the checkpoint-evaluation helper over the training output directory,
# scoring against the merged (multi-reference) validation set.
evaluate(REPO, valid_dataset)