In [None]:
import pathlib
import numpy as np
import seaborn as sns
import pandas as pd
import json
import plotly.express as px
from functools import partial
import matplotlib.pyplot as plt
import matplotlib as mpl

# Treat +/-inf as missing values so that the `.fillna(0)` calls applied to the
# ratio columns further down also neutralise division-by-zero results.
# NOTE(review): pandas deprecated `mode.use_inf_as_na` in 2.1 - confirm that the
# pinned pandas version still accepts this option.
pd.set_option("use_inf_as_na", True)

In [None]:
# Root directory of the evaluation output; each model has its own sub-directory.
RES_PATH = pathlib.Path("data") / "test-res"

# Result-directory name -> display name used in the generated tables.
base_model_names = {
    "519": "DeLux SC",
    "519_EC0.25": "DeLux EC0.25",
    "519_EC0.5": "DeLux EC0.5",
    "519_EC1.0": "DeLux EC1.0",
    "noop": "NoOp",
}
# Additional entry that is only included in the detection tables.
extra_det_model = {"noop_ec": "EC Only"}

# Raw `kind` value -> display name for the synthetic subsets.
synth_kind_names = {
    "HQFlareBasedAugmenter": "HQ Flare",
    "LensFlareAdder": "Lens Flare",
    "OverlitAugmenter": "Overlit",
    "SunAdder": "Sun",
    "VeilingGlareAdder": "Glare",
    "NoOpAugmenter": "No Artifact",
}
# Raw `kind` value -> display name for the real subsets.
real_kind_names = {
    "flare7k": "Flare 7k++",
    # The backslash is part of the value (LaTeX-escaped ampersand). It is written
    # as "\\&" because "\&" is an invalid Python escape sequence that newer
    # interpreters warn about; the resulting string is unchanged.
    "with_events": "E2VID\\&DSEC",
}


def format_numbers(s, max: bool = True):
    """Render a numeric Series as LaTeX math cells with four decimals.

    The best value (largest when ``max`` is true, smallest otherwise) is wrapped
    in ``\\mathbf``, the best of the remaining values in ``\\underline``; every
    other value is emitted as plain inline math.

    Returns a list of strings in the order of ``s``.
    """
    pick = (lambda values: values.max()) if max else (lambda values: values.min())
    best = pick(s)
    # Runner-up: best value once every occurrence of the winner is removed.
    runner_up = pick(s[s != best])

    def render(v):
        if v == best:
            return f"$\\mathbf{{{v:.4f}}}$"
        if v == runner_up:
            return f"$\\underline{{{v:.4f}}}$"
        return f"${v:.4f}$"

    return [render(v) for v in s]


def format_numbers_perc(s, max: bool = True):
    """Render a Series of fractions as LaTeX math cells in percent (two decimals).

    Highlighting follows the same scheme as ``format_numbers``: the best value
    (largest when ``max`` is true, smallest otherwise) is bold and the best of
    the remaining values is underlined.

    Returns a list of strings in the order of ``s``.
    """
    best = s.max() if max else s.min()
    remaining = s[s != best]
    second_best = remaining.max() if max else remaining.min()

    formatted = []
    for v in s:
        pct = f"{v:.2%}"
        if v == best:
            cell = f"$\\mathbf{{{pct}}}$"
        elif v == second_best:
            cell = f"$\\underline{{{pct}}}$"
        else:
            cell = f"${pct}$"
        formatted.append(cell)
    return formatted

In [None]:
# Load every `<model>/test_metrics.json` below RES_PATH into one long frame and
# tag the rows with the name of the directory they came from.
all_metrics = []
# Sorted so the row order of `test_df` does not depend on the order in which the
# filesystem happens to enumerate the directories.
for p in sorted(RES_PATH.glob("*/test_metrics.json")):
    try:
        p_metrics = pd.read_json(p)
        p_metrics["model"] = p.parent.name
        all_metrics.append(p_metrics)
    except Exception as e:
        # Best effort: report the unreadable file and keep loading the others.
        print(f"Error reading {p}: {e}")
if not all_metrics:
    # Fail with an actionable message instead of pandas' generic
    # "No objects to concatenate".
    raise FileNotFoundError(
        f"No readable test_metrics.json files found under {RES_PATH}"
    )
# Missing values are filled with 0.
test_df = pd.concat(all_metrics, ignore_index=True).fillna(0)


In [None]:
# Every `kind` that is not one of the two listed kinds counts as synthetic.
non_synth_kinds = ["flare7k", "with_events"]
synth_kinds = list(
    filter(lambda kind: kind not in non_synth_kinds, test_df["kind"].unique())
)

In [None]:
def post_process_latex_table(table: str, kind_nice_name: str = "Subset") -> str:
    """Adapt the output of ``DataFrame.to_latex`` for inclusion in the document.

    * ``kind`` / ``model`` are replaced by ``kind_nice_name`` / ``Model``
      everywhere except on the ``\\caption`` line, so caption prose such as
      "different models" is left untouched.
    * ``tabular`` becomes ``tabularx`` and the first environment gets the
      ``\\textwidth`` width argument.
    * ``%`` and ``_`` are escaped, except inside ``\\label{...}``: the key is kept
      verbatim, because an escaped underscore no longer matches the key used by
      references to it.

    Args:
        table: LaTeX source as returned by ``DataFrame.to_latex``.
        kind_nice_name: Header text that replaces ``kind``.

    Returns:
        The rewritten LaTeX source.
    """

    def _rewrite(text: str, rename_headers: bool) -> str:
        if rename_headers:
            text = text.replace("kind", kind_nice_name).replace("model", "Model")
        text = text.replace("tabular", "tabularx")
        return text.replace("%", "\\%").replace("_", "\\_")

    pieces = []
    for line in table.splitlines(keepends=True):
        rename_headers = "\\caption{" not in line
        before, marker, after = line.partition("\\label{")
        if not marker:
            pieces.append(_rewrite(line, rename_headers))
            continue
        # Keep the label key byte-identical; only the text around it is rewritten.
        key, closing, rest = after.partition("}")
        pieces.append(
            _rewrite(before, rename_headers)
            + marker
            + key
            + closing
            + _rewrite(rest, rename_headers)
        )
    table = "".join(pieces)
    # tabularx needs the total width as its first argument.
    return table.replace("tabularx}", "tabularx}{\\textwidth}", 1)

