# Training a counterfactually fair CAD model

## Counterfactual Fairness

Attempts to mitigate bias in a model by simply removing sensitive attributes from its training data (fairness through unawareness) often fail. Bias 'leaks' through causal relationships between the sensitive attribute and the other features retained in the data. The counterfactual fairness approach introduced by Kusner et al. (2017, in Advances in Neural Information Processing Systems, https://proceedings.neurips.cc/paper_files/paper/2017/file/a486cd07e4ac3d270571622f4f316ec5-Paper.pdf) addresses this limitation. It models the biased observed variables as the joint product of the protected attribute and a set of latent variables that are independent of it. The predictor then learns only from information that is independent of the protected attribute, which neutralises both direct and proxy bias pathways.
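The simulation below makes the leakage concrete. It uses assumed, made-up parameters, not the CAD data. A proxy feature `X` depends on both a latent factor `U` and the protected attribute `S`. A predictor that never sees `S` still produces very different positive rates across the two groups, whereas a predictor built on `U` alone does not. The helper `positive_rate_gap` is a hypothetical name. A positive-rate gap is only a crude group-level symptom of leakage, not a test of counterfactual fairness itself.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 10_000

# Toy causal model (illustrative assumption, not fitted to the CAD data):
# S is a binary protected attribute, U a latent factor independent of S,
# and X a proxy feature that depends on both U and S.
S = rng.integers(0, 2, size=n)
U = rng.normal(size=n)
X = U + 1.5 * S + rng.normal(scale=0.5, size=n)

# An "unaware" predictor: it never uses S directly, only the proxy X
y_hat_unaware = (X > 0.75).astype(int)

# A predictor that uses the latent factor U only
y_hat_latent = (U > 0).astype(int)


def positive_rate_gap(y_hat, s):
    """Difference in positive-prediction rate between group 1 and group 0."""
    return y_hat[s == 1].mean() - y_hat[s == 0].mean()


print(f"unaware gap: {positive_rate_gap(y_hat_unaware, S):.2f}")
print(f"latent gap:  {positive_rate_gap(y_hat_latent, S):.2f}")
```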

### Notations and definitions

We adopt the following notation, consistent with the Pearlian causal framework used by Kusner et al.:

- $S$ **Protected attribute**: The sensitive variable we wish to be fair toward
- $X$ **Observed features**: The set of features available in the dataset (e.g. Blood Pressure, Cholesterol)
- $U$ **Latent (unobserved) variables**: Unobserved variables that are independent of the protected attribute $S$
- $Y$ **Target**: The outcome we are predicting (e.g. Presence of CAD)
- $Y_{S \leftarrow s}$: The value of $Y$ under a counterfactual intervention where $S$ is set to $s$
- $M = (U, V, F)$ **Causal model**: The causal model of the observed data, where $V \equiv X \cup Y \cup S$ and $F$ is the set of structural equations of the model

**Counterfactual Fairness:** A predictor $\hat{Y}$ is counterfactually fair if, for a specific individual, the probability distribution of the prediction is the same in the actual world as it would be in a counterfactual world where their protected attribute (e.g., Sex) was different.

Formally, for any context $X=x$, $S=s$, for all $y$ and for any value $s'$ attainable by the protected attribute $S$:

$$P(\hat{Y}_{S \leftarrow s} = y \mid X=x, S=s) = P(\hat{Y}_{S \leftarrow s'} = y \mid X=x, S=s)$$
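Kusner et al. state a sufficient condition for this property as a lemma: $\hat{Y}$ depends only on variables that are not descendants of $S$. In this experiment, that means a predictor of the form $\hat{Y} = g(U, Age)$ for some function $g$, given the causal assumptions made here. Intervening on $S$ leaves the structural equations of $U$ and $Age$ untouched, so

$$\hat{Y}_{S \leftarrow s} = g(U, Age) = \hat{Y}_{S \leftarrow s'}$$

and both sides of the definition reduce to the same quantity:

$$P(\hat{Y}_{S \leftarrow s} = y \mid X=x, S=s) = P(g(U, Age) = y \mid X=x, S=s) = P(\hat{Y}_{S \leftarrow s'} = y \mid X=x, S=s)$$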

---


## The experiment

Using the fairness-unaware models trained to predict Cardiovascular Disease in Straw et al. (2024, doi: [10.2196/46936](https://doi.org/10.2196/46936)) as baseline models, we will apply the fairness algorithm proposed by Kusner et al. to train a fair CAD predictor.

### The target bias
Kusner et al. focus on mitigating bias in tasks where the protected attribute should have no influence on the target outcome (e.g. sex and exam results). The clinical domain brings a new challenge, because protected attributes such as sex or race often encompass two distinct variables:

- a clinically relevant biological attribute, which can cause a disease to present differently across individuals;
- a sociological attribute, through which societal factors can influence healthcare access, physician perception, diagnosis and care.

A clinical outcome may legitimately depend on the former but should remain independent of the latter.

If we aimed for fairness with respect to the high-level sex attribute, we would risk removing legitimate clinical signals and degrading diagnostic accuracy. Our objective in this experiment is therefore to make the model counterfactually fair with regard to **sociological sex**.

### Hypothesis

By using counterfactual inference to model latent variables that are independent of the protected attribute, we can build a predictor that satisfies counterfactual fairness (i.e. ensuring that an individual’s predicted risk of CAD remains invariant to their sociological sex), while maintaining clinically acceptable predictive performance and reducing the False Negative Rate (FNR) disparity observed in baseline models.
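The FNR disparity named in the hypothesis can be computed as below. This is a minimal sketch with several assumptions:

- `false_negative_rate` and `fnr_disparity` are hypothetical helper names;
- the disparity is taken as a simple difference (group 0 minus group 1);
- the group coding (0 = female, 1 = male, as the `sex` column is usually coded in this dataset) should be checked against the data dictionary.

```python
import numpy as np


def false_negative_rate(y_true, y_pred):
    """FNR = FN / (FN + TP): the share of actual positives predicted as negative."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    positives = y_true == 1
    if positives.sum() == 0:
        return float("nan")  # undefined when a group has no positive cases
    return float(np.mean(y_pred[positives] == 0))


def fnr_disparity(y_true, y_pred, group):
    """FNR of group 0 minus FNR of group 1 (hypothetical convention)."""
    y_true, y_pred, group = map(np.asarray, (y_true, y_pred, group))
    fnr_0 = false_negative_rate(y_true[group == 0], y_pred[group == 0])
    fnr_1 = false_negative_rate(y_true[group == 1], y_pred[group == 1])
    return fnr_0 - fnr_1


# Group 0 misses 1 of its 4 positives, group 1 misses 1 of its 2 positives
print(fnr_disparity([1, 1, 1, 1, 0, 1, 1, 0],
                    [0, 1, 1, 1, 0, 1, 0, 0],
                    [0, 0, 0, 0, 1, 1, 1, 1]))  # -0.25
```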

---

## Causal models

From the features observed in the [Heart Disease (CAD) dataset](https://ieee-dataport.org/open-access/heart-disease-dataset-comprehensive), we build two causal models to compare in this experiment. Both rely on clinical knowledge from the literature together with some strong assumptions:

### Latent variables model

We hypothesise that clinical features in the CAD dataset are manifestations of a patient's **Age- and Sex-independent Physiological Integrity**, which we define as our fair latent variable $U$. It is independent of the protected attributes $S_{bio}$ and $S_{soc}$.

We postulate that subjective symptoms and clinician-dependent interpretations (chest pain type, exercise-induced angina and resting ECG result) are influenced by sociological sex ($S_{soc}$). This creates an unfair pathway. The recorded value of such a feature is not solely a manifestation of the patient's physiological state; it is also a product of external factors:
- Reporting bias: how a patient describes symptoms like chest pain based on gendered expectations
- Diagnostic bias: how a clinician interprets those symptoms, potentially mislabeling 'atypical' presentations in women

We assume that objective biomarkers are influenced only by **biological sex** ($S_{bio}$), and therefore constitute fair pathways for the predictor. These are serum cholesterol, maximum heart rate, resting blood pressure, fasting blood sugar, and the exercise ECG ST-segment measurements (oldpeak and ST slope).

Apart from $U$, age is the only variable assumed to be independent of both $S_{bio}$ and $S_{soc}$.
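The assumptions above define a small directed acyclic graph. Encoding each node's parents makes it easy to check which observed features lie on an unfair pathway, i.e. which are descendants of $S_{soc}$. This is a sketch: the dictionary `PARENTS` and the helper `descendants` are hypothetical, and the node names simply mirror the variable names in the structural-equation table.

```python
# Parent sets of each node in the latent variables model, as assumed in the text
PARENTS = {
    "CP": {"U", "S_soc", "Age"},
    "Ang": {"U", "S_soc", "Age"},
    "ECG": {"U", "S_soc", "Age"},
    "MHR": {"U", "S_bio", "Age"},
    "FBS": {"U", "S_bio", "Age"},
    "ST": {"U", "S_bio", "Age"},
    "Slope": {"U", "S_bio", "Age"},
    "BP": {"U", "S_bio", "Age"},
    "Chol": {"U", "S_bio", "Age"},
    "Y": {"U", "S_bio", "Age"},
    "U": set(),
    "Age": set(),
    "S_soc": set(),
    "S_bio": set(),
}


def descendants(node, parents):
    """Return every node reachable from `node` by following parent -> child edges."""
    children = {n: {c for c, ps in parents.items() if n in ps} for n in parents}
    found, stack = set(), [node]
    while stack:
        for child in children[stack.pop()]:
            if child not in found:
                found.add(child)
                stack.append(child)
    return found


# Features whose recorded values carry the unfair S_soc pathway
unfair_features = descendants("S_soc", PARENTS)
print(sorted(unfair_features))
```

Under these assumptions only $CP$, $Ang$ and $ECG$ descend from $S_{soc}$. Changing the graph would change this set, and with it the inputs that a counterfactually fair predictor may use directly. For instance, adding an edge from $S_{soc}$ to $BP$ would put $BP$ on an unfair pathway.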

**Structural equations:**

| Feature | Variable name | Variable type | Model distribution |
|:---|:---|:---|:---|
|Chest Pain Type | $CP$|Categorical|Categorical|
|Exercise Induced Angina | $Ang$|Binary|Bernoulli|
|Resting ECG Result | $ECG$|Categorical|Categorical|
|Max Heart Rate |$MHR$ |Numerical| Normal |
|Fasting Blood Sugar |$FBS$ |Binary|Bernoulli|
|Oldpeak |$ST$ |Numerical| Zero-Inflated Log-Normal |
|ST Slope |$Slope$ |Categorical|Categorical|
|Resting Blood Pressure | $BP$|Numerical| Normal |
|Serum Cholesterol | $Chol$ |Numerical| Log-Normal |
|Age- and Sex-independent Physiological Integrity | $U$|Numerical| Normal |
|Cardiovascular Disease (target outcome) | $Y$|Binary|Bernoulli|

<br>

- $CP = f_{CP}(U, S_{soc}, Age) \sim \text{Categorical}(\text{softmax}(\alpha_{CP} + \beta_{CP,U}U + \beta_{CP,soc}S_{soc} + \beta_{CP,Age}Age))$

- $Ang = f_{Ang}(U, S_{soc}, Age) \sim \text{Bernoulli}(\text{invlogit}(\alpha_{Ang} + \beta_{Ang,U}U + \beta_{Ang,soc}S_{soc} + \beta_{Ang,Age}Age))$

- $ECG = f_{ECG}(U, S_{soc}, Age) \sim \text{Categorical}(\text{softmax}(\alpha_{ECG} + \beta_{ECG, U}U + \beta_{ECG,soc}S_{soc} + \beta_{ECG,Age}Age))$

- $MHR = f_{MHR}(U, S_{bio}, Age) \sim \mathcal{N}(\alpha_{MHR} + \beta_{MHR,U}U + \beta_{MHR,bio}S_{bio} + \beta_{MHR,Age}Age, \sigma_{MHR})$

- $FBS = f_{FBS}(U, S_{bio}, Age) \sim \text{Bernoulli}(\text{invlogit}(\alpha_{FBS} + \beta_{FBS,U}U + \beta_{FBS,bio}S_{bio} + \beta_{FBS,Age}Age))$

- $ST = f_{ST}(U, S_{bio}, Age) \sim \text{ZeroInflatedLogNormal}(\psi_{ST}, \mu_{ST}, \sigma_{ST})$

  - where $\psi_{ST} = \text{invlogit}(\alpha_{\psi} + \beta_{\psi,U}U + \beta_{\psi,bio}S_{bio} + \beta_{\psi,Age}Age)$ is the probability that $ST = 0$
  - and $\mu_{ST} = \alpha_{\mu} + \beta_{\mu,U}U + \beta_{\mu,bio}S_{bio} + \beta_{\mu,Age}Age$ is the mean of $\log ST$ for the non-zero values

- $Slope =f_{Slope}(U, S_{bio}, Age) \sim \text{Categorical}(\text{softmax}(\alpha_{Slope} + \beta_{Slope, U}U + \beta_{Slope,bio}S_{bio} + \beta_{Slope,Age}Age)) $

- $BP = f_{BP}(U, S_{bio}, Age) \sim \mathcal{N}(\alpha_{BP} + \beta_{BP,U}U + \beta_{BP,bio}S_{bio} + \beta_{BP,Age}Age, \sigma_{BP})$

- $Chol = f_{Chol}(U, S_{bio}, Age) \sim \text{LogNormal}(\alpha_{Chol} + \beta_{Chol,U}U + \beta_{Chol,bio}S_{bio} + \beta_{Chol,Age}Age, \sigma_{Chol})$

- $Y = f_{Y}(U, S_{bio}, Age) \sim \text{Bernoulli}(\text{invlogit}(\alpha_{Y} + \beta_{Y, U}U + \beta_{Y,bio}S_{bio} + \beta_{Y,Age}Age)) $

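- $U \sim \mathcal{N}(0, 1)$

The PyMC implementation departs from these equations in three ways:

- $\beta_{Y,U}$ is fixed to $-1$, so that a higher physiological integrity $U$ always means a lower CAD risk. The scale of $U$ is learned instead: $U = \sigma_U U_{raw}$, with $U_{raw} \sim \mathcal{N}(0, 1)$ and $\sigma_U \sim \text{HalfNormal}(1)$.
- For the categorical features, the slope coefficients of the first category are fixed to $0$ so that the softmax is identifiable.
- Cholesterol is fitted as a normal distribution on its standardised logarithm, which is equivalent to the log-normal above up to an affine rescaling.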

<br>
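To make the structural equations concrete, the sketch below evaluates the equations for $CP$ and $Y$ for a single individual. It holds $U$ and $Age$ fixed while intervening on $S_{soc}$. All parameter values are hypothetical placeholders, not posterior estimates, and the helper names `p_cp` and `p_y` are illustrative.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def invlogit(z):
    return 1.0 / (1.0 + np.exp(-z))


# Hypothetical parameters for illustration only (NOT posterior estimates).
# The first category's slopes are 0, mirroring the reference-category constraint.
alpha_cp = np.array([0.0, -0.5, 0.3, 0.8])
beta_cp_U = np.array([0.0, 0.4, 0.2, -0.6])
beta_cp_soc = np.array([0.0, 0.3, 0.1, -0.5])
beta_cp_age = np.array([0.0, -0.1, 0.0, 0.2])
alpha_y, beta_y_U, beta_y_bio, beta_y_age = 0.0, -1.0, 0.8, 0.5


def p_cp(u, s_soc, age):
    """Class probabilities of CP given its parents (U, S_soc, Age)."""
    return softmax(alpha_cp + beta_cp_U * u + beta_cp_soc * s_soc + beta_cp_age * age)


def p_y(u, s_bio, age):
    """P(Y = 1) given its parents (U, S_bio, Age); S_soc is not a parent of Y."""
    return invlogit(alpha_y + beta_y_U * u + beta_y_bio * s_bio + beta_y_age * age)


# One individual: U, Age and S_bio are held fixed in both worlds
u, age, s_bio = 0.2, 0.5, 1
p_cp_actual = p_cp(u, s_soc=1, age=age)
p_cp_counterfactual = p_cp(u, s_soc=0, age=age)  # intervention S_soc <- 0
print(p_cp_actual.round(3), p_cp_counterfactual.round(3), round(p_y(u, s_bio, age), 3))
```

Because $S_{soc}$ appears in $f_{CP}$, the distribution of chest pain type shifts under the intervention. By contrast, `p_y` takes no $S_{soc}$ argument, so the outcome probability is the same in both worlds.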

### Additive error model




```python
try:
  from google.colab import userdata
  from google.colab import drive
  drive.mount('/content/drive')
  PROJECT_ROOT = userdata.get('PROJECT_ROOT')
except ImportError:
  PROJECT_ROOT = '/'
```

```
Mounted at /content/drive
```


```python
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import pymc as pm
```

```python
import os

heart_disease = pd.read_csv(os.path.join(PROJECT_ROOT, 'data', 'heart_disease_cleveland_hungary.csv'))

# Remove duplicates and rows where a zero encodes a missing value, as per Straw et al.
rows_to_drop = (heart_disease['ST slope'] == 0) | (heart_disease['cholesterol'] == 0) | (heart_disease['resting bp s'] == 0) | (heart_disease.duplicated(keep='first'))
heart_disease.drop(heart_disease[rows_to_drop].index, inplace=True)

# S_bio and S_soc are both read from the single recorded 'sex' column
heart_disease['s_bio'] = heart_disease['sex']
heart_disease.rename(columns={'sex':'s_soc', 'chest pain type':'cp', 'resting bp s':'bp', 'cholesterol':'chol',
                              'fasting blood sugar':'fbs', 'resting ecg':'ecg', 'max heart rate':'mhr', 'exercise angina':'ang',
                              'oldpeak':'st', 'ST slope':'slope', 'target':'cvd'}, inplace=True)

# Clip negative values of st (oldpeak) to 0
heart_disease['st'] = heart_disease['st'].clip(lower=0)

# Z-score for age, bp and mhr
cont_variables = ['bp', 'mhr', 'age']
for var in cont_variables:
  heart_disease[var] = (heart_disease[var] - heart_disease[var].mean()) / heart_disease[var].std()

# Z-score of log(chol), kept in a separate column
log_chol = np.log(heart_disease['chol'])
heart_disease['chol_scaled'] = (log_chol - log_chol.mean()) / log_chol.std()

# Index cp and slope categories from 0
heart_disease['cp'] = heart_disease['cp'] - 1
heart_disease['slope'] = heart_disease['slope'] - 1

heart_disease.reset_index(drop=True, inplace=True)
heart_disease.describe()
```

| | age | s_soc | cp | bp | chol | fbs | ecg | mhr | ang | st | slope | cvd | s_bio | chol_scaled |
|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|:---|
| count | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 | 745.0 |
| mean | -2.861246e-16 | 0.755705 | 2.163758 | -1.192186e-18 | 244.702013 | 0.166443 | 0.640268 | 5.340992e-16 | 0.385235 | 0.902953 | 0.589262 | 0.47651 | 0.755705 | 2.279459e-15 |
| std | 1.0 | 0.429957 | 0.956037 | 1.0 | 59.165249 | 0.372728 | 0.8389 | 1.0 | 0.486978 | 1.072953 | 0.598288 | 0.499783 | 0.429957 | 1.0 |
| min | -2.618103 | 0.0 | 0.0 | -2.372254 | 85.0 | 0.0 | 0.0 | -2.903068 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -4.37978 |
| 25% | -0.7225862 | 1.0 | 1.0 | -0.7521781 | 208.0 | 0.0 | 0.0 | -0.7432164 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | -0.5737713 |
| 50% | 0.1198656 | 1.0 | 2.0 | -0.1735796 | 237.0 | 0.0 | 0.0 | -0.009681997 | 0.0 | 0.5 | 1.0 | 0.0 | 1.0 | -0.01865294 |
| 75% | 0.6463981 | 1.0 | 3.0 | 0.405019 | 275.0 | 0.0 | 1.0 | 0.8053562 | 1.0 | 1.5 | 1.0 | 1.0 | 1.0 | 0.613824 |
| max | 2.541915 | 1.0 | 3.0 | 3.87661 | 603.0 | 1.0 | 2.0 | 2.516936 | 1.0 | 6.2 | 2.0 | 1.0 | 1.0 | 3.953099 |


```python
with pm.Model() as causal_model:
  # 1. EVIDENCE & PRIORS:
  # Fixed predictors
  age = heart_disease['age'].values
  s_bio = heart_disease['s_bio'].values
  s_soc = heart_disease['s_soc'].values
  N = len(heart_disease)

  # The prior for U, the age- and sex-independent physiological integrity
  # Non-centred parametrisation: U = sigma_U * U_raw, with U_raw ~ N(0, 1)
  U_raw = pm.Normal('U_raw', mu=0, sigma=1, shape=N)
  sigma_U = pm.HalfNormal('sigma_U', sigma=1)
  U = pm.Deterministic('U', U_raw*sigma_U)

  # 2. LIKELIHOODS of the observed variables
  # Categorical features: the slopes of the first (reference) category are fixed to 0
  alpha_cp = pm.Normal('alpha_cp', mu=0, sigma=1, shape=4)
  beta_U_cp = pm.math.concatenate([[0],pm.Normal('beta_U_cp', mu=0, sigma=.25, shape=3)])
  beta_soc_cp = pm.math.concatenate([[0],pm.Normal('beta_soc_cp', mu=0, sigma=.25, shape=3)])
  beta_age_cp = pm.math.concatenate([[0],pm.Normal('beta_age_cp', mu=0, sigma=.25, shape=3)])
  cp = pm.Categorical('cp',
            logit_p=alpha_cp + beta_U_cp*U[:,None] + beta_soc_cp*s_soc[:,None] + beta_age_cp*age[:,None],
            observed=heart_disease['cp'])

  alpha_ang = pm.Normal('alpha_ang', mu=0, sigma=1)
  beta_U_ang = pm.Normal('beta_U_ang', mu=0, sigma=.25)
  beta_soc_ang = pm.Normal('beta_soc_ang', mu=0, sigma=.25)
  beta_age_ang = pm.Normal('beta_age_ang', mu=0, sigma=.25)
  ang = pm.Bernoulli('ang',
            logit_p=alpha_ang + beta_U_ang*U + beta_soc_ang*s_soc + beta_age_ang*age,
            observed=heart_disease['ang'])

  alpha_ecg = pm.Normal('alpha_ecg', mu=0, sigma=1, shape=3)
  beta_U_ecg = pm.math.concatenate([[0],pm.Normal('beta_U_ecg', mu=0, sigma=.25, shape=2)])
  beta_soc_ecg = pm.math.concatenate([[0],pm.Normal('beta_soc_ecg', mu=0, sigma=.25, shape=2)])
  beta_age_ecg = pm.math.concatenate([[0],pm.Normal('beta_age_ecg', mu=0, sigma=.25, shape=2)])
  ecg = pm.Categorical('ecg',
            logit_p=alpha_ecg + beta_U_ecg*U[:,None] + beta_soc_ecg*s_soc[:,None] + beta_age_ecg*age[:,None],
            observed=heart_disease['ecg'])

  alpha_mhr = pm.Normal('alpha_mhr', mu=0, sigma=1)
  beta_U_mhr = pm.Normal('beta_U_mhr', mu=0, sigma=.25)
  beta_bio_mhr = pm.Normal('beta_bio_mhr', mu=0, sigma=.25)
  beta_age_mhr = pm.Normal('beta_age_mhr', mu=0, sigma=.25)
  sigma_mhr = pm.HalfNormal('sigma_mhr', sigma=1)
  mhr = pm.Normal('mhr',
            mu=alpha_mhr + beta_U_mhr*U + beta_bio_mhr*s_bio + beta_age_mhr*age,
            sigma=sigma_mhr,
            observed=heart_disease['mhr'])

  alpha_fbs = pm.Normal('alpha_fbs', mu=0, sigma=1)
  beta_U_fbs = pm.Normal('beta_U_fbs', mu=0, sigma=.25)
  beta_bio_fbs = pm.Normal('beta_bio_fbs', mu=0, sigma=.25)
  beta_age_fbs = pm.Normal('beta_age_fbs', mu=0, sigma=.25)
  fbs = pm.Bernoulli('fbs',
            logit_p=alpha_fbs + beta_U_fbs*U + beta_bio_fbs*s_bio + beta_age_fbs*age,
            observed=heart_disease['fbs'])

  alpha_slope = pm.Normal('alpha_slope', mu=0, sigma=1, shape=3)
  beta_U_slope = pm.math.concatenate([[0],pm.Normal('beta_U_slope', mu=0, sigma=.25, shape=2)])
  beta_bio_slope = pm.math.concatenate([[0],pm.Normal('beta_bio_slope', mu=0, sigma=.25, shape=2)])
  beta_age_slope = pm.math.concatenate([[0],pm.Normal('beta_age_slope', mu=0, sigma=.25, shape=2)])
  slope = pm.Categorical('slope',
            logit_p=alpha_slope + beta_U_slope*U[:,None] + beta_bio_slope*s_bio[:,None] + beta_age_slope*age[:,None],
            observed=heart_disease['slope'])

  alpha_bp = pm.Normal('alpha_bp', mu=0, sigma=1)
  beta_U_bp = pm.Normal('beta_U_bp', mu=0, sigma=.25)
  beta_bio_bp = pm.Normal('beta_bio_bp', mu=0, sigma=.25)
  beta_age_bp = pm.Normal('beta_age_bp', mu=0, sigma=.25)
  sigma_bp = pm.HalfNormal('sigma_bp', sigma=1)
  bp = pm.Normal('bp',
            mu=alpha_bp + beta_U_bp*U + beta_bio_bp*s_bio + beta_age_bp*age,
            sigma=sigma_bp,
            observed=heart_disease['bp'])

  # Cholesterol: a Normal on the standardised log(chol), i.e. a log-normal on the raw values
  alpha_chol = pm.Normal('alpha_chol', mu=0, sigma=1)
  beta_U_chol = pm.Normal('beta_U_chol', mu=0, sigma=.25)
  beta_bio_chol = pm.Normal('beta_bio_chol', mu=0, sigma=.25)
  beta_age_chol = pm.Normal('beta_age_chol', mu=0, sigma=.25)
  sigma_chol = pm.HalfNormal('sigma_chol', sigma=1)
  chol = pm.Normal('chol',
            mu=alpha_chol + beta_U_chol*U + beta_bio_chol*s_bio + beta_age_chol*age,
            sigma=sigma_chol,
            observed=heart_disease['chol_scaled'])

  # Outcome: the coefficient of U is fixed to -1, forcing U to be inversely related to CVD
  alpha_cvd = pm.Normal('alpha_cvd', mu=0, sigma=1)
  beta_bio_cvd = pm.Normal('beta_bio_cvd', mu=0, sigma=.25)
  beta_age_cvd = pm.Normal('beta_age_cvd', mu=0, sigma=.25)
  cvd = pm.Bernoulli('cvd',
            logit_p=alpha_cvd - U + beta_bio_cvd*s_bio + beta_age_cvd*age,
            observed=heart_disease['cvd'])

  # CUSTOM DIST. for Oldpeak (ST): Zero-inflated Log-normal
  # Probability of being exactly zero
  alpha_zero_st = pm.Normal('alpha_zero_st', mu=0, sigma=.5)
  beta_U_zero_st = pm.Normal('beta_U_zero_st', mu=0, sigma=.25)
  beta_bio_zero_st = pm.Normal('beta_bio_zero_st', mu=0, sigma=.25)
  beta_age_zero_st = pm.Normal('beta_age_zero_st', mu=0, sigma=.25)
  psi_st = pm.math.invlogit(alpha_zero_st + beta_U_zero_st*U + beta_bio_zero_st*s_bio + beta_age_zero_st*age)

  # Log-normal parameters of the non-zero values
  # (despite their names, beta_U_bio_st and beta_U_age_st are the S_bio and age coefficients)
  non_zero_st = heart_disease['st'][heart_disease['st'] > 0]
  alpha_init_st = np.log(non_zero_st.mean()) if len(non_zero_st) > 0 else 0
  alpha_st = pm.Normal('alpha_st', mu=0, sigma=1, initval=alpha_init_st)
  beta_U_st = pm.Normal('beta_U_st', mu=0, sigma=.25)
  beta_U_bio_st = pm.Normal('beta_U_bio_st', mu=0, sigma=.25)
  beta_U_age_st = pm.Normal('beta_U_age_st', mu=0, sigma=.25)
  sigma_st = pm.HalfNormal('sigma_st', sigma=1)
  mu_st = alpha_st + beta_U_st*U + beta_U_bio_st*s_bio + beta_U_age_st*age

  # Log-probability function for the custom distribution
  def zero_inflated_lognormal(value, psi, mu, sigma):
    # Log-probability for the zero cases
    logp_zero = pm.math.log(psi)

    # Log-probability for the non-zero cases (zeros replaced by 1 to avoid evaluating log(0))
    safe_value = pm.math.switch(pm.math.eq(value, 0), 1.0, value)
    logp_nonzero = pm.math.log(1 - psi) + pm.logp(pm.LogNormal.dist(mu=mu, sigma=sigma), safe_value)

    return pm.math.switch(pm.math.eq(value, 0), logp_zero, logp_nonzero)

  # Define the custom distribution
  st = pm.CustomDist('st',
                     psi_st, mu_st, sigma_st,
                     logp=zero_inflated_lognormal,
                     observed=heart_disease['st'])

  # 3. INFERENCE of U (and of the model parameters) with MCMC
  posterior_samples = pm.sample(target_accept=0.95)
```
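Because `st` uses a custom log-probability, it can help to check the likelihood outside PyMC. The following NumPy/SciPy sketch mirrors `zero_inflated_lognormal` above; the name `zi_lognormal_logpdf` is hypothetical. SciPy's `lognorm` is parameterised with `s=sigma` and `scale=exp(mu)`, which matches a log-normal whose logarithm has mean `mu` and standard deviation `sigma`.

```python
import numpy as np
from scipy import stats


def zi_lognormal_logpdf(value, psi, mu, sigma):
    """Log-density of a zero-inflated log-normal.

    psi is the probability of an exact zero; non-zero values follow
    LogNormal(mu, sigma), i.e. log(value) ~ Normal(mu, sigma).
    """
    value = np.asarray(value, dtype=float)
    # As in the PyMC logp, zeros are replaced by 1.0 before evaluating the
    # log-normal branch, so log(0) is never computed for the zero entries
    safe_value = np.where(value == 0, 1.0, value)
    logp_nonzero = np.log1p(-psi) + stats.lognorm.logpdf(safe_value, s=sigma, scale=np.exp(mu))
    return np.where(value == 0, np.log(psi), logp_nonzero)


print(zi_lognormal_logpdf([0.0, 0.5, 2.0], psi=0.3, mu=0.1, sigma=0.8))
```

Evaluating it on a few values of `st` with fixed parameters, and comparing the result against the model's log-probability, is one way to unit-test the PyMC implementation.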

```python
import os

save_path = os.path.join(PROJECT_ROOT, 'counterfactual-fairness', 'results')
os.makedirs(save_path, exist_ok=True)

posterior_samples.to_netcdf(os.path.join(save_path, 'causal_model_posterior.nc'))
print('Posterior samples saved')
```

```
Posterior samples saved
```


```python
import os
import arviz as az

posterior = az.from_netcdf(os.path.join(PROJECT_ROOT, 'counterfactual-fairness', 'results', 'causal_model_posterior.nc'))

# Bulk effective sample size (ESS) of the individual latent variables U
U_ess = az.summary(posterior, var_names=['U'])['ess_bulk']
print(U_ess.describe())

# Bulk ESS of the model parameters, grouped by name prefix
print(az.summary(posterior, filter_vars='regex', var_names=['^alpha_'])['ess_bulk'])
print(az.summary(posterior, filter_vars='regex', var_names=['^beta_U_'])['ess_bulk'])
print(az.summary(posterior, filter_vars='regex', var_names=['^beta_age_'])['ess_bulk'])
print(az.summary(posterior, filter_vars='regex', var_names=['^beta_soc_'])['ess_bulk'])
print(az.summary(posterior, filter_vars='regex', var_names=['^sigma_'])['ess_bulk'])
```

```
count     745.000000
mean     3488.947651
std       594.985605
min      2055.000000
25%      3061.000000
50%      3423.000000
75%      3882.000000
max      6289.000000
Name: ess_bulk, dtype: float64
alpha_ang         1265.0
alpha_bp          1763.0
alpha_chol        1879.0
alpha_cp[0]        931.0
alpha_cp[1]        984.0
alpha_cp[2]        923.0
alpha_cp[3]        895.0
alpha_cvd         1096.0
alpha_ecg[0]      1031.0
alpha_ecg[1]      1028.0
alpha_ecg[2]      1008.0
alpha_fbs         1863.0
alpha_mhr         1657.0
alpha_slope[0]    1418.0
alpha_slope[1]    1467.0
alpha_slope[2]    1408.0
alpha_st          1471.0
alpha_zero_st     1690.0
Name: ess_bulk, dtype: float64
beta_U_age_st      1711.0
beta_U_ang         1136.0
beta_U_bio_st      1631.0
beta_U_bp          1969.0
beta_U_chol        2100.0
beta_U_cp[0]       1522.0
beta_U_cp[1]       1484.0
beta_U_cp[2]       1446.0
beta_U_ecg[0]      2094.0
beta_U_ecg[1]      2529.0
beta_U_fbs         2600.0
beta_U_mhr         1127.0
beta_U_s
```
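Rather than reading the ESS tables by eye, a small helper can list the parameters that fail simple convergence checks. This is a sketch: `flag_low_ess` is a hypothetical name, and the thresholds (bulk ESS above 400, R-hat below 1.01) are commonly used rules of thumb, not values from the original analysis. `az.summary` returns a DataFrame with `ess_bulk` and `r_hat` columns, so the helper can be called as `flag_low_ess(az.summary(posterior))`.

```python
import pandas as pd


def flag_low_ess(summary: pd.DataFrame, min_ess_bulk: float = 400.0, max_r_hat: float = 1.01) -> pd.DataFrame:
    """Return the rows of an ArviZ-style summary whose bulk ESS is too low or R-hat too high."""
    bad = (summary["ess_bulk"] < min_ess_bulk) | (summary["r_hat"] > max_r_hat)
    return summary.loc[bad, ["ess_bulk", "r_hat"]]


# Toy summary table with the same column names as az.summary
toy = pd.DataFrame({"ess_bulk": [1200.0, 350.0, 900.0], "r_hat": [1.00, 1.00, 1.02]},
                   index=["a", "b", "c"])
print(flag_low_ess(toy))
```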

```python
# AUGMENTED DATASET

# Number of posterior samples (m) of U drawn for each individual
m = 100

# Draw m posterior samples of U (pooled across chains) and flatten them to a long DataFrame
samples = az.extract(posterior, var_names=['U'], combined=True, num_samples=m).to_dataframe().reset_index()

# Merge samples with the original dataset: each individual now appears m times
heart_disease_reset = heart_disease.reset_index(drop=True)
aug_heart_disease = heart_disease_reset.merge(samples, left_index=True, right_on='U_dim_0')

# Fair dataset:
# Keep only the sex-independent variable (age), the latent fair variable U, and the outcome CVD
fair_heart_disease = aug_heart_disease[['age','U','cvd']]

fair_heart_disease.shape
```

```
(74500, 3)
```
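Following Kusner et al.'s recipe of fitting the predictor on the augmented data, a minimal sketch could look like the following. It runs on synthetic data shaped like `fair_heart_disease`; the generating process, `n` and the noise scale are assumptions. Logistic regression from scikit-learn is an assumed choice of classifier, not necessarily one of the baseline models from Straw et al.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LogisticRegression

# Synthetic stand-in for fair_heart_disease: n individuals x m posterior draws of U
rng = np.random.default_rng(1)
n, m = 200, 100
age = rng.normal(size=n)
u_mean = rng.normal(size=n)
# Assumed outcome model: U inversely related to CVD, age positively related
cvd = (rng.random(n) < 1 / (1 + np.exp(u_mean - 0.5 * age))).astype(int)

aug = pd.DataFrame({
    "age": np.repeat(age, m),
    # Each individual's m draws scatter around their (synthetic) latent value
    "U": np.repeat(u_mean, m) + rng.normal(scale=0.3, size=n * m),
    "cvd": np.repeat(cvd, m),
})

# Fit on the augmented rows: only age and U are used as inputs
clf = LogisticRegression()
clf.fit(aug[["age", "U"]], aug["cvd"])

# Per-individual risk: average the predicted probability over that individual's m draws of U
aug["p_hat"] = clf.predict_proba(aug[["age", "U"]])[:, 1]
individual_id = np.repeat(np.arange(n), m)
risk = aug.groupby(individual_id)["p_hat"].mean()
print(clf.coef_, risk.head())
```

The $m$ rows of one patient are near-duplicates. Any train/test split should therefore be made on individuals (for instance on `U_dim_0` in `aug_heart_disease`) rather than on rows; otherwise the same patient leaks into both sets.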