In [1]:
import pandas as pd
import numpy as np
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, f1_score, RocCurveDisplay

sns.set()

In [2]:
def pfbeta(labels, predictions, beta=1):
    """Probabilistic F-beta score.

    Like F-beta, but every prediction contributes its (clipped) score instead
    of a hard 0/1 decision, so the result does not depend on a threshold.

    Parameters
    ----------
    labels : array-like
        Ground-truth labels; truthy entries count as positives.
    predictions : array-like
        Predicted scores; clipped to [0, 1] before use.
    beta : float, default 1
        Weight of recall relative to precision (``beta=1`` gives probabilistic F1).

    Returns
    -------
    float
        The probabilistic F-beta score, or 0.0 when it is undefined or zero
        (empty input, no positive labels, all predictions zero, or zero
        precision/recall).

    Raises
    ------
    ValueError
        If ``labels`` and ``predictions`` have different shapes.
    """
    y_true = np.asarray(labels).astype(bool)
    y_score = np.clip(np.asarray(predictions, dtype=float), 0, 1)
    if y_true.shape != y_score.shape:
        raise ValueError(
            f"labels and predictions must have the same shape, "
            f"got {y_true.shape} and {y_score.shape}"
        )

    y_true_count = int(y_true.sum())
    ctp = y_score[y_true].sum()   # "soft" true positives
    cfp = y_score[~y_true].sum()  # "soft" false positives

    # Both denominators below can be zero; the score is undefined there, so
    # report 0 instead of raising ZeroDivisionError.
    if y_true_count == 0 or ctp + cfp == 0:
        return 0.0

    beta_squared = beta * beta
    c_precision = ctp / (ctp + cfp)
    c_recall = ctp / y_true_count
    if c_precision > 0 and c_recall > 0:
        return float(
            (1 + beta_squared) * (c_precision * c_recall)
            / (beta_squared * c_precision + c_recall)
        )
    return 0.0

In [3]:
def get_part_metrics(df: pl.DataFrame, threshold=0.3) -> dict:
    """Compute threshold-based and threshold-free metrics for one subset of predictions.

    Parameters
    ----------
    df : pl.DataFrame
        Must contain a ``labels`` column (ground truth) and a ``preds`` column
        (predicted scores).
    threshold : float, default 0.3
        Scores strictly greater than this count as positive predictions for
        accuracy, precision, recall and F1.

    Returns
    -------
    dict
        Keys ``accuracy``, ``precision``, ``recall``, ``f1`` (thresholded),
        ``pf1`` (probabilistic F1) and ``roc_auc`` (threshold-free). ``roc_auc``
        is NaN when the subset contains only one label class, where ROC AUC is
        undefined.
    """
    # Convert each column to NumPy once instead of once per metric.
    y_true = df["labels"].to_numpy()
    y_score = df["preds"].to_numpy()
    y_pred = (df["preds"] > threshold).to_numpy()

    metrics = {}
    # binary metrics using the threshold
    metrics["accuracy"] = accuracy_score(y_true, y_pred)
    metrics["precision"] = precision_score(y_true, y_pred)
    metrics["recall"] = recall_score(y_true, y_pred)
    metrics["f1"] = f1_score(y_true, y_pred)
    # probabilistic F1 (doesn't depend on the threshold)
    metrics["pf1"] = pfbeta(y_true, y_score)
    # ROC AUC: sklearn raises ValueError if only one class is present.
    if np.unique(y_true).size < 2:
        metrics["roc_auc"] = float("nan")
    else:
        metrics["roc_auc"] = roc_auc_score(y_true, y_score)
    return metrics


# Named evaluation groups -> the ``domains`` ids each one includes.
# Every generator-specific group pairs domain 0 with exactly one other domain
# (presumably 0 = real images -- confirm against the data pipeline).
DOMAIN_GROUPS = {
    "all": [0, 1, 2, 3, 4],
    "StableDiffusion": [0, 1],
    "Midjourney": [0, 4],
    "Dalle2": [0, 2],
    "Dalle3": [0, 3],
}


