In [149]:
import json
from pathlib import Path
from typing import NamedTuple
from functools import partial

import numpy as np
import pandas as pd


def round_and_percentage(num: float) -> float:
    """Express a fraction as a percentage rounded to two decimals (0.0123456 -> 1.23)."""
    percentage = num * 100
    return round(percentage, 2)


# Output-directory prefix -> model family shown as a column group in the table.
model_prefix_name_map = dict(transk="Transkribus", tess="Tesseract", trocr="TrOCR")

# Dataset key found in output-directory names -> human-readable training-data label.
# Several keys are alternative spellings of the same configuration.
dataset_map = {
    "ub_smi": "GT-Sámi (without base)",
    "smi_ub": "GT-Sámi (without base)",
    "smi_synth": "GT-Sámi (synth base)",
    "sb_smi": "GT-Sámi (synth base)",
    "smi": "GT-Sámi",
    "smi_nor": "GT-Sámi + GT-Nor",
    "smi_pred": "GT-Sámi + Pred-Sámi",
    "smi_nor_pred": "GT-Sámi + GT-Nor + Pred-Sámi",
    "sb_smi_nor_pred": "GT-Sámi + GT-Nor + Pred-Sámi (synth base)",
    "smi_pred_synth": "GT-Sámi + Pred-Sámi (synth base)",
}

# Names of the flag columns that make up the row index of the score table.
spec_names = ["w/o base", "GT-Sámi", "GT-Nor", "Pred-Sámi", "Synth base"]

# Which training-data components each configuration uses; flag order follows `spec_names`.
_flags_by_dataset = {
    #                                              w/o base GT-Sámi GT-Nor Pred-Sámi Synth base
    "GT-Sámi (without base)":                    (True,  True, False, False, False),
    "GT-Sámi":                                   (False, True, False, False, False),
    "GT-Sámi + GT-Nor":                          (False, True, True,  False, False),
    "GT-Sámi + Pred-Sámi":                       (False, True, False, True,  False),
    "GT-Sámi + GT-Nor + Pred-Sámi":              (False, True, True,  True,  False),
    "GT-Sámi (synth base)":                      (False, True, False, False, True),
    "GT-Sámi + Pred-Sámi (synth base)":          (False, True, False, True,  True),
    "GT-Sámi + GT-Nor + Pred-Sámi (synth base)": (False, True, True,  True,  True),
}
# Flags are stored as the strings "True"/"False"; these tuples are the row keys of `df`.
dataset_spec = {name: tuple(map(str, flags)) for name, flags in _flags_by_dataset.items()}
del _flags_by_dataset
index = pd.MultiIndex.from_tuples(list(dataset_spec.values()), names=spec_names)

# Columns: one (model family, metric) pair per cell.
primary_columns = model_prefix_name_map.values()
sub_columns = ["CER", "WER", "mean"]
columns = pd.MultiIndex.from_product([primary_columns, sub_columns])

# Every score starts out missing (pd.NA); the evaluation cell fills in what exists.
df = pd.DataFrame(
    data=[[pd.NA] * len(columns) for _ in range(len(index))],
    index=index,
    columns=columns,
)
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Transkribus,Transkribus,Transkribus,Tesseract,Tesseract,Tesseract,TrOCR,TrOCR,TrOCR
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,CER,WER,mean,CER,WER,mean,CER,WER,mean
w/o base,GT-Sámi,GT-Nor,Pred-Sámi,Synth base,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
True,True,False,False,False,,,,,,,,,
False,True,False,False,False,,,,,,,,,
False,True,True,False,False,,,,,,,,,
False,True,False,True,False,,,,,,,,,
False,True,True,True,False,,,,,,,,,
False,True,False,False,True,,,,,,,,,
False,True,False,True,True,,,,,,,,,
False,True,True,True,True,,,,,,,,,


In [150]:
output_dir = Path("../../output/valset_evaluation/line_level")

# Fill `df` (mutated in place; re-running overwrites the same cells) with the
# line-level validation scores of every Sámi model. Sub-directory names look like
# "<model prefix>_<dataset key>"; Transkribus directories carry an extra trailing
# "_<suffix>" and only those whose name contains "lm" are used. Each directory
# holds an `all_rows.json` with "CER_concat"/"WER_concat" scores (presumably
# fractions, since they are multiplied by 100 for display).
# sorted(): Path.iterdir() yields entries in arbitrary order; sorting makes the
# table reproducible if two directories ever map onto the same cell.
for model_dir in sorted(output_dir.iterdir()):
    # Skip stray files (they have no all_rows.json) and non-Sámi models.
    if not model_dir.is_dir() or "smi" not in model_dir.name:
        continue
    model_prefix, _, model_info = model_dir.name.partition("_")

    if model_prefix == "transk":
        if "lm" not in model_info:
            continue
        # Strip the trailing "_<lm suffix>" to recover the dataset key.
        model_info, _, lm = model_info.rpartition("_")

    eval_data_file = model_dir / "all_rows.json"
    eval_data = json.loads(eval_data_file.read_text(encoding="utf-8"))
    cer = eval_data["CER_concat"]
    wer = eval_data["WER_concat"]
    mean_ = np.mean([cer, wer])

    # Resolve the row key and the column group once, then write all three metrics.
    row_key = dataset_spec[dataset_map[model_info]]
    model_label = model_prefix_name_map[model_prefix]
    for metric, value in (("CER", cer), ("WER", wer), ("mean", mean_)):
        df.at[row_key, (model_label, metric)] = round_and_percentage(value)
