In [None]:
from pathlib import Path

# Root directory for this experiment's artifacts. Note that the parent of
# Path(".") is Path(".") itself, so every path below is resolved relative to
# the current working directory of the kernel.
EXPERIMENT_ROOT = Path(".").parent

# Input: the predictions table read by this notebook.
PREDICTIONS_TSV = EXPERIMENT_ROOT.joinpath("predictions.tsv")
# Location of the performance table (removed before the metrics are recomputed).
PERFORMANCE_TSV = EXPERIMENT_ROOT.joinpath("performance.tsv")

In [None]:
import pandas as pd

# Read the tab-separated predictions table; the first column of the file is
# used as the DataFrame index.
predictions_df_orig = pd.read_csv(
    PREDICTIONS_TSV,
    sep="\t",
    index_col=0,
)
# Bare expression so the notebook renders the loaded frame.
predictions_df_orig

In [None]:
import torch
from torchmetrics import Accuracy, F1Score, MetricCollection, Recall, Precision

# Factories for every metric reported per run, keyed as "<metric>/<averaging>".
# Each value is called with `num_classes=...` and builds a new torchmetrics
# object on every call, so metric state is never shared between runs.
# Keep the averaging mode in each key identical to its `average=` argument;
# a mismatch silently mislabels the corresponding result column.
CONSIDERED_METRICS = {
    "acc/macro": lambda num_classes: Accuracy(average="macro", num_classes=num_classes),
    "acc/micro": lambda num_classes: Accuracy(average="micro", num_classes=num_classes),
    "acc/weighted": lambda num_classes: Accuracy(average="weighted", num_classes=num_classes),
    "f1/macro": lambda num_classes: F1Score(average="macro", num_classes=num_classes),
    "f1/micro": lambda num_classes: F1Score(average="micro", num_classes=num_classes),
    "f1/weighted": lambda num_classes: F1Score(average="weighted", num_classes=num_classes),
    "recall/macro": lambda num_classes: Recall(average="macro", num_classes=num_classes),
    "recall/micro": lambda num_classes: Recall(average="micro", num_classes=num_classes),
    "recall/weighted": lambda num_classes: Recall(average="weighted", num_classes=num_classes),
    "precision/macro": lambda num_classes: Precision(average="macro", num_classes=num_classes),
    "precision/micro": lambda num_classes: Precision(average="micro", num_classes=num_classes),
    "precision/weighted": lambda num_classes: Precision(average="weighted", num_classes=num_classes),
}

# Remove any performance table left on disk before recomputing the metrics
# (no error if the file does not exist).
PERFORMANCE_TSV.unlink(missing_ok=True)

# Number of classes per dataset name; used to configure every metric.
DATASET_NUM_CLASSES = {"trec-coarse": 6, "trec-fine": 24, "trec": 6}

# Column-oriented accumulator: one independent list per identifying column,
# followed by one list per metric name. One entry is appended per group below.
performance = {
    column: []
    for column in (
        "run_id_a",
        "run_id_b",
        "model_type",
        "dataset_name",
        "embedder_a",
        "embedder_b",
        *CONSIDERED_METRICS,
    )
}

# Columns that identify one group of predictions.
# NOTE(review): "stitching" is a grouping key but has no column in
# `performance`, so rows coming from different stitching values cannot be told
# apart in the resulting frame - confirm this is intended.
KEYS = ["stitching", "run_id_a", "run_id_b", "model_type", "dataset_name", "embedder_a", "embedder_b"]
predictions_df = predictions_df_orig.groupby(KEYS)
for values, aggregate_df in predictions_df:
    aggregate_df: pd.DataFrame
    # Pair each grouping column with this group's value for it.
    key2value = dict(zip(KEYS, values))
    num_classes = DATASET_NUM_CLASSES[key2value["dataset_name"]]

    # Fresh metric objects for every group so nothing accumulates across groups.
    metrics = MetricCollection(
        {name: build(num_classes=num_classes) for name, build in CONSIDERED_METRICS.items()}
    )

    run_predictions = torch.as_tensor(aggregate_df["pred"].values)
    run_targets = torch.as_tensor(aggregate_df["target"].values)
    metrics.update(run_predictions, run_targets)

    # Record only the identifiers that have a matching output column.
    for column, column_values in performance.items():
        if column in key2value:
            column_values.append(key2value[column])

    # Store each computed metric as a plain Python number.
    computed = metrics.compute()
    for metric_name in computed:
        performance[metric_name].append(computed[metric_name].item())

performance_df = pd.DataFrame(data=performance)
performance_df