In [None]:
# ---- Model ------------------------------------------------------------------
# Checkpoints this notebook is meant to work with:
#   facebook/mbart-large-50, google/mt5-base, csebuetnlp/banglat5
MODEL = "facebook/mbart-large-50"

# ---- Sequence lengths (tokens) ----------------------------------------------
MAX_INPUT_LENGTH = 512   # questions are truncated to this length
MAX_OUTPUT_LENGTH = 128  # summaries are truncated to this; also the generation cap

# ---- Optimisation ------------------------------------------------------------
BATCH_SIZE = 16
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 0.03
EPOCHS = 50

# ---- Decoding (beam search) --------------------------------------------------
NUM_BEAMS = 15
NO_REPEAT_NGRAM_SIZE = 2
LENGTH_PENALTY = 1

# ---- Experiment tracking -----------------------------------------------------
USE_WANDB = True  # when True, the trainer reports to Weights & Biases

In [None]:
import os
import shutil
from statistics import mean

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from rouge import Rouge
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, AutoConfig, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Disable the Hugging Face `tokenizers` library's internal parallelism for this process.
# NOTE(review): presumably set to avoid the fork-after-use warning/deadlock when
# DataLoader workers are forked after the tokenizer has run — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
# Load the pre-split CSVs into a DatasetDict. Every split must provide the
# "question" (model input) and "summary" (target) columns used for tokenization.
path = "../Dataset"
REQUIRED_COLUMNS = {"question", "summary"}


def _load_split(split_name):
    """Read `<path>/<split_name>.csv` and convert it to a `datasets.Dataset`.

    Raises:
        ValueError: if the CSV lacks any of REQUIRED_COLUMNS, so the problem
            surfaces here rather than as a KeyError during tokenization.
    """
    frame = pd.read_csv(os.path.join(path, f"{split_name}.csv"))
    missing = REQUIRED_COLUMNS - set(frame.columns)
    if missing:
        raise ValueError(f"{split_name}.csv is missing required columns: {sorted(missing)}")
    # from_pandas is the DataFrame constructor; preserve_index=False keeps the
    # pandas index out of the dataset.
    return Dataset.from_pandas(frame, preserve_index=False)


ds = DatasetDict({split: _load_split(split) for split in ("train", "valid", "test")})

In [None]:
# `src_lang` / `tgt_lang` are only consumed by the mBART-50 tokenizer; for the
# other checkpoints listed in the config cell they are ignored.
# The same language code is used for both the input and the target side.
LANG_CODE = "bn_IN"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL,
    src_lang=LANG_CODE,
    tgt_lang=LANG_CODE,
)

In [None]:
def tokenize_data(data):
    """Tokenize a batch of question/summary pairs for seq2seq training.

    Args:
        data: batch mapping from `Dataset.map(batched=True)` holding lists of
            strings under "question" (model input) and "summary" (target).

    Returns:
        dict with "input_ids" and "attention_mask" for the questions (truncated
        to MAX_INPUT_LENGTH) and "labels" with the summary token ids (truncated
        to MAX_OUTPUT_LENGTH).
    """
    model_inputs = tokenizer(data["question"], truncation=True, max_length=MAX_INPUT_LENGTH)
    # `text_target` tokenizes in target mode, so language-aware tokenizers such
    # as mBART-50 apply tgt_lang's special tokens to the labels.
    targets = tokenizer(text_target=data["summary"], truncation=True, max_length=MAX_OUTPUT_LENGTH)

    return {
        "input_ids": model_inputs["input_ids"],
        "attention_mask": model_inputs["attention_mask"],
        "labels": targets["input_ids"],
    }


# Tokenize each split and drop *all* of its original columns, so only model
# inputs remain (the manually built test DataLoader does not filter extras).
tokenized_ds = DatasetDict({
    split: split_ds.map(
        tokenize_data,
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=split_ds.column_names,
    )
    for split, split_ds in ds.items()
})

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Beam-search settings used when the trainer decodes (predict_with_generate=True).
GENERATION_SETTINGS = dict(
    max_length=MAX_OUTPUT_LENGTH,
    length_penalty=LENGTH_PENALTY,
    no_repeat_ngram_size=NO_REPEAT_NGRAM_SIZE,
    num_beams=NUM_BEAMS,
)

model = AutoModelForSeq2SeqLM.from_pretrained(MODEL).to(device)
# Apply the settings to the model's generation config, which `generate()` reads
# for any option not passed explicitly. Putting them on the model config instead
# only works if that config is actually handed to `from_pretrained`, and newer
# transformers releases warn about (or reject) generation flags there.
model.generation_config.update(**GENERATION_SETTINGS)

In [None]:
# Dynamic padding: each batch is padded to its longest sequence and returned as
# PyTorch tensors. Labels use the collator's default label pad id (-100), the
# value metrics_func later maps back to the pad token.
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    return_tensors="pt",
    padding="longest",
)

