In [None]:
import sys
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
  from google.colab import userdata
  from google.colab import drive
  drive.mount('/content/drive')
  PROJECT_ROOT = userdata.get('PROJECT_ROOT')
else:
  PROJECT_ROOT = '../'

!pip install -q tableone
!pip install -q tqdm
!pip install -q scipy
!pip install -q sklearn
!pip install -q seaborn
!pip install -q matplotlib
!pip install -q semopy

# Detecting and mitigating bias using causal modelling

## The method

Introduced by Hui and Lau (2024, doi:[10.1109/CCCIS63483.2024.00016](https://doi.org/10.1109/CCCIS63483.2024.00016)).

**Aim:** Detect and mitigate bias with respect to a sensitive attribute in a black-box predictive model by modelling and inferring the causal relationships between the attribute, the model's predicted outcome and the ground-truth outcome.

## The experiment

Using the characterisation of statistical gender bias in cardiology algorithms conducted by Straw et al. (2024, doi: [10.2196/46936](https://doi.org/10.2196/46936)) as a reference, we will apply the causal modelling bias mitigation method to the ML model replicated by Straw et al. on the [Heart Disease (CAD) dataset](https://ieee-dataport.org/open-access/heart-disease-dataset-comprehensive).

### Hypothesis

Causal modelling (as per Hui and Lau's method) will detect a statistically significant direct causal path from the protected attribute (sex) to the prediction outcome, and mitigating this direct path will reduce the observed sex-based performance disparity (FNR difference) in the cardiac disease prediction model.

### Key variables

- **Protected attribute ($a$):** sex (male/female)
- **Actual outcome ($y$):** Heart disease diagnosis / No heart disease diagnosis
- **Predicted outcome ($\hat{y}$):** The continuous probability estimate from the black-box predictive model. The binary class (heart disease / no heart disease) is derived via a threshold.
- **Mitigated outcome ($\tilde{y}$):** the continuous, bias-corrected probability estimate derived from the causal modelling mitigation on the predicted outcome.
- **Fairness metrics:**
  - Disparity of Predictive Accuracy between subgroups
  - Disparity of False Positive Rate (FPR) between subgroups
  - Disparity of False Negative Rate (FNR) between subgroups
  - Equal Opportunity (EO), satisfied when the FNR disparity between subgroups nears zero

$$\text{FPR} = \frac{\text{False Positives}}{\text{False Positives} + \text{True Negatives}}$$

$$\text{FNR} = \frac{\text{False Negatives}}{\text{False Negatives} + \text{True Positives}}$$

$$ EO \Leftrightarrow FNR_{male} = FNR_{female}$$



# Function library

In [None]:
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score, confusion_matrix, precision_score,\
 recall_score, roc_auc_score
import semopy
from scipy.stats import sem, t

def train_random_forest(X_train, y_train, X_test, y_test, random_state=None):
  '''
    Tunes and fits a sklearn RandomForestClassifier on the training data, then
     scores the test set with the best estimator found.

     Hyperparameters are chosen by 3-fold GridSearchCV optimising ROC AUC;
     GridSearchCV (default refit=True) then refits the best configuration on
     the whole training set.

     Inputs:
       X_train: training features
       y_train: training labels
       X_test: test features
       y_test: test labels (not used here; kept for interface compatibility)
       random_state: optional seed passed to the RandomForestClassifier for
         reproducible forests. The default (None) keeps the unseeded behaviour.

     Outputs (returned as a list, in this order):
       rf_grid: the fitted GridSearchCV object (not the bare forest); its
         predict/predict_proba delegate to the refitted best estimator, which
         is available as rf_grid.best_estimator_
       y_pred: predicted labels for X_test
       y_pred_proba: column 1 of predict_proba for X_test, i.e. the predicted
         probability of classes_[1] (the positive class for 0/1 labels)
  '''
  param_grid = {
    "n_estimators": [100, 200],
    "max_depth": [5, 10, 20, None],
    "max_features": ["sqrt", "log2"],
    "min_samples_split": [2, 5]
  }

  # Only the forest needs a seed: an integer cv on a classifier uses
  # (unshuffled) stratified folds, which are deterministic.
  rf = RandomForestClassifier(random_state=random_state)

  rf_grid = GridSearchCV(estimator=rf, param_grid=param_grid, scoring='roc_auc', cv=3)
  rf_grid.fit(X_train, y_train)

  y_pred = rf_grid.predict(X_test)
  y_pred_proba = rf_grid.predict_proba(X_test)[:, 1]

  return [rf_grid, y_pred, y_pred_proba]

def get_causal_model_params(X_train, y_train, y_pred_proba, protected_attribute):
  '''
    Fits a linear structural equation model with semopy and returns the
     estimated direct effect of the protected attribute on the prediction.

     Model (semopy syntax):
       y_true ~ protected_attribute
       y_pred ~ y_true + protected_attribute
     i.e. y_pred = beta0 + beta1*y_true + beta2*protected_attribute, with the
     protected attribute also allowed to influence the ground truth.

     Inputs
       X_train: training features; only the protected_attribute column is used
       y_train: training labels for the rows of X_train
       y_pred_proba: predicted probabilities for the rows of X_train, in the
         same row order (matched by position, not by index)
       protected_attribute: name of the protected attribute column in X_train

     Outputs (returned as a list)
       beta2: estimated coefficient of the direct path
         protected_attribute -> y_pred
       beta2_pvalue: p-value of that estimate, as reported by semopy

     Raises
       ValueError: if semopy's estimates contain no such regression path
  '''
  import numpy as np

  causal_features = pd.DataFrame()
  causal_features['protected_attribute'] = X_train[protected_attribute]
  causal_features['y_true'] = y_train
  # Assign the predictions by position: they come from predict_proba as a bare
  # array, and a Series with an unrelated index would otherwise be misaligned.
  causal_features['y_pred'] = np.asarray(y_pred_proba)

  model_desc = '''
    y_true ~ protected_attribute
    y_pred ~ y_true + protected_attribute
  '''

  causal_model = semopy.Model(model_desc)
  causal_model.fit(causal_features)
  causal_params = causal_model.inspect()

  # Select the regression ('~') row y_pred <- protected_attribute; filtering on
  # the operator ensures (co)variance rows ('~~') can never match.
  direct_path = causal_params[
      (causal_params['lval'] == 'y_pred')
      & (causal_params['op'] == '~')
      & (causal_params['rval'] == 'protected_attribute')
  ]
  if direct_path.empty:
    raise ValueError('semopy estimates contain no y_pred ~ protected_attribute path')

  beta2 = direct_path['Estimate'].values[0]
  beta2_pvalue = direct_path['p-value'].values[0]

  return [beta2, beta2_pvalue]

def get_perf_metrics(y_true, y_pred, y_pred_proba):
  '''
    Computes discrimination and error-rate metrics for a binary classifier.

    Inputs
      y_true: ground-truth labels
      y_pred: hard (thresholded) predicted labels
      y_pred_proba: predicted probability of the positive class

    Outputs (returned as a list, in this order)
      accuracy: share of correct hard predictions
      roc_auc: area under the ROC curve, computed from y_pred_proba
      FNR: false negative rate, fn / (fn + tp)
      FPR: false positive rate, fp / (fp + tn)
      tn, fp, fn, tp: the four cells of the confusion matrix

    Note
      sklearn's roc_auc_score raises ValueError when y_true holds a single
      class, so every call needs both classes present in y_true.
  '''
  acc = accuracy_score(y_true, y_pred)
  auc = roc_auc_score(y_true, y_pred_proba)

  # confusion_matrix lays out [[tn, fp], [fn, tp]]; flatten row by row
  tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
  actual_positives = fn + tp
  actual_negatives = fp + tn

  return [acc, auc, fn / actual_positives, fp / actual_negatives, tn, fp, fn, tp]

def get_95_ci(data):
  '''
    Calculates the two-sided 95% confidence interval of the mean of data,
     using the Student t distribution with n - 1 degrees of freedom.

    Inputs
      data: data as a Pandas Series

    Outputs
      interval: tuple (lower, upper) of the confidence-interval bounds.
        When all values are identical (standard error 0) the interval
        collapses to (mean, mean); with fewer than two values the bounds
        are expected to be NaN.
  '''
  n = len(data)
  mean = data.mean()
  std_err = sem(data)
  if std_err == 0:
    # scipy distributions require scale > 0, so t.interval would return
    # (nan, nan); the limiting interval of a zero-variance sample is the point.
    return (mean, mean)
  return t.interval(0.95, n - 1, loc=mean, scale=std_err)


# Data Preprocessing

In [None]:
from tableone import TableOne

heart_disease = pd.read_csv(f'{PROJECT_ROOT}data/heart_disease_cleveland_hungary.csv')

# Cleaning rules, as per Straw et al.: a zero cholesterol or zero resting blood
# pressure stands in for a missing (null) measurement, so those records are
# discarded, together with exact duplicate rows (first occurrence kept).
missing_cholesterol = heart_disease['cholesterol'] == 0
missing_resting_bp = heart_disease['resting bp s'] == 0
duplicate_row = heart_disease.duplicated(keep='first')
heart_disease = heart_disease.loc[~(missing_cholesterol | missing_resting_bp | duplicate_row)].copy()

EXPECTED_N_RECORDS = 746  # cleaned cohort size reported by the reference study
if len(heart_disease) != EXPECTED_N_RECORDS:
  print("WARNING: Total count of records in cleaned dataset doesn't match reference study" )

# Descriptive statistics, stratified by sex
continuous_vars = ['age','cholesterol','max heart rate','resting bp s','oldpeak']
categorical_vars = ['target','chest pain type', 'fasting blood sugar','resting ecg','exercise angina','ST slope']
table1 = TableOne(heart_disease, groupby='sex', continuous=continuous_vars, categorical=categorical_vars)

print(table1)


                             Grouped by sex                                          
                                    Missing       Overall             0             1
n                                                     746           182           564
age, mean (SD)                            0    52.9 (9.5)    52.2 (9.3)    53.1 (9.6)
chest pain type, n (%)     1                     41 (5.5)       9 (4.9)      32 (5.7)
                           2                   166 (22.3)     59 (32.4)    107 (19.0)
                           3                   169 (22.7)     52 (28.6)    117 (20.7)
                           4                   370 (49.6)     62 (34.1)    308 (54.6)
resting bp s, mean (SD)                   0  133.0 (17.3)  132.0 (18.6)  133.4 (16.8)
cholesterol, mean (SD)                    0  244.6 (59.2)  255.8 (62.9)  241.0 (57.5)
fasting blood sugar, n (%) 0                   621 (83.2)    163 (89.6)    458 (81.2)
                           1                   125 (16

In [None]:
# Split the cleaned dataset into model inputs and the outcome to predict:
# every column except 'target' is a feature; 'target' is the label.
X = heart_disease.drop(columns='target')
y = heart_disease.loc[:, 'target']

# Model training

In [None]:
import os
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.stats import barnard_exact
import numpy as np

# Repeated random sub-sampling validation: N_RUNS independent, stratified
# 70/30 training/test splits (rows are split without replacement, so this is
# not a bootstrap in the resampling-with-replacement sense).
N_RUNS = 100

sss = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=0.3, random_state=4)

# The splits above are seeded, but the random forests are created without a
# random_state, in which case sklearn draws from NumPy's global RNG. Seeding it
# makes the whole experiment reproducible.
np.random.seed(4)

# Create the output folder up front, so a missing directory fails fast
# instead of at the final to_csv after the whole simulation has run.
results_dir = f'{PROJECT_ROOT}/results'
os.makedirs(results_dir, exist_ok=True)

perf_metrics = []

for i, (train_index, test_index) in tqdm(enumerate(sss.split(X, y)), total=N_RUNS, desc="Running simulations"):

  X_train, X_test = X.iloc[train_index], X.iloc[test_index]
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]

  rf, y_pred, y_pred_proba = train_random_forest(X_train, y_train, X_test, y_test)

  # Global (unstratified) performance of the baseline model on the test set
  accuracy, roc_auc, FNR, FPR, *_ = get_perf_metrics(y_test, y_pred, y_pred_proba)

  # Causal Model Analysis on the training set
  # Considering the following causal model:
  #   y_pred = beta0 + beta1*y_true + beta2*a
  # where a is the protected attribute, i.e. sex
  # NOTE(review): these are in-sample predictions of a forest fitted on the
  # same rows; random forests can fit their training data very closely, which
  # may shrink the estimated direct effect beta2. Out-of-fold predictions
  # (e.g. sklearn's cross_val_predict) would avoid this - confirm against
  # Hui and Lau's protocol before changing.
  y_train_pred_proba = rf.predict_proba(X_train)[:, 1]
  beta2, beta2_pvalue = get_causal_model_params(X_train, y_train, y_train_pred_proba, 'sex')

  # Audit dataset: test features plus true labels and baseline predictions
  audit_df = X_test.copy()
  audit_df['y_true'] = y_test
  audit_df['y_pred'] = y_pred
  audit_df['y_pred_proba'] = y_pred_proba

  # CAUSAL MITIGATION, applied only when the direct path sex -> prediction
  # is statistically significant
  if beta2_pvalue < 0.05:
    # Remove the estimated direct effect of sex from the test-set and
    # training-set probabilities
    audit_df['y_correct_proba'] = audit_df['y_pred_proba'] - beta2*audit_df['sex']
    y_train_pred_correct_proba = y_train_pred_proba - beta2*X_train['sex']

    # Classification threshold as per Hui and Lau: the quantile of the
    # corrected training-set scores at the prevalence of the negative class,
    # so that roughly the training prevalence of the positive class is
    # labelled positive.
    target_prevalence = y_train.value_counts()[0] / len(y_train)
    threshold = y_train_pred_correct_proba.quantile(target_prevalence)
    audit_df['y_correct'] = (audit_df['y_correct_proba'] > threshold).astype(int)

  else:
    # No significant direct path: keep the baseline predictions unchanged
    audit_df['y_correct_proba'] = audit_df['y_pred_proba']
    audit_df['y_correct'] = audit_df['y_pred']

  # Global perf of the corrected model
  accuracy_corrected, roc_auc_corrected, FNR_corrected, FPR_corrected, *_ = get_perf_metrics(
      audit_df['y_true'],
      audit_df['y_correct'],
      audit_df['y_correct_proba']
  )

  # STRATIFIED PERFORMANCE AUDIT
  male_df = audit_df[audit_df['sex'] == 1]
  female_df = audit_df[audit_df['sex'] == 0]

  ## Baseline model
  accuracy_m, roc_auc_m, FNR_m, FPR_m, tn_m, fp_m, fn_m, tp_m = get_perf_metrics(
      male_df['y_true'],
      male_df['y_pred'],
      male_df['y_pred_proba'])

  accuracy_f, roc_auc_f, FNR_f, FPR_f, tn_f, fp_f, fn_f, tp_f = get_perf_metrics(
      female_df['y_true'],
      female_df['y_pred'],
      female_df['y_pred_proba'])

  ## Barnard's exact test of independence between sex and the split of actual
  ## positives into false negatives vs true positives (FNR), and of actual
  ## negatives into false positives vs true negatives (FPR)
  fnr_barnard = barnard_exact([[fn_m, tp_m],[fn_f, tp_f]])
  fpr_barnard = barnard_exact([[fp_m, tn_m],[fp_f, tn_f]])

  ## Corrected predictions
  accuracy_corrected_m, roc_auc_corrected_m, FNR_corrected_m, FPR_corrected_m, *_ = get_perf_metrics(
      male_df['y_true'],
      male_df['y_correct'],
      male_df['y_correct_proba']
  )

  accuracy_corrected_f, roc_auc_corrected_f, FNR_corrected_f, FPR_corrected_f, *_ = get_perf_metrics(
      female_df['y_true'],
      female_df['y_correct'],
      female_df['y_correct_proba']
  )

  # All disparities are male minus female
  perf_metrics.append({
      'run': i,
      'accuracy': accuracy,
      'roc_auc': roc_auc,
      'FNR': FNR,
      'FPR': FPR,
      'accuracy_diff': accuracy_m - accuracy_f,
      'roc_auc_diff': roc_auc_m - roc_auc_f,
      'FNR_diff': FNR_m - FNR_f,
      'FPR_diff': FPR_m - FPR_f,
      'fnr_barnard_pvalue': fnr_barnard.pvalue,
      'fpr_barnard_pvalue': fpr_barnard.pvalue,
      'beta2': beta2,
      'beta2_pvalue': beta2_pvalue,
      'accuracy_corrected': accuracy_corrected,
      'roc_auc_corrected': roc_auc_corrected,
      'FNR_corrected': FNR_corrected,
      'FPR_corrected': FPR_corrected,
      'accuracy_corrected_diff': accuracy_corrected_m - accuracy_corrected_f,
      'roc_auc_corrected_diff': roc_auc_corrected_m - roc_auc_corrected_f,
      'FNR_corrected_diff': FNR_corrected_m - FNR_corrected_f,
      'FPR_corrected_diff': FPR_corrected_m - FPR_corrected_f
  })

perf_metrics_df = pd.DataFrame(perf_metrics)
perf_metrics_df.to_csv(f'{results_dir}/perf_metrics_{N_RUNS}_runs.csv')

Running simulations: 100%|██████████| 100/100 [54:10<00:00, 32.50s/it]


# Results

## Expected baseline disparity statistics

| Metric | Mean | P value |
|:--|--:|--:|
| Accuracy disparity (%) | 0.32 | 0.50 |
| ROC_AUC disparity (%) | 3.86 | <.01 |
| FNR disparity (%) | -11.66 | <.01 |
| FPR disparity (%) | 3.94 | <0.1 |

In [None]:
from scipy.stats import ttest_1samp, ttest_rel
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import chi2_contingency
import pandas as pd

# When running this cell without the simulation cell, define N_RUNS first
# (e.g. N_RUNS = 100) to select the matching results file.

# The simulation cell saved the DataFrame with to_csv's default index=True;
# read that first column back as the index so it does not reappear as a
# spurious 'Unnamed: 0' column.
perf_metrics_df = pd.read_csv(f'{PROJECT_ROOT}/results/perf_metrics_{N_RUNS}_runs.csv',
                              index_col=0)

# OVERALL PERFORMANCE: mean across runs and its 95% CI
accuracy_ci = get_95_ci(perf_metrics_df['accuracy'])
roc_auc_ci = get_95_ci(perf_metrics_df['roc_auc'])
fnr_ci = get_95_ci(perf_metrics_df['FNR'])
fpr_ci = get_95_ci(perf_metrics_df['FPR'])

print('---Overall performance---')
print(f'Accuracy: {round(perf_metrics_df["accuracy"].mean(), 3)}'
      f'  (95% CI: {round(accuracy_ci[0], 3)}, {round(accuracy_ci[1], 3)})')
print(f'ROC AUC: {round(perf_metrics_df["roc_auc"].mean(), 3)}'
      f'  (95% CI: {round(roc_auc_ci[0], 3)}, {round(roc_auc_ci[1], 3)})')
print(f'FNR: {round(perf_metrics_df["FNR"].mean()*100, 2)}'
      f'  (95% CI: {round(fnr_ci[0]*100, 2)}, {round(fnr_ci[1]*100, 2)})')
print(f'FPR: {round(perf_metrics_df["FPR"].mean()*100, 2)}'
      f'  (95% CI: {round(fpr_ci[0]*100, 2)}, {round(fpr_ci[1]*100, 2)})')
print('\n')

# DISPARITY STATS, as per Straw et al
# Does this experiment show the same statistical disparity across runs
# as the original study?
def disparity_stats(data, metric_diff):
  '''
    Summarises one male-minus-female disparity metric across runs.

    Inputs
      data: DataFrame with one row per run
      metric_diff: name of the disparity column to summarise

    Outputs (returned as a list)
      [metric_diff, mean in % (2 dp), sample std dev in % (2 dp),
       one-sample t statistic against 0, its two-sided p-value]
  '''
  diffs = data[metric_diff]
  t_stat, p_value = ttest_1samp(diffs, popmean=0)
  return [
      metric_diff,
      round(diffs.mean() * 100, 2),
      round(diffs.std() * 100, 2),
      t_stat,
      p_value,
  ]

# One summary row per disparity metric, printed as a markdown table
disparity_metrics = ['accuracy_diff', 'roc_auc_diff', 'FNR_diff', 'FPR_diff']
disparity_rows = [disparity_stats(perf_metrics_df, metric) for metric in disparity_metrics]
disparity_stats_df = pd.DataFrame(
    disparity_rows,
    columns=['Metric', 'Mean %', 'Std Dev %', 't-test', 'P value'],
)

print(disparity_stats_df.to_markdown())


# HYPOTHESIS 1: Causal modelling (as per Hui and Lau's method)
# will detect a statistically significant direct causal path
# from the protected attribute (sex) to the prediction outcome

# How reliably does the model detect significance in a single split?
# Share of runs whose p-value for beta2 is < 0.05. The denominator is the number
# of runs actually loaded, and the percentage is formatted directly because
# round(x, 3)*100 can print floating-point artefacts (e.g. 56.99999999999999).
causal_detected = perf_metrics_df['beta2_pvalue'] < 0.05
runs_with_causal_path = perf_metrics_df[causal_detected]
pct_runs_with_causal_path = 100 * len(runs_with_causal_path) / len(perf_metrics_df)
print(f'\n Proportion of runs where a causal path between sex and prediction'
      f' was detected: {pct_runs_with_causal_path:.1f} %')


def _report_causal_association(contingency, rate_name):
  '''
    Plots, saves and chi-squared-tests one crosstab of "causal path detected"
     against "significant <rate_name> disparity" across runs.

    Inputs
      contingency: pandas crosstab of the two boolean run indicators
      rate_name: 'FNR' or 'FPR'; used in the title, file name and printout

    Outputs
      the result of scipy.stats.chi2_contingency on the crosstab
  '''
  sns.heatmap(contingency, annot=True, cmap=sns.color_palette('light:#D92B89', as_cmap=True))
  plt.title(f'Association between detection of causal path and significant {rate_name} disparity')
  plt.savefig(f'{PROJECT_ROOT}/results/causal-{rate_name}-association-heatmap', format='png')
  plt.clf()

  result = chi2_contingency(contingency.values)
  print(f'\n Association between detection of causal path and significant {rate_name} disparity:'
        f' \n {result.statistic} (P-value = {result.pvalue})')
  # pd.crosstab keeps only observed categories: if every run falls on one side
  # of either indicator, the table is not 2x2, the test has zero degrees of
  # freedom and reports 0.0 / P = 1.0, which is not evidence of independence.
  if contingency.shape != (2, 2):
    print(f' NOTE: degenerate {contingency.shape[0]}x{contingency.shape[1]} table;'
          ' the test above is uninformative.')
  # NOTE: with few significant runs, expected cell counts may fall below 5,
  # where the chi-squared approximation is unreliable; an exact test
  # (e.g. scipy.stats.fisher_exact) may be preferable.
  return result


# Contingency table analysis: how often the causal model flags a problem exactly
# when the performance metrics show an inequitable outcome
sig_fnr_disparity = perf_metrics_df['fnr_barnard_pvalue'] < 0.05
Causal_FNR_disparity = pd.crosstab(causal_detected, sig_fnr_disparity,
                                   rownames=['Causal Path Detected (beta2)'],
                                   colnames=['Significant FNR Disparity (Barnard)'])
chi2_causal_FNR_disparity = _report_causal_association(Causal_FNR_disparity, 'FNR')

sig_fpr_disparity = perf_metrics_df['fpr_barnard_pvalue'] < 0.05
Causal_FPR_disparity = pd.crosstab(causal_detected, sig_fpr_disparity,
                                   rownames=['Causal Path Detected (beta2)'],
                                   colnames=['Significant FPR Disparity (Barnard)'])
chi2_causal_FPR_disparity = _report_causal_association(Causal_FPR_disparity, 'FPR')

# Correlation between disparity metrics
corr_matrix = perf_metrics_df[['accuracy_diff', 'roc_auc_diff', 'FNR_diff', 'FPR_diff', 'beta2']].corr()
sns.heatmap(corr_matrix, annot=True, cmap=sns.diverging_palette(250, 250, center='light', as_cmap=True))
plt.savefig(f'{PROJECT_ROOT}/results/disparity-metrics-correlation', format='png')
plt.close()

# HYPOTHESIS 2: mitigating the detected causal path
# will reduce the observed sex-based performance disparity (FNR difference)
# in the cardiac disease prediction model


# We analyse the statistical significance of the correction in the runs where
# a causal path was detected (in the other runs no correction is applied, so
# their paired differences are exactly zero and would only dilute the test).
# Disparity is the absolute male-female FNR gap, so the one-sided paired
# t-test asks whether the gap is smaller after mitigation, whichever sex the
# baseline model disadvantages.
mitigated_runs = perf_metrics_df[perf_metrics_df['beta2_pvalue'] < 0.05]
abs_FNR_diff = mitigated_runs['FNR_diff'].abs()
abs_FNR_corrected_diff = mitigated_runs['FNR_corrected_diff'].abs()
fnr_causal_correction_ttest = ttest_rel(abs_FNR_diff,
                                        abs_FNR_corrected_diff,
                                        alternative='greater')
print(f"\n--- FNR Disparity Mitigation Analysis (N={len(mitigated_runs)}) ---")
print(f"Mean Absolute FNR Disparity BEFORE Correction: {abs_FNR_diff.mean():.4f}")
print(f"Std Dev Absolute FNR Disparity BEFORE Correction: {abs_FNR_diff.std():.4f}")
print("-" * 50)
print(f"Mean Absolute FNR Disparity AFTER Correction: {abs_FNR_corrected_diff.mean():.4f}")
print(f"Std Dev Absolute FNR Disparity AFTER Correction: {abs_FNR_corrected_diff.std():.4f}")
print("-" * 50)
print(f"Paired T-Statistic (t): {fnr_causal_correction_ttest.statistic:.4f}")
print(f"P-value: {fnr_causal_correction_ttest.pvalue:.6f}")

# Change in the signed FNR gap (male - female) brought by the correction,
# in absolute terms and relative to the size of the baseline gap
perf_metrics_df['FNR_correction_abs'] = perf_metrics_df['FNR_corrected_diff'] \
  - perf_metrics_df['FNR_diff']
# A run with no baseline gap has no meaningful relative correction: mask its
# denominator to NaN (skipped by mean()) instead of dividing by zero, which
# would produce inf/NaN and poison the mean.
baseline_FNR_gap = perf_metrics_df['FNR_diff'].abs()
perf_metrics_df['FNR_correction_rel'] = perf_metrics_df['FNR_correction_abs'] \
  / baseline_FNR_gap.mask(baseline_FNR_gap == 0)

print(f'Mean correction (absolute): {perf_metrics_df["FNR_correction_abs"].mean()*100:.2f}%')
print(f'Mean correction (relative): {perf_metrics_df["FNR_correction_rel"].mean()*100:.2f}%')

# Impact of the correction on the overall (unstratified) test-set performance:
# mean across runs and its 95% CI, in the same layout as the baseline printout.
accuracy_corrected_ci = get_95_ci(perf_metrics_df['accuracy_corrected'])
roc_auc_corrected_ci = get_95_ci(perf_metrics_df['roc_auc_corrected'])
fnr_corrected_ci = get_95_ci(perf_metrics_df['FNR_corrected'])
fpr_corrected_ci = get_95_ci(perf_metrics_df['FPR_corrected'])

accuracy_corrected_mean = perf_metrics_df['accuracy_corrected'].mean()
roc_auc_corrected_mean = perf_metrics_df['roc_auc_corrected'].mean()
fnr_corrected_mean = perf_metrics_df['FNR_corrected'].mean()
fpr_corrected_mean = perf_metrics_df['FPR_corrected'].mean()

print('\n---Overall performance after correction---')
print(f'Accuracy: {round(accuracy_corrected_mean, 3)}'
      f'  (95% CI: {round(accuracy_corrected_ci[0], 3)}, {round(accuracy_corrected_ci[1], 3)})')
print(f'ROC AUC: {round(roc_auc_corrected_mean, 3)}'
      f'  (95% CI: {round(roc_auc_corrected_ci[0], 3)}, {round(roc_auc_corrected_ci[1], 3)})')
print(f'FNR: {round(fnr_corrected_mean*100, 2)}'
      f'  (95% CI: {round(fnr_corrected_ci[0]*100, 2)}, {round(fnr_corrected_ci[1]*100, 2)})')
print(f'FPR: {round(fpr_corrected_mean*100, 2)}'
      f'  (95% CI: {round(fpr_corrected_ci[0]*100, 2)}, {round(fpr_corrected_ci[1]*100, 2)})')
print('\n')


---Overall performance---
Accuracy: 0.859  (95% CI: 0.855, 0.863)
ROC AUC: 0.926  (95% CI: 0.923, 0.929)
FNR: 13.97  (95% CI: 13.29, 14.65)
FPR: 14.24  (95% CI: 13.56, 14.92)


|    | Metric        |   Mean % |   Std Dev % |     t-test |     P value |
|---:|:--------------|---------:|------------:|-----------:|------------:|
|  0 | accuracy_diff |     0.01 |        4.6  |  0.0218411 | 0.982619    |
|  1 | roc_auc_diff  |     1.54 |        3.73 |  4.12083   | 7.84779e-05 |
|  2 | FNR_diff      |   -12.18 |       12.33 | -9.88451   | 1.95412e-16 |
|  3 | FPR_diff      |     5.15 |        7.55 |  6.82241   | 7.20108e-10 |

 Proportion of runs where a causal path between sex and prediction was detected: 89.0 %

 Association between detection of causal path and significant FNR disparity: 
 0.05745658835546476 (P-value = 0.8105620290006389)

 Association between detection of causal path and significant FPR disparity: 
 0.0 (P-value = 1.0)

--- FNR Disparity Mitigation Analysis (N=100) ---
Me