# Disentangled Causal Effect Variational Autoencoder

**Inputs:**
- data/heart_disease_cleaned.csv

**Outputs:**
- DCEVAE model
- data/fair_disease_dcevae.csv
- data/cf_disease_dcevea.csv

## Setup and imports

In [1]:
# Resolve PROJECT_ROOT: on Google Colab, mount Drive and read the root from the
# notebook's user secrets; anywhere else, fall back to '/'.
try:
  # Keep only the import inside the try so that an ImportError raised while
  # mounting Drive or reading the secret is not mistaken for "not on Colab".
  from google.colab import userdata
  from google.colab import drive
except ImportError:
  PROJECT_ROOT = '/'
else:
  drive.mount('/content/drive')
  PROJECT_ROOT = userdata.get('PROJECT_ROOT')


Mounted at /content/drive


In [66]:
import pandas as pd
import numpy as np
import torch
import re
from torch import nn

## Classes and functions





### DCEVAE Model

### Utils

In [87]:
def _select_columns(dataset, names):
  # Return, as a 2-D array, the columns whose name matches any entry of `names`.
  if isinstance(names, str):
    names = [names]
  if not names:
    # An empty alternation would match every column; an empty bucket has none.
    return np.empty((len(dataset), 0))
  # Names are joined into one alternation and matched with a regex search, so
  # 'cp' also selects its one-hot encoded columns such as 'cp_1'.
  return dataset.filter(regex='|'.join(names)).to_numpy()


def _bucket_split_loader(buckets, index, batch_size, shuffle=False, generator=None):
  # Build one DataLoader holding the rows at `index` plus a row-permuted copy.
  from torch.utils.data import DataLoader, TensorDataset

  factual = [bucket[index] for bucket in buckets]
  # The permutation is applied to the split's own rows, so the permuted copy
  # never contains samples that belong to another split.
  permuted_indices = np.random.permutation(len(index))
  permuted = [bucket[permuted_indices] for bucket in factual]
  tensors = [
      torch.from_numpy(np.asarray(array, dtype=np.float32))
      for array in factual + permuted
  ]
  return DataLoader(TensorDataset(*tensors), batch_size=batch_size,
                    shuffle=shuffle, generator=generator)


def make_bucket_loader(dataset, map, val_size=0.1, test_size=0.1, seed=4,
                       batch_size=32):
  '''
    Create DataLoaders for the given dataset, separating features into
    independent, sensitive, descendant, and correlated features.

    Every batch is a tuple of ten float32 tensors:
      (ind, desc, corr, sens, target,
       ind_2, desc_2, corr_2, sens_2, target_2)
    where the "_2" tensors are a row-permuted copy of the same split, meant
    for the discriminator.

    Input:
      - dataset: a pandas DataFrame
      - map: a dictionary mapping buckets to feature names, with keys
        'ind', 'desc', 'corr' (lists of names, matched as regular
        expressions against the column names) and 'sens', 'target'
        (single column names)
      - val_size: the proportion of the dataset to use for validation
      - test_size: the proportion of the dataset to use for testing
      - seed: a seed for the random number generators
      - batch_size: number of samples per batch in every loader

    Output:
      - train_loader: Training DataLoader (shuffled each epoch)
      - val_loader: Validation DataLoader
      - test_loader: Testing DataLoader
  '''
  # NOTE: `map` shadows the builtin inside this function; the name is kept
  # because callers may pass it by keyword.
  np.random.seed(seed=seed)

  ## BUCKET DATASET
  # Order: independent, descendant, correlated, sensitive attribute, target.
  buckets = [
      _select_columns(dataset, map['ind']),
      _select_columns(dataset, map['desc']),
      _select_columns(dataset, map['corr']),
      dataset[map['sens']].to_numpy().reshape(-1, 1),
      dataset[map['target']].to_numpy().reshape(-1, 1),
  ]

  ## TRAIN-VAL-TEST SPLIT
  N = len(dataset)
  shuffled_indices = np.random.permutation(N)
  val_count = int(N * val_size)
  test_count = int(N * test_size)
  val_index = shuffled_indices[:val_count]
  test_index = shuffled_indices[val_count:val_count+test_count]
  train_index = shuffled_indices[val_count+test_count:]

  # A seeded generator keeps the training shuffle order reproducible.
  shuffle_generator = torch.Generator().manual_seed(seed)
  train_loader = _bucket_split_loader(buckets, train_index, batch_size,
                                      shuffle=True, generator=shuffle_generator)
  val_loader = _bucket_split_loader(buckets, val_index, batch_size)
  test_loader = _bucket_split_loader(buckets, test_index, batch_size)

  return train_loader, val_loader, test_loader

## Data preparation

In [88]:
# Load the cleaned heart-disease table from the project's data folder.
heart_disease = pd.read_csv(PROJECT_ROOT + '/data/heart_disease_cleaned.csv')

# One-hot encoding for categorical features (the first level of each is dropped).
heart_disease_encoded = pd.get_dummies(
    heart_disease,
    columns=['cp', 'ecg', 'slope'],
    drop_first=True,
    dtype=int,
)

# Bucket assignment of the columns, as consumed by make_bucket_loader.
feature_mapping = dict(
    # Features independent of the protected attribute and unconfounded
    ind=['age'],
    # Sensitive attribute
    sens='sex',
    # Features descendant of the protected attribute
    desc=['cp', 'ecg', 'ang'],
    # Features correlated with the protected attribute
    corr=['bp', 'chol', 'fbs', 'mhr', 'st', 'slope'],
    # Target outcome
    target='cvd',
)


make_bucket_loader(heart_disease_encoded, feature_mapping)
# Last expression of the cell: displays the first encoded rows.
heart_disease_encoded.head()

745


Unnamed: 0,age,sex,bp,chol,fbs,mhr,ang,st,cvd,cp_1,cp_2,cp_3,ecg_1,ecg_2,slope_1,slope_2
0,-1.354425,1,0.466063,0.825012,0,1.294379,0,0.0,0,1,0,0,0,0,0,0
1,-0.406667,0,1.511938,-1.188684,0,0.642349,0,1.0,1,0,1,0,0,0,1,0
2,-1.670345,1,-0.114382,0.735784,0,-1.721262,0,0.0,0,1,0,0,1,0,0,0
3,-0.511973,0,0.353365,-0.452823,0,-1.313743,1,1.5,1,0,0,1,0,0,1,0
4,0.119866,1,1.006445,-0.848258,0,-0.743216,0,0.0,0,0,1,0,0,0,0,0
