<a href="https://colab.research.google.com/github/Nid989/Isometric-Multilingual-Speech-Translation-using-Multi-Task-Learning-Framework/blob/main/de_en_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Install the notebook's dependencies with pip (the cell's output is captured
# by the %%capture magic above). Comment this line out if the packages are
# already installed in the current environment.
# NOTE(review): versions are unpinned, so a fresh install may pull releases
# that differ from the ones this notebook was written against — consider pinning.
!pip install datasets transformers sacrebleu torch sentencepiece transformers[sentencepiece] wandb boto3 --quiet 

In [2]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, MarianMTModel, MarianTokenizer
from torch.utils.data import DataLoader
import torch
import numpy as np
import datasets
import boto3
import os
import random
# from tqdm.notebook import tqdm
from tqdm import tqdm 
import wandb
import logging
import pandas as pd

In [None]:
# Authenticate with Weights & Biases so the training loss can be logged to wandb.ai.
# The API key is read from the WANDB_API_KEY environment variable instead of
# being hard-coded in the notebook. When the variable is unset, `access_key`
# is None and wandb.login() is left to resolve credentials itself.
# SECURITY: never commit an API key to the notebook; any key that was
# committed in the past should be revoked and regenerated.
access_key = os.environ.get("WANDB_API_KEY")
wandb.login(key=access_key)

In [None]:
# Data loading: fetch the "enimai/MuST-C-de" dataset and the "sacrebleu" metric.
# `raw_datasets` is indexed by split name ('train', 'validation') further down,
# and `metric` is the scorer called inside compute_metrics.
# NOTE(review): `load_metric` appears to be deprecated/removed in newer
# `datasets` releases — confirm against the installed version.
from datasets import load_dataset, load_metric
raw_datasets = load_dataset("enimai/MuST-C-de")
metric = load_metric("sacrebleu")

In [None]:
# Dataset description: print the loaded dataset object for a quick sanity check.
print(raw_datasets)

In [6]:
# Identifier of the pre-trained checkpoint used for both the tokenizer and the model.
# NOTE(review): the checkpoint name says "de-fr", while the preprocessing cell
# sets source_lang = "en" and target_lang = "de" — confirm this pairing is intended.
train_model_checkpoints = "Helsinki-NLP/opus-mt-de-fr"

In [None]:
# Load the tokenizer that belongs to the chosen pre-trained checkpoint
# (the same identifier is used later to load the model).
tokenizer = AutoTokenizer.from_pretrained(train_model_checkpoints)

In [9]:
def add_verbosity(input_list, target_list, short_threshold=0.95, long_threshold=1.10):
  """
  Prefix every source sequence with a length-control prompt.

  For each (source, target) pair the character-length ratio
  len(target) / len(source) is computed and mapped to a prompt:
    ratio <  short_threshold                   -> "short"
    short_threshold <= ratio <= long_threshold -> "normal"
    ratio >  long_threshold                    -> "long"

  input:
    input_list: list of source strings
    target_list: list of target strings (paired with input_list by position)
    short_threshold: ratios strictly below this value are labelled "short"
    long_threshold: ratios strictly above this value are labelled "long"
  output: list of source strings, each rewritten as "<prompt> <source>"
  """
  processed_input = []
  for source, target in zip(input_list, target_list):
    if len(source) > 0:
      ts_ratio = len(target) / len(source)
    else:
      # An empty source would raise ZeroDivisionError; treat an empty pair as
      # length-matched and a non-empty target as infinitely longer.
      ts_ratio = 1.0 if len(target) == 0 else float("inf")
    if ts_ratio < short_threshold:
      prefix = "short"
    elif ts_ratio <= long_threshold:
      prefix = "normal"
    else:
      prefix = "long"
    processed_input.append(prefix + " " + source)
  return processed_input

In [17]:
# Settings for preprocessing the MUST-C dataset:
# column names of the source / target text, and the token limits applied
# when truncating tokenized sources and targets.
source_lang, target_lang = "en", "de"
max_input_length = max_target_length = 128
def preprocess_function(examples):
    """
    Tokenize one batch of examples for seq2seq training.

    examples: batch mapping that holds a list of source strings under
        `source_lang` and a list of target strings under `target_lang`.
    returns: the tokenizer output for the prompted sources, with the target
        token ids stored under the "labels" key.
    """
    targets = examples[target_lang]
    # Prepend the length prompt to every source sentence.
    inputs = add_verbosity(examples[source_lang], targets)
    model_inputs = tokenizer(inputs, max_length=max_input_length, truncation=True)
    # Tokenize the targets via `text_target`, which replaces the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Run the batched preprocessing over the train split first, then the validation split.
tokenized_train_datasets, tokenized_validation_datasets = [
    raw_datasets[split_name].map(preprocess_function, batched=True)
    for split_name in ("train", "validation")
]

In [None]:
# Training procedure: load the pre-trained seq2seq model from the same
# checkpoint identifier that was used for the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(train_model_checkpoints)

In [22]:
# Per-device batch size; change it according to GPU availability.
batch_size = 16
# Checkpoint name without the organisation prefix (text after the last "/").
model_name = train_model_checkpoints.rsplit("/", 1)[-1]

# Training arguments: the first positional value is the output directory name,
# evaluation is scheduled per epoch and predictions are produced with generate().
args = Seq2SeqTrainingArguments(
    f"{model_name}-finetuned-{source_lang}-to-{target_lang}",
    num_train_epochs=1,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy="epoch",
    predict_with_generate=True,
    save_total_limit=3,
)

# Data collator for seq2seq batches, built from the tokenizer and the model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [23]:
def postprocess_text(preds, labels):
    """
    Tidy decoded strings before scoring.

    Every prediction is stripped of surrounding whitespace; every label is
    stripped and wrapped in its own single-item list.
    Returns the pair (stripped predictions, wrapped labels).
    """
    stripped_preds = []
    for prediction in preds:
        stripped_preds.append(prediction.strip())

    wrapped_labels = []
    for reference in labels:
        wrapped_labels.append([reference.strip()])

    return stripped_preds, wrapped_labels
    
def compute_metrics(eval_preds):
    """
    Compute evaluation metrics from generated predictions and reference labels.

    eval_preds: pair (preds, labels) of token-id arrays; `preds` may also be a
        tuple, in which case only its first element is used.
    returns: dict with "bleu" (the metric's "score" value) and "gen_len" (mean
        number of non-padding tokens per prediction), both rounded to 4 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 is a padding/ignore marker that cannot be decoded; map it to the pad
    # token in the predictions as well as the labels. Doing so also keeps such
    # positions out of the gen_len count below.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Strip whitespace and put the references into the nested-list layout.
    decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    result = {"bleu": result["score"]}
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [24]:
# Assemble the trainer from the model, training arguments, tokenized splits,
# data collator, tokenizer and metric callback defined above.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    train_dataset=tokenized_train_datasets,
    eval_dataset=tokenized_validation_datasets,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

In [None]:
# Train the model with the trainer configured above.
trainer.train()