In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd


def round_and_percentage(num: float) -> float:
    """Express a fraction as a percentage rounded to two decimal places.

    Args:
        num: Value to convert, e.g. an error rate given as a fraction.

    Returns:
        ``num`` multiplied by 100 and rounded to two decimals.
    """
    as_percent = num * 100
    return round(as_percent, 2)


# Directory-name prefix of a model run -> column group shown in the table.
model_prefix_name_map = {"transk": "Transkribus", "tess": "Tesseract", "trocr": "TrOCR"}

# Directory-name suffix of a model run -> row label shown in the table.
# Some suffixes are aliases that map to the same row label.
dataset_map = {
    "ub_smi": "GT-Sámi (without base)",
    "smi_ub": "GT-Sámi (without base)",
    "smi_synth": "GT-Sámi (synth base)",
    "sb_smi": "GT-Sámi (synth base)",
    "smi": "GT-Sámi",
    "smi_nor": "GT-Sámi + GT-Nor",
    "smi_pred": "GT-Sámi + Pred-Sámi",
    "smi_nor_pred": "GT-Sámi + GT-Nor + Pred-Sámi",
    "sb_smi_nor_pred": "GT-Sámi + GT-Nor + Pred-Sámi (synth base)",
    "smi_pred_synth": "GT-Sámi + Pred-Sámi (synth base)",
}

# Row order of the final table.
ordered_dataset_values = [
    "GT-Sámi (without base)",
    "GT-Sámi",
    "GT-Sámi + GT-Nor",
    "GT-Sámi + Pred-Sámi",
    "GT-Sámi + GT-Nor + Pred-Sámi",
    "GT-Sámi (synth base)",
    "GT-Sámi + Pred-Sámi (synth base)",
    "GT-Sámi + GT-Nor + Pred-Sámi (synth base)",
]

# Two-level column header: one group per model, three metrics per group.
# A list snapshot (rather than a live dict view) keeps the header fixed.
primary_columns = list(model_prefix_name_map.values())
sub_columns = ["CER", "WER", "mean"]
columns = pd.MultiIndex.from_product([primary_columns, sub_columns])

# Start from a table of empty-string placeholders; cells are overwritten with
# floats once results are loaded, and untouched cells render as blanks.
# dtype=object is pinned explicitly so every column may hold either the ""
# placeholder or a float, independent of how pandas would infer the dtype of
# an all-string column.
df = pd.DataFrame(
    columns=columns,
    index=ordered_dataset_values,
    data=[[""] * len(columns) for _ in ordered_dataset_values],
    dtype=object,
)
df

Unnamed: 0_level_0,Transkribus,Transkribus,Transkribus,Tesseract,Tesseract,Tesseract,TrOCR,TrOCR,TrOCR
Unnamed: 0_level_1,CER,WER,mean,CER,WER,mean,CER,WER,mean
GT-Sámi (without base),,,,,,,,,
GT-Sámi,,,,,,,,,
GT-Sámi + GT-Nor,,,,,,,,,
GT-Sámi + Pred-Sámi,,,,,,,,,
GT-Sámi + GT-Nor + Pred-Sámi,,,,,,,,,
GT-Sámi (synth base),,,,,,,,,
GT-Sámi + Pred-Sámi (synth base),,,,,,,,,
GT-Sámi + GT-Nor + Pred-Sámi (synth base),,,,,,,,,


In [2]:
# Directory whose entries are per-model evaluation result directories.
output_dir = Path("../../output/valset_evaluation/line_level")