In [None]:
def make_table_summary(
    df: pd.DataFrame,
    models_names: dict[str, str],
    kind_names: dict[str, str],
    metric: str,
    totals_row: bool = True,
    agg_func: str = "mean",
) -> pd.DataFrame:
    """Aggregate one metric into a model-by-kind table.

    Args:
        df: Long frame with the columns ``kind``, ``model``, ``metric``, ``value``.
        models_names: Raw model id -> display name; also selects the models.
        kind_names: Raw kind -> display name; also selects the kinds.
        metric: Value of the ``metric`` column to summarise.
        totals_row: When true, append an ``"Overall Mean"`` column holding the
            aggregate over all selected kinds of each model.
        agg_func: Aggregation handed to ``groupby(...).agg``.

    Returns:
        Frame indexed by model display name with one column per kind display name.
    """
    df = df.copy()

    wanted = (
        df["kind"].isin(kind_names.keys())
        & df["model"].isin(models_names.keys())
        & (df["metric"] == metric)
    )
    subset_df = df[wanted]

    per_kind = (
        subset_df.groupby(["kind", "model", "metric"])["value"]
        .agg(agg_func)
        .reset_index()
    )
    pivot_df = per_kind.pivot(index="model", columns="kind", values="value").rename(
        index=models_names, columns=kind_names
    )
    if totals_row:
        # Aggregate straight from the raw rows, i.e. across kinds, not from the
        # per-kind aggregates.
        overall = (
            subset_df.groupby(["model", "metric"])["value"]
            .agg(agg_func)
            .reset_index()
            .pivot(index="model", columns="metric", values="value")
            .rename(index=models_names)
        )
        pivot_df["Overall Mean"] = overall[metric]
    return pivot_df


In [None]:
# Inspect the raw rows of the "noop" / "noop_ec" models on the non-synthetic kinds.
is_real_kind = ~test_df["kind"].isin(synth_kinds)
is_noop_model = test_df["model"].isin(["noop", "noop_ec"])
test_df[is_real_kind & is_noop_model]

# Results

## Detection

### Synthetic Artifacts

In [None]:
# Detection tables use the base models plus the extra "EC Only" entry.
det_models = base_model_names | extra_det_model

# One formatted kind-by-model table per detection metric (best value bold,
# runner-up underlined), in the order accuracy, F1, IoU.
pivot_acc_synth_df, pivot_f1_synth_df, pivot_iou_synth_df = (
    make_table_summary(test_df, det_models, synth_kind_names, metric_key)
    .apply(format_numbers)
    .T
    for metric_key in ("accuracy_th10", "f1_score_th10", "iou_th10")
)
# The first frame fixes the column order of the concatenation, so "Metric" is
# inserted in front there and simply appended on the other frames.
pivot_acc_synth_df.insert(0, "Metric", "Accuracy")
pivot_f1_synth_df["Metric"] = "F1 Score"
pivot_iou_synth_df["Metric"] = "IoU"

pivot_det_synth_df = pd.concat(
    [pivot_acc_synth_df, pivot_f1_synth_df, pivot_iou_synth_df], axis=0
)
pivot_det_synth_df = pivot_det_synth_df.reset_index().set_index(["Metric", "kind"])
pivot_det_synth_df

In [None]:
# Two index columns ("xX") followed by one right-aligned column per model; the
# "Metric" helper column of the accuracy frame is not a model column, hence -1.
n_synth_det_model_cols = len(pivot_acc_synth_df.columns) - 1
synth_det_table = pivot_det_synth_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="xX" + n_synth_det_model_cols * "r",
    label="tab:synth_det",
    caption="Detection metrics on synthetic datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)
with pathlib.Path("data/analysis/tables/results/synth_det.tex").open("w") as tex_file:
    tex_file.write(post_process_latex_table(synth_det_table))

### Real Artifacts

In [None]:
# Same layout as the synthetic detection table, restricted to the real subsets.
pivot_acc_real_df, pivot_f1_real_df, pivot_iou_real_df = (
    make_table_summary(test_df, det_models, real_kind_names, metric_key)
    .apply(format_numbers)
    .T
    for metric_key in ("accuracy_th10", "f1_score_th10", "iou_th10")
)
# "Metric" goes first on the leading frame (it defines the concat column order)
# and is appended on the rest.
pivot_acc_real_df.insert(0, "Metric", "Accuracy")
pivot_f1_real_df["Metric"] = "F1 Score"
pivot_iou_real_df["Metric"] = "IoU"

pivot_det_real_df = pd.concat(
    [pivot_acc_real_df, pivot_f1_real_df, pivot_iou_real_df], axis=0
)
pivot_det_real_df = pivot_det_real_df.reset_index().set_index(["Metric", "kind"])
pivot_det_real_df

