In [None]:
import sys
import os
try:
  from google.colab import userdata
  PROJECT_ROOT = userdata.get('PROJECT_ROOT')
except ImportError:
  PROJECT_ROOT = '/'

if PROJECT_ROOT not in sys.path:
  sys.path.append(PROJECT_ROOT)

os.chdir(PROJECT_ROOT)

!pip install -q -r "{PROJECT_ROOT}dependencies.txt"

# Detecting and mitigating bias using causal modelling

## The method

Introduced by Hui and Lau (2024, doi:[10.1109/CCCIS63483.2024.00016](https://doi.org/10.1109/CCCIS63483.2024.00016)).

**Aim:** Detect and mitigate bias on a sensitive attribute in a black-box predictive model, by modelling and inferring the causal relationship between the attribute, the model's predicted outcome and the ground-truth outcome.

## The experiment

Using the characterisation of statistical gender bias in cardiology algorithms conducted by Straw et al. (2024, doi: [10.2196/46936](https://doi.org/10.2196/46936)) as a reference, we will apply the causal modelling bias mitigation method to the ML model replicated by Straw et al. on the [Heart Disease (CAD) dataset](https://ieee-dataport.org/open-access/heart-disease-dataset-comprehensive).

### Hypothesis

Causal modelling (as per Hui and Lau's method) will detect a statistically significant direct causal path from the protected attribute (sex) to the prediction outcome, and mitigating this direct path will reduce the observed sex-based performance disparity (FNR difference) in the cardiac disease prediction model.

### Key variables

- **Protected attribute ($a$):** sex (male/female)
- **Actual outcome ($y$):** Heart disease diagnosis / No heart disease diagnosis
- **Predicted outcome ($\hat{y}$):** The continuous probability estimate from the black-box predictive model. The binary class (heart disease / no heart disease) is derived via a threshold.
- **Mitigated outcome ($\tilde{y}$):** the continuous, bias-corrected probability estimate derived from the causal modelling mitigation on the predicted outcome.
- **Fairness metrics:**
  - Disparity of Predictive Accuracy between subgroups
  - Disparity of False Positive Rate (FPR) between subgroups
  - Disparity of False Negative Rate (FNR) between subgroups
  - Equal Opportunity (EO), satisfied when the FNR disparity between subgroups nears zero

$$\text{FPR} = \frac{\text{False Positives}}{\text{False Positives} + \text{True Negatives}}$$

$$\text{FNR} = \frac{\text{False Negatives}}{\text{False Negatives} + \text{True Positives}}$$

$$\text{EO} \Leftrightarrow \text{FNR}_{\text{male}} = \text{FNR}_{\text{female}}$$



# Data Preprocessing

In [None]:
import pandas as pd
from tableone import TableOne

# Load the heart disease records from the project's data folder
heart_disease = pd.read_csv('data/heart_disease_cleveland_hungary.csv')

# Remove duplicates and null values, as per Straw et al.
# A reading of 0 for cholesterol or resting blood pressure is treated as missing.
has_zero_cholesterol = heart_disease['cholesterol'] == 0
has_zero_resting_bp = heart_disease['resting bp s'] == 0
is_repeated_record = heart_disease.duplicated(keep='first')
rows_to_drop = has_zero_cholesterol | has_zero_resting_bp | is_repeated_record
heart_disease = heart_disease.loc[~rows_to_drop]

# Sanity check against the number of records kept in the reference study
if len(heart_disease) != 746:
  print("WARNING: Total count of records in cleaned dataset doesn't match reference study" )

# Descriptive statistics, stratified by sex
continuous_columns = ['age','cholesterol','max heart rate','resting bp s','oldpeak']
categorical_columns = ['target','chest pain type', 'fasting blood sugar','resting ecg','exercise angina','ST slope']
table1 = TableOne(
    heart_disease,
    groupby='sex',
    continuous=continuous_columns,
    categorical=categorical_columns,
)

print(table1)


                             Grouped by sex                                          
                                    Missing       Overall             0             1
n                                                     746           182           564
age, mean (SD)                            0    52.9 (9.5)    52.2 (9.3)    53.1 (9.6)
chest pain type, n (%)     1                     41 (5.5)       9 (4.9)      32 (5.7)
                           2                   166 (22.3)     59 (32.4)    107 (19.0)
                           3                   169 (22.7)     52 (28.6)    117 (20.7)
                           4                   370 (49.6)     62 (34.1)    308 (54.6)
resting bp s, mean (SD)                   0  133.0 (17.3)  132.0 (18.6)  133.4 (16.8)
cholesterol, mean (SD)                    0  244.6 (59.2)  255.8 (62.9)  241.0 (57.5)
fasting blood sugar, n (%) 0                   621 (83.2)    163 (89.6)    458 (81.2)
                           1                   125 (16

In [None]:
# Split the cleaned records into the outcome label (y) and the predictors (X)
y = heart_disease['target']
X = heart_disease.drop(columns=['target'])

# Model training

In [None]:
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from scipy.stats import chi2_contingency
import numpy as np

from src.models import train_random_forest
from src.causal_model import get_causal_model_params
from src.audit import get_perf_metrics


# Bootstrapping approach with 100 runs and a 70/30 training/test split
N_RUNS = 100
sss = StratifiedShuffleSplit(n_splits=N_RUNS, test_size=0.3, random_state=4)

# One dict of metrics is appended per run
perf_metrics = []

# enumerate() has no length, so the number of runs is passed to tqdm explicitly
for i, (train_index, test_index) in tqdm(enumerate(sss.split(X, y)), total=N_RUNS, desc="Running simulations"):

  X_train, X_test = X.iloc[train_index], X.iloc[test_index]
  y_train, y_test = y.iloc[train_index], y.iloc[test_index]

  # NOTE(review): the argument order (X_train, X_test, y_train, y_test) must match
  # the signature of train_random_forest in src.models — confirm against that module
  rf, y_pred, y_pred_proba = train_random_forest(X_train, X_test, y_train, y_test)

  # Measure global performance metrics
  # (get_perf_metrics returns accuracy, roc_auc, FNR, FPR, tn, fp, fn, tp)
  accuracy, roc_auc, FNR, FPR,*_ = get_perf_metrics(y_test, y_pred, y_pred_proba)

  # Causal Model Analysis on the training set
  # Considering the following causal model:
  # y_pred = beta0 + beta1*y_true + beta2*a
  # where a is the protected attribute, i.e. sex
  y_train_pred_proba = rf.predict_proba(X_train)[:, 1]
  beta2, beta2_pvalue = get_causal_model_params(X_train, y_train, y_train_pred_proba, 'sex')

  # Audit dataset: test features with the true labels and the baseline predictions
  audit_df = X_test.copy()
  audit_df['y_true'] = y_test
  audit_df['y_pred'] = y_pred
  audit_df['y_pred_proba'] = y_pred_proba


  # CAUSAL MITIGATION
  # The correction is applied only when the sex coefficient is significant at the 5% level
  if (beta2_pvalue < 0.05):
    # Calculate y_correct_proba from y_pred_proba
    # and y_train_pred_correct_proba from y_train_pred_proba
    # by subtracting the estimated contribution of sex (beta2 * sex)
    audit_df['y_correct_proba'] = audit_df['y_pred_proba'] - beta2*audit_df['sex']
    y_train_pred_correct_proba = y_train_pred_proba - beta2*X_train['sex']

    # Apply classification threshold as per Hui and Lau,
    # using the quantile in y_correct matching the prevalence of the negative class in the training set:
    target_prevalence = y_train.value_counts()[0] / len(y_train)
    threshold = y_train_pred_correct_proba.quantile(target_prevalence)
    audit_df['y_correct'] = (audit_df['y_correct_proba'] > threshold).astype(int)

  else:
    # No significant effect of sex detected: the baseline predictions are kept unchanged
    audit_df['y_correct_proba'] = audit_df['y_pred_proba']
    audit_df['y_correct'] = audit_df['y_pred']


  # Global perf of the corrected model
  accuracy_corrected, roc_auc_corrected, FNR_corrected, FPR_corrected, *_ = get_perf_metrics(
      audit_df['y_true'],
      audit_df['y_correct'],
      audit_df['y_correct_proba']
  )

  # STRATIFIED PERFORMANCE AUDIT
  male_df = audit_df[audit_df['sex'] == 1]
  female_df = audit_df[audit_df['sex'] == 0]

  ## Baseline model
  accuracy_m, roc_auc_m, FNR_m, FPR_m, tn_m, fp_m, fn_m, tp_m = get_perf_metrics(
      male_df['y_true'],
      male_df['y_pred'],
      male_df['y_pred_proba'])

  accuracy_f, roc_auc_f, FNR_f, FPR_f, tn_f, fp_f, fn_f, tp_f = get_perf_metrics(
      female_df['y_true'],
      female_df['y_pred'],
      female_df['y_pred_proba'])

  ## Chi-squared tests of independence between sex and the false negative rate
  ## (FN vs TP counts), and between sex and the false positive rate (FP vs TN counts)
  fnr_chi2_result = chi2_contingency(np.array([[fn_m, tp_m],[fn_f, tp_f]]))
  fpr_chi2_result = chi2_contingency(np.array([[fp_m, tn_m],[fp_f, tn_f]]))

  ## Corrected predictions
  accuracy_corrected_m, roc_auc_corrected_m, FNR_corrected_m, FPR_corrected_m, *_ = get_perf_metrics(
      male_df['y_true'],
      male_df['y_correct'],
      male_df['y_correct_proba']
  )

  accuracy_corrected_f, roc_auc_corrected_f, FNR_corrected_f, FPR_corrected_f, *_ = get_perf_metrics(
      female_df['y_true'],
      female_df['y_correct'],
      female_df['y_correct_proba']
  )

  # All *_diff entries are computed as male minus female
  perf_metrics.append({
      'run': i,
      'accuracy': accuracy,
      'roc_auc': roc_auc,
      'FNR': FNR,
      'FPR': FPR,
      'accuracy_diff': accuracy_m - accuracy_f,
      'roc_auc_diff': roc_auc_m - roc_auc_f,
      'FNR_diff': FNR_m - FNR_f,
      'FPR_diff': FPR_m - FPR_f,
      'fnr_chi2_pvalue': fnr_chi2_result.pvalue,
      'fpr_chi2_pvalue': fpr_chi2_result.pvalue,
      'beta2': beta2,
      'beta2_pvalue': beta2_pvalue,
      'accuracy_corrected': accuracy_corrected,
      'roc_auc_corrected': roc_auc_corrected,
      'FNR_corrected': FNR_corrected,
      'FPR_corrected': FPR_corrected,
      'accuracy_corrected_diff': accuracy_corrected_m - accuracy_corrected_f,
      'roc_auc_corrected_diff': roc_auc_corrected_m - roc_auc_corrected_f,
      'FNR_corrected_diff': FNR_corrected_m - FNR_corrected_f,
      'FPR_corrected_diff': FPR_corrected_m - FPR_corrected_f
  })



Running simulations: 100%|██████████| 100/100 [58:09<00:00, 34.90s/it]


# Results

## Expected baseline disparity statistics

| Metric | Mean | P value |
|:--|--:|--:|
| Accuracy disparity (%) | 0.32 | 0.50 |
| ROC_AUC disparity (%) | 3.86 | <.01 |
| FNR disparity (%) | -11.66 | <.01 |
| FPR disparity (%) | 3.94 | <0.1 |

In [None]:
from scipy.stats import ttest_1samp, ttest_rel
import seaborn as sns
import matplotlib.pyplot as plt

# Persist the metrics of the current session, or fall back to a previously saved
# file when the training cell has not been run (perf_metrics is then undefined).
try:
  perf_metrics_df = pd.DataFrame(perf_metrics)
  results_csv = f'{PROJECT_ROOT}/results/perf_metrics_{N_RUNS}_runs.csv'
  perf_metrics_df.to_csv(results_csv)
except NameError:
  # Saved results are looked up for the default number of runs
  N_RUNS = 100
  results_csv = f'{PROJECT_ROOT}/results/perf_metrics_{N_RUNS}_runs.csv'
  perf_metrics_df = pd.read_csv(results_csv)

# DISPARITY STATS, as per Straw et al
# Does this experiment show the same statistical disparity across runs
# as the original study?
def disparity_stats(data, metric_diff):
  """Summarise one disparity column across runs.

  Args:
    data: DataFrame holding one row per run.
    metric_diff: name of the column to summarise.

  Returns:
    A list [metric_diff, mean in %, standard deviation in %, t statistic,
    p-value], where the t statistic and p-value come from a one-sample t-test
    of the column against a population mean of 0. The mean and standard
    deviation are multiplied by 100 and rounded to 2 decimals.
  """
  values = data[metric_diff]
  ttest = ttest_1samp(values, popmean=0)
  mean_pct = round(values.mean()*100, 2)
  std_pct = round(values.std()*100, 2)
  return [metric_diff, mean_pct, std_pct, ttest.statistic, ttest.pvalue]

# One summary row per male-minus-female disparity metric of the baseline model
disparity_metrics = ['accuracy_diff', 'roc_auc_diff', 'FNR_diff', 'FPR_diff']
disparity_rows = [disparity_stats(perf_metrics_df, metric) for metric in disparity_metrics]

disparity_stats_df = pd.DataFrame(
    disparity_rows,
    columns=['Metric', 'Mean %', 'Std Dev %', 't-test', 'P value'],
)

print(disparity_stats_df.to_markdown())


# HYPOTHESIS 1: Causal modelling (as per Hui and Lau's method)
# will detect a statistically significant direct causal path
# from the protected attribute (sex) to the prediction outcome

# Detection rate: share of runs in which the sex coefficient (beta2) of the
# causal model is significant at the 5% level
causal_detected = perf_metrics_df['beta2_pvalue'].lt(0.05)
runs_with_causal_path = perf_metrics_df.loc[causal_detected]
detection_rate = round(len(runs_with_causal_path)/N_RUNS, 2)
print('\n Proportion of runs where a causal path between sex and prediction'
      f' was detected: {detection_rate}')

# Contingency table analysis: is detecting a causal path associated with a
# significant FNR disparity in the same run?
sig_fnr_disparity = perf_metrics_df['fnr_chi2_pvalue'].lt(0.05)
Causal_FNR_disparity = pd.crosstab(causal_detected, sig_fnr_disparity).to_numpy()
chi2_causal_FNR_disparity = chi2_contingency(Causal_FNR_disparity)
print('\n Association between detection of causal path and significant FNR disparity:'
      f' \n {chi2_causal_FNR_disparity.statistic} (P-value = {chi2_causal_FNR_disparity.pvalue})')


# Same analysis for a significant FPR disparity
sig_fpr_disparity = perf_metrics_df['fpr_chi2_pvalue'].lt(0.05)
Causal_FPR_disparity = pd.crosstab(causal_detected, sig_fpr_disparity).to_numpy()
chi2_causal_FPR_disparity = chi2_contingency(Causal_FPR_disparity)
print('\n Association between detection of causal path and significant FPR disparity:'
      f' \n {chi2_causal_FPR_disparity.statistic} (P-value = {chi2_causal_FPR_disparity.pvalue})')


# HYPOTHESIS 2: mitigating the detected causal path
# will reduce the observed sex-based performance disparity (FNR difference)
# in the cardiac disease prediction model


# We analyse the statistical significance of the correction with a paired-sample
# t-test on the absolute FNR difference before and after causal mitigation.
# All runs in perf_metrics_df are included, not only those where a causal path
# was detected (those are available separately in runs_with_causal_path).
fnr_abs_before = perf_metrics_df['FNR_diff'].abs()
fnr_abs_after = perf_metrics_df['FNR_corrected_diff'].abs()

# One-sided alternative: the absolute disparity is greater before correction than after
fnr_causal_correction_ttest = ttest_rel(fnr_abs_before,
                                         fnr_abs_after,
                                         alternative='greater')

# The summary statistics are computed on the same absolute values that enter the
# t-test, so that the printed figures match their "Absolute" labels.
print(f"--- FNR Disparity Mitigation Analysis (N={len(perf_metrics_df)}) ---")
print(f"Mean Absolute FNR Disparity BEFORE Correction: {fnr_abs_before.mean():.4f}")
print(f"Std Dev Absolute FNR Disparity BEFORE Correction: {fnr_abs_before.std():.4f}")
print("-" * 50)
print(f"Mean Absolute FNR Disparity AFTER Correction: {fnr_abs_after.mean():.4f}")
print(f"Std Dev Absolute FNR Disparity AFTER Correction: {fnr_abs_after.std():.4f}")
print("-" * 50)
print(f"Paired T-Statistic (t): {fnr_causal_correction_ttest.statistic:.4f}")
print(f"One-Tailed P-value: {fnr_causal_correction_ttest.pvalue:.6f}")



|    | Metric        |   Mean % |   Std Dev % |    t-test |     P value |
|---:|:--------------|---------:|------------:|----------:|------------:|
|  0 | accuracy_diff |     0.16 |        4.43 |  0.368732 | 0.713115    |
|  1 | roc_auc_diff  |     1.43 |        3.67 |  3.88565  | 0.000184552 |
|  2 | FNR_diff      |   -11.97 |       12.32 | -9.71811  | 4.50863e-16 |
|  3 | FPR_diff      |     4.9  |        6.99 |  7.00178  | 3.06943e-10 |

 Proportion of runs where a causal path between sex and prediction was detected: 0.44

 Association between dtection of causal path and FNR disparity: 
 0.0031830914183855305 (P-value = 0.955008106975321)

 Association between dtection of causal path and FPR disparity: 
 0.10997067448680355 (P-value = 0.7401775249481735)
--- FNR Disparity Mitigation Analysis (N=100) ---
Mean Absolute FNR Disparity BEFORE Correction: -0.1197
Std Dev Absolute FNR Disparity BEFORE Correction: 0.1232
--------------------------------------------------
Mean Absolute FNR D



### Disparity in model performance

|    | Metric        |        Mean |   Std Dev |   t-test |     P value |
|---:|:--------------|------------:|----------:|---------:|------------:|
|  0 | accuracy_diff |  0.00512749 | 0.0454495 |  1.12817 | 0.261973    |
|  1 | roc_auc_diff  |  0.0203175  | 0.043054  |  4.71907 | 7.79602e-06 |
|  2 | FNR_diff      | -0.10364    | 0.123364  | -8.4012  | 3.29169e-13 |
|  3 | FPR_diff      |  0.0384224  | 0.0596807 |  6.438   | 4.3677e-09  |