In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Load every per-item decision. A missing "model" value is replaced by "" --
# presumably so those rows form their own group in the next cell instead of
# being dropped (pandas groupby drops NaN keys by default).
all_decisions_df = pd.read_csv("results/all_decisions_df.csv").fillna({"model": ""})

In [None]:
def _binary_scores(decisions):
    """Score one (model, dataset, task_scope, experiment_run) group.

    A decision of "yes" is the prediction; the "benchmark" column is the
    reference, with True as the positive label. Empty denominators score 0.

    Returns a Series with precision, recall, f1-score and support. Note that
    scikit-learn returns ``support=None`` whenever ``average`` is set, so the
    "support" entry is always missing (NaN).
    """
    predicted_yes = decisions["decision"] == "yes"
    metrics = precision_recall_fscore_support(
        decisions["benchmark"],
        predicted_yes,
        average="binary",
        pos_label=True,
        zero_division=0.0,
    )
    return pd.Series(metrics, index=["precision", "recall", "f1-score", "support"])


scores_df = (
    all_decisions_df
    .groupby(["model", "dataset", "task_scope", "experiment_run"])
    .apply(_binary_scores, include_groups=False)
    .reset_index()
)

In [None]:
# For every (model, task_scope, dataset), pick the single experiment run whose
# F1-score is the median over runs, so that the reported precision/recall come
# from the same run as the reported F1-score.
#
# The *lower* median (quantile 0.5 with interpolation="lower") is used rather
# than `.median()`: with an even number of runs `.median()` averages the two
# middle F1-scores, a value that generally matches no actual run, so an exact
# lookup finds nothing. For an odd number of runs both give the same value, and
# the lower median is always one of the group's own values, so `==` is exact.
group_keys = ["model", "task_scope", "dataset"]

median_f1_per_run = scores_df.groupby(group_keys)["f1-score"].transform(
    lambda f1_scores: f1_scores.quantile(0.5, interpolation="lower")
)

# Several runs may tie on the median F1-score; keep the first one in
# scores_df row order, then order rows by the group keys.
median_experiments = (
    scores_df[scores_df["f1-score"] == median_f1_per_run]
    .drop_duplicates(subset=group_keys, keep="first")
    .sort_values(group_keys)
    .reset_index(drop=True)
)

n_groups = scores_df.groupby(group_keys).ngroups
if len(median_experiments) != n_groups:
    # Only reachable when a group has no non-NaN F1-score at all.
    raise ValueError(
        f"median-F1 run found for {len(median_experiments)} of {n_groups} "
        "(model, task_scope, dataset) groups"
    )

# Exactly one run per group remains, so it can be pivoted directly:
# rows = datasets, columns = (metric, model, task_scope).
median_scores = median_experiments.pivot(
    index="dataset",
    columns=["model", "task_scope"],
    values=["f1-score", "precision", "recall"],
)

# Append a "mean" row: every column averaged over the datasets (NaNs skipped).
median_scores = pd.concat(
    (median_scores, median_scores.mean().to_frame(name="mean").T),
    axis="index",
)

In [None]:
# Heatmap: one row per dataset (plus the "mean" row), one column per
# (model, task_scope). Colour encodes each cell's median-run F1-score minus the
# F1-score in the first column of the same row; the annotation shows
# "F1 (precision, recall)" for that median run.
# NOTE(review): the first column is treated as the baseline. Pivot sorts the
# columns, so this is presumably the model filled with "" at load time — confirm.
median_f1_scores = median_scores.loc[:, ("f1-score", slice(None), slice(None))]

baseline_f1 = median_f1_scores.iloc[:, 0]
values = median_f1_scores.sub(baseline_f1, axis="index").to_numpy()

texts = [
    [
        f"{f1_score:.3f} ({median_scores.loc[dataset, ('precision', model, task_scope)]:.2f}, {median_scores.loc[dataset, ('recall', model, task_scope)]:.2f})"
        for (_, model, task_scope), f1_score in row.items()
    ]
    for dataset, row in median_f1_scores.iterrows()
]

fig = go.Figure(
    data=go.Heatmap(
        # Two-level (multicategory) x-axis: model outside, task_scope inside.
        x=[
            median_f1_scores.columns.get_level_values(1),
            median_f1_scores.columns.get_level_values(2),
        ],
        y=median_f1_scores.index,
        z=values,
        text=texts,
        texttemplate="%{text}",
        textfont={"size": 12},
        colorscale="PRGn",
        # Symmetric range so a difference of 0 sits at the centre of the
        # diverging colour scale.
        zmin=-1.0,
        zmax=1.0,
        showscale=False,
    ),
    layout=dict(
        title="Median F1-scores compared to baseline (green: better, purple: worse)",
        height=600,
        width=1000,
        xaxis={"title": "model / task scope"},
        yaxis={"title": "dataset", "autorange": "reversed"},
    ),
)
fig.show()