# Generalization Error

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import plotly.graph_objects as go
import plotly.io as pio

from htc.evaluation.model_comparison.paper_runs import collect_comparison_runs
from htc.models.common.MetricAggregation import MetricAggregation
from htc.settings import settings
from htc.settings_seg import settings_seg

# Disable MathJax for the Kaleido image export used by `fig.write_image` below
# NOTE(review): presumably this avoids MathJax loading artifacts in the exported PDFs — confirm
pio.kaleido.scope.mathjax = None

In [2]:
# Table of the comparison runs for the timestamp configured in settings_seg
# (one row per model with its run folder names in the columns run_rgb, run_param and run_hsi)
df_runs = collect_comparison_runs(settings_seg.model_comparison_timestamp)
df_runs

Unnamed: 0,model,name,main_loss,run_rgb,model_image_size,run_param,run_hsi
0,pixel,pixel,train/ce_loss_epoch,2022-02-03_22-58-44_generated_default_rgb_mode...,307200,2022-02-03_22-58-44_generated_default_paramete...,2022-02-03_22-58-44_generated_default_model_co...
1,superpixel_classification,superpixel_classification,train/kl_loss_epoch,2022-02-03_22-58-44_generated_default_rgb_mode...,1000,2022-02-03_22-58-44_generated_default_paramete...,2022-02-03_22-58-44_generated_default_model_co...
2,patch,patch_32,train/dice_loss_epoch,2022-02-03_22-58-44_generated_default_rgb_mode...,300,2022-02-03_22-58-44_generated_default_paramete...,2022-02-03_22-58-44_generated_default_model_co...
3,patch,patch_64,train/dice_loss_epoch,2022-02-03_22-58-44_generated_default_64_rgb_m...,75,2022-02-03_22-58-44_generated_default_64_param...,2022-02-03_22-58-44_generated_default_64_model...
4,image,image,train/dice_loss_epoch,2022-02-03_22-58-44_generated_default_rgb_mode...,1,2022-02-03_22-58-44_generated_default_paramete...,2022-02-03_22-58-44_generated_default_model_co...


In [3]:
def plot_generalization_error(model_type: str, metric: str, first_epoch: int = 19) -> go.Figure:
    """
    Plot the per-epoch validation metric of every comparison run for both validation datasets.

    For each row of the global `df_runs` table, the validation table of the corresponding run is loaded and
    aggregated via `MetricAggregation.grouped_metrics_epochs(mode="image_level")`. The metric is then averaged
    per `epoch_index`, separately for `dataset_index == 0` (solid line, hover label `<name>_unknown`) and
    `dataset_index == 1` (dotted line, hover label `<name>_known`).

    As a side effect, the mean, median and standard deviation of the difference between both curves
    (dataset 1 minus dataset 0) are printed per run, using only the epochs from position `first_epoch` onwards.

    Args:
        model_type: Modality suffix which selects the run column `run_{model_type}` of `df_runs` (the notebook
            uses "hsi", "param" and "rgb"). The legend is only shown for "hsi".
        metric: Name of the metric column to aggregate and plot (e.g. "dice_metric_image").
        first_epoch: Position of the first epoch which is included in the printed difference statistics.

    Returns:
        Figure with two traces (one per validation dataset) for every run in `df_runs`.
    """
    fig = go.Figure()

    for _, row in df_runs.iterrows():
        name = row["name"]
        run_dir = settings.training_dir / row["model"] / row[f"run_{model_type}"]
        agg = MetricAggregation(run_dir / "validation_table.pkl.xz", metrics=[metric])
        df_val = agg.grouped_metrics_epochs(mode="image_level")
        values_0 = df_val.query("dataset_index == 0").groupby("epoch_index")[metric].mean().values
        values_1 = df_val.query("dataset_index == 1").groupby("epoch_index")[metric].mean().values

        # Difference between the two validation curves, computed once and shared by all three statistics
        diff = values_1[first_epoch:] - values_0[first_epoch:]
        print(
            f"{name}: mean diff {round(np.mean(diff), 2)}, median diff"
            f" {round(np.median(diff), 2)}, std diff"
            f" {round(np.std(diff), 2)}"
        )

        name_legend = name.replace("superpixel_classification", "superpixel")
        # Only these three models are drawn initially, the remaining ones can be enabled via the legend
        visible = None if name_legend in ("pixel", "patch_32", "image") else "legendonly"
        fig.add_trace(
            go.Scatter(
                y=values_0,
                name=name_legend,
                hovertemplate=f"{name_legend}_unknown",
                marker_color=settings_seg.model_colors[name],
                legendgroup=name,
                visible=visible,
            )
        )
        # Same legend group and color as the solid line, hence no separate legend entry for the dotted line
        fig.add_trace(
            go.Scatter(
                y=values_1,
                name=name_legend,
                hovertemplate=f"{name_legend}_known",
                marker_color=settings_seg.model_colors[name],
                line=dict(dash="dot"),
                legendgroup=name,
                visible=visible,
                showlegend=False,
            )
        )

    fig.layout.height = 400
    fig.layout.width = 1000
    # NOTE: the y-axis title is fixed to DSC independent of the metric argument
    fig.update_layout(xaxis_title="<b>epoch</b>", yaxis_title="<b>DSC</b>")
    fig.update_layout(
        template="plotly_white", font_family="Libertinus Serif", font_size=22, margin=dict(l=0, r=0, b=0, t=20)
    )
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="right", x=0.90))
    fig.update_yaxes(range=[0, 1])

    if model_type != "hsi":
        fig.update_layout(showlegend=False)

    return fig