In [None]:
# Export the real-data detection table: two index columns ("xX") plus one
# right-aligned column per model.
# The caption previously read "does come with neuromorphic data", which
# contradicts its own parenthetical (the estimate is based on the SC detector
# only); it now reads "does not come".
real_det_table = pivot_det_real_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="xX" + "r" * (len(pivot_det_real_df.columns)),
    label="tab:real_det",
    caption="Detection metrics on Flare7k++ and E2VID+DSEC with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame. For EC models, the Flare7k++ scores are cleared as Flare7k++ does not come with neuromorphic data (hence the estimate map is based only on the same detector as for SC).",
    position="htbp",
    multicolumn=True,
)
with open("data/analysis/tables/results/real_det.tex", "w") as f:
    f.write(post_process_latex_table(real_det_table))

## Removal

### Synthetic Artifacts


In [None]:
# Reconstruction metrics on the synthetic subsets. MSE and MAPE are errors
# (smallest value highlighted); PSNR and MS-SSIM highlight the largest value.
(
    pivot_mse_synth_df,
    pivot_mape_synth_df,
    pivot_psnr_synth_df,
    pivot_msssim_synth_df,
) = (
    make_table_summary(test_df, base_model_names, synth_kind_names, metric_key)
    .apply(partial(format_numbers, max=highlight_max))
    .T
    for metric_key, highlight_max in (
        ("mse", False),
        ("mape", False),
        ("psnr", True),
        ("mssim", True),
    )
)
# "Metric" leads on the first frame, which fixes the concat column order.
pivot_mse_synth_df.insert(0, "Metric", "MSE")
pivot_mape_synth_df["Metric"] = "MAPE"
pivot_psnr_synth_df["Metric"] = "PSNR"
pivot_msssim_synth_df["Metric"] = "MS-SSIM"

pivot_rec_synth_df = pd.concat(
    [
        pivot_mse_synth_df,
        pivot_mape_synth_df,
        pivot_psnr_synth_df,
        pivot_msssim_synth_df,
    ],
    axis=0,
)
pivot_rec_synth_df = pivot_rec_synth_df.reset_index().set_index(["Metric", "kind"])
pivot_rec_synth_df

In [None]:
# Two "X" index columns plus one right-aligned column per model ("Metric" on the
# MSE frame is a helper column, hence the -1).
n_synth_rec_model_cols = len(pivot_mse_synth_df.columns) - 1
synth_rec_table = pivot_rec_synth_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="XX" + n_synth_rec_model_cols * "r",
    label="tab:synth_rec",
    caption="Artifact removal metrics on synthetic datasets with different models. NoOp model is used as a reference and denotes predicting the input as output.",
    position="htbp",
    multicolumn=True,
)
with pathlib.Path("data/analysis/tables/results/synth_rec.tex").open("w") as tex_file:
    tex_file.write(post_process_latex_table(synth_rec_table))

### Real Artifacts

In [None]:
def _relative_fastvqa_drop(ref_score, score):
    """Score drop relative to the reference; 0.0 when the reference is 0."""
    return (ref_score - score) / ref_score if ref_score != 0 else 0.0


with open("data/test-res/fastvqascores.json") as scores_file:
    fastvqa = json.load(scores_file)
# The "original" entry holds the reference scores the other entries are compared to.
orig_scores = fastvqa["original"]
processed_fast_vqa = [
    {
        "frame_idx": 0,  # placeholder: these scores are not per-frame
        "metric": "diff_fastvqa_rel",
        # A recording without a reference score falls back to 0.0, i.e. value 0.0.
        "value": _relative_fastvqa_drop(orig_scores.get(recording, 0.0), recording_score),
        "kind": recording,
        "model": model_name,
    }
    for model_name, model_scores in fastvqa.items()
    if model_name != "original"
    for recording, recording_score in model_scores.items()
]
fastvqa_df = pd.DataFrame(processed_fast_vqa)

In [None]:
# Build the long evaluation frame: one row per (kind, model, frame, metric).
# Kind and model come from the two directory levels above each refscores.json.
all_eval_metrics = []
for scores_path in RES_PATH.glob("**/refscores.json"):
    dat = (
        pd.read_json(scores_path)
        .reset_index()
        .assign(
            # Relative columns: NaN ratios are set to 0.
            diff_brisque_rel=lambda d: (d["diff_brisque"] / d["ref_brisque"]).fillna(0),
            diff_mean_detection=lambda d: d["est_mean"] - d["post_mean"],
            diff_mean_detection_rel=lambda d: (
                d["diff_mean_detection"] / d["est_mean"]
            ).fillna(0),
            diff_over_threshold=lambda d: (
                d["est_over_threshold"] - d["post_over_threshold"]
            ),
            diff_over_threshold_rel=lambda d: (
                d["diff_over_threshold"] / d["est_over_threshold"]
            ).fillna(0),
        )
    )
    p_metrics = (
        dat.melt("index")
        .rename(columns={"index": "frame_idx", "variable": "metric"})
        .assign(kind=scores_path.parent.name, model=scores_path.parent.parent.name)
    )
    all_eval_metrics.append(p_metrics)
# The FastVQA rows prepared earlier already use the same long layout.
all_eval_metrics.append(fastvqa_df)
eval_df = pd.concat(all_eval_metrics, ignore_index=True)

In [None]:
# Variance of the per-frame mean intensity for every (kind, model) pair, for the
# "ref_mean_intensity" and "test_mean_intensity" metrics respectively.
ref_fiv = (
    eval_df[eval_df.metric == "ref_mean_intensity"]
    .groupby(["kind", "model"])["value"]
    .var()
)
test_fiv = (
    eval_df[eval_df.metric == "test_mean_intensity"]
    .groupby(["kind", "model"])["value"]
    .var()
)