df

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Transkribus,Transkribus,Transkribus,Tesseract,Tesseract,Tesseract,TrOCR,TrOCR,TrOCR
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,CER,WER,mean,CER,WER,mean,CER,WER,mean
w/o base,GT-Sámi,GT-Nor,Pred-Sámi,Synth base,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2,Unnamed: 9_level_2,Unnamed: 10_level_2,Unnamed: 11_level_2,Unnamed: 12_level_2,Unnamed: 13_level_2
True,True,False,False,False,1.59,5.67,3.63,5.53,24.7,15.11,,,
False,True,False,False,False,1.28,4.34,2.81,2.05,9.84,5.95,1.98,9.29,5.64
False,True,True,False,False,1.31,4.35,2.83,2.37,11.39,6.88,1.95,8.88,5.42
False,True,False,True,False,1.48,4.02,2.75,1.85,8.17,5.01,1.28,5.0,3.14
False,True,True,True,False,1.07,3.58,2.33,1.81,7.96,4.89,1.32,5.14,3.23
False,True,False,False,True,,,,1.78,8.78,5.28,1.15,5.04,3.09
False,True,False,True,True,,,,,,,1.08,4.29,2.69
False,True,True,True,True,,,,1.79,7.7,4.75,,,


In [151]:
def multiline_cell(s: str) -> str:
    """Wrap ``s`` in a nested one-column ``tabular`` so that LaTeX line breaks
    inside it stack vertically within a single table cell."""
    return r"\begin{tabular}[c]{@{}l@{}}" + s + r"\end{tabular}"


def new_name(s: str, table_cell: bool = True) -> str:
    s_pre = s
    if "(" in s:
        s = s.replace("(", r"\\(")
    if "+" in s:
        s = s.replace(" + ", r"\\+")
    if table_cell and s_pre != s:
        return multiline_cell(s)
    return s


def df_to_latex_df(df: pd.DataFrame, table_cell: bool) -> pd.DataFrame:
    """Return a copy of ``df`` whose string index labels are passed through ``new_name``.

    ``DataFrame.rename(index=func)`` applies ``func`` to each label of every index
    level, so a ``MultiIndex`` (such as the flag index of the score table) is handled
    label by label and the index/level names are preserved.

    Args:
        df: Frame whose index labels should be made LaTeX-friendly; it is not modified.
        table_cell: Forwarded to ``new_name`` (wrap changed labels in a nested tabular).

    Returns:
        A new DataFrame with renamed index labels.
    """
    def rename_label(label):
        # new_name only handles strings; leave numbers, bools, etc. untouched.
        if isinstance(label, str):
            return new_name(label, table_cell=table_cell)
        return label

    return df.rename(index=rename_label)


def add_hline(latex_code: str) -> str:
    r"""Append ``\hline`` to every body row of a LaTeX table except the last one.

    The body is the text between the first ``\midrule`` line and the first
    ``\bottomrule`` line.

    Raises:
        ValueError: if either ``\midrule\n`` or ``\n\bottomrule`` is missing.
    """
    start_marker = "\\midrule\n"
    body_start = latex_code.index(start_marker) + len(start_marker)
    body_end = latex_code.index("\n\\bottomrule")
    body_lines = latex_code[body_start:body_end].split("\n")
    return latex_code[:body_start] + "\\hline\n".join(body_lines) + latex_code[body_end:]

In [153]:
# Lowest (best) score of every (model, metric) column, i.e. the values that
# get_latex_table sets in bold. Shown as a labelled Series rather than bare
# print()ed numbers so each minimum can be matched to its column; missing (NA)
# cells are skipped by Series.min().
pd.Series({column: df[column].min() for column in df.columns}, name="min")

1.07
3.58
2.33
1.78
7.7
4.75
1.08
4.29
2.69


