In [None]:
import numpy as np
import pandas as pd
from sklearn.metrics import precision_recall_fscore_support

In [None]:
# Load the per-item decision table written to results/.
# Rows with a missing model get "" instead of NaN: groupby excludes NaN keys by
# default, so without this those rows would silently vanish from the scores.
# The result is assigned rather than filled with inplace=True (pandas idiom:
# no in-place mutation, keeps the expression chainable).
all_decisions_df = pd.read_csv("results/all_decisions_df.csv").fillna({"model": ""})

In [None]:
def _binary_scores(group: pd.DataFrame) -> pd.Series:
    """Score one experiment run: "yes" decisions vs. the benchmark labels.

    Returns a Series indexed by precision, recall, f1-score and support.
    `benchmark` is presumably a boolean ground-truth column (pos_label=True);
    TODO confirm against the CSV schema.
    """
    predicted_positive = group["decision"] == "yes"
    precision, recall, f1, support = precision_recall_fscore_support(
        group["benchmark"],
        predicted_positive,
        average="binary",
        pos_label=True,
        zero_division=0.0,
    )
    # NOTE(review): with average="binary" sklearn returns support=None, so this
    # column is NaN for every group. It is kept only for schema compatibility.
    return pd.Series(
        [precision, recall, f1, support],
        index=["precision", "recall", "f1-score", "support"],
    )


# One row of scores per (model, dataset, task_scope, experiment_run).
scores_df = (
    all_decisions_df
    .groupby(["model", "dataset", "task_scope", "experiment_run"])
    .apply(_binary_scores, include_groups=False)
    .reset_index()
)

In [None]:
# Run-to-run variability: the std (pandas default ddof=1) of each metric across
# experiment runs within every (model, task_scope, dataset), excluding n-gram rows.
# A group with a single run yields NaN here.
scores_std = (
    scores_df[scores_df["task_scope"] != "n-gram"]
    .groupby(["model", "task_scope", "dataset"])[["f1-score", "precision", "recall"]]
    .std()
)

# Average those stds over datasets (NaN entries are skipped by mean()), then lay
# out one row per task_scope and a (metric, model) MultiIndex on the columns.
consistency_table = (
    scores_std
    .groupby(["model", "task_scope"])
    .mean()
    .reset_index()
    .pivot(
        index="task_scope",
        columns="model",
        values=["f1-score", "precision", "recall"],
    )
)

In [None]:
import plotly.graph_objects as go

# Heatmap of run-to-run variability per (task scope, model). Each cell shows the
# F1-score std (averaged over datasets), with the precision and recall stds in
# parentheses. The colour encodes the F1-score std.
MODELS = ["GPT-3.5", "GPT-4"]
TASK_SCOPES = ["1-to-1", "1-to-n", "n-to-1", "n-to-n"]

# Select rows/columns explicitly so z lines up with the x/y labels below instead
# of relying on the pivot's sort order. Any other model columns in the pivot
# (e.g. rows whose model was blank in the CSV) are excluded rather than shifting
# the grid or inflating zmax. A missing model/scope raises KeyError.
f1_std = consistency_table["f1-score"].loc[TASK_SCOPES, MODELS]
precision_std = consistency_table["precision"].loc[TASK_SCOPES, MODELS]
recall_std = consistency_table["recall"].loc[TASK_SCOPES, MODELS]

# Cell annotations "F1 (precision, recall)", rows = task scopes, cols = models.
# No quotes are nested inside the f-strings, so this parses on Python < 3.12 too.
texts = [
    [
        f"{f1_std.loc[task_scope, model]:.3f} "
        f"({precision_std.loc[task_scope, model]:.2f}, "
        f"{recall_std.loc[task_scope, model]:.2f})"
        for model in MODELS
    ]
    for task_scope in TASK_SCOPES
]
fig = go.Figure(
    data=go.Heatmap(
        x=MODELS,
        y=TASK_SCOPES,
        z=f1_std.to_numpy(),
        text=texts,
        texttemplate="%{text}",
        textfont={"size": 12},
        colorscale="greens_r",
        zmin=0,
        # Colour range covers only the plotted cells.
        zmax=f1_std.max().max(),
        showscale=False,
    ),
    layout=dict(
        title="Standard deviation of F1-score, precision and recall.",
        height=600,
        width=1000,
        yaxis={"autorange": "reversed"},  # list the first task scope at the top
    ),
)
fig.show()