<a href="https://colab.research.google.com/github/Nid989/Isometric-Multi-task-NMT/blob/main/finetune_multi_lingual_MT5_training_%26_evaluation_for_translation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Install the runtime dependencies (required on a fresh Colab runtime; comment this line out if they are already installed locally).
!pip install datasets transformers sacrebleu torch sentencepiece transformers[sentencepiece] wandb boto3 --quiet 

In [None]:
%%capture
# Upgrade nltk; presumably needed by the METEOR metric loaded later via load_metric("meteor").
!pip install nltk -U

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, MarianMTModel, MarianTokenizer
from torch.utils.data import DataLoader
from datasets import load_dataset, load_metric
import torch
import numpy as np
import datasets
import boto3
import shutil
import os
import random
from tqdm.notebook import tqdm
# from tqdm import tqdm 
import wandb
import logging
import pandas as pd

# Register tqdm's pandas integration (adds `progress_apply` / `progress_map`); no visible cell uses them yet.
tqdm.pandas()

In [None]:
# Absolute path of the working directory; data and checkpoint paths below are built from it.
current_directory = os.path.abspath(os.curdir)

In [None]:
# for logging loss to wandb.ai
# The API key is read from the WANDB_API_KEY environment variable instead of being
# hard-coded in the notebook (a committed key leaks with every copy of the file).
# If the variable is unset, key=None lets wandb.login() fall back to its usual
# cached-credentials / interactive-prompt behaviour.
# NOTE(review): the key previously committed here should be revoked.
access_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=access_key)

In [None]:
# Translation direction: the English source is translated into each target language below.
source_language = "en"
target_languages = [
    "de",
    "fr",
    "it",
    "ru",
]

In [None]:
# data processing: download MuST-C for every target language and dump each split
# to "<lang>_data/<split>.csv" under the working directory.
data_types = ["train", "test", "validation"]
for target_language in tqdm(target_languages, total=len(target_languages)):
  raw_datasets = load_dataset(f"enimai/MuST-C-{target_language}")
  # The output directory depends only on the language, so create it once per language.
  path_to_data_directory = os.path.join(current_directory, f"{target_language}_data")
  os.makedirs(path_to_data_directory, exist_ok=True)
  for data_type in tqdm(data_types, total=len(data_types)):
    path_to_data_file = os.path.join(path_to_data_directory, f"{data_type}.csv")
    raw_datasets[data_type].to_csv(path_to_data_file, index=False)

# Keep the directories in target_languages order. os.listdir order is arbitrary, so
# deriving the list from it makes the concatenated multilingual splits non-reproducible.
data_directories = [
    data_directory
    for data_directory in (f"{language}_data" for language in target_languages)
    if os.path.isdir(os.path.join(current_directory, data_directory))
]

In [None]:
# Show which per-language data directories will be processed.
print(data_directories)

In [None]:
# process data: rename the language-specific columns to the generic
# "input_text"/"target_text" schema, tag each row with its target language, and
# rewrite every split CSV in place.
for data_directory in tqdm(data_directories, total=len(data_directories)):
  # Directories are named "<lang>_data"; strip the suffix instead of assuming a
  # two-letter language code.
  language = data_directory[: -len("_data")]
  path_to_data_directory = os.path.join(current_directory, data_directory)
  for data_file in os.listdir(path_to_data_directory):
    # Only split CSVs belong here; skip stray entries such as ".ipynb_checkpoints".
    if not data_file.endswith(".csv"):
      continue
    path_to_data_file = os.path.join(path_to_data_directory, data_file)
    df = pd.read_csv(path_to_data_file)
    df.rename(columns={
        source_language: 'input_text',
        language: 'target_text'
    }, inplace=True)
    df['lang'] = language
    df.to_csv(path_to_data_file, index=False)

In [None]:
# prepare singleton data: merge the per-language CSVs of each split into one
# multilingual dataset. Defines train_datasets, test_datasets and validation_datasets,
# each a DatasetDict keyed by its own split name (e.g. train_datasets["train"]).
def _load_multilingual_split(data_type):
  """Load "<lang>_data/<data_type>.csv" for every language as a single split."""
  path_to_datafiles = [
      os.path.join(current_directory, data_directory, f"{data_type}.csv")
      for data_directory in data_directories
  ]
  return load_dataset("csv", data_files={data_type: path_to_datafiles})

