In [1]:
# Development convenience: enable the jupyter_black extension so code cells
# are formatted with black. Not needed to reproduce any result below.
import jupyter_black
jupyter_black.load()

In [2]:
!pip install fastplot



In [3]:
import numpy as np
import pandas as pd
import fastplot
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

In [4]:
# CSV with one row per evaluated run (columns include model, preprocess_mode,
# dataset and test_acc/test_prec/test_rec/test_f1).
# NOTE(review): despite its name this is a file path, not one of the
# DATASET_NAMES keys below — consider renaming to e.g. RESULTS_CSV.
DATASET_NAME = "test.csv"
# Preprocessing-mode identifiers as stored in the CSV -> short codes used in
# tables and plot legends.
PREPROCESSING_MODES = {
    "all": "PP0",
    "basic": "PP1",
    "manual-1": "PP2",
    "manual-2": "PP3",
    "manual-3": "PP4",
    "random": "PP5",
}
# SwinV2 checkpoint names as stored in the CSV -> short codes
# (mB = base, mL = large, mS = small, mT = tiny).
MODEL_NAMES = {
    "swinv2-base-patch4-window16-256": "mB",
    "swinv2-large-patch4-window12-192-22k": "mL",
    "swinv2-small-patch4-window16-256": "mS",
    "swinv2-tiny-patch4-window16-256": "mT",
}
# Dataset identifiers as stored in the CSV -> short codes.
DATASET_NAMES = {"lab": "D0", "lab+maria": "D1", "lab+pasi": "D2"}

In [5]:
# Load the per-run evaluation metrics.
# NOTE(review): the displayed frame carries "Unnamed: 0" and "Unnamed: 0.1"
# columns, apparently because the CSV was written with its index included
# (more than once). Consider `to_csv(index=False)` upstream or `index_col=0` here.
df = pd.read_csv(filepath_or_buffer=DATASET_NAME)

In [6]:
# Replace the long identifiers with the short codes used in tables and plots.
# `.replace` (unlike `.map`) leaves values that are not keys untouched, so
# re-running this cell is a no-op. With `.map`, a re-run would map the
# already-converted codes ("PP0", "mB", ...) to NaN and silently empty the
# later group-bys. Unknown raw values are also kept as-is instead of becoming NaN.
df["preprocess_mode"] = df["preprocess_mode"].replace(PREPROCESSING_MODES)
df["model"] = df["model"].replace(MODEL_NAMES)
df["dataset"] = df["dataset"].replace(DATASET_NAMES)

In [7]:
# Peek at the first rows after the short-code conversion.
df.head(3)

Unnamed: 0.1,Unnamed: 0,model,preprocess_mode,dataset,test_acc,test_prec,test_rec,test_f1
0,0,mB,PP0,D1,0.74375,0.777952,0.74375,0.711543
1,1,mB,PP4,D1,0.7,0.6925,0.7,0.664997
2,2,mB,PP3,D1,0.671875,0.685011,0.671875,0.64126


In [8]:
# Row label of the highest-test-accuracy run for every (model, preprocess_mode) pair.
max_acc_rows = df.groupby(by=["model", "preprocess_mode"]).test_acc.idxmax()
df.loc[max_acc_rows].head(3)

Unnamed: 0.1,Unnamed: 0,model,preprocess_mode,dataset,test_acc,test_prec,test_rec,test_f1
6,6,mB,PP0,D2,0.795,0.839236,0.795,0.767389
11,11,mB,PP1,D2,0.7825,0.842084,0.7825,0.765571
9,9,mB,PP2,D2,0.8025,0.846665,0.8025,0.796293


In [9]:
def make_scatterplot_cb(plt, data=None):
    """Draw a precision-vs-recall scatter plot of the best run per (model, mode).

    Intended as a fastplot ``mode="callback"`` callback. Colour encodes
    ``preprocess_mode``, marker style encodes ``model``, and the legend is
    placed below the axes.

    Args:
        plt: Object passed in by fastplot. It is not used here (seaborn draws on
            the current axes), but it is kept so the callback signature is unchanged.
        data: Rows to plot. Defaults to ``df.loc[max_acc_rows]`` from earlier
            cells, looked up at call time so re-running those cells is picked up.
    """
    if data is None:
        data = df.loc[max_acc_rows]

    ax = sns.scatterplot(
        data,
        alpha=0.8,
        x="test_rec",
        y="test_prec",
        hue="preprocess_mode",
        style="model",
    )

    # The original code assumes the legend entries are laid out as
    #   [hue title, hue levels..., style title, style levels...]
    # Derive positions from the number of hue levels rather than hard-coding
    # them, so adding or removing a preprocessing mode does not corrupt the legend.
    handles, _ = ax.get_legend_handles_labels()
    n_hue_levels = data["preprocess_mode"].nunique()
    handles.pop(0)  # drop the "preprocess_mode" title entry
    style_title_idx = n_hue_levels  # position of the "model" title after the pop
    # Replace the "model" title with a blank (white, unlabeled) spacer entry.
    handles[style_title_idx] = mpatches.Patch(color="white", label="")
    # Swap the first two model entries (mB, mL -> mL, mB with the current data).
    first, second = style_title_idx + 1, style_title_idx + 2
    if second < len(handles):
        handles[first], handles[second] = handles[second], handles[first]

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")

    ax.legend(
        handles=handles,
        loc="upper center",
        bbox_to_anchor=(0.5, -0.22),  # below the axes, horizontally centred
        labelspacing=1,
        ncol=2,
        borderpad=1.3,
        handletextpad=0.1,
        prop={"size": 5},
    )


# Render the precision/recall figure to prec_rec.pdf. No data is passed
# directly (first argument None); in callback mode make_scatterplot_cb draws
# the contents.
fastplot.plot(
    None,
    "prec_rec.pdf",
    mode="callback",
    callback=make_scatterplot_cb,
    style="latex",
    figsize=(3, 4),
    ylim=(0.65, 0.95),
    grid=True,
)