### run this code to generate a table from selected results

In [7]:
import pandas as pd


DATA_DIR_ORIGINAL = [
    "F:/Github Repos/Restrictive-Hierarchical-Code/Predictions/unet_orig_extended_test/metrics.csv",
    "F:/Github Repos/Restrictive-Hierarchical-Code/Predictions/hrnet_orig_extended_test/metrics.csv"
]
DATA_DIR_HIERARCHY = [
    "F:/Github Repos/Restrictive-Hierarchical-Code/Predictions/unet_hier_extended_test/metrics.csv",
    "F:/Github Repos/Restrictive-Hierarchical-Code/Predictions/hrnet_hier_extended_test/metrics.csv"
]
RESULTS_HEADING = [
    "UNet",
    "HRNet",
]

# Table caption + label
TABLE_CAPTION = "Per-class semantic segmentation metrics mean ($\pm$ standard deviation), comparing Original and Hierarchical versions of UNet and HRNet on the validation set for the extended hierarchy experiment."
TABLE_LABEL = "tab:hier_seg2_results_extended"

# Sanity check
assert len(DATA_DIR_ORIGINAL) == len(DATA_DIR_HIERARCHY) == len(RESULTS_HEADING), \
    "DATA_DIR_ORIGINAL, DATA_DIR_HIERARCHY and RESULTS_HEADING must have the same length."

# CLASS_ORDER = ["All", "0", "1", "2", "3", "4", "5", "6", "7"]
# CLASS_LABELS = {
#     "All": "Average",
#     "0": "Background",
#     "1": "Upper",
#     "2": "Lower",
#     "3": "Tooth",
#     "4": "Pulp",
#     "5": "Dentin",
#     "6": "Enamel",
#     "7": "Composite",
# }

CLASS_ORDER = ["All", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "10"]
CLASS_LABELS = {
    "All": "Average",
    "0": "Background",
    "1": "T+A",
    "2": "Alveolar",
    "3": "Tooth",
    "4": "Upper",
    "5": "Lower",
    "6": "Composite",
    "7": "Healthy",
    "8": "Pulp",
    "9": "Dentin",
    "10": "Enamel",
}

DECIMALS = 3


def format_metric(mean, std, bold=False, decimals=3):
    """Format a mean Â± std cell, with optional bolding of the *mean only*."""
    mean_str = f"{mean:.{decimals}f}"
    if bold:
        mean_str = f"\\textbf{{{mean_str}}}"
    return f"{mean_str} ($\\pm$ {std:.{decimals}f})"


def load_metrics(csv_path):
    """Load CSV and index by Class."""
    df = pd.read_csv(csv_path)
    df["Class"] = df["Class"].astype(str)
    df = df[df["Class"].isin(CLASS_ORDER)]
    return df.set_index("Class")


def build_model_block(model_name, orig_df, hier_df):
    """Build one 9-row block for a single model."""
    lines = []
    n_rows = len(CLASS_ORDER)

    for i, cls in enumerate(CLASS_ORDER):
        label = CLASS_LABELS[cls]
        o = orig_df.loc[cls]
        h = hier_df.loc[cls]

        if i == 0:
            prefix = rf"\multirow{{{n_rows}}}{{*}}{{\textbf{{{model_name}}}}} & {label}"
        else:
            prefix = f"& {label}"

        # Determine bolding for each metric
        bold_iou_orig = o["IoU Mean"] > h["IoU Mean"]
        bold_dice_orig = o["Dice Mean"] > h["Dice Mean"]
        bold_prec_orig = o["Precision Mean"] > h["Precision Mean"]
        bold_rec_orig  = o["Recall Mean"] > h["Recall Mean"]

        line = (
            prefix
            + f" & {format_metric(o['IoU Mean'], o['IoU Std'], bold=bold_iou_orig)}"
            + f" & {format_metric(o['Dice Mean'], o['Dice Std'], bold=bold_dice_orig)}"
            + f" & {format_metric(o['Precision Mean'], o['Precision Std'], bold=bold_prec_orig)}"
            + f" & {format_metric(o['Recall Mean'], o['Recall Std'], bold=bold_rec_orig)}"
            + f" & {format_metric(h['IoU Mean'], h['IoU Std'], bold=not bold_iou_orig)}"
            + f" & {format_metric(h['Dice Mean'], h['Dice Std'], bold=not bold_dice_orig)}"
            + f" & {format_metric(h['Precision Mean'], h['Precision Std'], bold=not bold_prec_orig)}"
            + f" & {format_metric(h['Recall Mean'], h['Recall Std'], bold=not bold_rec_orig)} \\\\"
        )

        lines.append(line)

    return "\n".join(lines)