# Explicit assignments (instead of globals()[...]) keep these names visible to readers and linters.
train_datasets = _load_multilingual_split("train")
test_datasets = _load_multilingual_split("test")
validation_datasets = _load_multilingual_split("validation")

In [None]:
# Summarise the three combined multilingual splits.
print(f"train: {train_datasets}\ntest: {test_datasets}\nvalidation: {validation_datasets}")

In [None]:
# pre-trained model checkpoint (Hugging Face Hub id); used to load both the tokenizer and the model
train_model_checkpoints = "google/mt5-base"

In [None]:
# load the tokenizer that belongs to the pre-trained checkpoint (mT5, not MarianMT)
tokenizer = AutoTokenizer.from_pretrained(train_model_checkpoints)

In [None]:
def add_verbosity(input_list, target_list, language_list, short_threshold=0.90, long_threshold=1.10):
  """
  Prefix each source sentence with a "<lang> <length-class>" control prompt.

  The length class comes from the character-length ratio len(target) / len(source):
    ratio <  short_threshold                    -> "<lang> short"
    short_threshold <= ratio <= long_threshold  -> "<lang> normal"
    ratio >  long_threshold                     -> "<lang> long"

  input: parallel lists of source sentences, target sentences and target-language codes
         (iteration stops at the shortest list, as with zip)
  output: the source sentences with their prompt prepended, in the same order

  An empty source sentence has no defined ratio. It is classed "normal" when the
  target is empty too, and "long" otherwise, instead of raising ZeroDivisionError.
  """
  processed_input = []
  for source, target, language in zip(input_list, target_list, language_list):
    if len(source) == 0:
      ts_ratio = 1.0 if len(target) == 0 else float("inf")
    else:
      ts_ratio = len(target) / len(source)
    if ts_ratio < short_threshold:
      length_class = "short"
    elif ts_ratio <= long_threshold:
      length_class = "normal"
    else:
      length_class = "long"
    processed_input.append(f"{language} {length_class} {source}")
  return processed_input

In [None]:
# preprocess MUST-C dataset: prompt-prefix the sources, then tokenize sources and targets
max_input_length = 128
max_target_length = 128
def preprocess_function(examples):
    """Turn a batch of rows into tokenized model inputs plus `labels`.

    Each source sentence first receives its "<lang> <short|normal|long>" prompt from
    add_verbosity. Sources are truncated to max_input_length tokens and targets to
    max_target_length tokens.
    """
    sources, targets, langs = examples["input_text"], examples["target_text"], examples["lang"]
    prompted_sources = add_verbosity(sources, targets, langs)
    encoded = tokenizer(prompted_sources, max_length=max_input_length, truncation=True)
    # tokenize the reference translations in target mode
    with tokenizer.as_target_tokenizer():
        encoded_targets = tokenizer(targets, max_length=max_target_length, truncation=True)
    encoded["labels"] = encoded_targets["input_ids"]
    return encoded

In [None]:
# tokenize raw data: apply the batched preprocessing to the combined training and
# validation splits (the test split is not tokenized here)
tokenized_train_datasets = train_datasets["train"].map(
    preprocess_function,
    batched=True,
)
tokenized_validation_datasets = validation_datasets["validation"].map(
    preprocess_function,
    batched=True,
)

In [None]:
# training procedure: load the pre-trained seq2seq model named by train_model_checkpoints for fine-tuning
model = AutoModelForSeq2SeqLM.from_pretrained(train_model_checkpoints)

In [None]:
# Per-device batch size; change according to available GPU memory.
batch_size = 2
# Last path component of the checkpoint id (e.g. "mt5-base"), used to name the output directory.
model_name = train_model_checkpoints.split("/")[-1]
epoch = 1

