In [1]:
import pandas as pd
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, RocCurveDisplay

# Apply seaborn's default plot styling globally for every figure in this notebook.
sns.set()

In [2]:
def pfbeta(labels, predictions, beta=1):
    """Probabilistic F-beta score.

    Works like F-beta, except that each prediction contributes its clipped
    probability (rather than a thresholded 0/1 decision) to the soft
    true-positive / false-positive totals. The score therefore does not
    depend on any decision threshold.

    Args:
        labels: Sequence/array of ground-truth labels; truthy entries count
            as positives.
        predictions: Sequence/array of predicted positive-class scores, same
            length as ``labels``. Values are clipped to [0, 1].
        beta: Weight of recall relative to precision (``beta=1`` gives pF1).

    Returns:
        The pF-beta score as a float. Returns 0.0 when the score is zero or
        undefined: no positive labels, all (clipped) predictions equal to 0,
        or empty input.

    Raises:
        ValueError: If ``labels`` and ``predictions`` differ in length.
    """
    is_positive = np.asarray(labels).astype(bool)
    probs = np.clip(np.asarray(predictions, dtype=float), 0, 1)
    if is_positive.shape != probs.shape:
        raise ValueError(
            f"labels and predictions must have the same shape, "
            f"got {is_positive.shape} and {probs.shape}"
        )

    y_true_count = int(is_positive.sum())
    ctp = probs[is_positive].sum()   # soft true positives
    cfp = probs[~is_positive].sum()  # soft false positives

    # Guard the two denominators below: with no positive labels, or with all
    # predictions at 0, precision/recall are undefined and the original
    # implementation raised ZeroDivisionError.
    if y_true_count == 0 or ctp + cfp == 0:
        return 0.0

    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if c_precision > 0 and c_recall > 0:
        beta_squared = beta * beta
        return float(
            (1 + beta_squared) * (c_precision * c_recall)
            / (beta_squared * c_precision + c_recall)
        )
    return 0.0

In [3]:
def get_part_metrics(df: pl.DataFrame, threshold=0.3) -> dict:
    """Compute classification metrics for one subset of predictions.

    Args:
        df: Frame with a ``labels`` column (ground truth) and a ``preds``
            column (predicted positive-class scores).
        threshold: Scores strictly greater than this count as positive for
            the thresholded metrics.

    Returns:
        Dict with keys, in order: ``accuracy``, ``precision``, ``recall``,
        ``f1`` (all computed at ``threshold``), ``pf1`` (probabilistic F1,
        threshold-free) and ``roc_auc``. ``roc_auc`` is NaN when the subset
        contains fewer than two distinct labels, where ROC AUC is undefined.
    """
    # Convert each column once instead of once per metric.
    y_true = df["labels"].to_numpy()
    y_score = df["preds"].to_numpy()
    y_pred = y_score > threshold

    metrics = {
        # binary metrics using the threshold
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred),
        "recall": recall_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred),
        # probabilistic F1 (doesn't depend on the threshold)
        "pf1": pfbeta(y_true, y_score),
    }
    # roc_auc_score raises ValueError when only one class is present, which
    # would abort the whole per-group table; record NaN for that group instead.
    if len(np.unique(y_true)) < 2:
        metrics["roc_auc"] = float("nan")
    else:
        metrics["roc_auc"] = roc_auc_score(y_true, y_score)
    return metrics


# Subsets to evaluate: report name -> values of the `domains` column to keep.
# Domain 0 is part of every subset. NOTE(review): presumably 0 marks real
# (non-generated) images and 1-4 the generators named here; confirm against
# the data.
DOMAIN_GROUPS = {
    "all": [0, 1, 2, 3, 4],
    "StableDiffusion": [0, 1],
    "Midjourney": [0, 4],
    "Dalle2": [0, 2],
    "Dalle3": [0, 3],
}


def get_all_metrics(df: pl.DataFrame, threshold=0.3, groups=None) -> pd.DataFrame:
    """Compute ``get_part_metrics`` for each domain subset and tabulate them.

    Args:
        df: Predictions frame with ``domains``, ``labels`` and ``preds`` columns.
        threshold: Decision threshold forwarded to ``get_part_metrics``.
        groups: Optional mapping of group name -> list of ``domains`` values
            to include. Defaults to ``DOMAIN_GROUPS``.

    Returns:
        A pandas DataFrame with one row per group (in mapping order). It holds
        the metric columns plus a trailing ``group`` column with the group name.
    """
    if groups is None:
        groups = DOMAIN_GROUPS

    rows = []
    for name, domains in groups.items():
        subset = df.filter(pl.col("domains").is_in(domains))
        row = get_part_metrics(subset, threshold=threshold)
        row["group"] = name
        rows.append(row)

    return pd.DataFrame(rows)