In [172]:
def get_latex_table(df: pd.DataFrame) -> str:
    r"""Render the score table ``df`` as a LaTeX ``tabular``.

    - The lowest value of each column is set in bold; missing scores are blank.
    - The "True"/"False" index flags become ``\checkmark`` / empty cells, headed by
      the rotated names from ``spec_names``.
    - Each model's three metric columns get a bold ``\multicolumn`` header.
    - ``\midrule`` is replaced with grouped ``\cmidrule``s.
    - Every second body row gets a grey ``\rowcolor`` background.

    NOTE(review): several string replacements match pandas' exact ``to_latex``
    output for this 5-flag x 3-model x 3-metric (14-column) layout and silently
    do nothing if the layout or pandas' output format changes. The commands used
    (\rowcolor, \checkmark, \multirow, \rotatebox, \cmidrule) presumably need the
    matching packages in the LaTeX preamble — confirm.

    Args:
        df: Score frame with the flag MultiIndex and (model, metric) columns.

    Returns:
        The LaTeX source of the table, without a trailing newline.

    Raises:
        ValueError: if no ``\cmidrule`` line (the header/body separator) is found.
    """
    # Computed once instead of once per formatted cell.
    column_minima = {column: df[column].min() for column in df.columns}

    def min_formatter(value, column):
        # Two decimals; the best (lowest) value of the column is bold.
        formatted = f"{value:.2f}"
        if value == column_minima[column]:
            return r"\textbf{VALUE}".replace("VALUE", formatted)
        return formatted

    # Index-level names, rotated, as a header row above the flag columns.
    rotated_header = " & ".join(
        r"\multirow{4}*{\rotatebox{90}{NAME}}".replace("NAME", spec_name) for spec_name in spec_names
    )

    latex_code = (
        df.to_latex(
            formatters={col: partial(min_formatter, column=col) for col in df.columns},
            # Spaces only group the spec (flag) columns and each model's columns visually; they are stripped.
            column_format="lllll ccc ccc ccc".replace(" ", ""),
            multicolumn=True,
            multirow=True,
            multicolumn_format="c",
            sparsify=False,
            na_rep="",
        )
        # Drop pandas' own index-name row; the names are re-added as rotated headers below.
        .replace("w/o base & GT-Sámi & GT-Nor & Pred-Sámi & Synth base &  &  &  &  &  &  &  &  &  \\\\\n", "")
        .replace("\\cline{1-14} \\cline{2-14} \\cline{3-14} \\cline{4-14}\n", "")
        # Index flags are the strings "True"/"False": tick or blank.
        .replace("False", "")
        .replace("True", "\\checkmark")
        .replace("Transkribus & Transkribus & Transkribus", r"\multicolumn{3}{c}{\textbf{Transkribus}}")
        .replace("Tesseract & Tesseract & Tesseract", r"\multicolumn{3}{c}{\textbf{Tesseract}}")
        .replace("TrOCR & TrOCR & TrOCR", r"\multicolumn{3}{c}{\textbf{TrOCR}}")
        .replace("\\toprule", "\\toprule\n" + rotated_header + " \\\\\n &&&&&&&&&&&&& \\\\")
        .replace(r"\midrule", r"\cmidrule(r){1-5}\cmidrule(lr){6-8}\cmidrule(lr){9-11}\cmidrule(l){12-14}")
    )

    # Shade every second body row (2nd, 4th, ...). The body starts right after the
    # \cmidrule line inserted above and has one line per row of `df`; locating it
    # instead of hard-coding line numbers keeps this correct if the header or the
    # number of configurations changes.
    lines = latex_code.splitlines()
    rule_index = next((i for i, line in enumerate(lines) if line.startswith(r"\cmidrule")), None)
    if rule_index is None:
        raise ValueError("Could not find the \\cmidrule header/body separator in the generated LaTeX")
    body_start = rule_index + 1
    body_end = min(body_start + len(df), len(lines))
    for row in range(body_start + 1, body_end, 2):
        lines[row] = r"\rowcolor{gray!20}" + lines[row]
    return "\n".join(lines)

In [173]:
# print() rather than a bare expression: we want the raw LaTeX source, not the
# string repr with doubled backslashes.
latex_table = get_latex_table(df)
print(latex_table)

\begin{tabular}{lllllccccccccc}
\toprule
\multirow{4}*{\rotatebox{90}{w/o base}} & \multirow{4}*{\rotatebox{90}{GT-Sámi}} & \multirow{4}*{\rotatebox{90}{GT-Nor}} & \multirow{4}*{\rotatebox{90}{Pred-Sámi}} & \multirow{4}*{\rotatebox{90}{Synth base}} \\
 &&&&&&&&&&&&& \\
 &  &  &  &  & \multicolumn{3}{c}{\textbf{Transkribus}} & \multicolumn{3}{c}{\textbf{Tesseract}} & \multicolumn{3}{c}{\textbf{TrOCR}} \\
 &  &  &  &  & CER & WER & mean & CER & WER & mean & CER & WER & mean \\
\cmidrule(r){1-5}\cmidrule(lr){6-8}\cmidrule(lr){9-11}\cmidrule(l){12-14}
\checkmark & \checkmark &  &  &  & 1.59 & 5.67 & 3.63 & 5.53 & 24.70 & 15.11 &  &  &  \\
\rowcolor{gray!20} & \checkmark &  &  &  & 1.28 & 4.34 & 2.81 & 2.05 & 9.84 & 5.95 & 1.98 & 9.29 & 5.64 \\
 & \checkmark & \checkmark &  &  & 1.31 & 4.35 & 2.83 & 2.37 & 11.39 & 6.88 & 1.95 & 8.88 & 5.42 \\
\rowcolor{gray!20} & \checkmark &  & \checkmark &  & 1.48 & 4.02 & 2.75 & 1.85 & 8.17 & 5.01 & 1.28 & 5.00 & 3.14 \\
 & \checkmark & \checkmark & \che