In [1]:
#| default_exp _experiments.heart

# Heart disease

Heart Disease [1] is a classification dataset used for predicting the presence of heart disease. It has 13 features (age, sex, cp, trestbps, chol, fbs, restecg, thalach, exang, oldpeak, slope, ca, thal), and, following [2], the prediction is treated as monotonically increasing with respect to two of them: resting blood pressure (trestbps) and cholesterol (chol). The `monotonicity_indicator` values corresponding to these two features are set to 1.


References:


1.   John H. Gennari, Pat Langley, and Douglas H. Fisher. Models of incremental concept formation. Artif. Intell., 40(1-3):11–61, 1989.

  https://archive.ics.uci.edu/ml/datasets/heart+disease

2.   Aishwarya Sivaraman, Golnoosh Farnadi, Todd Millstein, and Guy Van den Broeck. Counterexample-guided learning of monotonic neural networks. Advances in Neural Information Processing Systems, 33:11936–11948, 2020.


In [2]:
#| include: false

from mono_dense_keras.experiments import get_train_n_test_data, find_hyperparameters, create_tuner_stats

In [3]:
#| include: false

from os import environ

In [4]:
#| include: false

# Ask TensorFlow to grow GPU memory on demand instead of mapping (nearly) all of
# it at startup. This has to be in place before TensorFlow initializes its GPU
# devices, i.e. before the first model of the search is built.
environ.update(TF_FORCE_GPU_ALLOW_GROWTH="true")

Below are the first few rows of the (already standardized) training set, transposed so that each row is a feature:

In [5]:
#| echo: false

train_df, test_df = get_train_n_test_data(dataset_name="heart")

# Show the first five training samples as columns (one row per feature),
# which keeps the wide feature set readable.
sample_rows = train_df.head()
display(sample_rows.T.style)

Unnamed: 0,0,1,2,3,4
age,0.972778,1.415074,1.415074,-1.902148,-1.459852
sex,0.649445,0.649445,0.649445,0.649445,-1.533413
cp,-2.020077,0.884034,0.884034,-0.084003,-1.05204
trestbps,0.721008,1.543527,-0.649858,-0.101512,-0.101512
chol,-0.251855,0.740555,-0.326754,0.066465,-0.794872
fbs,2.426901,-0.410346,-0.410346,-0.410346,-0.410346
restecg,1.070838,1.070838,1.070838,-0.953715,1.070838
thalach,-0.025055,-1.831151,-0.928103,1.56603,0.920995
exang,-0.72101,1.381212,1.381212,-0.72101,-0.72101
oldpeak,0.98644,0.330395,1.232457,1.970508,0.248389


In [6]:
# Hyperparameter search on the heart dataset. The indicator is 1 for the two
# features the prediction should increase with (trestbps and chol) and 0 for
# every other feature; keys keep the dataset's column order.
tuner = find_hyperparameters(
    "heart",
    monotonicity_indicator={
        feature: 1 if feature in {"trestbps", "chol"} else 0
        for feature in (
            "age", "sex", "cp", "trestbps", "chol", "fbs", "restecg",
            "thalach", "exang", "oldpeak", "slope", "ca", "thal",
        )
    },
    max_trials=100,
    final_activation="sigmoid",
    loss="binary_crossentropy",
    metrics="accuracy",
    objective="val_accuracy",
)

Trial 100 Complete [00h 00m 24s]
val_accuracy: 0.8633879621823629

Best val_accuracy So Far: 0.8797814249992371
Total elapsed time: 00h 36m 01s
INFO:tensorflow:Oracle triggered exit


In [7]:
#| include: false

