In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import os
import sys
# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import re

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

from datasets import Dataset as DS
from datasets import load_metric
from torchmetrics.text import WordErrorRate, CharErrorRate


import random

import time

from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import confusion_matrix
from sklearn.preprocessing import MinMaxScaler, StandardScaler


import gc

import inspect
    
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments,
    pipeline
)
    
# Confirm that every import above succeeded. The cell only imports the
# libraries (nothing is installed here). A normal newline is used so later
# output cannot overwrite this line.
print("All libraries have been imported successfully!")

In [None]:
# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Evaluation split, read from the Kaggle input dataset mount.
TEST_CSV = "/kaggle/input/full-regipa-dataset/testset.csv"
test_df = pd.read_csv(TEST_CSV)

In [None]:
# Remove ASCII letters and digits from the input text column.
# Note the uppercase range is `A-Z`: the previous pattern used `A-z`, which
# spans ASCII 65-122 and therefore also deleted the punctuation
# characters [ \ ] ^ _ and the backtick.
alpha_pat = "[a-zA-Z0-9]"

test_df["Contents"] = test_df["Contents"].str.replace(alpha_pat, "", regex=True)

In [None]:
# Fine-tuned checkpoints that can be evaluated here; change the index to
# evaluate a different one. The first entry is the active choice.
MODEL_CANDIDATES = (
    "teamapocalypseml/regben2ipa-byt5small",
    "teamapocalypseml/regben2ipa-umt5base",
    "teamapocalypseml/regben2ipa-mt5-base",
)
MODEL_NAME = MODEL_CANDIDATES[0]

In [None]:
# Text-to-text generation pipeline for the selected checkpoint, placed on
# the device chosen earlier.
pipe = pipeline(
    task="text2text-generation",
    model=MODEL_NAME,
    device=device,
)

In [None]:
# Pull the input text and its District value out as parallel plain lists.
texts, dists = (test_df[column].tolist() for column in ("Contents", "District"))

In [None]:
# Prefix each input with its district tag: "<district> text".
reformed_texts = [f"<{district}> {text}" for text, district in zip(texts, dists)]

In [None]:
# Generation settings depend on the checkpoint family: ByT5 checkpoints get
# max_length=2048 with batches of 128, everything else max_length=512 with
# batches of 8.
if "byt5" in MODEL_NAME:
    gen_kwargs = dict(max_length=2048, batch_size=128)
else:
    gen_kwargs = dict(max_length=512, batch_size=8)

ipas = pipe(reformed_texts, **gen_kwargs)

# Empty accumulator for the decoded prediction strings.
gen_txt = []

In [None]:
# The pipeline yields one {"generated_text": ...} dict per input; keep only
# the strings. Entries that are already plain strings pass through unchanged,
# so re-running this cell does not fail.
ipas = [ipa["generated_text"] if isinstance(ipa, dict) else ipa for ipa in ipas]
gc.collect()
# Release cached CUDA memory once after unpacking, not once per prediction.
torch.cuda.empty_cache()

In [None]:
# Ask PyTorch to release its cached CUDA memory now that generation is done.
torch.cuda.empty_cache()

In [None]:
# Attach the decoded predictions as a new "string" column, positionally
# aligned with the rows, then order the frame by its index.
test_df = test_df.assign(string=ipas).sort_index()

In [None]:
# Predictions ("string") and reference transcriptions ("IPA") as row-aligned lists.
preds, gts = [test_df[name].to_list() for name in ("string", "IPA")]

In [None]:
# Word-level and character-level error-rate metric objects from torchmetrics.
wer, cer = WordErrorRate(), CharErrorRate()

In [None]:
# Score the predictions against the references (WER first, then CER) and
# convert each metric result to a Python number with .item().
wer_res, cer_res = (metric(preds, gts).item() for metric in (wer, cer))

In [None]:
# Report both error rates. The blank first and last entries reproduce the
# leading and trailing newlines of the report.
report_lines = [
    "",
    f"    Word error rate: {wer_res},",
    f"    Char error rate: {cer_res},",
    "",
]
print("\n".join(report_lines))