In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed.
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load:

# for dirname, _, filenames in os.walk('/kaggle/input'):
#     for filename in filenames:
#         print(os.path.join(dirname, filename))
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

import re

import torch
from torch import nn

from torchmetrics.text import WordErrorRate, CharErrorRate

import gc

from torchmetrics.text.rouge import ROUGEScore
from torchmetrics.text import BLEUScore

from pprint import pprint

import inspect

    
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer, 
    Seq2SeqTrainingArguments,
    pipeline
)


# This cell only imports packages (it does not install anything), so report that.
print("All libraries have been imported successfully!", end="\r")

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# NOTE(review): the evaluation frame is named `test_df` but is read from
# trainset.csv — confirm this is the intended split for evaluation.
DATA_CSV = "/kaggle/input/full-regipa-dataset/trainset.csv"
test_df = pd.read_csv(DATA_CSV)

In [None]:
# Discard every row that has a missing value in any column.
test_df = test_df.dropna()

In [None]:
# Character class matching ASCII letters and digits, which are removed from
# the "Contents" column below.
# Note `A-Z`, not `A-z`: the range `A-z` also spans the ASCII characters
# between 'Z' and 'a' ([ \ ] ^ _ `), which would be deleted as well.
alpha_pat = "[a-zA-Z0-9]"

test_df["Contents"] = test_df["Contents"].str.replace(alpha_pat, "", regex=True)

In [None]:
# Checkpoint evaluated by this notebook. Alternatives that can be swapped in:
#   "teamapocalypseml/regben2ipa-umt5base"
#   "teamapocalypseml/regben2ipa-mt5-base"
MODEL_NAME = "teamapocalypseml/regben2ipa-byt5small"

In [None]:
# Ask PyTorch to release cached, unused CUDA memory before the pipeline is
# created further down (useful when re-running in a kernel that already
# held GPU tensors).
torch.cuda.empty_cache()

In [None]:
# Preview the first few rows of the cleaned evaluation frame.
test_df.head(n=5)

In [None]:
# Metric objects shared by the per-district evaluation loop.
wer, cer = WordErrorRate(), CharErrorRate()
rouge = ROUGEScore()
bleu = BLEUScore(n_gram=1)  # unigram BLEU only

In [None]:
# Districts evaluated and reported one at a time, in this order.
districts = "Kishoreganj Narail Narsingdi Rangpur Tangail Chittagong".split()

In [None]:
# Text-to-text generation pipeline for the selected checkpoint on `device`.
pipe = pipeline(
    task="text2text-generation",
    model=MODEL_NAME,
    device=device,
)

In [None]:
# Evaluate the pipeline separately for each district and print WER/CER,
# unigram BLEU and ROUGE against the reference "IPA" column.

# Generation settings depend only on the checkpoint, so choose them once.
# byt5 checkpoints get a longer max_length and a larger batch than the others.
if "byt5" in MODEL_NAME:
    gen_kwargs = {"max_length": 2048, "batch_size": 128}
else:
    gen_kwargs = {"max_length": 512, "batch_size": 8}

for district in districts:
    dist_df = test_df[test_df["District"] == district]
    if dist_df.empty:
        # Nothing to generate or score; avoid passing empty lists to the metrics.
        print(f"No rows for district {district}; skipping.")
        continue

    # Every row here has District == district, so the tag uses the loop value.
    # NOTE(review): presumably "<District> text" is the input format the model
    # was fine-tuned on — confirm.
    prompts = [f"<{district}> {text}" for text in dist_df["Contents"]]
    preds = [out["generated_text"] for out in pipe(prompts, **gen_kwargs)]
    gts = dist_df["IPA"].tolist()

    # Free memory once per district rather than once per generated item.
    gc.collect()
    torch.cuda.empty_cache()

    wer_res = wer(preds, gts).item()
    cer_res = cer(preds, gts).item()
    # BLEUScore's target format is one *list* of references per prediction;
    # a bare string can be iterated character-by-character, turning each
    # character into a separate reference.
    bleu_res = bleu(preds, [[gt] for gt in gts]).item()

    print(f"""
        For district {district}:
        Word error rate: {wer_res},
        Char error rate: {cer_res},
        
    """)
    print(f"Bleu score : {bleu_res}")
    print("Rouge metrics:")
    pprint(rouge(preds, gts))
    print()
    print("=====================")
    print()

    # Calling a metric returns the value for that call; clear the accumulated
    # state so it does not keep growing across districts.
    for metric in (wer, cer, bleu, rouge):
        metric.reset()