# Summarize the finished search into a table of trials. Its columns are the
# sampled hyperparameters (units, n_layers, activation, learning_rate,
# weight_decay, dropout, decay_rate), the mean/std/min/max of val_accuracy,
# and the model's parameter count (params).
# NOTE(review): while running, this cell appears to re-display the growing
# table after each evaluated trial; `include: false` hides that from the docs.
stats = create_tuner_stats(tuner)

Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
5,32,2,elu,0.001,0.114788,0.446662,0.925633,0.87541,0.008979,0.868852,0.885246,2017
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
6,29,2,elu,0.002381,0.135978,0.41784,0.912855,0.872131,0.007331,0.868852,0.885246,1682
5,32,2,elu,0.001,0.114788,0.446662,0.925633,0.87541,0.008979,0.868852,0.885246,2017
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
6,29,2,elu,0.002381,0.135978,0.41784,0.912855,0.872131,0.007331,0.868852,0.885246,1682
5,32,2,elu,0.001,0.114788,0.446662,0.925633,0.87541,0.008979,0.868852,0.885246,2017
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
7,17,2,elu,0.001112,0.169679,0.5,1.0,0.881967,0.013716,0.868852,0.901639,680
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
6,29,2,elu,0.002381,0.135978,0.41784,0.912855,0.872131,0.007331,0.868852,0.885246,1682
5,32,2,elu,0.001,0.114788,0.446662,0.925633,0.87541,0.008979,0.868852,0.885246,2017
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
7,17,2,elu,0.001112,0.169679,0.5,1.0,0.881967,0.013716,0.868852,0.901639,680
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897
8,16,3,elu,0.001,0.205682,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


Unnamed: 0,units,n_layers,activation,learning_rate,weight_decay,dropout,decay_rate,val_accuracy_mean,val_accuracy_std,val_accuracy_min,val_accuracy_max,params
2,20,1,elu,0.165676,0.24654,0.120052,0.889096,0.865574,0.007331,0.852459,0.868852,291
0,29,3,elu,0.002105,0.135955,0.467131,0.936366,0.868852,0.0,0.868852,0.868852,2552
9,9,4,elu,0.007938,0.234227,0.376129,0.964579,0.868852,0.011592,0.852459,0.885246,432
3,22,1,elu,0.3,0.2335,0.179865,0.919013,0.868852,0.016393,0.852459,0.885246,317
6,29,2,elu,0.002381,0.135978,0.41784,0.912855,0.872131,0.007331,0.868852,0.885246,1682
5,32,2,elu,0.001,0.114788,0.446662,0.925633,0.87541,0.008979,0.868852,0.885246,2017
4,26,2,elu,0.001,0.14437,0.5,0.9943,0.878689,0.008979,0.868852,0.885246,1377
7,17,2,elu,0.001112,0.169679,0.5,1.0,0.881967,0.013716,0.868852,0.901639,680
1,16,3,elu,0.001,0.205112,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897
8,16,3,elu,0.001,0.205682,0.5,1.0,0.881967,0.007331,0.868852,0.885246,897


In [8]:
#| echo: false

# Top five trials, best first, transposed so each trial is a column.
# Several trials share the same mean validation accuracy (e.g. three at
# 0.881967), so ties are broken explicitly instead of being left to the sort
# algorithm: prefer the lower spread across runs (val_accuracy_std), then the
# smaller model (params).
(
    stats.sort_values(
        by=["val_accuracy_mean", "val_accuracy_std", "params"],
        ascending=[False, True, True],
    )
    .head()
    .reset_index(drop=True)
    .T.style
)

Unnamed: 0,0,1,2,3,4
units,16,16,17,26,32
n_layers,3,3,2,2,2
activation,elu,elu,elu,elu,elu
learning_rate,0.001000,0.001000,0.001112,0.001000,0.001000
weight_decay,0.205112,0.205682,0.169679,0.144370,0.114788
dropout,0.500000,0.500000,0.500000,0.500000,0.446662
decay_rate,1.000000,1.000000,1.000000,0.994300,0.925633
val_accuracy_mean,0.881967,0.881967,0.881967,0.878689,0.875410
val_accuracy_std,0.007331,0.007331,0.013716,0.008979,0.008979
val_accuracy_min,0.868852,0.868852,0.868852,0.868852,0.868852
