In [None]:
# Run configuration shared by the cells below.
kwargs = {
    "seed": 42,  # forwarded to Seq2SeqTrainingArguments(seed=...)
    "data_dir": "data/",
    "model_file": "outputs/pytorch_model.bin",  # weights loaded when do_train is False
    # Trainer output directory. It has to be the same "outputs/" folder that
    # model_file and the final copy-to-Drive command point at.
    "train_dir": "outputs/",
    "epoch": 6,
    "learning_rate": 1e-5,
    "batch_size": 16,  # batch size used by the tokenisation map
    "do_train": True,  # False: skip training and load model_file instead
    "checkpoint": "google/mt5-small",
}

In [None]:
from google.colab import drive

# Google Drive is mounted only for training runs.
if kwargs["do_train"]:
    drive.mount('/content/gdrive')

In [None]:
!pip install --upgrade pip
!pip install datasets
!pip install transformers
!pip install route_score
!pip install evaluate
!pip install rouge_score
!pip install sentencepiece

In [None]:
import numpy as np
import torch
from transformers import AutoTokenizer, Seq2SeqTrainingArguments, AutoModelForSeq2SeqLM, AutoConfig, DataCollatorForSeq2Seq, \
    Seq2SeqTrainer
import evaluate
from datasets import load_dataset, DatasetDict, concatenate_datasets
import json
from rouge_score import rouge_scorer

In [None]:
# Read the three sections of dataset.json, one load_dataset call per section.
# `field` names the top-level JSON key to read; every call asks for split="train"
# (presumably because the json loader exposes a single file as a "train" split — confirm).
train_dataset, eval_dataset, test_dataset = (
    load_dataset('json', data_files='dataset.json', field=section, split="train")
    for section in ("train", "validation", "test")
)

In [None]:
# Bundle the three splits into one DatasetDict so they can be processed together.
ds = DatasetDict(
    {
        "train": train_dataset,
        "test": test_dataset,
        "validation": eval_dataset,
    }
)
ds

In [None]:
# Load the tokenizer that belongs to the configured checkpoint.
# NOTE(review): max_length/padding/truncation are passed at load time here, while
# tokenize__data passes its own truncation and length arguments on every call —
# confirm these load-time values have any effect.
tokenizer = AutoTokenizer.from_pretrained(kwargs["checkpoint"], max_length=1024, padding="max_length",
                                          truncation=True)


def tokenize__data(data):
    """Tokenise one batch of examples for sequence-to-sequence training.

    Args:
        data: Batch of examples; the "text" and "summary" columns are read.

    Returns:
        dict with "input_ids" and "attention_mask" taken from the tokenised
        text (truncated to 1024 tokens) and "labels" holding the token ids of
        the summary (truncated to 100 tokens).

    No padding is applied here. Padding the summaries at this stage would put
    pad-token ids into "labels", and the loss would then be computed on the
    padding. Padding is left to the batch collator (DataCollatorForSeq2Seq),
    which pads labels with -100 so that padded positions are ignored.
    """
    input_feature = tokenizer(data["text"], truncation=True, max_length=1024)
    label = tokenizer(data["summary"], truncation=True, max_length=100)
    return {
        "input_ids": input_feature["input_ids"],
        "attention_mask": input_feature["attention_mask"],
        "labels": label["input_ids"],
    }

# Register "[MASK]" as an additional special token of the tokenizer.
# NOTE(review): no visible cell resizes the model's embedding matrix afterwards —
# confirm the checkpoint's vocabulary already has room for the new token id.
tokenizer.add_tokens(['[MASK]'], special_tokens=True)

In [None]:
# Tokenise every split in batches and drop the raw columns, so that only the
# fields returned by tokenize__data remain.
ds = ds.map(
    tokenize__data,
    batched=True,
    batch_size=kwargs["batch_size"],
    remove_columns=["id", "summary", "text"],
)
ds

In [None]:
# Pick the compute device: start from the CPU and switch to the GPU when CUDA is usable.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

cuda


In [None]:
# Load the checkpoint's configuration with max_length overridden to 100, then
# instantiate the seq2seq model from the same checkpoint using that configuration.
config = AutoConfig.from_pretrained(kwargs["checkpoint"], max_length=100)
model = AutoModelForSeq2SeqLM.from_pretrained(
    kwargs["checkpoint"],
    config=config,
)

In [None]:
# Inference-only run: overwrite the checkpoint weights with the ones stored in
# kwargs["model_file"], mapped onto the selected device.
# NOTE(review): torch.load unpickles the file — only point model_file at a trusted checkpoint.
if not kwargs["do_train"]:
    model.load_state_dict(torch.load(kwargs["model_file"], map_location=torch.device(device)))

