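# Classifier training

Trains a baseline random forest on the original data and a "fair" random forest on the fair hybrid dataset (without the `sex` feature) over repeated stratified 70/30 train/test splits. It then audits sex-stratified performance and counterfactual prediction flips of the baseline model.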

**Inputs:**
- data/heart_disease_cleaned.csv
- data/fair_heart_disease_hybrid.csv
- data/cf_heart_disease_hybrid.csv

**Outputs:**
- results/perf_metrics_hybrid_{N_RUNS}_runs_{YYYY-MM-DD_HHMM}.csv (one timestamped file per execution)

## Setup and imports

In [None]:
import sys
# Detect whether we are running inside Google Colab (the colab package is
# already loaded into sys.modules there).
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
  # Local run: paths are resolved relative to the parent directory.
  PROJECT_ROOT = '../'
else:
  # Colab run: mount Google Drive and read the project location from the
  # Colab user secrets.
  from google.colab import drive, userdata
  drive.mount('/content/drive')
  PROJECT_ROOT = userdata.get('PROJECT_ROOT')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [13]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

from google.colab import output
# output.enable_custom_widget_manager()
output.disable_custom_widget_manager()

sns.set_style('whitegrid')
sns.set_context('paper', font_scale=1)

In [14]:
# Every input table is read from <PROJECT_ROOT>/data.
DATA_DIR = f'{PROJECT_ROOT}/data'

heart_disease = pd.read_csv(f'{DATA_DIR}/heart_disease_cleaned.csv')            # factual data
fair_heart_disease = pd.read_csv(f'{DATA_DIR}/fair_heart_disease_hybrid.csv')  # fair dataset
cf_heart_disease = pd.read_csv(f'{DATA_DIR}/cf_heart_disease_hybrid.csv')      # counterfactual dataset

### Function library

In [15]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score,\
 recall_score, roc_auc_score

def train_random_forest(X_train, y_train, X_test, y_test,
                        n_iter=10, cv=3, random_state=4, n_jobs=-1):
  '''
    Tunes a sklearn RandomForestClassifier with a randomized hyperparameter
    search, then predicts on the test set.

    The search is RandomizedSearchCV (not an exhaustive grid search): `n_iter`
    candidate settings are sampled from `param_grid` (32 combinations in total).
    Each candidate is scored by ROC AUC with `cv`-fold cross-validation on the
    training data. The best candidate is refit on the full training set and used
    for the test-set predictions.

    Inputs:
      X_train: training features
      y_train: training labels
      X_test: test features
      y_test: test labels (not used; kept so existing callers keep working)
      n_iter: number of sampled parameter settings (default 10)
      cv: number of cross-validation folds (default 3)
      random_state: seed for both the forest and the parameter sampling (default 4)
      n_jobs: parallel jobs for the search (default -1, all cores)

    Outputs (returned as a list):
      rf_search: the fitted RandomizedSearchCV object. Its predict/predict_proba
        use the refitted best estimator (rf_search.best_estimator_).
      y_pred: predicted labels for X_test
      y_pred_proba: predicted probability of the second class (classes_[1])
        for X_test
  '''
  # Search space: 4 * 2 * 2 * 2 = 32 combinations, of which n_iter are sampled.
  param_grid = {
    "max_depth": [5, 10, 20, None],
    "max_features": ["sqrt", "log2"],
    "min_samples_split": [2, 5],
    "min_samples_leaf": [1, 2]
  }

  # Base estimator; each sampled candidate overrides the searched hyperparameters.
  rf = RandomForestClassifier(random_state=random_state, n_estimators=100)

  # Randomized (not exhaustive) search, optimising ROC AUC.
  rf_search = RandomizedSearchCV(estimator=rf, param_distributions=param_grid,
                                 n_iter=n_iter, scoring='roc_auc',
                                 cv=cv, n_jobs=n_jobs, random_state=random_state)

  # Fit the search. The best candidate is refit on all of X_train (sklearn's
  # default refit=True), and that refit model serves the predictions below.
  rf_search.fit(X_train, y_train)
  y_pred = rf_search.predict(X_test)
  y_pred_proba = rf_search.predict_proba(X_test)[:, 1]

  return [rf_search, y_pred, y_pred_proba]

def get_perf_metrics(y_true, y_pred, y_pred_proba, labels=(0, 1)):
  '''
    Calculates the performance metrics for a given set of predictions.

    Safe to call on small subgroups, such as the per-sex subsets:
    - the confusion matrix is always 2x2;
    - ROC AUC is NaN when y_true contains a single class;
    - FNR / FPR are NaN when their denominator is zero.

    Inputs
      y_true: true labels
      y_pred: predicted labels
      y_pred_proba: predicted probabilities of the positive class
      labels: (negative, positive) label values used to build the confusion
        matrix (default (0, 1))

    Outputs (returned as a list)
      accuracy: accuracy score
      roc_auc: ROC AUC (Receiver Operating Characteristic Area Under the Curve),
        NaN if y_true contains only one class
      FNR: False Negative Rate, NaN if there are no actual positives
      FPR: False Positive Rate, NaN if there are no actual negatives
      tn: True Negatives
      fp: False Positives
      fn: False Negatives
      tp: True Positives
  '''
  accuracy = accuracy_score(y_true, y_pred)

  # roc_auc_score raises ValueError when only one class is present in y_true;
  # report NaN instead so one degenerate subgroup does not abort the whole run.
  if len(np.unique(y_true)) < 2:
    roc_auc = np.nan
  else:
    roc_auc = roc_auc_score(y_true, y_pred_proba)

  # Passing `labels` forces a 2x2 matrix even when only one class occurs in
  # both y_true and y_pred; without it .ravel() yields a single value and the
  # 4-way unpack fails.
  tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=list(labels)).ravel()

  FNR = fn / (fn + tp) if (fn + tp) > 0 else np.nan
  FPR = fp / (fp + tn) if (fp + tn) > 0 else np.nan

  return [accuracy, roc_auc, FNR, FPR, tn, fp, fn, tp]

def get_counterfactual_flips(y_pred, y_cf_total, y_cf_path):
  '''
    Measures how often the prediction flips under the TOTAL-effect and the
    PATHWAY-effect counterfactuals. Each rate is expressed as a fraction of all
    rows in y_pred.

    Inputs
      y_pred: predictions made on the factual dataset as a numpy array
      y_cf_total: predictions made on the corresponding TOTAL EFFECT counterfactual dataset as a numpy array
      y_cf_path: predictions made on the corresponding PATHWAY EFFECT counterfactual dataset as a numpy array

    Outputs (returned as a list, in this order)
      pos_flipped_full: 0 -> 1 under both TOTAL and PATHWAY effect
      pos_flipped_total: 0 -> 1 under TOTAL effect only (PATHWAY stays 0)
      pos_flipped_path: 0 -> 1 under PATHWAY effect only (TOTAL stays 0)
      neg_flipped_full: 1 -> 0 under both TOTAL and PATHWAY effect
      neg_flipped_total: 1 -> 0 under TOTAL effect only (PATHWAY stays 1)
      neg_flipped_path: 1 -> 0 under PATHWAY effect only (TOTAL stays 1)
  '''
  n_rows = len(y_pred)

  # Elementary conditions, combined below into the six flip categories.
  factual_neg, factual_pos = (y_pred == 0), (y_pred == 1)
  total_neg, total_pos = (y_cf_total == 0), (y_cf_total == 1)
  path_neg, path_pos = (y_cf_path == 0), (y_cf_path == 1)

  flip_masks = [
    factual_neg & total_pos & path_pos,   # positive flip, both counterfactuals
    factual_neg & total_pos & path_neg,   # positive flip, total effect only
    factual_neg & total_neg & path_pos,   # positive flip, pathway effect only
    factual_pos & total_neg & path_neg,   # negative flip, both counterfactuals
    factual_pos & total_neg & path_pos,   # negative flip, total effect only
    factual_pos & total_pos & path_neg,   # negative flip, pathway effect only
  ]
  return [np.sum(mask) / n_rows for mask in flip_masks]

def get_causal_metrics(y_pred, y_cf_pred):
  '''
    Counts the counterfactual prediction flips between factual and
    counterfactual predictions and expresses them as fractions of all rows.

    Inputs
      y_pred: predictions made on the factual dataset as a numpy array
      y_cf_pred: predictions made on the corresponding counterfactual dataset as a numpy array

    Outputs (returned as a list, in this order)
      num_pos_flips: number of counterfactual flips from y=0 to y=1
      num_neg_flips: number of counterfactual flips from y=1 to y=0
      freq_pos_flips: frequency of counterfactual flips from y=0 to y=1
      freq_neg_flips: frequency of counterfactual flips from y=1 to y=0
  '''
  n_rows = len(y_pred)
  up_count = np.sum((y_pred == 0) & (y_cf_pred == 1))    # 0 -> 1
  down_count = np.sum((y_pred == 1) & (y_cf_pred == 0))  # 1 -> 0
  return [up_count, down_count, up_count / n_rows, down_count / n_rows]


## Model training

In [None]:
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.stats import barnard_exact

# Baseline features and target class.
X = heart_disease.drop(['cvd'], axis=1)
y = heart_disease['cvd']

# Counterfactual features: drop the target and 'U', then select the columns in
# the baseline training order so the baseline model always receives the same
# feature layout it was fitted on.
X_cf = cf_heart_disease.drop(['cvd', 'U'], axis=1)[X.columns]
y_cf = cf_heart_disease['cvd']  # not used in this cell

# Counterfactual rows are selected with the same positional indices as the
# factual test rows (X_cf.iloc[test_index] below) and paired with them
# row-by-row, so both tables must have the same number of rows.
assert len(cf_heart_disease) == len(heart_disease), \
  'cf_heart_disease must be row-aligned with heart_disease'

# Repeated stratified random subsampling (Monte Carlo cross-validation, not
# bootstrapping: rows are drawn without replacement). N_RUNS independent
# stratified 70/30 train/test splits are generated.
N_RUNS = 50

sss = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=0.3, random_state=42)

perf_metrics = []

for i, (train_index, test_index) in tqdm(enumerate(sss.split(X, y)), total=N_RUNS, desc="Running simulations"):

  X_train, X_test = X.iloc[train_index], X.iloc[test_index]
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]

  # Equivalent fair training and test sets. Fair rows are matched to the split
  # by comparing their 'ID' column with the positional indices returned by
  # sss.split, so 'ID' is assumed to equal the row position in heart_disease.
  # 'sex' is excluded from the fair model's features.
  fair_train_mask = fair_heart_disease['ID'].isin(train_index)
  fair_test_mask = fair_heart_disease['ID'].isin(test_index)
  fair_X_train = fair_heart_disease.loc[fair_train_mask].drop(['cvd', 'ID', 'sex'], axis=1)
  fair_X_test = fair_heart_disease.loc[fair_test_mask].drop(['cvd', 'ID', 'sex'], axis=1)
  fair_y_train = fair_heart_disease.loc[fair_train_mask, 'cvd']
  fair_y_test = fair_heart_disease.loc[fair_test_mask, 'cvd']

  # Train the baseline and fair models
  rf, y_pred, y_pred_proba = train_random_forest(X_train, y_train, X_test, y_test)
  fair_rf, fair_y_pred, fair_y_pred_proba = train_random_forest(
      fair_X_train, fair_y_train, fair_X_test, fair_y_test)

  # GLOBAL PERFORMANCE METRICS
  accuracy, roc_auc, FNR, FPR, *_ = get_perf_metrics(y_test, y_pred, y_pred_proba)
  fair_accuracy, fair_roc_auc, fair_FNR, fair_FPR, *_ = get_perf_metrics(fair_y_test, fair_y_pred, fair_y_pred_proba)

  # COUNTERFACTUAL PREDICTIONS for the baseline model

  # 1. TOTAL: Flip Ssoc through ECG, ANG, CP and flip the sex feature, while keeping other variables unchanged
  # = what is the total effect of flipping the social sex label and the sociologically-biased pathways?
  X_cf_1_test = X_cf.iloc[test_index].copy()
  y_cf_1_pred = rf.predict(X_cf_1_test)

  # 2. PATHWAY-SPECIFIC: Flip Ssoc through ECG, ANG, CP but keep factual sex feature and other variables unchanged
  # = What is the prediction for the individual if their symptoms had been
  # reported and interpreted as the other sex
  X_cf_2_test = X_cf.iloc[test_index].copy()
  X_cf_2_test['sex'] = X_test['sex'].values  # positional copy: relies on row alignment
  y_cf_2_pred = rf.predict(X_cf_2_test)

  # STRATIFIED PERFORMANCE AND FAIRNESS
  # Baseline audit dataset (prediction arrays are attached positionally, in X_test row order)
  baseline_audit_df = X_test.copy()
  baseline_audit_df['y_true'] = y_test
  baseline_audit_df['y_pred'] = y_pred
  baseline_audit_df['y_pred_proba'] = y_pred_proba
  baseline_audit_df['y_cf_1_pred'] = y_cf_1_pred
  baseline_audit_df['y_cf_2_pred'] = y_cf_2_pred

  baseline_male_df = baseline_audit_df[baseline_audit_df['sex'] == 1]
  baseline_female_df = baseline_audit_df[baseline_audit_df['sex'] == 0]

  # Fair audit dataset: same row selection and order as fair_X_test, so the
  # fair prediction arrays line up positionally.
  fair_audit_df = fair_heart_disease.loc[fair_test_mask].copy()
  fair_audit_df['y_true'] = fair_y_test
  fair_audit_df['y_pred'] = fair_y_pred
  fair_audit_df['y_pred_proba'] = fair_y_pred_proba
  fair_male_df = fair_audit_df[fair_audit_df['sex'] == 1]
  fair_female_df = fair_audit_df[fair_audit_df['sex'] == 0]

  ### STRATIFIED PERFORMANCE AUDIT
  # Baseline Model:
  accuracy_m, roc_auc_m, FNR_m, FPR_m, tn_m, fp_m, fn_m, tp_m = get_perf_metrics(
      baseline_male_df['y_true'],
      baseline_male_df['y_pred'],
      baseline_male_df['y_pred_proba'])

  accuracy_f, roc_auc_f, FNR_f, FPR_f, tn_f, fp_f, fn_f, tp_f = get_perf_metrics(
      baseline_female_df['y_true'],
      baseline_female_df['y_pred'],
      baseline_female_df['y_pred_proba'])

  # Fair Model before correction of direct bias:
  fair_accuracy_m, fair_roc_auc_m, fair_FNR_m, fair_FPR_m, *_ = get_perf_metrics(
      fair_male_df['y_true'],
      fair_male_df['y_pred'],
      fair_male_df['y_pred_proba'])

  fair_accuracy_f, fair_roc_auc_f, fair_FNR_f, fair_FPR_f, *_ = get_perf_metrics(
      fair_female_df['y_true'],
      fair_female_df['y_pred'],
      fair_female_df['y_pred_proba'])

  ### COUNTERFACTUAL FAIRNESS METRICS on the baseline model

  pos_flipped_full_m, pos_flipped_total_m, pos_flipped_path_m,\
   neg_flipped_full_m, neg_flipped_total_m, neg_flipped_path_m = get_counterfactual_flips(
      baseline_male_df['y_pred'],
      baseline_male_df['y_cf_1_pred'],
      baseline_male_df['y_cf_2_pred'])

  pos_flipped_full_f, pos_flipped_total_f, pos_flipped_path_f,\
   neg_flipped_full_f, neg_flipped_total_f, neg_flipped_path_f = get_counterfactual_flips(
      baseline_female_df['y_pred'],
      baseline_female_df['y_cf_1_pred'],
      baseline_female_df['y_cf_2_pred'])

  perf_metrics.append({
      'run': i,
      'accuracy': accuracy,
      'roc_auc': roc_auc,
      'FNR': FNR,
      'FPR': FPR,
      'fair_accuracy': fair_accuracy,
      'fair_roc_auc': fair_roc_auc,
      'fair_FNR': fair_FNR,
      'fair_FPR': fair_FPR,
      'accuracy_m': accuracy_m,
      'accuracy_f': accuracy_f,
      'accuracy_diff': accuracy_m - accuracy_f,
      'roc_auc_m': roc_auc_m,
      'roc_auc_f': roc_auc_f,
      'roc_auc_diff': roc_auc_m - roc_auc_f,
      'FNR_m': FNR_m,
      'FNR_f': FNR_f,
      'FNR_diff': FNR_m - FNR_f,
      'FPR_m': FPR_m,
      'FPR_f': FPR_f,
      'FPR_diff': FPR_m - FPR_f,
      'fair_accuracy_m': fair_accuracy_m,
      'fair_accuracy_f': fair_accuracy_f,
      'fair_accuracy_diff': fair_accuracy_m - fair_accuracy_f,
      'fair_roc_auc_m': fair_roc_auc_m,
      'fair_roc_auc_f': fair_roc_auc_f,
      'fair_roc_auc_diff': fair_roc_auc_m - fair_roc_auc_f,
      'fair_FNR_m': fair_FNR_m,
      'fair_FNR_f': fair_FNR_f,
      'fair_FNR_diff': fair_FNR_m - fair_FNR_f,
      'fair_FPR_m': fair_FPR_m,
      'fair_FPR_f': fair_FPR_f,
      'fair_FPR_diff': fair_FPR_m - fair_FPR_f,
      'pos_flipped_full_m': pos_flipped_full_m,
      'pos_flipped_full_f': pos_flipped_full_f,
      'pos_flipped_total_m': pos_flipped_total_m,
      'pos_flipped_total_f': pos_flipped_total_f,
      'pos_flipped_path_m': pos_flipped_path_m,
      'pos_flipped_path_f': pos_flipped_path_f,
      'neg_flipped_full_m': neg_flipped_full_m,
      'neg_flipped_full_f': neg_flipped_full_f,
      'neg_flipped_total_m': neg_flipped_total_m,
      'neg_flipped_total_f': neg_flipped_total_f,
      'neg_flipped_path_m': neg_flipped_path_m,
      'neg_flipped_path_f': neg_flipped_path_f
  })


Running simulations:   4%|▍         | 2/50 [00:56<22:20, 27.92s/it]

In [None]:
import os
import datetime
# Results are written to <PROJECT_ROOT>/results. The file name embeds the run
# count and a minute-resolution timestamp, so executions in different minutes
# produce separate files.
save_path = f'{PROJECT_ROOT}/results'
os.makedirs(save_path, exist_ok=True)

perf_metrics_df = pd.DataFrame(perf_metrics)  # one row per simulation run

stamp = datetime.datetime.now().strftime('%Y-%m-%d_%H%M')
csv_file = f'{save_path}/perf_metrics_hybrid_{N_RUNS}_runs_{stamp}.csv'
perf_metrics_df.to_csv(csv_file)
print('Performance metrics saved')