# In [None]:
# ===============================
# Fairness Metrics + Plot (Colab)
# ===============================

import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

# -------------------------------
# 1. LOAD MODEL + TEST DATA
# -------------------------------
from common.dataset import load_imbalanced_cifar
from torch.utils.data import DataLoader, random_split
from common.model import SimpleCNN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed for the train/test split. Without an explicit generator, random_split draws
# from torch's global RNG, so every run evaluates on a different test subset.
# NOTE(review): this must equal the seed the training script used for the same
# 80/20 split; otherwise "test" samples may overlap the data the model was trained
# on. Confirm against the training code (and that load_imbalanced_cifar itself is
# deterministic).
SPLIT_SEED = 42

dataset = load_imbalanced_cifar(cat_count=30, dog_count=500)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
_, test_ds = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

model = SimpleCNN().to(device)
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only runtime.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("/content/model.pth", map_location=device))  # update the path if needed
model.eval()

# -------------------------------
# 2. COLLECT PREDICTIONS
# -------------------------------

# Run the model over the whole test split (no gradients) and gather, per sample:
#   y_true - ground-truth label from the loader
#   y_pred - argmax over the model outputs
#   y_prob - softmax probability assigned to class index 1
# Per-batch arrays are kept separately and joined once at the end.
true_batches, pred_batches, prob_batches = [], [], []

with torch.no_grad():
    for images, labels in test_loader:
        outputs = model(images.to(device))
        probs = outputs.softmax(dim=1)

        prob_batches.append(probs[:, 1].cpu().numpy())
        pred_batches.append(outputs.argmax(dim=1).cpu().numpy())
        true_batches.append(labels.numpy())

# np.concatenate rejects an empty list, so an empty loader falls back to an empty array.
y_true, y_pred, y_prob = (
    np.concatenate(batches) if batches else np.array([])
    for batches in (true_batches, pred_batches, prob_batches)
)

# -------------------------------
# 3. GROUP MASKS
# 0 = Cats  |  1 = Dogs
# -------------------------------

# Boolean masks over the test samples: group_A selects label 0, group_B selects label 1.
# NOTE(review): the groups are defined by the ground-truth label itself, not by an
# independent sensitive attribute. So group_A contains no positives (its TPR is
# undefined) and group_B contains no negatives (its FPR is undefined). Per-group DP
# also collapses to group_A's FPR and group_B's TPR. Interpret the chart accordingly.
group_A, group_B = (y_true == label for label in (0, 1))

# -------------------------------
# 4. METRIC FUNCTIONS
# -------------------------------


def demographic_parity(preds, mask):
    """Return the positive-prediction rate P(pred == 1) among samples selected by ``mask``.

    Comparing this value across groups gives the demographic-parity gap.

    Args:
        preds: array of predicted labels.
        mask: boolean array selecting the group.

    Returns:
        The fraction of selected predictions equal to 1. An empty ``mask`` yields
        NaN, because that is NumPy's mean of an empty array.
    """
    predicted_positive = preds[mask] == 1
    return predicted_positive.mean()


def tpr(true, pred, mask):
    """Return the true-positive rate (recall for label 1) among samples selected by ``mask``.

    TPR = TP / P, where P counts the selected samples whose true label is 1.

    Args:
        true: array of ground-truth labels.
        pred: array of predicted labels, aligned with ``true``.
        mask: boolean array selecting the group.

    Returns:
        The rate as a float, or NaN when the group has no positives. The rate is
        undefined in that case, and reporting 0 would be indistinguishable from a
        classifier that misses every positive.
    """
    true_g = true[mask]
    pred_g = pred[mask]
    is_pos = true_g == 1
    p = np.sum(is_pos)
    if p == 0:
        return np.nan
    tp = np.sum(is_pos & (pred_g == 1))
    return tp / p


def fpr(true, pred, mask):
    """Return the false-positive rate among samples selected by ``mask``.

    FPR = FP / N, where N counts the selected samples whose true label is 0 and FP
    counts those among them predicted as 1.

    Args:
        true: array of ground-truth labels.
        pred: array of predicted labels, aligned with ``true``.
        mask: boolean array selecting the group.

    Returns:
        The rate as a float, or NaN when the group has no negatives. The rate is
        undefined in that case, and reporting 0 would look like a perfect FPR.
    """
    true_g = true[mask]
    pred_g = pred[mask]
    is_neg = true_g == 0
    n = np.sum(is_neg)
    if n == 0:
        return np.nan
    fp = np.sum(is_neg & (pred_g == 1))
    return fp / n


def accuracy(true, pred, mask):
    """Return the fraction of samples selected by ``mask`` whose prediction equals the label.

    Args:
        true: array of ground-truth labels.
        pred: array of predicted labels, aligned with ``true``.
        mask: boolean array selecting the group.

    Returns:
        The per-group accuracy. An empty ``mask`` yields NaN, because that is NumPy's
        mean of an empty array.
    """
    correct = true[mask] == pred[mask]
    return correct.mean()


# -------------------------------
# 5. COMPUTE METRICS
# -------------------------------

# Each entry maps a display name to [value for group_A, value for group_B].
# Insertion order fixes the left-to-right bar order in the plot below.
metrics = {
    metric_name: [score(mask) for mask in (group_A, group_B)]
    for metric_name, score in (
        ("DP (Demographic Parity)", lambda m: demographic_parity(y_pred, m)),
        ("TPR (Equal Opportunity)", lambda m: tpr(y_true, y_pred, m)),
        ("FPR (Equalized Odds)", lambda m: fpr(y_true, y_pred, m)),
        ("Accuracy per Group", lambda m: accuracy(y_true, y_pred, m)),
    )
}

# -------------------------------
# 6. PLOT
# -------------------------------

# Grouped bar chart: one cluster per metric, Cats bar on the left, Dogs bar on the right.
fig, ax = plt.subplots(figsize=(10, 6))

metric_names = list(metrics)
index = np.arange(len(metric_names))
bar_width = 0.35

# Split each metric's [group_A, group_B] pair into one series per group.
A_vals, B_vals = ([metrics[m][g] for m in metric_names] for g in (0, 1))

for offset, values, series_label in (
    (0, A_vals, "Group A (Cats)"),
    (bar_width, B_vals, "Group B (Dogs)"),
):
    ax.bar(index + offset, values, bar_width, label=series_label)

ax.set(
    xlabel="Metric",
    ylabel="Value",
    title="Fairness Metrics by Group (Empirical Results)",
)
# Centre each tick label between its pair of bars.
ax.set_xticks(index + bar_width / 2)
ax.set_xticklabels(metric_names, rotation=45, ha="right")
ax.legend()
plt.tight_layout()

plt.show()