In [None]:
# ...existing code...
"""
Train a DQN on the preprocessed heart dataset and evaluate fairness by sex and age.

Run from the src/ directory (the data paths in CONFIG are one level up from it):
  cd src && python3 train_dqn_and_fairness.py
Assumes Ella_data_processing.ipynb has already written ../data/Xy_train_resampled.csv,
../data/X_test.csv and ../data/y_test.csv, and that the raw dataset is at
../datasets/heart_disease_uci.csv (all paths relative to src/).
"""
import os
import random
import numpy as np
import pandas as pd
from collections import deque
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
from sklearn.model_selection import train_test_split

# Reproducibility: seed Python's, NumPy's and PyTorch's global RNGs with one shared value,
# in that order.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng

# --- CONFIG ---
# The script lives in src/ and the data directories sit one level above it.
# Resolve them against the script's own directory so the run does not depend on
# the current working directory. When there is no __file__ (e.g. the code runs as
# a notebook cell), fall back to the working directory, which reproduces the
# plain "../data" behaviour.
try:
    _BASE_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    _BASE_DIR = os.getcwd()
DATA_DIR = os.path.normpath(os.path.join(_BASE_DIR, "..", "data"))
RAW_DATA = os.path.normpath(os.path.join(_BASE_DIR, "..", "datasets", "heart_disease_uci.csv"))
MODEL_OUT = os.path.join(DATA_DIR, "dqn_cardiac.pth")

NUM_ACTIONS = 5        # classes 0..4
STATE_DTYPE = np.float32

# Hyperparams
HIDDEN = [128, 64]     # hidden-layer widths of the Q-network
LR = 1e-3
GAMMA = 0.99           # no effect on targets while every transition is terminal (done=True)
BATCH_SIZE = 128
REPLAY_CAP = 20000
MIN_REPLAY = 256       # no gradient updates until the replay buffer holds this many transitions
# eps is multiplied by EPS_DECAY after every gradient update (not every epoch), so
# with these values it reaches EPS_END after roughly 600 updates.
EPS_START, EPS_END, EPS_DECAY = 1.0, 0.05, 0.995
NUM_EPOCHS = 40
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Load preprocessed CSVs produced by Ella_data_processing.ipynb ---
Xy_train_res_path = os.path.join(DATA_DIR, "Xy_train_resampled.csv")
X_test_path = os.path.join(DATA_DIR, "X_test.csv")
y_test_path = os.path.join(DATA_DIR, "y_test.csv")

# Fail fast, listing every missing input at once rather than only the first one.
_missing_inputs = [p for p in (Xy_train_res_path, X_test_path, y_test_path) if not os.path.exists(p)]
if _missing_inputs:
    raise FileNotFoundError("Required file(s) not found: " + ", ".join(_missing_inputs))

Xy_train_res = pd.read_csv(Xy_train_res_path)
X_test = pd.read_csv(X_test_path)
# Cast to int like y_train so both label arrays share one integer dtype.
y_test = pd.read_csv(y_test_path)["num"].values.astype(int)

# X_test.csv fixes the feature set and column order; the training frame is indexed
# with the same list so train and test rows present their features in the same order.
feature_cols = list(X_test.columns)
X_train = Xy_train_res[feature_cols].values.astype(STATE_DTYPE)
y_train = Xy_train_res["num"].values.astype(int)
# From here on X_test is the float32 feature matrix rather than the DataFrame.
X_test = X_test[feature_cols].values.astype(STATE_DTYPE)

# Locate the sex feature. get_dummies encodes it as the one-hot column "sex_Male";
# a plain "sex" column is the fallback. Without either, sex-based metrics are skipped.
_sex_col = next((name for name in ("sex_Male", "sex") if name in feature_cols), None)
if _sex_col is None:
    test_sex = None
    print("Warning: no sex column found in features; sex-based metrics will be skipped.")
else:
    sex_idx = feature_cols.index(_sex_col)
    test_sex = X_test[:, sex_idx].astype(int)