def build_full_table():
    """Construct the full LaTeX table containing all model blocks."""
    body_blocks = []

    for orig_path, hier_path, heading in zip(DATA_DIR_ORIGINAL, DATA_DIR_HIERARCHY, RESULTS_HEADING):
        orig_df = load_metrics(orig_path)
        hier_df = load_metrics(hier_path)
        block = build_model_block(heading, orig_df, hier_df)
        body_blocks.append(block)

    body = ("\n\\midrule\n".join(body_blocks)) + "\n"

    header = r"""\begin{table}[hbt!]
\centering
\renewcommand{\arraystretch}{1.}
\resizebox{\textwidth}{!}{
\begin{tabular}{|ll|cccc|cccc|}
\toprule
\multicolumn{2}{|c|}{} & \multicolumn{8}{c|}{\textbf{Validation Set}} \\
\cmidrule(lr){3-10}
\multicolumn{2}{|c}{} & \multicolumn{4}{|c}{\textbf{Original}} & \multicolumn{4}{|c|}{\textbf{With Hierarchy}} \\
\cmidrule(lr){3-6} \cmidrule(lr){7-10}
\textbf{Model} & \textbf{Class} & IoU & Dice & Precision & Recall & IoU & Dice & Precision & Recall \\
\midrule
"""

    footer = rf"""\bottomrule
\end{{tabular}}
}}
\caption{{{TABLE_CAPTION}}}
\label{{{TABLE_LABEL}}}
\end{{table}}"""

    return header + body + footer


# --------------------------------------------------------------------
# Generate table
# --------------------------------------------------------------------
latex_table = build_full_table()
print(latex_table)

# Optional:
# with open("semantic_seg_results_table.tex", "w") as f:
#     f.write(latex_table)

\begin{table}[hbt!]
\centering
\renewcommand{\arraystretch}{1.}
\resizebox{\textwidth}{!}{
\begin{tabular}{|ll|cccc|cccc|}
\toprule
\multicolumn{2}{|c|}{} & \multicolumn{8}{c|}{\textbf{Validation Set}} \\
\cmidrule(lr){3-10}
\multicolumn{2}{|c}{} & \multicolumn{4}{|c}{\textbf{Original}} & \multicolumn{4}{|c|}{\textbf{With Hierarchy}} \\
\cmidrule(lr){3-6} \cmidrule(lr){7-10}
\textbf{Model} & \textbf{Class} & IoU & Dice & Precision & Recall & IoU & Dice & Precision & Recall \\
\midrule
\multirow{12}{*}{\textbf{UNet}} & Average & 0.761 ($\pm$ 0.006) & 0.847 ($\pm$ 0.005) & 0.869 ($\pm$ 0.004) & 0.831 ($\pm$ 0.007) & \textbf{0.828} ($\pm$ 0.006) & \textbf{0.888} ($\pm$ 0.005) & \textbf{0.895} ($\pm$ 0.006) & \textbf{0.887} ($\pm$ 0.006) \\
& Background & 0.981 ($\pm$ 0.001) & 0.990 ($\pm$ 0.000) & 0.992 ($\pm$ 0.001) & \textbf{0.988} ($\pm$ 0.001) & \textbf{0.981} ($\pm$ 0.002) & \textbf{0.991} ($\pm$ 0.001) & \textbf{0.993} ($\pm$ 0.001) & 0.988 ($\pm$ 0.002) \\
& T+A & 0.914 ($\pm$ 0.00