# Relative reduction of that variance, stored as the metric "diff_fiv_rel".
fiv_df = ((ref_fiv - test_fiv) / ref_fiv).reset_index()
fiv_df["metric"] = "diff_fiv_rel"
fiv_df["frame_idx"] = 0  # Placeholder, as we don't have frame indices in this context
# Drop rows left over from a previous run of this cell before appending.
# Otherwise re-running the cell duplicates the diff_fiv_rel rows and the later
# pivot on (kind, model, frame_idx) fails on duplicate entries.
eval_df = pd.concat(
    [eval_df[eval_df["metric"] != "diff_fiv_rel"], fiv_df], ignore_index=True
)

In [None]:
# Widen to one row per (kind, model, frame) so frames can be filtered on a
# metric value, keep only frames with est_over_threshold > 0, then return to
# the long layout.
eval_sas_filtered_pivot = (
    eval_df.pivot(
        index=["kind", "model", "frame_idx"],
        columns="metric",
        values="value",
    )
    .reset_index()
    .loc[lambda wide: wide["est_over_threshold"] > 0]
    .copy()
)
eval_sas_filtered_df = pd.melt(
    eval_sas_filtered_pivot,
    id_vars=["kind", "model", "frame_idx"],
    var_name="metric",
    value_name="value",
).reset_index(drop=True)

In [None]:
# Identity mapping (raw name == display name) over all recordings except the
# ones listed below.
excluded_recordings = {"sun9", "sun10", "sun14", "highway1"}
selected_recordings = {
    recording: recording
    for recording in eval_df.kind.unique()
    if recording not in excluded_recordings
}

#### Quality

In [None]:
# Quality metrics on the real recordings. VMAF is shown as a plain number, the
# relative BRISQUE / FastVQA differences as percentages; for FastVQA the
# smallest value is highlighted.
pivot_vmaf_real_df, pivot_brisque_real_df, pivot_fastvqa_real_df = (
    make_table_summary(
        eval_df, base_model_names, selected_recordings, metric_key, totals_row=False
    )
    .apply(formatter)
    .T
    for metric_key, formatter in (
        ("vmaf", format_numbers),
        ("diff_brisque_rel", format_numbers_perc),
        ("diff_fastvqa_rel", partial(format_numbers_perc, max=False)),
    )
)
# "Metric" leads on the first frame, which fixes the concat column order.
pivot_vmaf_real_df.insert(0, "Metric", "VMAF")
pivot_brisque_real_df["Metric"] = "BRISQUE"
pivot_fastvqa_real_df["Metric"] = "FastVQA"

pivot_rec_real_df = pd.concat(
    [pivot_vmaf_real_df, pivot_brisque_real_df, pivot_fastvqa_real_df], axis=0
)
pivot_rec_real_df = pivot_rec_real_df.reset_index().set_index(["Metric", "kind"])
pivot_rec_real_df


In [None]:
# Two "X" index columns plus one right-aligned column per model.
n_real_rec_model_cols = len(pivot_rec_real_df.columns)
real_rec_table = pivot_rec_real_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="XX" + n_real_rec_model_cols * "r",
    label="tab:real_rec",
    caption="Artifact removal metrics on real datasets with different models. NoOp model is used as a reference and denotes predicting the input as output.",
    position="htbp",
    multicolumn=True,
)
with pathlib.Path("data/analysis/tables/results/real_rec.tex").open("w") as tex_file:
    # The row header reads "Recording" instead of the default "Subset".
    tex_file.write(post_process_latex_table(real_rec_table, "Recording"))

#### Artifact Removal

In [None]:
# Artifact-removal metrics on the real recordings. The relative metrics
# (MAR, SAS, FIV) are rendered as percentages, SASA as a plain number.
pivot_mar_real_df, pivot_sas_real_df, pivot_sasa_real_df, pivot_fiv_real_df = (
    make_table_summary(
        eval_df, base_model_names, selected_recordings, metric_key, totals_row=False
    )
    .apply(formatter)
    .T
    for metric_key, formatter in (
        ("diff_mean_detection_rel", format_numbers_perc),
        ("diff_over_threshold_rel", format_numbers_perc),
        ("diff_over_threshold", format_numbers),
        ("diff_fiv_rel", format_numbers_perc),
    )
)
# "Metric" leads on the first frame, which fixes the concat column order.
pivot_mar_real_df.insert(0, "Metric", "MAR")
pivot_sas_real_df["Metric"] = "SAS"
pivot_sasa_real_df["Metric"] = "SASA"
pivot_fiv_real_df["Metric"] = "FIV"

# NOTE(review): this rebinds `pivot_det_real_df`, which the real-data detection
# section also uses for its own table - re-running that section's export cell
# after this one would write this table instead.
pivot_det_real_df = pd.concat(
    [pivot_mar_real_df, pivot_sas_real_df, pivot_sasa_real_df, pivot_fiv_real_df],
    axis=0,
)
pivot_det_real_df = pivot_det_real_df.reset_index().set_index(["Metric", "kind"])
pivot_det_real_df