In [31]:
# Model 1: load its saved predictions and score every domain subset.
# NOTE(review): how this decision threshold was chosen is not shown in this
# notebook; record its source so the thresholded metrics are reproducible.
THRESHOLD_MODEL_1 = 0.206
df1 = pl.read_csv("outputs/preds-image-classifier-1.csv")
metrics_df1 = get_all_metrics(df1, threshold=THRESHOLD_MODEL_1)

In [32]:
# Display the per-group metrics table for model 1.
metrics_df1

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.920103,0.852803,0.950105,0.898828,0.862582,0.978179,all
1,0.912616,0.654157,0.967005,0.780395,0.79686,0.985916,StableDiffusion
2,0.911468,0.642767,0.962078,0.770656,0.77424,0.981999,Midjourney
3,0.903834,0.525093,0.917599,0.667952,0.648632,0.965689,Dalle2
4,0.905294,0.49929,0.93484,0.650926,0.651111,0.971403,Dalle3


In [49]:
# Model 14: load its saved predictions and score every domain subset.
# NOTE(review): how this decision threshold was chosen is not shown in this
# notebook; record its source so the thresholded metrics are reproducible.
THRESHOLD_MODEL_14 = 0.652
df14 = pl.read_csv("outputs/preds-image-classifier-14.csv")
metrics_df14 = get_all_metrics(df14, threshold=THRESHOLD_MODEL_14)

In [50]:
# Display the per-group metrics table for model 14.
metrics_df14

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.938002,0.891046,0.950221,0.919683,0.885374,0.987696,all
1,0.932351,0.722037,0.940899,0.817065,0.741327,0.986172,StableDiffusion
2,0.9375,0.72009,0.974592,0.828231,0.748441,0.992818,Midjourney
3,0.929644,0.610222,0.920541,0.733928,0.635053,0.980448,Dalle2
4,0.933304,0.590574,0.958112,0.73073,0.627171,0.989697,Dalle3


In [73]:
# Model 142: load its saved predictions and score every domain subset.
# NOTE(review): how this decision threshold was chosen is not shown in this
# notebook; record its source so the thresholded metrics are reproducible.
THRESHOLD_MODEL_142 = 0.57
df142 = pl.read_csv("outputs/preds-image-classifier-142.csv")
metrics_df142 = get_all_metrics(df142, threshold=THRESHOLD_MODEL_142)

In [74]:
# Display the per-group metrics table for model 142.
metrics_df142

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.944693,0.906177,0.950337,0.927732,0.901785,0.988109,all
1,0.940444,0.753137,0.935823,0.8346,0.780422,0.98546,StableDiffusion
2,0.944242,0.749556,0.960182,0.841895,0.788624,0.990542,Midjourney
3,0.942052,0.655678,0.948205,0.775265,0.705979,0.98744,Dalle2
4,0.94329,0.631051,0.962101,0.762181,0.687148,0.989456,Dalle3


In [93]:
# Model 1423: load its saved predictions and score every domain subset.
# NOTE(review): how this decision threshold was chosen is not shown in this
# notebook; record its source so the thresholded metrics are reproducible.
THRESHOLD_MODEL_1423 = 0.9675
df1423 = pl.read_csv("outputs/preds-image-classifier-1423.csv")
metrics_df1423 = get_all_metrics(df1423, threshold=THRESHOLD_MODEL_1423)

In [94]:
# Display the per-group metrics table for model 1423.
# NOTE(review): in the saved output, pf1 here is lower than model 142's in
# every group even though ROC AUC is similar or higher. pf1 ignores the
# threshold, so this points at the score distribution itself (e.g.
# calibration) rather than the threshold choice; worth investigating.
metrics_df1423

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.947474,0.912932,0.949988,0.931091,0.830166,0.989359,all
1,0.942656,0.766136,0.925308,0.838233,0.612302,0.985141,StableDiffusion
2,0.947702,0.764154,0.957148,0.849832,0.604318,0.990077,Midjourney
3,0.946644,0.67501,0.952325,0.790039,0.496156,0.990272,Dalle2
4,0.949193,0.654239,0.980053,0.784669,0.467308,0.994807,Dalle3