# Preview for the HSI runs (the statistics are printed and the returned figure is the cell output)
plot_generalization_error("hsi", "dice_metric_image")

pixel: mean diff 0.02, median diff 0.03, std diff 0.01
superpixel_classification: mean diff 0.08, median diff 0.08, std diff 0.02
patch_32: mean diff 0.08, median diff 0.08, std diff 0.01
patch_64: mean diff 0.09, median diff 0.09, std diff 0.01
image: mean diff 0.11, median diff 0.11, std diff 0.01


In [4]:
# One figure per modality: each figure is exported as PDF for the paper and kept for the HTML page below
figs = {}
for modality in ("hsi", "param", "rgb"):
    fig = plot_generalization_error(modality, "dice_metric_image")

    # The PDF is written first so that it keeps the legend setting chosen by the plotting function
    pdf_path = settings_seg.paper_dir / f"generalization_error_{modality}.pdf"
    fig.write_image(pdf_path)

    # The stored figures always show the legend (update_layout returns the figure it was called on)
    figs[modality] = fig.update_layout(showlegend=True)

pixel: mean diff 0.02, median diff 0.03, std diff 0.01
superpixel_classification: mean diff 0.08, median diff 0.08, std diff 0.02
patch_32: mean diff 0.08, median diff 0.08, std diff 0.01
patch_64: mean diff 0.09, median diff 0.09, std diff 0.01
image: mean diff 0.11, median diff 0.11, std diff 0.01
pixel: mean diff 0.03, median diff 0.03, std diff 0.0
superpixel_classification: mean diff 0.13, median diff 0.13, std diff 0.01
patch_32: mean diff 0.12, median diff 0.12, std diff 0.01
patch_64: mean diff 0.1, median diff 0.1, std diff 0.0
image: mean diff 0.11, median diff 0.11, std diff 0.01
pixel: mean diff 0.01, median diff 0.01, std diff 0.0
superpixel_classification: mean diff 0.13, median diff 0.13, std diff 0.01
patch_32: mean diff 0.12, median diff 0.12, std diff 0.01
patch_64: mean diff 0.1, median diff 0.1, std diff 0.01
image: mean diff 0.11, median diff 0.11, std diff 0.0


In [5]:
# Stylesheet which is embedded into the HTML page generated below
# The font files are referenced via relative URLs (either in a fonts/ subfolder or next to the HTML file),
# so the fonts also need to be copied to the resulting folder
css = """
/* Load custom font */
@font-face {
    font-family: libertinus;
    font-style: normal;
    src: url("fonts/LibertinusSerifDisplay-Regular.otf"), url("LibertinusSerifDisplay-Regular.otf");
}
@font-face {
    font-family: libertinus;
    font-weight: bold;
    src: url("fonts/LibertinusSerif-Semibold.otf"), url("LibertinusSerif-Semibold.otf");
}
body {
    font-family: libertinus, serif;
    hyphens: auto;
}
figure {
    width: min-content;
}
figure > figcaption {
    margin-left: 15px;
    margin-right: 15px;
}
figcaption {
    text-align: center;
    margin-top: 10px;
}
"""

# Combine all figures in one html file
# Only the first figure is rendered with include_plotlyjs=True, the other two are rendered without it
html = f"""
<!DOCTYPE html>
<html lang="en">
    <head>
        <meta charset="utf-8">
        <title>Generalization error</title>
        <style>
        {css}
        </style>
    </head>
    <body>
        <h1>HSI</h1>
        <figure>
            {figs["hsi"].to_html(full_html=False, include_plotlyjs=True, div_id="hsi")}
        </figure>

        <h1>TPI</h1>
        <figure>
            {figs["param"].to_html(full_html=False, include_plotlyjs=False, div_id="tpi")}
        </figure>

        <h1>RGB</h1>
        <figure>
            {figs["rgb"].to_html(full_html=False, include_plotlyjs=False, div_id="rgb")}
            <figcaption>Generalization error over training time by comparing the two validation data sets V_unknown (solid lines) with V_known (dotted lines). The shown values are obtained by first averaging dice similarity coefficient (DSC) values of all images of one pig and then averaging the mean DSC values of the different pigs in V_unknown and V_known. See the paper for more details on the validation data set splits. The error curves for the superpixel and patch_64 models are hidden by default and can be made visible by clicking on the legend name.</figcaption>
        </figure>
    </body>
</html>"""

# The document declares charset utf-8, so the file is written with this encoding instead of the platform default
with (settings_seg.paper_dir / "generalization_error.html").open("w", encoding="utf-8") as f:
    f.write(html)