# Iterate in sorted order: iterdir() yields entries in arbitrary order, and
# several dataset_map keys are aliases for the same table row, so without a
# fixed order the run that ends up in a shared cell could differ between
# executions.
for model_name in sorted(output_dir.iterdir()):
    # iterdir() also yields plain files; only directories whose name contains
    # "smi" are read.
    if not model_name.is_dir() or "smi" not in model_name.name:
        continue
    # "<prefix>_<rest>": the prefix selects the model column group.
    model_prefix, _, model_info = model_name.name.partition("_")

    if model_prefix == "transk":
        # Only Transkribus runs with "lm" in their name are tabulated
        # (presumably the language-model variant -- confirm).
        if "lm" not in model_info:
            continue
        # Drop the trailing "_<suffix>" so the remainder is a dataset_map key.
        model_info = model_info.rpartition("_")[0]

    eval_data_file = model_name / "all_rows.json"
    eval_data = json.loads(eval_data_file.read_text(encoding="utf-8"))
    cer = eval_data["CER_concat"]
    wer = eval_data["WER_concat"]
    mean_ = np.mean([cer, wer])

    # Unknown prefixes/suffixes raise KeyError here rather than being
    # silently dropped from the table.
    row = dataset_map[model_info]
    model_column = model_prefix_name_map[model_prefix]
    for metric, value in (("CER", cer), ("WER", wer), ("mean", mean_)):
        df.at[row, (model_column, metric)] = round_and_percentage(value)

In [3]:
# Display the results table as it stands at this point in the notebook.
df

Unnamed: 0_level_0,Transkribus,Transkribus,Transkribus,Tesseract,Tesseract,Tesseract,TrOCR,TrOCR,TrOCR
Unnamed: 0_level_1,CER,WER,mean,CER,WER,mean,CER,WER,mean
GT-Sámi (without base),1.59,5.67,3.63,7.93,24.7,16.31,,,
GT-Sámi,1.28,4.34,2.81,4.59,9.84,7.22,1.98,9.29,5.64
GT-Sámi + GT-Nor,1.31,4.35,2.83,4.91,11.39,8.15,1.95,8.88,5.42
GT-Sámi + Pred-Sámi,1.48,4.02,2.75,4.42,8.17,6.3,1.28,5.0,3.14
GT-Sámi + GT-Nor + Pred-Sámi,1.07,3.58,2.33,4.4,7.96,6.18,1.32,5.14,3.23
GT-Sámi (synth base),,,,4.33,8.78,6.56,1.15,5.04,3.09
GT-Sámi + Pred-Sámi (synth base),,,,,,,1.08,4.29,2.69
GT-Sámi + GT-Nor + Pred-Sámi (synth base),,,,4.36,7.7,6.03,,,


In [4]:
def multiline_cell(s: str) -> str:
    r"""Wrap ``s`` in a nested LaTeX ``tabular`` environment.

    Args:
        s: Cell content to place inside the nested ``tabular``.

    Returns:
        ``s`` enclosed between ``\begin{tabular}[c]{@{}l@{}}`` and
        ``\end{tabular}``.
    """
    opening = r"\begin{tabular}[c]{@{}l@{}}"
    closing = r"\end{tabular}"
    return "".join((opening, s, closing))


def new_name(s: str, table_cell: bool = True) -> str:
    r"""Insert LaTeX line breaks into a row label.

    Every ``(`` becomes ``\\(`` and every `` + `` (with surrounding spaces)
    becomes ``\\+``.

    Args:
        s: Original row label.
        table_cell: When true and the label was changed, wrap the result with
            ``multiline_cell``.

    Returns:
        The rewritten label; labels without ``(`` or `` + `` come back unchanged.
    """
    # str.replace leaves the string untouched when the pattern is absent,
    # so no membership checks are needed before replacing.
    broken = s.replace("(", r"\\(").replace(" + ", r"\\+")
    was_changed = broken != s
    if table_cell and was_changed:
        return multiline_cell(broken)
    return broken


def df_to_latex_df(df: pd.DataFrame, table_cell: bool) -> pd.DataFrame:
    """Return a copy of ``df`` whose index labels are rewritten by ``new_name``.

    Args:
        df: Table to convert; it is copied, not modified.
        table_cell: Passed through to ``new_name`` for every index label.

    Returns:
        The copied frame with the rewritten index.
    """
    latex_df = df.copy()
    relabelled = []
    for label in latex_df.index:
        relabelled.append(new_name(label, table_cell=table_cell))
    latex_df.index = relabelled
    return latex_df