In [None]:
# Export the artifact-removal table: two "X" index columns plus one
# right-aligned column per model ("Metric" on the MAR frame is a helper column).
pivot_ar_real_table = pivot_det_real_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="XX" + "r" * (len(pivot_mar_real_df.columns) - 1),
    label="tab:real_ar",
    caption="Artifact removal metrics on real datasets with different models. NoOp model is used as a reference and denotes predicting the input as output.",
    position="htbp",
    multicolumn=True,
)
with open("data/analysis/tables/results/real_ar.tex", "w") as f:
    f.write(
        # "\\$" is the same two-character string the invalid escape "\$" produced
        # (backslash + dollar); spelling it out avoids the invalid-escape warning.
        post_process_latex_table(pivot_ar_real_table, "Recording").replace("\\$", "$")
    )

### Artifact Removal filter median

In [None]:
# SAS per recording and model on the frames that passed the
# est_over_threshold > 0 filter, aggregated with the 0.25 quantile.
# NOTE(review): the section title says "median" but the aggregate is the first
# quartile - confirm which one is intended.
pivot_sas_real_med_df = make_table_summary(
    eval_sas_filtered_df,
    base_model_names,
    selected_recordings,
    "diff_over_threshold_rel",
    totals_row=False,
    agg_func=lambda frame_values: np.quantile(frame_values, 0.25),
)
pivot_sas_real_med_df = pivot_sas_real_med_df.apply(format_numbers_perc).T

pivot_sas_real_med_df

In [None]:
# Export the filtered SAS table: one "X" index column plus one right-aligned
# column per model. The count is taken from the exported frame itself, so the
# column specification stays correct even when a model has no frames left after
# filtering (it previously depended on the unrelated MAR frame).
pivot_sas_real_med_table = pivot_sas_real_med_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="X" + "r" * len(pivot_sas_real_med_df.columns),
    label="tab:real-sas-med",
    position="htbp",
    multicolumn=False,
)
with open("data/analysis/tables/results/real_sas_med.tex", "w") as f:
    f.write(
        # "\\$" is the same backslash+dollar string the invalid escape "\$" yielded.
        post_process_latex_table(pivot_sas_real_med_table, "Recording").replace(
            "\\$", "$"
        )
    )

# Ablation

In [None]:
# Result-directory name -> display name of the ablation variants.
ablation_models = {
    "519": "DeLux SC",
    "524": "DeLux NE",
    "520": "DeLux ND",
    "522": "DeLux MR",
    "518": "DeLux NM",
}

# Same mapping without "520" (DeLux ND), used by the detection tables.
ablation_models_det = {
    model_id: display_name
    for model_id, display_name in ablation_models.items()
    if model_id != "520"
}

In [None]:
def sort_by_place_in_list(x: pd.Series, lst: list[str]) -> pd.Series:
    """Map every element of ``x`` to its position in ``lst``.

    Intended as a ``key=`` callable for ``sort_values`` / ``sort_index``.
    Elements missing from ``lst`` map to NaN, which pandas sorts last by default.
    """
    position_of = dict(zip(lst, range(len(lst))))
    return x.map(position_of)


## Synthetic Artifacts

In [None]:
# Raw metric key -> display name, grouped by task.
synth_det_metrics = {
    "accuracy_th10": "Accuracy",
    "f1_score_th10": "F1 Score",
    "iou_th10": "IoU",
}
synth_rec_metrics = {
    "mse": "MSE",
    "mape": "MAPE",
    "psnr": "PSNR",
    "mssim": "MS-SSIM",
}
# All synthetic metrics: detection first, then reconstruction.
synth_metrics = {**synth_det_metrics, **synth_rec_metrics}

### Full Detection

In [None]:
# One formatted kind-by-model block per detection metric, stacked under a
# (Metric, kind) row index.
partial_dfs = []
for metric, name in synth_det_metrics.items():
    # Error metrics highlight the smallest value, all others the largest.
    highlight_max = metric not in ["mse", "mape"]
    metric_summary = make_table_summary(
        test_df, ablation_models_det, synth_kind_names, metric
    )
    synth_partial_det_ablation_df = metric_summary.apply(
        partial(format_numbers, max=highlight_max)
    ).T
    synth_partial_det_ablation_df.insert(0, "Metric", name)
    partial_dfs.append(synth_partial_det_ablation_df)

full_synth_det_ablation_df = (
    pd.concat(partial_dfs, axis=0).reset_index().set_index(["Metric", "kind"])
)