# define training model arguments:
# - evaluate once per epoch, generating text so compute_metrics can score it
# - log every 500 steps and report to wandb
# - save a checkpoint every 1000 steps, keeping only the most recent one
args = Seq2SeqTrainingArguments(
    output_dir=f"{model_name}-finetuned-multilingual-singleton-for-{source_language}",
    num_train_epochs=epoch,
    learning_rate=5e-5,
    optim="adafactor",
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    report_to="wandb",
)

# initialize data-collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Evaluation metrics used by compute_metrics.
# NOTE(review): `datasets.load_metric` is deprecated and removed in datasets 3.x;
# `evaluate.load(...)` from the `evaluate` package is the replacement if this call fails.
sacrebleu = load_metric("sacrebleu")
meteor = load_metric("meteor")

def postprocess_text(preds, labels):
    """Strip surrounding whitespace from predictions and labels.

    Each label is wrapped in a one-element list, giving the list-of-references
    shape that sacrebleu expects.
    """
    stripped_preds = [text.strip() for text in preds]
    wrapped_labels = []
    for text in labels:
        wrapped_labels.append([text.strip()])
    return stripped_preds, wrapped_labels
    
def compute_metrics(eval_preds):
    """Decode generated and reference token ids and score them.

    Returns a dict containing:
      - "bleu": the sacreBLEU corpus score
      - "meteor": the METEOR score
      - "gen_len": the mean count of non-pad tokens in the generated sequences
    Each value is rounded to 4 decimals. The dict is also printed.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # When generated sequences from several eval batches are concatenated, shorter
    # ones are padded with -100. Labels are masked with -100 as well. -100 is not a
    # valid token id, so map it to the pad id before decoding; otherwise batch_decode
    # fails and gen_len counts the padding as generated tokens.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Some simple post-processing
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    sacrebleu_result = sacrebleu.compute(predictions=decoded_preds, references=decoded_labels)
    meteor_result = meteor.compute(predictions=decoded_preds, references=decoded_labels)
    result = {
        "bleu": sacrebleu_result["score"],
        "meteor": meteor_result['meteor']
    }
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    result = {k: round(v, 4) for k, v in result.items()}
    print(result)
    return result

In [None]:
# initialize the trainer module: fine-tune on the combined training split and
# evaluate on the combined validation split, scoring with compute_metrics
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    train_dataset=tokenized_train_datasets,
    eval_dataset=tokenized_validation_datasets,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# # train the model
# NOTE(review): training is disabled here. The archive / upload / delete cells below
# operate on the output directory named in `args`, which is presumably only populated
# once training runs and saves checkpoints. Re-enable this line for a real run.
# trainer.train()

In [None]:
# compress model checkpoint directory into "<checkpoint directory>.zip" alongside it
model_checkpoints = f"{model_name}-finetuned-multilingual-singleton-for-{source_language}"
model_checkpoint_directory = os.path.join(current_directory, model_checkpoints)
print(model_checkpoint_directory)
shutil.make_archive(
    base_name=model_checkpoint_directory,
    format="zip",
    root_dir=model_checkpoint_directory.rsplit('/', 1)[-1],
)

In [None]:
# upload the zipped checkpoint to S3.
# Credentials come from boto3's default provider chain (AWS_ACCESS_KEY_ID /
# AWS_SECRET_ACCESS_KEY environment variables, ~/.aws/credentials, or an attached
# role) instead of being hard-coded in the notebook.
# NOTE(review): the access key pair previously committed here should be revoked.
session = boto3.Session()
s3 = session.resource('s3')
key = f"{epoch}_{model_checkpoints}"
filename = f"{model_checkpoints}.zip"
print(key)
s3.meta.client.upload_file(Bucket='tsd2022', Key=key, Filename=filename)

In [None]:
# delete the unzipped checkpoint directory from the working directory
current_directory = os.getcwd()
path_to_directory = os.path.join(current_directory, model_checkpoints)
shutil.rmtree(path=path_to_directory)

In [None]:
# clean up local artifacts: first the zip archive, then every per-language data directory
current_directory = os.getcwd()
path_to_zip_file = os.path.join(current_directory, filename)
os.unlink(path_to_zip_file)

for path_to_data_directory in [os.path.join(current_directory, name) for name in data_directories]:
  shutil.rmtree(path_to_data_directory)

---