# --- Simple RL environment (one patient = one episode) ---
class CardiacEnv:
    """Single-step classification environment over a fixed labelled dataset.

    Each episode presents one row of ``X`` as the state. The agent's action is
    scored against that row's label, and the episode ends immediately. Rows are
    served in their stored order, wrapping around after the last one.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray):
        """Wrap features ``X`` and labels ``y``.

        Raises:
            ValueError: if ``X`` is empty or ``X`` and ``y`` differ in length.
        """
        if len(X) == 0:
            raise ValueError("CardiacEnv needs at least one sample")
        if len(X) != len(y):
            raise ValueError(f"X and y length mismatch: {len(X)} != {len(y)}")
        self.X = X
        self.y = y
        self.n = len(X)
        self.idx = 0             # row served by the next reset()
        self.curr_label = None   # label of the row served by the most recent reset()

    def reset(self) -> np.ndarray:
        """Start an episode: return the next row as the state and remember its label."""
        state = self.X[self.idx]
        self.curr_label = int(self.y[self.idx])
        self.idx = (self.idx + 1) % self.n
        return state

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, dict]:
        """Score ``action``: reward +1.0 if it equals the current label, else -1.0.

        The episode always terminates. The returned next state is an all-zero
        placeholder shaped like one row of ``X``.

        Raises:
            RuntimeError: if called before ``reset()``.
        """
        if self.curr_label is None:
            raise RuntimeError("step() called before reset()")
        reward = 1.0 if action == self.curr_label else -1.0
        done = True
        next_state = np.zeros_like(self.X[0], dtype=STATE_DTYPE)
        return next_state, reward, done, {}

# --- Replay buffer ---
class ReplayBuffer:
    """Bounded FIFO memory of (state, action, reward, next_state, done) transitions."""

    def __init__(self, capacity: int):
        # Once `capacity` items are stored, each append drops the oldest one.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s2, done):
        """Store one transition."""
        transition = (s, a, r, s2, done)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns five arrays (states, actions, rewards, next_states, dones), each
        stacked along a new leading batch axis.
        """
        minibatch = random.sample(self.buffer, batch_size)
        fields = [np.stack(column) for column in zip(*minibatch)]
        states, actions, rewards, next_states, dones = fields
        return states, actions, rewards, next_states, dones

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)

