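# Dataset Size

Segmentation performance (DSC, ASD, NSD) of each model as a function of the number of pigs *n* in the dataset, averaged over seeds.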

In [1]:
%load_ext autoreload
%autoreload 2

import plotly.graph_objects as go
import plotly.io as pio

from htc.evaluation.model_comparison.dataset_size import dataset_size_table
from htc.settings import settings
from htc.settings_seg import settings_seg
from htc.utils.sqldf import sqldf
from htc.utils.visualization import add_std_fill

# Disable MathJax for kaleido's static image export (used by fig.write_image for the PDF figures).
# NOTE(review): presumably done to avoid the "Loading [MathJax]..." banner that kaleido may render into exported PDFs — confirm
pio.kaleido.scope.mathjax = None

In [2]:
# NBVAL_IGNORE_OUTPUT
# All training runs of the dataset-size experiment (recursive search below the training directory), in sorted path order
runs = list(settings.training_dir.rglob(f"{settings_seg.dataset_size_timestamp}_generated_default*_dataset_size"))
runs.sort()


def plot_metric(metric_name: str, metric_short: str) -> go.Figure:
    """Plot a metric against the number of pigs for every model and export the figure as PDF.

    The results of the dataset-size runs (module-level ``runs``) are collected via ``dataset_size_table``, averaged
    over seeds per (model, n_pigs) and drawn as one line per model with a shaded band derived from the seed standard
    deviation. The figure is written to ``settings_seg.paper_dir / f"dataset_size_{metric_short}.pdf"``.

    Args:
        metric_name: Metric identifier forwarded to ``dataset_size_table`` (e.g. ``"dice_metric"``).
        metric_short: Short metric label used as y-axis title and in the output filename (e.g. ``"DSC"``).

    Returns:
        The generated plotly figure.
    """
    df_res = dataset_size_table(runs, metric_name)

    # Average over seeds. sqldf presumably resolves `df_res` by name from this scope, so keep the variable name in sync
    # with the query. The ORDER BY is required: without it, the row order of a GROUP BY result is not guaranteed, but
    # the x-axis ticks below assume that the values of each model are sorted by ascending n_pigs
    df_total = sqldf(
        """
        SELECT model, n_pigs, AVG(metric_images) AS metric_mean, STD(metric_images) AS metric_std
        FROM df_res
        WHERE n_pigs < 15
        GROUP BY model, n_pigs
        ORDER BY model, n_pigs
    """
    )

    # Shorter legend labels for some models (all other models keep their name)
    display_names = {"superpixel_classification": "superpixel"}

    fig = go.Figure()
    for model in settings_seg.model_colors.keys():
        df_model = df_total.query("model == @model")
        if df_model.empty:
            # No dataset-size results for this model -> do not add an empty trace/legend entry
            continue

        # NOTE(review): only half of the seed std is passed to add_std_fill; the rationale for the 0.5 factor is not
        # visible here — confirm it matches the figure caption
        add_std_fill(
            fig,
            df_model["metric_mean"],
            df_model["metric_std"] * 0.5,
            linecolor=settings_seg.model_colors[model],
            label=display_names.get(model, model),
        )

    fig.update_layout(title_x=0.5, xaxis_title="<b><i>n</i> (# pigs)</b>", yaxis_title=f"<b>{metric_short}</b>")
    fig.update_layout(
        template="plotly_white", font_family="Libertinus Serif", font_size=22, margin=dict(l=0, r=0, b=0, t=0)
    )
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.05, xanchor="center", x=0.5))
    fig.update_layout(width=1000, height=400)
    # Positions 0, 2, ..., 12 are labeled 1, 3, ..., 13, i.e. the i-th point of each line corresponds to n = i + 1 pigs
    fig.update_layout(xaxis=dict(tickmode="array", tickvals=list(range(0, 14, 2)), ticktext=list(range(1, 15, 2))))
    fig.write_image(settings_seg.paper_dir / f"dataset_size_{metric_short}.pdf")

    return fig


# Dice similarity coefficient vs. number of pigs
plot_metric(metric_name="dice_metric", metric_short="DSC")

Output()

In [3]:
# NBVAL_IGNORE_OUTPUT
# Average surface distance vs. number of pigs
plot_metric(metric_name="surface_distance_metric", metric_short="ASD")

Output()

In [4]:
# NBVAL_IGNORE_OUTPUT
# Normalized surface distance vs. number of pigs (metric key taken from the segmentation settings)
plot_metric(metric_name=settings_seg.nsd_aggregation_short, metric_short="NSD")

Output()