def add_hline(latex_code: str) -> str:
    r"""Append ``\hline`` to every table body row except the last.

    The body is the text between the first ``\midrule`` line and the first
    ``\bottomrule`` line.

    Args:
        latex_code: LaTeX table source containing both markers.

    Returns:
        The source with ``\hline`` inserted before each newline inside the body.

    Raises:
        ValueError: If either marker is missing from ``latex_code``.
    """
    start_marker = "\\midrule\n"
    body_start = latex_code.index(start_marker) + len(start_marker)
    body_end = latex_code.index("\n\\bottomrule")

    header = latex_code[:body_start]
    body = latex_code[body_start:body_end]
    footer = latex_code[body_end:]

    # Newlines inside the body separate rows; the final row has none after it
    # within the body, so it stays without an \hline.
    return header + body.replace("\n", "\\hline\n") + footer

In [5]:
def print_latex_table(df: pd.DataFrame):
    r"""Print ``df`` as LaTeX table source with ``\hline`` between body rows.

    The index labels are first rewritten by ``df_to_latex_df`` (with
    ``table_cell=True``), the frame is rendered with ``DataFrame.to_latex``,
    and the result is passed through ``add_hline`` before printing.

    Args:
        df: Table to render. The column format is hard-coded to one label
            column followed by three groups of three centred columns.
    """
    latex_df = df_to_latex_df(df, table_cell=True)
    latex_code = latex_df.to_latex(
        float_format="%.2f",
        column_format="l|ccc|ccc|ccc",
        multicolumn=True,
        multirow=True,
        multicolumn_format="c",
        # The rewritten index labels contain raw LaTeX (nested tabular, "\\"),
        # so escaping must be off regardless of the pandas default or options.
        escape=False,
    )
    print(add_hline(latex_code=latex_code))

In [6]:
# Print the LaTeX source for the results table.
print_latex_table(df)

\begin{tabular}{l|ccc|ccc|ccc}
\toprule
 & \multicolumn{3}{c}{Transkribus} & \multicolumn{3}{c}{Tesseract} & \multicolumn{3}{c}{TrOCR} \\
 & CER & WER & mean & CER & WER & mean & CER & WER & mean \\
\midrule
\begin{tabular}[c]{@{}l@{}}GT-Sámi \\(without base)\end{tabular} & 1.59 & 5.67 & 3.63 & 7.93 & 24.70 & 16.31 &  &  &  \\\hline
GT-Sámi & 1.28 & 4.34 & 2.81 & 4.59 & 9.84 & 7.22 & 1.98 & 9.29 & 5.64 \\\hline
\begin{tabular}[c]{@{}l@{}}GT-Sámi\\+GT-Nor\end{tabular} & 1.31 & 4.35 & 2.83 & 4.91 & 11.39 & 8.15 & 1.95 & 8.88 & 5.42 \\\hline
\begin{tabular}[c]{@{}l@{}}GT-Sámi\\+Pred-Sámi\end{tabular} & 1.48 & 4.02 & 2.75 & 4.42 & 8.17 & 6.30 & 1.28 & 5.00 & 3.14 \\\hline
\begin{tabular}[c]{@{}l@{}}GT-Sámi\\+GT-Nor\\+Pred-Sámi\end{tabular} & 1.07 & 3.58 & 2.33 & 4.40 & 7.96 & 6.18 & 1.32 & 5.14 & 3.23 \\\hline
\begin{tabular}[c]{@{}l@{}}GT-Sámi \\(synth base)\end{tabular} &  &  &  & 4.33 & 8.78 & 6.56 & 1.15 & 5.04 & 3.09 \\\hline
\begin{tabular}[c]{@{}l@{}}GT-Sámi\\+Pred-Sámi \\(synth bas