# --- Q-network ---
class QNet(nn.Module):
    """Fully connected Q-network: input features -> ReLU hidden layers -> one Q-value per action.

    Args:
        input_dim: number of input features.
        output_dim: number of actions (size of the output layer).
        hidden: widths of the hidden layers; ``None`` uses the module-level HIDDEN setting.
            An empty sequence yields a single linear layer.
    """

    def __init__(self, input_dim, output_dim, hidden=None):
        super().__init__()
        hidden_sizes = HIDDEN if hidden is None else hidden
        layers = []
        prev = input_dim
        for width in hidden_sizes:
            layers.append(nn.Linear(prev, width))
            layers.append(nn.ReLU())
            prev = width
        layers.append(nn.Linear(prev, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return Q-values for ``x``; the last dimension becomes ``output_dim``."""
        return self.net(x)

# --- Agent ---
class DQNAgent:
    """Epsilon-greedy DQN agent with an online Q-network, a target network and a replay buffer.

    ``act`` explores with probability ``eps`` unless ``greedy`` is set. Once the
    buffer holds at least MIN_REPLAY transitions, ``train_step`` performs one
    minibatch update. It hard-syncs the target network every 500 updates and
    multiplies ``eps`` by EPS_DECAY (floored at EPS_END) after each update.
    """

    def __init__(self, state_dim, n_actions):
        # Build the online network first, then the target network, which is
        # immediately overwritten with the online weights.
        online_net = QNet(state_dim, n_actions).to(DEVICE)
        target_net = QNet(state_dim, n_actions).to(DEVICE)
        target_net.load_state_dict(online_net.state_dict())
        self.q, self.target_q = online_net, target_net
        self.opt = optim.Adam(online_net.parameters(), lr=LR)
        self.replay = ReplayBuffer(REPLAY_CAP)
        self.eps, self.steps = EPS_START, 0  # steps counts gradient updates

    def act(self, state, greedy=False):
        """Return an action index for one NumPy ``state`` (random with prob. eps unless greedy)."""
        explore = not greedy and random.random() < self.eps
        if explore:
            return random.randrange(NUM_ACTIONS)
        batch = torch.from_numpy(state).float().to(DEVICE).unsqueeze(0)
        with torch.no_grad():
            q_row = self.q(batch)
        return int(q_row.argmax().item())

    def push(self, *args):
        """Forward one transition to the replay buffer."""
        self.replay.push(*args)

    def train_step(self):
        """Run one DQN update; return the loss, or None while the buffer is below MIN_REPLAY."""
        if len(self.replay) < MIN_REPLAY:
            return None

        states, actions, rewards, next_states, dones = self.replay.sample(BATCH_SIZE)
        states_t = torch.from_numpy(states).float().to(DEVICE)
        actions_t = torch.from_numpy(actions).long().to(DEVICE)
        rewards_t = torch.from_numpy(rewards).float().to(DEVICE)
        next_states_t = torch.from_numpy(next_states).float().to(DEVICE)
        dones_t = torch.from_numpy(dones.astype(np.float32)).float().to(DEVICE)

        # Q(s, a) for the actions that were actually taken.
        chosen_q = self.q(states_t).gather(1, actions_t.unsqueeze(1)).squeeze(1)
        # Bootstrapped target from the target network; (1 - done) zeroes it for terminal steps.
        with torch.no_grad():
            best_next_q = self.target_q(next_states_t).max(1)[0]
        target = rewards_t + GAMMA * (1.0 - dones_t) * best_next_q

        loss = nn.MSELoss()(chosen_q, target)
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        # Hard target sync; also fires on the very first update (steps == 0).
        if self.steps % 500 == 0:
            self.target_q.load_state_dict(self.q.state_dict())

        self.eps = max(EPS_END, self.eps * EPS_DECAY)
        self.steps += 1
        return loss.item()

    def save(self, path):
        """Write the online network's state_dict to ``path``."""
        torch.save(self.q.state_dict(), path)

    def load(self, path):
        """Load online-network weights from ``path`` and copy them into the target network."""
        weights = torch.load(path, map_location=DEVICE)
        self.q.load_state_dict(weights)
        self.target_q.load_state_dict(self.q.state_dict())

# --- Training ---
env = CardiacEnv(X_train, y_train)
agent = DQNAgent(state_dim=X_train.shape[1], n_actions=NUM_ACTIONS)

print("Training DQN...")
for epoch in range(NUM_EPOCHS):
    # One epoch = one single-step episode per training row.
    losses = []
    for _ in range(len(X_train)):
        state = env.reset()
        action = agent.act(state, greedy=False)
        next_state, reward, done, _info = env.step(action)
        agent.push(state, action, reward, next_state, float(done))
        step_loss = agent.train_step()
        if step_loss is not None:
            losses.append(step_loss)

    # Report the first epoch and every fifth one after that.
    is_report_epoch = epoch == 0 or (epoch + 1) % 5 == 0
    if is_report_epoch:
        avg_loss = np.mean(losses) if losses else float("nan")
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}  replay={len(agent.replay)}  eps={agent.eps:.3f}  avg_loss={avg_loss:.4f}")

agent.save(MODEL_OUT)
print("Model saved to:", MODEL_OUT)

# --- Evaluation (greedy) ---
# Reload the checkpoint that was just written so the metrics reflect the saved weights.
agent.load(MODEL_OUT)
# greedy=True already bypasses exploration; eps is also zeroed so any later act()
# calls on this agent stay greedy as well.
agent.eps = 0.0
preds = np.array([agent.act(row, greedy=True) for row in X_test])

print("\nOverall test metrics:")
print("Accuracy:", accuracy_score(y_test, preds))
# zero_division=0 matches the per-class precision/recall in group_class_metrics and
# makes the value used for ill-defined per-class scores explicit instead of warning.
print("Macro F1:", f1_score(y_test, preds, average="macro", zero_division=0))

# --- Fairness metrics helpers ---
def group_class_metrics(y_true, y_pred, groups, classes=None):
    """Compute per-group accuracy, macro F1 and one-vs-rest per-class metrics.

    Args:
        y_true: true class labels, one per sample (array-like).
        y_pred: predicted class labels, aligned with ``y_true``.
        groups: group identifier per sample (e.g. sex or age bucket), aligned with ``y_true``.
        classes: labels to report per-class metrics for; defaults to the labels seen in ``y_true``.

    Returns:
        dict mapping ``str(group)`` to a dict with keys ``n``, ``accuracy``, ``macro_f1`` and,
        for each class ``c``, ``precision_class_{c}``, ``recall_class_{c}`` and
        ``pred_rate_class_{c}`` (share of the group's predictions equal to ``c``).

    Note:
        ``macro_f1`` uses scikit-learn's default label set, i.e. only labels present in
        that group's true or predicted values, so groups may average over different classes.
    """
    # Accept lists as well as arrays: `groups == g` must be an element-wise mask.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    groups = np.asarray(groups)
    classes = classes if classes is not None else np.unique(y_true)
    results = {}
    for g in np.unique(groups):
        mask = groups == g
        yt, yp = y_true[mask], y_pred[mask]
        res = {
            "n": int(len(yt)),
            "accuracy": float(accuracy_score(yt, yp)),
            # zero_division=0 for consistency with the per-class precision/recall below.
            "macro_f1": float(f1_score(yt, yp, average="macro", zero_division=0)),
        }
        for c in classes:
            yt_bin = (yt == c).astype(int)
            yp_bin = (yp == c).astype(int)
            res[f"precision_class_{c}"] = float(precision_score(yt_bin, yp_bin, zero_division=0))
            res[f"recall_class_{c}"] = float(recall_score(yt_bin, yp_bin, zero_division=0))
            res[f"pred_rate_class_{c}"] = float(yp_bin.mean())
        results[str(g)] = res
    return results

# --- Sex-based fairness (if available) ---
if test_sex is None:
    print("\nSex-based evaluation skipped (no sex column found).")
else:
    # One group per distinct integer value of the sex feature in the test rows.
    groups = test_sex
    all_labels = np.arange(NUM_ACTIONS)
    metrics_by_sex = group_class_metrics(y_test, preds, groups, classes=all_labels)

    print("\nMetrics by sex (sex_Male=1 means male):")
    for group_key, stats in metrics_by_sex.items():
        print(f"Group {group_key}: n={stats['n']}, acc={stats['accuracy']:.3f}, macro_f1={stats['macro_f1']:.3f}")
        for cls in range(NUM_ACTIONS):
            print(
                f"   class {cls}: pred_rate={stats[f'pred_rate_class_{cls}']:.3f}, "
                f"recall={stats[f'recall_class_{cls}']:.3f}, prec={stats[f'precision_class_{cls}']:.3f}"
            )

    # Male-minus-female gaps, reported only when both groups occur in the test split.
    male = metrics_by_sex.get("1")
    female = metrics_by_sex.get("0")
    if male and female:
        print("\nDisparities (male - female):")
        print("Accuracy diff:", male["accuracy"] - female["accuracy"])
        print("Macro F1 diff:", male["macro_f1"] - female["macro_f1"])
        for cls in range(NUM_ACTIONS):
            pred_gap = male[f"pred_rate_class_{cls}"] - female[f"pred_rate_class_{cls}"]
            tpr_gap = male[f"recall_class_{cls}"] - female[f"recall_class_{cls}"]
            print(f" Class {cls}: pred_rate_diff={pred_gap:.3f}, tpr_diff={tpr_gap:.3f}")

    print("\nConfusion matrices by sex:")
    for g in np.unique(groups):
        mask = groups == g
        cm = confusion_matrix(y_test[mask], preds[mask], labels=all_labels)
        print(f"Group {int(g)} (n={cm.sum()}):\n{cm}")

# --- Age-based fairness ---
# X_test only holds scaled features, so the unscaled ages are recovered by re-running
# the preprocessing train/test split on the raw CSV (same test_size, random_state and
# stratification). That split depends only on the row order and the `num` labels, so
# only `age` and `num` are rebuilt. The recovered labels are then checked against
# y_test, so a reconstruction that has drifted from the preprocessing is reported
# instead of silently pairing predictions with the wrong patients' ages.
AGE_BINS = [0, 45, 55, 65, 75, 200]
AGE_LABELS = ['<=45', '46-55', '56-65', '66-75', '>75']


def _reconstruct_test_ages(raw_path, expected_y):
    """Return the raw (unscaled) ages of the test rows, in the order of ``expected_y``.

    Re-derives the split (test_size=0.2, random_state=42, stratified on ``num``)
    from the raw CSV at ``raw_path``.

    Raises:
        OSError: if the raw CSV cannot be opened.
        KeyError: if the raw CSV has no ``age`` or ``num`` column.
        ValueError: if the re-derived test labels differ from ``expected_y`` (rows
            cannot be aligned), or the CSV/split is otherwise invalid.
    """
    raw = pd.read_csv(raw_path)
    for col in ("age", "num"):
        if col not in raw.columns:
            raise KeyError(f"{col} column not found in raw dataset")
    # Missing values are median-imputed, as in the preprocessing being reconstructed.
    ages = raw["age"].fillna(raw["age"].median())
    labels = raw["num"].fillna(raw["num"].median())
    _, ages_test, _, labels_test = train_test_split(
        ages, labels, test_size=0.2, random_state=42, stratify=labels
    )
    expected = np.asarray(expected_y)
    if len(labels_test) != len(expected) or not np.array_equal(labels_test.to_numpy(), expected):
        raise ValueError(
            "raw-data test split does not reproduce y_test; "
            "cannot align raw ages with predictions"
        )
    return ages_test.to_numpy()


try:
    test_age_raw = _reconstruct_test_ages(RAW_DATA, y_test)
except (OSError, KeyError, ValueError) as e:
    print("Could not compute age-based metrics:", str(e))
else:
    # pd.cut on an ndarray returns a Categorical, and Categorical.astype(str) already
    # returns an ndarray (it has no `.values`), hence np.asarray rather than `.values`.
    age_group = pd.cut(test_age_raw, bins=AGE_BINS, labels=AGE_LABELS, include_lowest=True)
    age_group_arr = np.asarray(age_group.astype(str))
    metrics_by_age = group_class_metrics(y_test, preds, age_group_arr, classes=np.arange(NUM_ACTIONS))
    print("\nMetrics by age group:")
    for g, met in metrics_by_age.items():
        print(f"Age {g}: n={met['n']}, acc={met['accuracy']:.3f}, macro_f1={met['macro_f1']:.3f}")
        for c in range(NUM_ACTIONS):
            print(f"   class {c}: pred_rate={met[f'pred_rate_class_{c}']:.3f}, recall={met[f'recall_class_{c}']:.3f}, prec={met[f'precision_class_{c}']:.3f}")
    # Example disparity: oldest vs youngest bucket, when both occur in the test split.
    if '<=45' in metrics_by_age and '>75' in metrics_by_age:
        young = metrics_by_age['<=45']
        old = metrics_by_age['>75']
        print("\nDisparities (old - young):")
        print("Accuracy diff (old - young):", old["accuracy"] - young["accuracy"])
        print("Macro F1 diff (old - young):", old["macro_f1"] - young["macro_f1"])
    print("\nConfusion matrices by age group:")
    # np.unique only yields buckets that occur, so every mask below is non-empty.
    for g in np.unique(age_group_arr):
        idx = age_group_arr == g
        cm = confusion_matrix(y_test[idx], preds[idx], labels=np.arange(NUM_ACTIONS))
        print(f"Age {g} (n={cm.sum()}):\n{cm}")

# End of script
# ...existing code...

Training DQN...
Epoch 1/40  replay=1645  eps=0.050  avg_loss=0.2239
Epoch 5/40  replay=8225  eps=0.050  avg_loss=0.0430
Epoch 10/40  replay=16450  eps=0.050  avg_loss=0.0189
Epoch 15/40  replay=20000  eps=0.050  avg_loss=0.0100
Epoch 20/40  replay=20000  eps=0.050  avg_loss=0.0057
Epoch 25/40  replay=20000  eps=0.050  avg_loss=0.0048
Epoch 30/40  replay=20000  eps=0.050  avg_loss=0.0046
Epoch 35/40  replay=20000  eps=0.050  avg_loss=0.0041
Epoch 40/40  replay=20000  eps=0.050  avg_loss=0.0035
Model saved to: ../data/dqn_cardiac.pth

Overall test metrics:
Accuracy: 0.45108695652173914
Macro F1: 0.31862595793648846

Metrics by sex (sex_Male=1 means male):
Group 0: n=32, acc=0.625, macro_f1=0.208
   class 0: pred_rate=0.656, recall=0.704, prec=0.905
   class 1: pred_rate=0.125, recall=0.250, prec=0.250
   class 2: pred_rate=0.094, recall=0.000, prec=0.000
   class 3: pred_rate=0.094, recall=0.000, prec=0.000
   class 4: pred_rate=0.031, recall=0.000, prec=0.000
Group 1: n=152, acc=0.414, 

  raw_df[col] = raw_df[col].fillna(raw_df[col].mode()[0])
