## Load models

In [4]:
import torch
import pickle
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# === Paths ===
# Location of the two pickled models, relative to this notebook.
MODELS_DIR = "../../models"
MODEL_ORIG_PKL = f"{MODELS_DIR}/dqn_original.pkl"
MODEL_RES_PKL = f"{MODELS_DIR}/dqn_resampled.pkl"


# === Load models ===
def _load_pickled_model(path):
    """Unpickle and return the object stored at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code while
    deserializing, so only point this at files produced by this project.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


model_orig = _load_pickled_model(MODEL_ORIG_PKL)
model_res = _load_pickled_model(MODEL_RES_PKL)

# === Evaluation helper ===
def evaluate_model(model, X):
    """Return the arg-max prediction of ``model`` for every row of ``X``.

    Parameters
    ----------
    model
        Object exposing ``eval()`` and callable on a float32 tensor. Its
        output is reduced with ``argmax(dim=1)``, giving one index per row.
    X
        Feature rows: a NumPy array, nested list, or tensor. Converted to a
        ``torch.float32`` tensor before the forward pass.

    Returns
    -------
    numpy.ndarray
        Predicted index for each row, on the CPU.

    Notes
    -----
    ``model.eval()`` is called and the model is left in eval mode afterwards.
    """
    model.eval()

    # Run the forward pass on the device that holds the model's weights.
    # Feeding a CPU tensor to a GPU-resident model raises a device-mismatch
    # error. Objects without parameters keep the default placement.
    params = model.parameters() if hasattr(model, "parameters") else ()
    first_param = next(iter(params), None)
    device = None if first_param is None else first_param.device

    with torch.no_grad():
        # as_tensor accepts arrays, lists and existing tensors alike, and
        # avoids the copy-construct warning torch.tensor emits for tensors.
        X_tensor = torch.as_tensor(X, dtype=torch.float32, device=device)
        preds = model(X_tensor).argmax(dim=1).cpu().numpy()
    return preds

## Evaluate

### Overall accuracy (resampled model)

In [12]:
# Load the test split; features are cast to float32 before being handed to
# evaluate_model, labels are kept as read from the CSV.
X_test = pd.read_csv("../../data/X_test.csv").to_numpy().astype(np.float32)
y_test = pd.read_csv("../../data/y_test.csv")

# Fraction of test rows whose predicted class matches the label.
preds = evaluate_model(model_res, X_test)
accuracy = accuracy_score(y_true=y_test, y_pred=preds)
accuracy

0.11956521739130435

### Per-gender metrics (resampled model)

In [14]:
import pandas as pd
from sklearn.metrics import accuracy_score, recall_score, f1_score

import json

X_test = pd.read_csv('../../data/X_test.csv')
y_test = pd.read_csv('../../data/y_test.csv')

# Take the feature matrix *before* the helper columns are appended below, so
# the model never receives "pred"/"true" as inputs.
X_test_num = X_test.values.astype(np.float32)

preds_test = evaluate_model(model_res, X_test_num)

X_test["pred"] = preds_test
X_test["true"] = y_test   # ground truth; assumes y_test.csv has a single label column (TODO confirm)

# Protected attribute
protected_attr = "sex_Male"  # since it's 0/1 after encoding
# Series.unique() returns values in order of first appearance, not sorted.
# Sort the groups and label each row from its actual value, so the
# Female/Male labels cannot be swapped by the row order of the test set.
groups = np.sort(X_test[protected_attr].unique())
group_labels = {0: "Female (0)", 1: "Male (1)"}

metrics = {}
for g in groups:
    group_df = X_test[X_test[protected_attr] == g]
    n_samples = len(group_df)
    acc = accuracy_score(group_df["true"], group_df["pred"])
    # Macro average over the classes that occur in this group's true/pred
    # labels. Groups may therefore average over different class sets; pass
    # labels=[...] to force a fixed set.
    rec = recall_score(group_df["true"], group_df["pred"], average="macro")
    f1 = f1_score(group_df["true"], group_df["pred"], average="macro")
    # Unexpected values fall back to their string form instead of being
    # mislabelled.
    metrics[group_labels.get(g, str(g))] = {
        "accuracy": acc,
        "recall": rec,
        "f1": f1,
        "n_samples": n_samples,
    }

# One row per group, one column per metric.
metrics_df = pd.DataFrame(metrics).T
print(metrics_df)

# save
with open("../../results/dqn_gender_resampled.json", "w") as f:
    json.dump(metrics_df.to_dict(), f, indent=4)


            accuracy  recall        f1  n_samples
Female (0)  0.138158    0.20  0.048555      152.0
Male (1)    0.031250    0.25  0.015625       32.0


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
