# Training and inference are both run on our personal GPU, using a Jupyter notebook

# Initialization

In [None]:
!pip install jiwer

In [None]:
import pandas as pd
import seaborn as sns
from datasets import Dataset
from datasets import load_metric
import jiwer
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tqdm import tqdm # tqdm is used to show progress bar
import re # re is used for regular expressions
import os # os is used for operating system related functions
import torch # torch is used for building deep learning models
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Data pre-processing

In [None]:
# Training data; later cells use its "text" and "ipa" columns.
TRAIN_CSV_PATH = "/kaggle/input/dataverse_2023/trainIPAdb_u.csv"
train_df = pd.read_csv(TRAIN_CSV_PATH)

In [None]:
# Remove ASCII letters and digits from the source text.
# The uppercase range must be `A-Z`, not `A-z`: `A-z` also spans the ASCII
# characters between `Z` and `a` ([ \ ] ^ _ `), which would then be deleted
# from the text as well.
alpha_pat = r"[a-zA-Z0-9]"

train_df["text"] = train_df["text"].str.replace(alpha_pat, "", regex=True)

In [None]:
# Drop the final character of every source ("text") and target ("ipa") string.
# NOTE(review): presumably a trailing terminator present on every row; this
# slices unconditionally, so confirm against the CSV that no row lacks it.
for column in ("text", "ipa"):
    train_df[column] = train_df[column].str.slice(stop=-1)

In [None]:
# Quick visual check of the cleaned rows.
train_df.head(n=5)

In [None]:
# Hold out 2% of the rows as a validation set (fixed seed for
# reproducibility) and give both splits a fresh 0..n-1 index.
train_df, val_df = (
    split.reset_index(drop=True)
    for split in train_test_split(
        train_df, test_size=0.02, shuffle=True, random_state=3000
    )
)

# Dataset

In [None]:
# Wrap the pandas splits as Hugging Face datasets (train first, then eval).
ds_train, ds_eval = map(Dataset.from_pandas, (train_df, val_df))

# Model

In [None]:
# Pretrained seq2seq checkpoint from the Hugging Face Hub.
model_id = "google/umt5-base"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)

# Pads inputs and labels per batch when the trainer collates samples.
data_collator = DataCollatorForSeq2Seq(tokenizer)

In [None]:
def prepare_dataset(sample):
    """Tokenize one row for seq2seq training.

    The "text" column becomes the model inputs and the tokenized "ipa"
    column becomes ``labels``. A ``length`` field (number of label tokens)
    is added for length-grouped batching.
    """
    encoded = tokenizer(sample["text"])
    label_ids = tokenizer(sample["ipa"])["input_ids"]
    encoded["labels"] = label_ids
    encoded["length"] = len(label_ids)
    return encoded


# Tokenize both splits, keeping only the fields returned by prepare_dataset.
ds_train, ds_eval = (
    split.map(prepare_dataset, remove_columns=split.column_names)
    for split in (ds_train, ds_eval)
)

# Metric

In [None]:
wer_metric = load_metric("wer")


def compute_metrics(eval_preds):
    """Return ``{"wer": ...}`` for generated predictions against the labels.

    ``eval_preds`` is a ``(predictions, label_ids)`` pair of token-id arrays.
    Both are decoded to text (special tokens skipped) before scoring.
    """
    predictions, label_ids = eval_preds
    if isinstance(predictions, tuple):
        # When predictions arrive as a tuple, the token ids are the first item.
        predictions = predictions[0]

    def _decode(token_ids):
        # -100 marks ignored/padded positions; swap in the pad id so the
        # tokenizer can decode the whole batch.
        cleaned = np.where(token_ids != -100, token_ids, tokenizer.pad_token_id)
        return tokenizer.batch_decode(cleaned, skip_special_tokens=True)

    return {
        "wer": wer_metric.compute(
            predictions=_decode(predictions),
            references=_decode(label_ids),
        )
    }

# Training

In [None]:
# Let Seq2SeqTrainer own device placement and multi-GPU handling. It moves
# the model to `training_args.device` and wraps it in DataParallel itself
# when more than one GPU is visible. Wrapping the model manually here breaks
# training for two reasons:
#   * DataParallel requires the parameters to live on device_ids[0] (cuda:0),
#     so moving the wrapped model to cuda:1 fails on the first forward pass;
#   * the wrapper does not expose `generate()`, which the trainer calls
#     during evaluation because `predict_with_generate=True`.
# `device` is kept (on cuda:0, where the trainer places the model) for any
# later code that needs to move tensors next to the model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

In [None]:
# Rebinds model_id: from here on it names the run / output directory
# (also the path passed to trainer.save_model).
model_id = "iit-eight"

training_args = Seq2SeqTrainingArguments(
    # Output location; no Hub upload and no external loggers.
    output_dir=model_id,
    push_to_hub=False,
    report_to="none",
    # Optimisation schedule.
    num_train_epochs=15,
    learning_rate=5e-4,
    weight_decay=1e-2,
    warmup_steps=1000,
    # Batching: bucket samples by the precomputed "length" column.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    group_by_length=True,
    length_column_name="length",
    # Evaluate, log and checkpoint every 2000 steps, cap saved checkpoints
    # at 3, and reload the checkpoint with the lowest WER when training ends.
    evaluation_strategy="steps",
    eval_steps=2000,
    logging_steps=2000,
    save_steps=2000,
    save_total_limit=3,
    metric_for_best_model="wer",
    greater_is_better=False,
    load_best_model_at_end=True,
    # Evaluate on generated sequences so WER is computed on decoded text.
    predict_with_generate=True,
    generation_max_length=175,
)

In [None]:
# Wire together the model, data, collator and WER metric, then train.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

In [None]:
# Persist the final model (the best checkpoint, since load_best_model_at_end
# is set) to the run directory.
trainer.save_model(output_dir=model_id)