In [None]:
# Move the model to the device selected above ("cuda" when available, otherwise "cpu").
model.to(device)

In [None]:
# Batch collator for seq2seq training; it receives both the tokenizer and the
# model and is asked to return PyTorch tensors.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")

# Training configuration. Logging, evaluation and checkpointing all happen once
# per epoch, and the best checkpoint is reloaded when training ends.
training_args = Seq2SeqTrainingArguments(
    output_dir=kwargs["train_dir"],
    seed=kwargs["seed"],
    overwrite_output_dir=True,
    label_names=["labels"],
    learning_rate=kwargs["learning_rate"],
    num_train_epochs=kwargs["epoch"],
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_strategy="epoch",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # compute_metrics returns the ROUGE scores, so the selection metric has to
    # be one of those keys (the trainer prefixes them with "eval_"); there is
    # no "f1" entry to select on.
    metric_for_best_model="eval_rougeL",
    greater_is_better=True,
    # Evaluate on generated token ids. Without this flag compute_metrics would
    # be handed raw model outputs that cannot be decoded into text.
    predict_with_generate=True,
    generation_max_length=100,
)

# ROUGE metric loaded through the `evaluate` library; used by compute_metrics.
rouge_metric = evaluate.load("rouge")


def tokenize_sentence(arg):
    """Split ``arg`` into the token strings produced by the notebook's tokenizer.

    The text is encoded with ``tokenizer`` and the resulting ids are mapped
    back to their token strings. Passed to the ROUGE metric as its tokenizer.
    """
    token_ids = tokenizer(arg).input_ids
    return tokenizer.convert_ids_to_tokens(token_ids)


def _ids_to_numpy(token_ids):
    """Return token ids as a NumPy array, moving torch tensors to the CPU first."""
    if torch.is_tensor(token_ids):
        token_ids = token_ids.detach().cpu().numpy()
    return np.asarray(token_ids)


def compute_metrics(eval_preds):
    """Decode predicted and reference token ids and score them with ROUGE.

    Args:
        eval_preds: Pair ``(predictions, labels)`` of token-id arrays or
            tensors.

    Returns:
        The result of ``rouge_metric.compute`` for the decoded texts.
    """
    predictions, labels = eval_preds
    pad_id = tokenizer.pad_token_id

    # -100 marks padded positions and is not a valid token id, so it has to be
    # replaced before decoding. This applies to the predictions as well as the
    # labels.
    predictions = _ids_to_numpy(predictions)
    labels = _ids_to_numpy(labels)
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    print(predictions)
    print(labels)

    return rouge_metric.compute(predictions=predictions, references=labels, tokenizer=tokenize_sentence)


# Trainer wiring. Per-epoch evaluation (and therefore best-checkpoint selection)
# runs on the validation split, so the test split is not used to choose the model.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
from torch.utils.data import DataLoader

# Smoke test: generate a summary for the first test example with beam search
# and score it with compute_metrics.
torch.cuda.empty_cache()

sample_dataloader = DataLoader(
    ds["test"].with_format("torch"),
    collate_fn=data_collator,
    batch_size=1)

# Only the first batch is needed.
batch = next(iter(sample_dataloader))
with torch.no_grad():
    preds = model.generate(
        batch["input_ids"].to(device),
        # Pass the mask explicitly so padded input positions are not attended to.
        attention_mask=batch["attention_mask"].to(device),
        num_beams=15,
        num_return_sequences=1,
        no_repeat_ngram_size=1,
        remove_invalid_values=True,
        max_length=100,
    )
labels = batch["labels"]

# Generated ids live on the model's device; move them to the CPU for scoring.
compute_metrics([preds.cpu(), labels])

In [None]:
# Release cached GPU memory first, then fine-tune only when training is enabled.
torch.cuda.empty_cache()
if kwargs["do_train"]:
    trainer.train()

In [None]:
# Save the trainer's model (no path given, so presumably the trainer's configured output directory).
# NOTE(review): this runs even when do_train is False — confirm that is intended.
trainer.save_model()

In [None]:
# Copy the saved weights to Google Drive.
# NOTE(review): the source path is hard-coded to /content/outputs/ and has to match
# kwargs["train_dir"]; Drive is only mounted when do_train is True.
!cp "/content/outputs/pytorch_model.bin" "/content/gdrive/MyDrive/pytorch_model_sum.bin"