In [1]:
import json
from pathlib import Path

import pandas as pd


def round_and_percentage(num: float) -> float:
    """Convert the fraction ``num`` to a percentage rounded to two decimals (0.5 -> 50.0)."""
    percent = num * 100
    return round(percent, 2)


# Result directories are named "<model prefix>_<dataset key>", e.g. "tess_smi_nor":
# the part before the first "_" selects the column, the rest selects the row.
model_prefix_name_map = {"transk": "Transkribus", "tess": "Tesseract", "trocr": "TrOCR"}
dataset_map = {
    "ub_smi": "GT-Sámi (without base)",
    "smi": "GT-Sámi",
    "smi_nor": "GT-Sámi + GT-Nor",
    "smi_pred": "GT-Sámi + Pred-Sámi",
    "smi_nor_pred": "GT-Sámi + GT-Nor + Pred-Sámi",
}

# Empty-string placeholders (rather than NaN) so model/dataset combinations without
# results stay blank in the displayed and LaTeX tables.
cer_df = pd.DataFrame(
    {
        "dataset": list(dataset_map.values()),
        **{name: [""] * len(dataset_map) for name in model_prefix_name_map.values()},
    }
)
cer_df = cer_df.set_index("dataset")
cer_df.columns.name = "model"

wer_df = cer_df.copy()

output_dir = Path("../output/evaluation/line_level")

for model_dir in sorted(output_dir.iterdir()):
    # Only result directories whose name contains "smi" (the dataset keys above) are
    # tabulated; stray files are skipped instead of crashing on "<file>/all_rows.json".
    if not model_dir.is_dir() or "smi" not in model_dir.name:
        continue
    model_prefix, _, model_info = model_dir.name.partition("_")

    eval_data_file = model_dir / "all_rows.json"
    # JSON is UTF-8 by spec; the platform default encoding (e.g. cp1252 on Windows)
    # could mis-decode or reject any non-ASCII content in the file.
    eval_data = json.loads(eval_data_file.read_text(encoding="utf-8"))

    row_label = dataset_map[model_info]
    column_label = model_prefix_name_map[model_prefix]
    for metric_key, metric_df in (("CER_concat", cer_df), ("WER_concat", wer_df)):
        metric_df.at[row_label, column_label] = round_and_percentage(eval_data[metric_key])

In [None]:
# CER in percent per training dataset (rows) and model (columns); "" marks a missing result.
cer_df

In [None]:
# WER in percent per training dataset (rows) and model (columns); "" marks a missing result.
wer_df

In [4]:
def multiline_cell(s: str) -> str:
    r"""Wrap ``s`` in a nested one-column LaTeX ``tabular`` so ``\\`` inside it breaks the line.

    The ``[c]`` option centres the nested table vertically in its row, and ``@{}``
    on both sides removes the inner column padding so the text lines up with
    ordinary single-line cells.
    """
    template_start = r"\begin{tabular}[c]{@{}l@{}}"
    template_end = r"\end{tabular}"
    return template_start + s + template_end


def new_name(s: str) -> str:
    r"""Return a LaTeX row label for dataset name ``s``, split over several lines.

    A LaTeX line break ``\\`` is inserted before every ``(`` and in place of every
    `` + `` separator (the ``+`` starts the next line). Both rules are applied, so
    a label like ``"A + B (x)"`` is broken at every point. If any break was
    inserted, the result is wrapped with :func:`multiline_cell`; otherwise ``s``
    is returned unchanged.
    """
    broken = s.replace("(", r"\\(").replace(" + ", r"\\+")
    return multiline_cell(broken) if broken != s else s


def df_to_latex_df(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose row labels have been passed through :func:`new_name`.

    ``df`` itself is not modified. Because the new index is assigned as a plain
    list, the copy's index carries no name.
    """
    relabelled = df.copy()
    relabelled.index = list(map(new_name, df.index))
    return relabelled


def add_hline(latex_code: str) -> str:
    r"""Turn a booktabs-style LaTeX table into a fully ``\hline``-ruled one.

    Every newline between the ``\midrule`` line and the ``\bottomrule`` line is
    preceded by ``\hline``, so each body row (assumed to sit on one line) is
    followed by a rule. Then ``\toprule``, ``\midrule`` and ``\bottomrule`` are
    all replaced by ``\hline``.

    Raises:
        ValueError: if ``latex_code`` lacks a ``\midrule`` line or a ``\bottomrule`` line.
    """
    head_marker = "\\midrule\n"
    tail_marker = "\n\\bottomrule"

    body_start = latex_code.index(head_marker) + len(head_marker)
    body_end = latex_code.index(tail_marker)

    body_lines = latex_code[body_start:body_end].split("\n")
    ruled = latex_code[:body_start] + "\\hline\n".join(body_lines) + latex_code[body_end:]

    for booktabs_rule in (r"\toprule", r"\midrule", r"\bottomrule"):
        ruled = ruled.replace(booktabs_rule, r"\hline")
    return ruled

In [5]:
def print_latex_table(df: pd.DataFrame) -> None:
    r"""Print ``df`` as a fully ruled LaTeX ``tabular`` for copy-pasting.

    Row labels are reformatted with :func:`df_to_latex_df`, every column is
    left-aligned with vertical rules between and around the columns, floats are
    rendered with two decimals, and :func:`add_hline` converts the booktabs rules
    into ``\hline`` rules.

    ``escape=False`` is passed explicitly because the reformatted row labels
    contain raw LaTeX (``\begin{tabular}`` ... ``\\``) that must reach the output
    verbatim. pandas < 2.0 escaped by default (option ``display.latex.escape``),
    and pandas >= 2.0 escapes when ``styler.format.escape`` is set to "latex".

    Returns None on purpose, so a notebook cell ending in this call shows only
    the printed LaTeX and no extra ``Out[]`` value.
    """
    latex_df = df_to_latex_df(df)
    # One "l" column for the index plus one per data column, all separated by "|".
    column_format = "|" + "|".join(["l"] * (len(df.columns) + 1)) + "|"
    latex_code = latex_df.to_latex(
        float_format="%.2f",
        column_format=column_format,
        escape=False,
    )
    print(add_hline(latex_code=latex_code))

In [None]:
# Print the CER table as LaTeX source for copy-pasting into a document.
print_latex_table(cer_df)

In [None]:
# Print the WER table as LaTeX source for copy-pasting into a document.
print_latex_table(wer_df)