# Disentangled Causal Effect Variational Autoencoder

**Inputs:**
- data/heart_disease_cleaned.csv

**Outputs:**
- DCEVAE model
- data/fair_disease_dcevae.csv
- data/cf_disease_dcevea.csv

## Setup and imports

In [None]:
import sys

# The notebook treats the presence of the `google.colab` module in
# sys.modules as the signal that it is running inside Google Colab.
IN_COLAB = 'google.colab' in sys.modules

if not IN_COLAB:
    # Local run: the project root is the parent of the notebook directory.
    PROJECT_ROOT = '../'
else:
    # Colab run: mount Drive, then read the project location from the
    # user's Colab secret named 'PROJECT_ROOT'.
    from google.colab import userdata
    from google.colab import drive
    drive.mount('/content/drive')
    PROJECT_ROOT = userdata.get('PROJECT_ROOT')

# Make the project's `src` package importable from the following cells.
sys.path.append(PROJECT_ROOT)

In [None]:
import torch
import pandas as pd
from src.config import Config
from src.data_loader import make_bucketed_loader
from src.model import DCEVAE
from src.train import train_dcevae
from src.test import test_dcevae
from src.utils import load_feature_mapping, setup_logger

In [None]:
from argparse import Namespace
import datetime
import os

# --- Runtime -----------------------------------------------------------------
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Paths are built from PROJECT_ROOT with os.path.join so they resolve both
# locally ('../') and in Colab, whether or not the root ends with a separator.
MODEL_PATH = os.path.join(PROJECT_ROOT, 'models', 'dcevae.pt')
SEED = 4

# --- Model / optimisation hyperparameters ------------------------------------
# NOTE(review): the exact meaning of each value is defined by src.model and
# src.train, which consume them through `args`; the names suggest latent
# sizes (UC/UD), hidden size, activation, epochs and learning rate — confirm.
BATCH_SIZE = 32
UC_DIM = 12
UD_DIM = 12
H_DIM = 5
ACT_FN = 'relu'
N_EPOCHS = 40
LEARNING_RATE = 0.01

# --- Loss weights ------------------------------------------------------------
# NOTE(review): presumably the weights of the individual loss terms
# (reconstruction, prediction, fairness, total correlation) and the
# KL-annealing setting for distillation — verify against src.train.
CORR_RECON_ALPHA = 1
DESC_RECON_ALPHA = 1
PRED_ALPHA = 10
FAIR_BETA = 1
TC_BETA = 1
DISTILL_KL_ANN_N = 10

# --- Data and experiment bookkeeping -----------------------------------------
MAPPING = os.path.join(PROJECT_ROOT, 'configs', 'uci_feature_mapping.json')
# File name only; the data directory is prepended where the CSV is loaded.
DATA = 'heart_disease_cleaned.csv'
# Today's date plus a fixed '-01' suffix, e.g. '2024-05-17-01'.
EXP_NAME = datetime.datetime.now().strftime('%Y-%m-%d') + '-01'

# Bundle every setting into one namespace that is handed to the model,
# training and test entry points.
args = Namespace(
    device=DEVICE,
    model_path=MODEL_PATH,
    seed=SEED,
    batch_size=BATCH_SIZE,
    uc_dim=UC_DIM,
    ud_dim=UD_DIM,
    h_dim=H_DIM,
    act_fn=ACT_FN,
    n_epochs=N_EPOCHS,
    lr=LEARNING_RATE,
    distill_kl_ann=DISTILL_KL_ANN_N,
    corr_a=CORR_RECON_ALPHA,
    desc_a=DESC_RECON_ALPHA,
    pred_a=PRED_ALPHA,
    fair_b=FAIR_BETA,
    tc_b=TC_BETA,
    mapping=MAPPING,
    data=DATA,
    exp_name=EXP_NAME,
)

In [None]:
# End-to-end run: set up logging, load the data and feature mapping, build
# the loaders and the model, train, then evaluate on the test loader.

# Initialise the logger in the configured log directory, tagged with the
# experiment name.
logger = setup_logger(PROJECT_ROOT + Config.LOG_DIR, args.exp_name)

# Load the dataset CSV from the configured data directory.
dataset = pd.read_csv(PROJECT_ROOT + Config.DATA_DIR + args.data)

# Load the feature mapping from the path given in args.mapping.
feature_mapping = load_feature_mapping(args.mapping)

# Bucketed data loaders for training, validation, and test.
train_loader, val_loader, test_loader = make_bucketed_loader(dataset, feature_mapping)

# Feature metadata, one entry per feature group in the mapping.
# NOTE(review): 'ind' / 'desc' / 'corr' / 'sens' presumably stand for
# independent, descendant, correlated and sensitive features — confirm
# against the mapping file.
ind_meta = feature_mapping['ind']
desc_meta = feature_mapping['desc']
corr_meta = feature_mapping['corr']
sens_meta = feature_mapping['sens']

# Build the model from the per-group metadata and the run settings.
model = DCEVAE(ind_meta, desc_meta, corr_meta, sens_meta, 
              args=args)

# Train; the two returned logs are kept as notebook globals, and
# training_log feeds the plotting cells below.
training_log, epoch_metrics_log = train_dcevae(
  model,
  train_loader,
  val_loader,
  logger,
  args
)

# Evaluate the model on the test loader.
test_dcevae(model, test_loader, logger, args)


In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# One row per entry of the training log; the row index is used as the epoch.
training_metrics = pd.DataFrame(training_log)

# Train vs. validation total VAE loss on a shared axis.
fig, ax = plt.subplots(figsize=(8, 4))
sns.lineplot(
    x=training_metrics.index,
    y=training_metrics["avg_train_loss"],
    ax=ax,
    label='Train VAE Loss',
    errorbar=None,
)
# Validation points are drawn half an epoch to the right of the training
# points (presumably because validation follows each training pass — confirm).
sns.lineplot(
    x=training_metrics.index + .5,
    y=training_metrics["avg_val_loss"],
    ax=ax,
    label='Val VAE Loss',
    errorbar=None,
)
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Total VAE Loss')
plt.show()

In [None]:
# Discriminator loss and the VAE's TC loss per epoch, on a shared axis.
fig, ax = plt.subplots(figsize=(8, 4))
sns.lineplot(
    x=training_metrics.index,
    y=training_metrics["avg_disc_loss"],
    ax=ax,
    label='Discriminator loss',
    errorbar=None,
)
sns.lineplot(
    x=training_metrics.index,
    y=training_metrics["avg_tc_loss"],
    ax=ax,
    label='VAE TC loss',
    errorbar=None,
)
ax.legend()
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

In [None]:
# Distillation loss per epoch on its own, shorter figure.
plt.figure(figsize=(8, 3))
distill_ax = sns.lineplot(
    x=training_metrics.index,
    y=training_metrics['avg_distill_loss'],
)
distill_ax.set_xlabel('Epoch')
distill_ax.set_ylabel('Distillation Loss')
plt.show()