In [None]:
# Metric backends used by metrics_func: ROUGE from the `rouge` package and
# BERTScore loaded through `evaluate`.
rouge = Rouge()
bert_score = evaluate.load("bertscore")

def tokenize_sentence(arg):
    """Return the subword token strings the global tokenizer produces for `arg`."""
    token_ids = tokenizer(arg).input_ids
    return tokenizer.convert_ids_to_tokens(token_ids)

def metrics_func(eval_arg, return_bertscore = False):
    """Score generated predictions against references with ROUGE (and optionally BERTScore).

    Args:
        eval_arg: `(predictions, label_ids)` pair, e.g. the Trainer's EvalPrediction.
            With predict_with_generate=True these are token-id arrays; either may
            contain -100 at padded/ignored positions. `predictions` may also be a
            tuple whose first element holds the ids.
        return_bertscore: when True, also compute mean BERTScore F1 using the
            csebuetnlp/banglabert encoder.

    Returns:
        dict with 'rouge-1', 'rouge-2', 'rouge-l' F-scores and, if requested,
        'bertscore' (mean F1 rounded to 4 decimals).
    """
    preds, labels = eval_arg
    # Some models return (generated_ids, ...) — the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 cannot be decoded; map it to the pad token, which skip_special_tokens drops.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    text_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    text_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_scores = rouge.get_scores(text_preds, text_labels, avg = True, ignore_empty = True)
    results = {name: rouge_scores[name]['f'] for name in ('rouge-1', 'rouge-2', 'rouge-l')}

    if return_bertscore:
        bertscore_result = bert_score.compute(
            predictions=text_preds,
            references=text_labels,
            model_type="csebuetnlp/banglabert",
            num_layers=12,
            batch_size=4
        )
        # Average the full-precision F1 values and round once; rounding each
        # value before averaging needlessly loses precision.
        results['bertscore'] = round(mean(bertscore_result["f1"]), 4)

    return results

In [None]:
# Batch the tokenized test split (as torch tensors) with the training collator.
test_split_torch = tokenized_ds["test"].with_format("torch")
test_dataloader = DataLoader(test_split_torch, batch_size=BATCH_SIZE, collate_fn=data_collator)

In [None]:
# Linear warmup over the first 20% of optimizer steps. Steps per epoch use
# ceiling division because the final partial batch is still an optimizer step;
# the result is an int, as `warmup_steps` expects.
# NOTE(review): assumes a single device and no gradient accumulation — confirm.
WARMUP_FRACTION = 0.2
steps_per_epoch = -(-len(ds["train"]) // BATCH_SIZE)
warmup_steps = int(WARMUP_FRACTION * steps_per_epoch * EPOCHS)

training_args = Seq2SeqTrainingArguments(
    report_to="wandb" if USE_WANDB else "none",
    output_dir = "./results",
    overwrite_output_dir = True,
    # Reload the best checkpoint (by eval loss, the default metric) after training.
    load_best_model_at_end = True,
    log_level = "error",
    num_train_epochs = EPOCHS,
    learning_rate = LEARNING_RATE,
    lr_scheduler_type = "linear",
    warmup_steps = warmup_steps,
    optim = "adamw_torch",
    weight_decay = WEIGHT_DECAY,
    per_device_train_batch_size = BATCH_SIZE,
    per_device_eval_batch_size = BATCH_SIZE,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`.
    evaluation_strategy = "epoch",
    # Decode with generate() during evaluation so ROUGE sees real summaries.
    predict_with_generate = True,
    generation_max_length = MAX_OUTPUT_LENGTH,
    save_total_limit = 1,
    logging_steps = 10,
    push_to_hub = False,
    group_by_length = True,
    # Must match the evaluation strategy for load_best_model_at_end.
    save_strategy="epoch",
    gradient_checkpointing=True
)

In [None]:
trainer = Seq2SeqTrainer(
    model = model,
    args = training_args,
    data_collator = data_collator,
    compute_metrics = metrics_func,
    train_dataset = tokenized_ds["train"],
    eval_dataset = tokenized_ds["valid"],
    tokenizer = tokenizer,
)
trainer.train()
# load_best_model_at_end=True, so the in-memory weights are the best checkpoint's.
trainer.save_model(output_dir="./best_model/")
# The intermediate checkpoints are no longer needed. Remove the configured
# output directory in pure Python; this is portable and always matches
# training_args.output_dir. ignore_errors mirrors `rm -rf` when it is absent.
shutil.rmtree(training_args.output_dir, ignore_errors=True)

In [None]:
# Close the Weights & Biases run used by the trainer (report_to="wandb").
# wandb is imported only when enabled, so the notebook still runs without it.
if USE_WANDB:
    import wandb

    wandb.finish()