def get_all_metrics(df: pl.DataFrame, threshold=0.3, groups=None) -> pd.DataFrame:
    """Compute ``get_part_metrics`` for each domain group and collect them in a table.

    Parameters
    ----------
    df : pl.DataFrame
        Predictions with ``labels``, ``preds`` and ``domains`` columns.
    threshold : float, default 0.3
        Decision threshold forwarded to ``get_part_metrics``.
    groups : mapping of str -> iterable of int, optional
        Group name -> ``domains`` ids to include. Defaults to ``DOMAIN_GROUPS``.
        Rows follow the mapping's iteration order.

    Returns
    -------
    pd.DataFrame
        One row per group: the metric columns followed by a ``group`` column.

    Raises
    ------
    ValueError
        If a group selects no rows (sklearn metrics cannot be computed on it).
    """
    if groups is None:
        groups = DOMAIN_GROUPS

    all_metrics = []
    for group_name, domain_ids in groups.items():
        domain_ids = list(domain_ids)
        subset = df.filter(pl.col("domains").is_in(domain_ids))
        if subset.height == 0:
            raise ValueError(f"group {group_name!r} matches no rows (domains={domain_ids})")
        metrics = get_part_metrics(subset, threshold=threshold)
        metrics["group"] = group_name
        all_metrics.append(metrics)

    return pd.DataFrame(all_metrics)

In [18]:
# Classifier 1: load its prediction file and score every domain group at threshold 0.5.
df1 = pl.read_csv("outputs/preds-image-classifier-1.csv")
metrics_df1 = df1.pipe(get_all_metrics, threshold=0.5)

In [19]:
# Per-group metrics table for classifier 1 (threshold 0.5).
metrics_df1

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.922883,0.905793,0.885671,0.895619,0.862582,0.978179,all
1,0.942132,0.763441,0.926759,0.837209,0.79686,0.985916,StableDiffusion
2,0.939611,0.751802,0.909746,0.823267,0.77424,0.981999,Midjourney
3,0.931319,0.636029,0.814597,0.714323,0.648632,0.965689,Dalle2
4,0.935942,0.617021,0.848404,0.714446,0.651111,0.971403,Dalle3


In [20]:
# Classifier 14: load its prediction file and score every domain group at threshold 0.5.
df14 = pl.read_csv("outputs/preds-image-classifier-14.csv")
metrics_df14 = df14.pipe(get_all_metrics, threshold=0.5)

In [21]:
# Per-group metrics table for classifier 14 (threshold 0.5).
metrics_df14

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.928488,0.857466,0.96976,0.910163,0.885374,0.987696,all
1,0.91378,0.657693,0.965555,0.78243,0.741327,0.986172,StableDiffusion
2,0.916393,0.652021,0.984831,0.784592,0.748441,0.992818,Midjourney
3,0.908425,0.537229,0.947028,0.685556,0.635053,0.980448,Dalle2
4,0.910758,0.514536,0.976729,0.674008,0.627171,0.989697,Dalle3


In [22]:
# Classifier 142: load its prediction file and score every domain group at threshold 0.5.
df142 = pl.read_csv("outputs/preds-image-classifier-142.csv")
metrics_df142 = df142.pipe(get_all_metrics, threshold=0.5)

In [23]:
# Per-group metrics table for classifier 142 (threshold 0.5).
metrics_df142

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.943042,0.89642,0.958246,0.926303,0.901785,0.988109,all
1,0.935379,0.731982,0.942712,0.824089,0.780422,0.98546,StableDiffusion
2,0.939083,0.728155,0.967008,0.830754,0.788624,0.990542,Midjourney
3,0.936965,0.632006,0.962331,0.762949,0.705979,0.98744,Dalle2
4,0.937072,0.604323,0.966755,0.743734,0.687148,0.989456,Dalle3


In [24]:
# Classifier 1423: load its prediction file and score every domain group at threshold 0.5.
df1423 = pl.read_csv("outputs/preds-image-classifier-1423.csv")
metrics_df1423 = df1423.pipe(get_all_metrics, threshold=0.5)

In [25]:
# Per-group metrics table for classifier 1423 (threshold 0.5).
metrics_df1423

Unnamed: 0,accuracy,precision,recall,f1,pf1,roc_auc,group
0,0.935309,0.867694,0.975576,0.918478,0.888946,0.987979,all
1,0.919369,0.674637,0.961566,0.792944,0.746404,0.985404,StableDiffusion
2,0.922139,0.669253,0.981418,0.795818,0.748204,0.989346,Midjourney
3,0.9186,0.565705,0.980577,0.717485,0.657445,0.986546,Dalle2
4,0.918294,0.536762,0.985372,0.694959,0.641829,0.991926,Dalle3