In [None]:
# Columns are emitted in the order of `ablation_models_det`; the layout is a
# centred metric column, a rule, the "X" subset column and one right-aligned
# column per model.
synth_det_ablation_columns = list(ablation_models_det.values())
full_synth_det_ablation_table = full_synth_det_ablation_df[
    synth_det_ablation_columns
].to_latex(
    escape=False,
    bold_rows=True,
    column_format="c|X" + len(full_synth_det_ablation_df.columns) * "r",
    label="tab:synth-det-ablation",
    caption="Ablation study on synthetic datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)

with pathlib.Path("data/analysis/tables/ablation/synth_det.tex").open("w") as tex_file:
    tex_file.write(post_process_latex_table(full_synth_det_ablation_table, "Subset"))

### Full Reconstruction

In [None]:
# One formatted kind-by-model block per reconstruction metric, stacked under a
# (Metric, kind) row index.
partial_dfs = []
for metric, name in synth_rec_metrics.items():
    # MSE / MAPE are errors: highlight the smallest value instead of the largest.
    highlight_max = metric not in ["mse", "mape"]
    metric_summary = make_table_summary(
        test_df, ablation_models, synth_kind_names, metric
    )
    synth_partial_rec_ablation_df = metric_summary.apply(
        partial(format_numbers, max=highlight_max)
    ).T
    synth_partial_rec_ablation_df.insert(0, "Metric", name)
    partial_dfs.append(synth_partial_rec_ablation_df)

full_synth_rec_ablation_df = (
    pd.concat(partial_dfs, axis=0).reset_index().set_index(["Metric", "kind"])
)

In [None]:
# Columns are emitted in the order of `ablation_models`.
synth_rec_ablation_columns = list(ablation_models.values())
full_synth_rec_ablation_table = full_synth_rec_ablation_df[
    synth_rec_ablation_columns
].to_latex(
    escape=False,
    bold_rows=True,
    column_format="c|X" + len(full_synth_rec_ablation_df.columns) * "r",
    label="tab:synth-rec-ablation",
    caption="Ablation study on synthetic datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)

with pathlib.Path("data/analysis/tables/ablation/synth_rec.tex").open("w") as tex_file:
    tex_file.write(post_process_latex_table(full_synth_rec_ablation_table, "Subset"))

### Compact

In [None]:
def make_compat_summary(
    df: pd.DataFrame,
    models_names: dict[str, str],
    kind_names: dict[str, str],
    metric_names: dict[str, str],
) -> pd.DataFrame:
    """Build a compact model-by-(type, metric) table averaged over all kinds.

    Args:
        df: Long frame with the columns ``kind``, ``model``, ``metric``,
            ``type`` and ``value``.
        models_names: Raw model id -> display name; selects the models and
            defines the row order.
        kind_names: Only its keys are used, to select the kinds that are averaged.
        metric_names: Raw metric key -> display name; selects the metrics and
            defines the column order.

    Returns:
        Frame indexed by model display name whose columns are a
        ``(type, metric display name)`` MultiIndex holding the mean value.
    """
    wanted = (
        df["kind"].isin(kind_names.keys())
        & df["model"].isin(models_names.keys())
        & df["metric"].isin(metric_names.keys())
    )
    # `.assign` returns a new frame, so relabelling the metrics neither touches
    # the caller's data nor writes into a filtered slice (which raised
    # SettingWithCopyWarning before); the former full-frame copy is not needed.
    subset_df = df[wanted].assign(metric=lambda d: d["metric"].map(metric_names))
    pivot_df = (
        subset_df.groupby(["model", "metric", "type"])["value"]
        .mean()
        .reset_index()
        .pivot(index="model", columns=["type", "metric"], values="value")
        # Rows follow the order of `models_names` ...
        .sort_values(
            by="model",
            key=partial(sort_by_place_in_list, lst=list(models_names.keys())),
        )
        .T.rename(columns=models_names)
        # ... and metrics follow the order of `metric_names`.
        .sort_index(key=partial(sort_by_place_in_list, lst=list(metric_names.values())))
        .T
    )
    return pivot_df

In [None]:
# Metrics for which the largest value is the best one; every other metric
# highlights the smallest value.
max_metrics = ["accuracy_th10", "f1_score_th10", "iou_th10", "psnr", "mssim"]
higher_is_better_labels = {synth_metrics[m] for m in max_metrics}

ab_synth_det_rec = make_compat_summary(
    test_df,
    ablation_models,
    synth_kind_names,
    metric_names=synth_metrics,
).round(4)
# Columns are (type, metric display name) pairs, so `.name[1]` is the metric.
ab_synth_det_rec = ab_synth_det_rec.apply(
    lambda col: format_numbers(col, max=col.name[1] in higher_is_better_labels)
).T

ab_synth_det_rec_table = ab_synth_det_rec.to_latex(
    escape=False,
    bold_rows=True,
    column_format="xX" + len(ab_synth_det_rec.columns) * "r",
    label="tab:synth_ablation",
    caption="Ablation study on synthetic datasets with different models.",
    position="htbp",
    multicolumn=True,
)
with pathlib.Path("data/analysis/tables/ablation/synth_compact_all.tex").open("w") as tex_file:
    tex_file.write(post_process_latex_table(ab_synth_det_rec_table))
ab_synth_det_rec

## Real Artifacts

In [None]:
# Raw metric key -> display name for the real-data tables, grouped by task.
real_det_metrics = dict(
    accuracy_th10="Accuracy",
    f1_score_th10="F1 Score",
    iou_th10="IoU",
)
real_rec_metrics = dict(
    vmaf="VMAF",
    diff_brisque_rel="BRISQUE",
    diff_fastvqa_rel="FastVQA",
)
real_removal_metrics = dict(
    diff_mean_detection_rel="MAR",
    diff_over_threshold_rel="SAS",
    diff_fiv_rel="FIV",
)

### Full Detection

In [None]:
# One formatted kind-by-model block per detection metric on the real subsets,
# stacked under a (Metric, kind) row index.
partial_dfs = []
for metric, name in real_det_metrics.items():
    # Error metrics highlight the smallest value, all others the largest.
    highlight_max = metric not in ["mse", "mape"]
    metric_summary = make_table_summary(
        test_df, ablation_models, real_kind_names, metric
    )
    real_partial_det_ablation_df = metric_summary.apply(
        partial(format_numbers, max=highlight_max)
    ).T
    real_partial_det_ablation_df.insert(0, "Metric", name)
    partial_dfs.append(real_partial_det_ablation_df)

full_real_det_ablation_df = (
    pd.concat(partial_dfs, axis=0).reset_index().set_index(["Metric", "kind"])
)

In [None]:
# Only the detection ablation models are exported. The column specification is
# derived from the exported selection; it was previously derived from the
# unselected frame, which is built from all of `ablation_models`.
real_det_ablation_selected = full_real_det_ablation_df[
    list(ablation_models_det.values())
]
full_real_det_ablation_table = real_det_ablation_selected.to_latex(
    escape=False,
    bold_rows=True,
    column_format="c|X" + "r" * len(real_det_ablation_selected.columns),
    label="tab:real-det-ablation",
    # Caption typo fixed: "realetic" -> "real".
    caption="Ablation study on real datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)

with open("data/analysis/tables/ablation/real_det.tex", "w") as f:
    f.write(post_process_latex_table(full_real_det_ablation_table, "Subset"))

### Full Reconstruction

In [None]:
# One formatted recording-by-model block per quality metric, stacked under a
# (Metric, kind) row index.
partial_dfs = []
for metric, name in real_rec_metrics.items():
    # VMAF is shown as a plain number, the relative metrics as percentages.
    formatter = format_numbers if metric == "vmaf" else format_numbers_perc
    # For the relative FastVQA difference the smallest value is highlighted.
    highlight_max = metric != "diff_fastvqa_rel"
    metric_summary = make_table_summary(
        eval_df, ablation_models, selected_recordings, metric, totals_row=False
    )
    real_partial_rec_ablation_df = metric_summary.apply(
        partial(formatter, max=highlight_max)
    ).T
    real_partial_rec_ablation_df.insert(0, "Metric", name)
    partial_dfs.append(real_partial_rec_ablation_df)

full_real_rec_ablation_df = (
    pd.concat(partial_dfs, axis=0).reset_index().set_index(["Metric", "kind"])
)

In [None]:
# Export the real-data quality ablation table, columns in `ablation_models` order.
full_real_rec_ablation_table = full_real_rec_ablation_df[
    list(ablation_models.values())
].to_latex(
    escape=False,
    bold_rows=True,
    column_format="c|X" + "r" * (len(full_real_rec_ablation_df.columns)),
    label="tab:real-rec-ablation",
    # Caption typo fixed: "realetic" -> "real".
    caption="Ablation study on real datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)

with open("data/analysis/tables/ablation/real_rec.tex", "w") as f:
    f.write(post_process_latex_table(full_real_rec_ablation_table, "Subset"))

### Full Removal

In [None]:
# One percentage-formatted recording-by-model block per removal metric (largest
# value highlighted), stacked under a (Metric, kind) row index.
partial_dfs = []
for metric, name in real_removal_metrics.items():
    metric_summary = make_table_summary(
        eval_df, ablation_models, selected_recordings, metric, totals_row=False
    )
    real_partial_removal_ablation_df = metric_summary.apply(
        partial(format_numbers_perc, max=True)
    ).T
    real_partial_removal_ablation_df.insert(0, "Metric", name)
    partial_dfs.append(real_partial_removal_ablation_df)

full_real_removal_ablation_df = (
    pd.concat(partial_dfs, axis=0).reset_index().set_index(["Metric", "kind"])
)

In [None]:
# Export the real-data removal ablation table, columns in `ablation_models` order.
full_real_removal_ablation_table = full_real_removal_ablation_df[
    list(ablation_models.values())
].to_latex(
    escape=False,
    bold_rows=True,
    column_format="c|X" + "r" * (len(full_real_removal_ablation_df.columns)),
    label="tab:real-removal-ablation",
    # Caption typo fixed: "realetic" -> "real".
    caption="Ablation study on real datasets with different models. NoOp model is used as a reference and denotes predicting no artifacts in the entire frame.",
    position="htbp",
    multicolumn=True,
)

with open("data/analysis/tables/ablation/real_removal.tex", "w") as f:
    f.write(post_process_latex_table(full_real_removal_ablation_table, "Subset"))

In [None]:
# Preview the removal ablation with the columns in `ablation_models` order.
full_real_removal_ablation_df.loc[:, list(ablation_models.values())]

### Artifact Removal filter median

In [None]:
# SAS per recording for the detection ablation models, on the frames that
# passed the est_over_threshold > 0 filter, aggregated with the 0.25 quantile.
# NOTE(review): the section title says "median" but the aggregate is the first
# quartile - confirm which one is intended.
sas_ab_quartiles = make_table_summary(
    eval_sas_filtered_df,
    ablation_models_det,
    selected_recordings,
    "diff_over_threshold_rel",
    totals_row=False,
    agg_func=lambda frame_values: np.quantile(frame_values, 0.25),
)
pivot_sas_ab_real_med_df = sas_ab_quartiles.apply(format_numbers_perc).T[
    list(ablation_models_det.values())
]

pivot_sas_ab_real_med_df

In [None]:
# Export the filtered SAS ablation table: one "X" index column plus one
# right-aligned column per exported model. The count now comes from the exported
# frame; it previously came from the MAR frame of the main results, which is
# built from `base_model_names` rather than `ablation_models_det`.
pivot_sas_ab_real_med_table = pivot_sas_ab_real_med_df.to_latex(
    escape=False,
    bold_rows=True,
    column_format="X" + "r" * len(pivot_sas_ab_real_med_df.columns),
    label="tab:ablation-real-sas-med",
    position="htbp",
    multicolumn=False,
)
with open("data/analysis/tables/ablation/real_sas_med.tex", "w") as f:
    f.write(
        # "\\$" is the same backslash+dollar string the invalid escape "\$" yielded.
        post_process_latex_table(pivot_sas_ab_real_med_table, "Recording").replace(
            "\\$", "$"
        )
    )

### Compact

In [None]:
# Compact detection / reconstruction summary on the real subsets.
# MSE and MAPE are errors, so their best value is the smallest one; formatting
# every column with the default (max) used to highlight the worst model there.
lower_is_better_labels = {synth_metrics["mse"], synth_metrics["mape"]}
ab_real_det_rec = (
    make_compat_summary(
        test_df,
        ablation_models,
        real_kind_names,
        metric_names=synth_metrics,
    )
    .round(4)
    # Columns are (type, metric display name) pairs, so `.name[1]` is the metric.
    .apply(
        lambda col: format_numbers(
            col, max=col.name[1] not in lower_is_better_labels
        )
    )
    .T
)
ab_real_det_rec

In [None]:
# `make_compat_summary` groups by a "type" column; derive it from the metric:
# quality metrics are "reconstruction", everything else "removal".
eval_with_type = eval_df.copy()
eval_with_type["type"] = eval_df["metric"].apply(
    lambda x: "reconstruction" if x in real_rec_metrics else "removal"
)
# Display names of the metrics whose largest value is highlighted; metrics not
# listed here (FastVQA) highlight the smallest value.
max_metrics_real = ["VMAF", "BRISQUE", "MAR", "SAS", "FIV"]
ab_real_rec_rem = (
    make_compat_summary(
        eval_with_type,
        ablation_models,
        selected_recordings,
        metric_names={
            **real_rec_metrics,
            **real_removal_metrics,
        },
    )
    .round(4)
    # VMAF is a plain number, every other metric is shown as a percentage.
    .apply(
        lambda x: format_numbers(x, max=x.name[1] in max_metrics_real)
        if x.name[1] == "VMAF"
        else format_numbers_perc(x, max=x.name[1] in max_metrics_real)
    )
    .T
)

ab_full_real = pd.concat(
    [
        ab_real_det_rec,
        ab_real_rec_rem,
    ],
    axis=0,
)

ab_full_real_table = ab_full_real.to_latex(
    escape=False,
    bold_rows=True,
    column_format="xX" + "r" * len(ab_full_real.columns),
    label="tab:real_ablation",
    caption="Ablation study on real datasets with different models.",
    position="htbp",
    multicolumn=True,
)
with open("data/analysis/tables/ablation/real_compact_all.tex", "w") as f:
    f.write(
        # "\\$" is the same backslash+dollar string the invalid escape "\$" yielded.
        post_process_latex_table(ab_full_real_table, "Recording").replace("\\$", "$")
    )

ab_full_real

In [None]:
# Unformatted per-recording means of every real-data metric for the ablation
# models, over all recordings (no exclusions), in long form.
# Built in a loop instead of six copy-pasted blocks. This also stops the cell
# from rebinding `fiv_df` and `fastvqa_df`, which earlier cells use as inputs
# when building `eval_df` (re-running those cells after this one picked up the
# wrong frames).
all_recordings = {k: k for k in eval_df["kind"].unique()}
ab_real_metric_labels = {
    "diff_mean_detection_rel": "MAR",
    "diff_over_threshold_rel": "SAS",
    "diff_fiv_rel": "FIV",
    "diff_brisque_rel": "BRISQUE",
    "diff_fastvqa_rel": "FastVQA",
    "vmaf": "VMAF",
}
ab_real_metric_tables = []
for metric_key, metric_label in ab_real_metric_labels.items():
    metric_table = make_table_summary(
        eval_df,
        ablation_models,
        all_recordings,
        metric_key,
        totals_row=False,
    )
    metric_table.insert(0, "Metric", metric_label)
    ab_real_metric_tables.append(metric_table)

ab_full_real_metrics = pd.concat(ab_real_metric_tables, axis=0).reset_index()
ab_full_real_metrics.melt(id_vars=["model", "Metric"])

# Discussion

In [None]:
# Global seaborn look for the discussion figures.
# NOTE(review): the plotting cell below calls `sns.set(font_scale=1.3)`, which
# presumably resets style and palette to seaborn's defaults - confirm that the
# settings made here are still in effect when the figure is drawn.
sns.set_theme(style="whitegrid")
sns.set_palette("dark")

In [None]:
# Per-frame relative SAS change on the selected recordings, without model "520".
# `.assign` builds a new frame, so the relabelling no longer writes into a slice
# of `eval_df` through attribute assignment (SettingWithCopyWarning).
# Models that are not in `ablation_models` map to NaN.
failure_rate_df = eval_df.loc[
    (eval_df.metric == "diff_over_threshold_rel")
    & (eval_df.kind.isin(selected_recordings))
    & (eval_df.model != "520")
].assign(model=lambda d: d["model"].map(ablation_models))


# NOTE(review): `sns.set` is an alias of `set_theme` and presumably resets the
# whitegrid style configured earlier - confirm this is intended.
sns.set(font_scale=1.3)
sns.catplot(
    # Order the rows so the hue levels follow the order of `ablation_models`.
    data=failure_rate_df.sort_values(
        "model", key=partial(sort_by_place_in_list, lst=list(ablation_models.values()))
    ),
    x="kind",
    y="value",
    hue="model",
    kind="strip",
    height=6,
    aspect=2,
    alpha=0.8,
    palette=sns.color_palette("bright", n_colors=len(ablation_models)),
)
plt.xlabel("Recording")
plt.ylabel("Relative change in SAS")

plt.savefig(
    pathlib.Path("data/analysis/figures/discussion/plot_sas_decrease.png"),
    bbox_inches="tight",
)

In [None]:
# Largest relative SAS change among the plotted frames.
failure_rate_df.loc[:, "value"].max()

In [None]:
# Relative SAS change over the frames of recording "sun11", one line per model.
sun11_frames = failure_rate_df.loc[failure_rate_df["kind"] == "sun11"]
sns.lineplot(data=sun11_frames, x="frame_idx", y="value", hue="model")