Causal Transformer for estimating counterfactual outcomes over time.
The project is built with the following Python libraries:
- PyTorch Lightning - deep learning models
- Hydra - simplified command-line argument management
- MLflow - experiment tracking
First, one needs to create a virtual environment and install all the requirements:
pip3 install virtualenv
python3 -m virtualenv -p python3 --always-copy venv
source venv/bin/activate
pip3 install -r requirements.txt
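After installing the requirements, a quick sanity check that the activated environment can see the main libraries might look like the sketch below. The import names pytorch_lightning, hydra, and mlflow are assumptions based on the library list above; requirements.txt remains the authoritative list.

```python
import importlib.util

# Assumed import names for PyTorch Lightning, Hydra, and MLflow;
# requirements.txt is the authoritative dependency list.
REQUIRED_MODULES = ("pytorch_lightning", "hydra", "mlflow")


def missing_modules(modules=REQUIRED_MODULES):
    """Return the subset of `modules` that cannot be found in the active environment."""
    return [name for name in modules if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules()
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```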
To start the MLflow experiment-tracking server, run:
mlflow server --port=5000
To access the MLflow web UI with all the experiments, connect via SSH:
ssh -N -f -L localhost:5000:localhost:5000 <username>@<server-link>
Then, open http://localhost:5000 in a local browser.
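Before opening the browser, one may want to confirm that the SSH tunnel is up. A minimal sketch using only the Python standard library (the function name mlflow_ui_reachable is hypothetical); it only checks whether some HTTP server answers at the forwarded address, so it cannot tell MLflow apart from another service on that port.

```python
import http.client
import urllib.error
import urllib.request


def mlflow_ui_reachable(url="http://localhost:5000", timeout=3.0):
    """Return True if an HTTP server answers at `url`, e.g. the forwarded MLflow UI."""
    try:
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except urllib.error.HTTPError:
        # The server answered, just with an error status code.
        return True
    except (urllib.error.URLError, http.client.HTTPException, OSError):
        # Connection refused, timeout, non-HTTP reply, etc.
        return False
```

From Python code, the same address can be passed to mlflow.set_tracking_uri("http://localhost:5000") to log runs to this server; how the training scripts in this repository pick up the tracking URI is not described here.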
The main training script is universal for different models and datasets. For details on mandatory arguments, see the main configuration file config/config.yaml and other files in the configs/ folder.
The generic command, with logging and a fixed random seed, is as follows (where <training-type> is one of enc_dec, gnet, rmsn, or multi):
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_<training-type>.py +dataset=<dataset> +backbone=<backbone> exp.seed=10 exp.logging=True
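When launching many runs from Python (for example from a notebook or a job scheduler), the generic command can be assembled programmatically. The helper name build_train_command and its defaults (seed 10, logging on, device "0") are illustrative assumptions, not part of the repository.

```python
import shlex

# The four training types listed above, plus msm (runnables/train_msm.py is used for MSMs).
TRAINING_TYPES = ("enc_dec", "gnet", "rmsn", "multi", "msm")


def build_train_command(training_type, dataset, backbone, seed=10, logging=True, devices="0"):
    """Assemble the generic training command as a single shell string (hypothetical helper)."""
    if training_type not in TRAINING_TYPES:
        raise ValueError(f"unknown training type: {training_type!r}")
    args = [
        "python3",
        f"runnables/train_{training_type}.py",
        f"+dataset={dataset}",
        f"+backbone={backbone}",
        f"exp.seed={seed}",
        f"exp.logging={logging}",
    ]
    # Environment variables prefix the command so they apply to this process only.
    env = f"PYTHONPATH=. CUDA_VISIBLE_DEVICES={shlex.quote(devices)}"
    return env + " " + " ".join(shlex.quote(arg) for arg in args)


print(build_train_command("multi", "cancer_sim", "ct"))
```

With subprocess.run, passing the arguments as a list plus an env dict is usually cleaner than a shell string; the string form is shown only because it mirrors the command above.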
One needs to choose a backbone and then fill in the specific hyperparameters (they are left blank in the configs):
- Causal Transformer (this paper):
runnables/train_multi.py +backbone=ct
- Encoder-Decoder Causal Transformer (this paper):
runnables/train_enc_dec.py +backbone=edct
- Marginal Structural Models (MSMs):
runnables/train_msm.py +backbone=msm
- Recurrent Marginal Structural Networks (RMSNs):
runnables/train_rmsn.py +backbone=rmsn
- Counterfactual Recurrent Network (CRN):
runnables/train_enc_dec.py +backbone=crn
- G-Net:
runnables/train_gnet.py +backbone=gnet
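The backbone list above amounts to a lookup table. A sketch of it as a Python dict (the names BACKBONE_SCRIPTS and script_for are hypothetical), which can catch a mismatched script/backbone pair before a run is launched:

```python
# Backbone name -> training script, transcribed from the list above.
BACKBONE_SCRIPTS = {
    "ct": "runnables/train_multi.py",
    "edct": "runnables/train_enc_dec.py",
    "msm": "runnables/train_msm.py",
    "rmsn": "runnables/train_rmsn.py",
    "crn": "runnables/train_enc_dec.py",
    "gnet": "runnables/train_gnet.py",
}


def script_for(backbone):
    """Return the training script for a backbone, or raise KeyError listing valid choices."""
    try:
        return BACKBONE_SCRIPTS[backbone]
    except KeyError:
        choices = ", ".join(sorted(BACKBONE_SCRIPTS))
        raise KeyError(f"unknown backbone {backbone!r}; expected one of: {choices}") from None
```

EDCT and CRN share train_enc_dec.py, so the +backbone override is what selects the model there.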
The best hyperparameters are already saved for each model and dataset; one can access them via +backbone/<backbone>_hparams/cancer_sim_<balancing_objective>=<coeff_value> or +backbone/<backbone>_hparams/mimic3_real=diastolic_blood_pressure.
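The saved-hyperparameter override follows a fixed pattern, so it can be generated rather than typed. In this sketch (best_hparams_override is a hypothetical helper), numeric option names are wrapped in single quotes, mirroring the '0','1','2' values in the multirun example further down; that Hydra needs them quoted to read them as strings is an assumption inferred from that example.

```python
def best_hparams_override(backbone, group, option):
    """Return +backbone/<backbone>_hparams/<group>=<option> as a Hydra override string.

    Numeric option names are wrapped in single quotes (assumption based on the
    multirun example in this README) so Hydra treats them as strings.
    """
    option = str(option)
    if option.isdigit():
        option = f"'{option}'"
    return f"+backbone/{backbone}_hparams/{group}={option}"


print(best_hparams_override("ct", "mimic3_real", "diastolic_blood_pressure"))
print(best_hparams_override("ct", "cancer_sim_domain_conf", 1))
```

When such a string is typed into a shell, the quotes must be escaped (as \' in the example command) or the whole argument quoted; when passed in a subprocess.run argument list, no shell escaping is needed.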
For CT, EDCT, and CRN, several adversarial balancing objectives are available:
- counterfactual domain confusion loss (this paper):
exp.balancing=domain_confusion
- gradient reversal (originally in CRN, but can be used for all the methods):
exp.balancing=grad_reverse
To train a decoder (for CRN and RMSNs), use the flag model.train_decoder=True.
To perform manual hyperparameter tuning, use the flag model.<sub_model>.tune_hparams=True and then see model.<sub_model>.hparams_grid. Use model.<sub_model>.tune_range to specify the number of trials for random search.
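The repository's own tuning code is not shown in this README. As a rough illustration of random search with a fixed trial budget, the sketch below draws tune_range distinct configurations, without replacement, from a grid. The grid format (hyperparameter name mapped to a list of candidates) and the names in the usage lines are assumptions, not the actual contents of hparams_grid.

```python
import itertools
import random


def sample_random_search(hparams_grid, tune_range, seed=None):
    """Draw up to `tune_range` distinct configurations uniformly from a grid.

    `hparams_grid` maps each hyperparameter name to a list of candidate values
    (assumed format). If `tune_range` exceeds the grid size, every configuration
    is returned once.
    """
    names = sorted(hparams_grid)
    candidates = [
        dict(zip(names, values))
        for values in itertools.product(*(hparams_grid[name] for name in names))
    ]
    rng = random.Random(seed)
    # Sample without replacement so no configuration is trained twice.
    return rng.sample(candidates, min(tune_range, len(candidates)))


# Hypothetical grid, for illustration only.
grid = {"learning_rate": [0.01, 0.001], "hidden_units": [16, 32, 64]}
for trial in sample_random_search(grid, tune_range=4, seed=0):
    print(trial)
```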
One needs to specify a dataset or dataset generator (and some additional parameters, e.g. set gamma for cancer_sim with dataset.coeff=1.0):
- Synthetic Tumor Growth Simulator:
+dataset=cancer_sim
- MIMIC III Semi-synthetic Simulator (multiple treatments and outcomes):
+dataset=mimic3_synthetic
- MIMIC III Real-world dataset:
+dataset=mimic3_real
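The dataset choice and its extra parameters are ordinary Hydra overrides. A small sketch (dataset_overrides is a hypothetical helper) that produces them; restricting dataset.coeff to cancer_sim is an assumption drawn from the gamma example above, not a documented rule.

```python
# Dataset config names from the list above.
DATASETS = ("cancer_sim", "mimic3_synthetic", "mimic3_real")


def dataset_overrides(dataset, coeff=None):
    """Return Hydra overrides selecting a dataset and, optionally, gamma for cancer_sim."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset {dataset!r}; expected one of {DATASETS}")
    overrides = [f"+dataset={dataset}"]
    if coeff is not None:
        # Assumption of this sketch: dataset.coeff (gamma) only applies to cancer_sim.
        if dataset != "cancer_sim":
            raise ValueError("dataset.coeff is only used with cancer_sim in this sketch")
        overrides.append(f"dataset.coeff={float(coeff)}")
    return overrides


print(dataset_overrides("cancer_sim", coeff=1))
```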
Before running MIMIC III experiments, place the MIMIC-III-extract dataset (all_hourly_data.h5) in data/processed/.
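A guard like the following (the function name mimic_extract_path is hypothetical) fails early with a clear message if the file is missing, instead of failing later inside data loading. It only checks that the expected path exists relative to the project root; it does not validate the file contents.

```python
from pathlib import Path


def mimic_extract_path(project_root="."):
    """Return the expected path of all_hourly_data.h5, raising if the file is absent."""
    path = Path(project_root) / "data" / "processed" / "all_hourly_data.h5"
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found; place all_hourly_data.h5 in data/processed/ "
            "before running MIMIC III experiments"
        )
    return path
```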
Example of running Causal Transformer on the Synthetic Tumor Growth Simulator with gamma = [1.0, 2.0, 3.0] and different random seeds (a total of 15 subruns), using the saved hyperparameters:
PYTHONPATH=. CUDA_VISIBLE_DEVICES=<devices> \
python3 runnables/train_multi.py -m +dataset=cancer_sim +backbone=ct +backbone/ct_hparams/cancer_sim_domain_conf=\'0\',\'1\',\'2\' exp.seed=10,101,1010,10101,101010
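With -m, Hydra's basic sweeper launches one job for every combination of the comma-separated values, i.e. the Cartesian product of the sweeps. The number of jobs for the example above can be checked with a short sketch that simply enumerates that product:

```python
import itertools

# Sweep values copied from the example command above.
hparam_settings = ["0", "1", "2"]        # +backbone/ct_hparams/cancer_sim_domain_conf
seeds = [10, 101, 1010, 10101, 101010]   # exp.seed

# One job per element of the Cartesian product of all swept parameters.
jobs = [
    {"+backbone/ct_hparams/cancer_sim_domain_conf": h, "exp.seed": s}
    for h, s in itertools.product(hparam_settings, seeds)
]
print(len(jobs))  # 15
```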
New results for the semi-synthetic and real-world experiments after fixing a bug with self- and cross-attentions (#7). The bug affected only Tables 1 and 2 and Figure 5 of the paper (https://arxiv.org/pdf/2204.07258.pdf). Nevertheless, the performance of CT with the bug fixed did not change drastically.
Table 1 (updated). Results for semi-synthetic data for
MSMs | 0.37 ± 0.01 | 0.57 ± 0.03 | 0.74 ± 0.06 | 0.88 ± 0.03 | 1.14 ± 0.10 | 1.95 ± 1.48 | 3.44 ± 4.57 | > 10.0 | > 10.0 | > 10.0 |
RMSNs | 0.24 ± 0.01 | 0.47 ± 0.01 | 0.60 ± 0.01 | 0.70 ± 0.02 | 0.78 ± 0.04 | 0.84 ± 0.05 | 0.89 ± 0.06 | 0.94 ± 0.08 | 0.97 ± 0.09 | 1.00 ± 0.11 |
CRN | 0.30 ± 0.01 | 0.48 ± 0.02 | 0.59 ± 0.02 | 0.65 ± 0.02 | 0.68 ± 0.02 | 0.71 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.02 |
G-Net | 0.34 ± 0.01 | 0.67 ± 0.03 | 0.83 ± 0.04 | 0.94 ± 0.04 | 1.03 ± 0.05 | 1.10 ± 0.05 | 1.16 ± 0.05 | 1.21 ± 0.06 | 1.25 ± 0.06 | 1.29 ± 0.06 |
EDCT (GR; | 0.29 ± 0.01 | 0.46 ± 0.01 | 0.56 ± 0.01 | 0.62 ± 0.01 | 0.67 ± 0.01 | 0.70 ± 0.01 | 0.72 ± 0.01 | 0.74 ± 0.01 | 0.76 ± 0.01 | 0.78 ± 0.01 |
CT ( | 0.20 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.52 ± 0.01 | 0.54 ± 0.01 | 0.56 ± 0.01 | 0.57 ± 0.01 | 0.59 ± 0.01 | 0.60 ± 0.01 |
CT (ours, fixed) | 0.21 ± 0.01 | 0.38 ± 0.01 | 0.46 ± 0.01 | 0.50 ± 0.01 | 0.53 ± 0.01 | 0.54 ± 0.01 | 0.55 ± 0.01 | 0.57 ± 0.01 | 0.58 ± 0.01 | 0.59 ± 0.01 |
Table 2 (updated). Results for experiments with real-world medical data (MIMIC-III). Shown: RMSE as mean ± standard deviation over five runs.
MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |
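The result rows use a "name | mean ± std | ..." layout. The sketch below parses that layout and picks the model with the lowest mean in each column, using the rows of Table 2 (updated) copied as-is; it assumes nothing about the columns beyond their being comparable RMSE values.

```python
import re

# Rows copied verbatim from Table 2 (updated) above.
TABLE_2 = """\
MSMs | 6.37 ± 0.26 | 9.06 ± 0.41 | 11.89 ± 1.28 | 13.12 ± 1.25 | 14.44 ± 1.12 |
RMSNs | 5.20 ± 0.15 | 9.79 ± 0.31 | 10.52 ± 0.39 | 11.09 ± 0.49 | 11.64 ± 0.62 |
CRN | 4.84 ± 0.08 | 9.15 ± 0.16 | 9.81 ± 0.17 | 10.15 ± 0.19 | 10.40 ± 0.21 |
G-Net | 5.13 ± 0.05 | 11.88 ± 0.20 | 12.91 ± 0.26 | 13.57 ± 0.30 | 14.08 ± 0.31 |
CT (ours, fixed) | 4.60 ± 0.08 | 9.01 ± 0.21 | 9.58 ± 0.19 | 9.89 ± 0.21 | 10.12 ± 0.22 |
"""

# One "mean ± std" cell.
CELL = re.compile(r"([\d.]+)\s*±\s*([\d.]+)")


def parse_rows(text):
    """Map each model name to a list of (mean, std) pairs, one per column."""
    table = {}
    for line in text.strip().splitlines():
        name, _, rest = line.partition(" | ")
        table[name.strip()] = [(float(m), float(s)) for m, s in CELL.findall(rest)]
    return table


def best_per_column(table):
    """Return, for each column, the model with the lowest mean RMSE."""
    n_cols = len(next(iter(table.values())))
    return [min(table, key=lambda model: table[model][c][0]) for c in range(n_cols)]


print(best_per_column(parse_rows(TABLE_2)))
```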
Figure 6 (updated). Subnetwork importance scores based on the semi-synthetic benchmark (higher values correspond to higher importance of subnetwork connectivity via cross-attentions). Shown: RMSE differences between the model with an isolated subnetwork and the full CT, means ± standard errors.
New results after fixing a bug in the synthetic tumor-growth simulator: the outcome corresponding to the last entry of every time series was zeroed.
Table 9 (updated). Normalized RMSE for one-step-ahead prediction. Shown: mean and standard deviation over five runs (lower is better). Parameter
MSMs | 1.091 ± 0.115 | 1.202 ± 0.108 | 1.383 ± 0.090 | 1.647 ± 0.121 | 1.981 ± 0.232 |
RMSNs | 0.834 ± 0.072 | 0.860 ± 0.025 | 1.000 ± 0.134 | 1.131 ± 0.057 | 1.434 ± 0.148 |
CRN | 0.755 ± 0.059 | 0.788 ± 0.057 | 0.881 ± 0.066 | 1.062 ± 0.088 | 1.358 ± 0.167 |
G-Net | 0.795 ± 0.066 | 0.841 ± 0.038 | 0.946 ± 0.083 | 1.057 ± 0.146 | 1.319 ± 0.248 |
CT ( | 0.772 ± 0.051 | 0.783 ± 0.071 | 0.862 ± 0.052 | 1.062 ± 0.119 | 1.331 ± 0.217 |
CT (ours) | 0.770 ± 0.049 | 0.783 ± 0.071 | 0.864 ± 0.059 | 1.098 ± 0.097 | 1.413 ± 0.259 |
Table 10 (updated). Normalized RMSE for
('2', 'MSMs') | 0.975 ± 0.063 | 1.183 ± 0.146 | 1.428 ± 0.274 | 1.673 ± 0.431 | 1.884 ± 0.637 |
('2', 'RMSNs') | 0.825 ± 0.057 | 0.851 ± 0.043 | 0.861 ± 0.078 | 0.993 ± 0.126 | 1.269 ± 0.294 |
('2', 'CRN') | 0.761 ± 0.058 | 0.760 ± 0.037 | 0.805 ± 0.050 | 2.045 ± 1.491 | 1.209 ± 0.192 |
('2', 'G-Net') | 1.006 ± 0.082 | 0.994 ± 0.086 | 1.185 ± 0.077 | 1.083 ± 0.145 | 1.243 ± 0.202 |
('2', 'CT ( | 0.766 ± 0.029 | 0.781 ± 0.066 | 0.814 ± 0.078 | 0.944 ± 0.144 | 1.191 ± 0.316 |
('2', 'CT (ours)') | 0.762 ± 0.028 | 0.781 ± 0.058 | 0.818 ± 0.091 | 1.001 ± 0.150 | 1.163 ± 0.233 |
('3', 'MSMs') | 0.937 ± 0.060 | 1.133 ± 0.158 | 1.344 ± 0.262 | 1.525 ± 0.400 | 1.564 ± 0.545 |
('3', 'RMSNs') | 0.824 ± 0.043 | 0.871 ± 0.036 | 0.857 ± 0.109 | 1.020 ± 0.140 | 1.267 ± 0.298 |
('3', 'CRN') | 0.769 ± 0.057 | 0.777 ± 0.037 | 0.826 ± 0.077 | 1.789 ± 1.108 | 1.356 ± 0.330 |
('3', 'G-Net') | 1.103 ± 0.092 | 1.097 ± 0.095 | 1.355 ± 0.107 | 1.225 ± 0.184 | 1.382 ± 0.242 |
('3', 'CT ( | 0.766 ± 0.037 | 0.806 ± 0.060 | 0.828 ± 0.106 | 0.996 ± 0.185 | 1.335 ± 0.465 |
('3', 'CT (ours)') | 0.762 ± 0.036 | 0.807 ± 0.056 | 0.838 ± 0.120 | 1.072 ± 0.196 | 1.283 ± 0.312 |
('4', 'MSMs') | 0.845 ± 0.060 | 1.022 ± 0.149 | 1.196 ± 0.233 | 1.325 ± 0.363 | 1.308 ± 0.482 |
('4', 'RMSNs') | 0.780 ± 0.046 | 0.834 ± 0.040 | 0.814 ± 0.123 | 0.988 ± 0.146 | 1.169 ± 0.269 |
('4', 'CRN') | 0.734 ± 0.061 | 0.743 ± 0.037 | 0.805 ± 0.096 | 1.567 ± 0.825 | 1.327 ± 0.293 |
('4', 'G-Net') | 1.092 ± 0.090 | 1.074 ± 0.098 | 1.385 ± 0.117 | 1.212 ± 0.202 | 1.358 ± 0.253 |
('4', 'CT ( | 0.730 ± 0.042 | 0.776 ± 0.056 | 0.802 ± 0.119 | 0.983 ± 0.208 | 1.394 ± 0.563 |
('4', 'CT (ours)') | 0.726 ± 0.041 | 0.777 ± 0.054 | 0.810 ± 0.128 | 1.075 ± 0.220 | 1.302 ± 0.356 |
('5', 'MSMs') | 0.747 ± 0.056 | 0.896 ± 0.136 | 1.038 ± 0.210 | 1.128 ± 0.320 | 1.155 ± 0.448 |
('5', 'RMSNs') | 0.717 ± 0.053 | 0.775 ± 0.041 | 0.747 ± 0.124 | 0.922 ± 0.141 | 1.057 ± 0.246 |
('5', 'CRN') | 0.678 ± 0.062 | 0.692 ± 0.037 | 0.761 ± 0.104 | 1.410 ± 0.604 | 1.242 ± 0.239 |
('5', 'G-Net') | 1.033 ± 0.086 | 1.014 ± 0.097 | 1.358 ± 0.118 | 1.160 ± 0.199 | 1.285 ± 0.242 |
('5', 'CT ( | 0.673 ± 0.044 | 0.722 ± 0.052 | 0.748 ± 0.124 | 0.931 ± 0.213 | 1.405 ± 0.648 |
('5', 'CT (ours)') | 0.669 ± 0.043 | 0.723 ± 0.053 | 0.751 ± 0.125 | 1.036 ± 0.238 | 1.264 ± 0.389 |
('6', 'MSMs') | 0.647 ± 0.055 | 0.778 ± 0.123 | 0.894 ± 0.188 | 0.952 ± 0.284 | 1.060 ± 0.432 |
('6', 'RMSNs') | 0.646 ± 0.058 | 0.702 ± 0.043 | 0.675 ± 0.121 | 0.847 ± 0.132 | 0.947 ± 0.225 |
('6', 'CRN') | 0.614 ± 0.057 | 0.631 ± 0.035 | 0.706 ± 0.104 | 1.308 ± 0.438 | 1.132 ± 0.194 |
('6', 'G-Net') | 0.963 ± 0.083 | 0.942 ± 0.090 | 1.321 ± 0.118 | 1.092 ± 0.183 | 1.195 ± 0.223 |
('6', 'CT ( | 0.609 ± 0.042 | 0.657 ± 0.046 | 0.684 ± 0.122 | 0.864 ± 0.201 | 1.383 ± 0.699 |
('6', 'CT (ours)') | 0.605 ± 0.040 | 0.657 ± 0.047 | 0.685 ± 0.119 | 0.979 ± 0.249 | 1.201 ± 0.419 |
Table 11 (updated). Normalized RMSE for
('2', 'MSMs') | 1.362 ± 0.109 | 1.612 ± 0.172 | 1.939 ± 0.365 | 2.290 ± 0.545 | 2.468 ± 1.058 |
('2', 'RMSNs') | 0.742 ± 0.043 | 0.760 ± 0.047 | 0.827 ± 0.056 | 0.957 ± 0.106 | 1.276 ± 0.240 |
('2', 'CRN') | 0.671 ± 0.066 | 0.666 ± 0.052 | 0.741 ± 0.042 | 1.668 ± 1.184 | 1.151 ± 0.166 |
('2', 'G-Net') | 1.021 ± 0.067 | 1.009 ± 0.092 | 1.271 ± 0.075 | 1.113 ± 0.149 | 1.257 ± 0.227 |
('2', 'CT ( | 0.685 ± 0.050 | 0.679 ± 0.044 | 0.714 ± 0.053 | 0.875 ± 0.105 | 1.072 ± 0.315 |
('2', 'CT (ours)') | 0.681 ± 0.052 | 0.677 ± 0.044 | 0.713 ± 0.042 | 0.908 ± 0.122 | 1.274 ± 0.366 |
('3', 'MSMs') | 1.679 ± 0.132 | 1.953 ± 0.208 | 2.302 ± 0.437 | 2.640 ± 0.639 | 2.622 ± 1.132 |
('3', 'RMSNs') | 0.783 ± 0.053 | 0.792 ± 0.047 | 0.889 ± 0.050 | 1.086 ± 0.175 | 1.382 ± 0.286 |
('3', 'CRN') | 0.700 ± 0.078 | 0.692 ± 0.046 | 0.818 ± 0.051 | 1.959 ± 1.032 | 1.360 ± 0.225 |
('3', 'G-Net') | 1.253 ± 0.079 | 1.226 ± 0.104 | 1.611 ± 0.102 | 1.383 ± 0.200 | 1.574 ± 0.328 |
('3', 'CT ( | 0.707 ± 0.053 | 0.711 ± 0.038 | 0.770 ± 0.043 | 0.969 ± 0.119 | 1.261 ± 0.462 |
('3', 'CT (ours)') | 0.703 ± 0.055 | 0.712 ± 0.040 | 0.770 ± 0.032 | 1.010 ± 0.119 | 1.536 ± 0.450 |
('4', 'MSMs') | 1.871 ± 0.145 | 2.145 ± 0.227 | 2.489 ± 0.471 | 2.791 ± 0.681 | 2.615 ± 1.142 |
('4', 'RMSNs') | 0.821 ± 0.079 | 0.837 ± 0.058 | 0.963 ± 0.106 | 1.216 ± 0.240 | 1.416 ± 0.304 |
('4', 'CRN') | 0.734 ± 0.087 | 0.722 ± 0.041 | 0.898 ± 0.068 | 2.201 ± 0.967 | 1.573 ± 0.255 |
('4', 'G-Net') | 1.390 ± 0.087 | 1.347 ± 0.112 | 1.819 ± 0.133 | 1.544 ± 0.243 | 1.769 ± 0.413 |
('4', 'CT ( | 0.729 ± 0.056 | 0.749 ± 0.033 | 0.826 ± 0.046 | 1.053 ± 0.147 | 1.426 ± 0.574 |
('4', 'CT (ours)') | 0.726 ± 0.057 | 0.748 ± 0.036 | 0.822 ± 0.036 | 1.089 ± 0.122 | 1.762 ± 0.523 |
('5', 'MSMs') | 1.963 ± 0.155 | 2.221 ± 0.231 | 2.547 ± 0.479 | 2.810 ± 0.684 | 2.542 ± 1.122 |
('5', 'RMSNs') | 0.855 ± 0.099 | 0.889 ± 0.074 | 1.030 ± 0.165 | 1.349 ± 0.326 | 1.434 ± 0.299 |
('5', 'CRN') | 0.769 ± 0.094 | 0.755 ± 0.039 | 0.976 ± 0.082 | 2.361 ± 1.000 | 1.730 ± 0.292 |
('5', 'G-Net') | 1.477 ± 0.092 | 1.430 ± 0.119 | 1.963 ± 0.157 | 1.667 ± 0.275 | 1.907 ± 0.471 |
('5', 'CT ( | 0.758 ± 0.055 | 0.788 ± 0.036 | 0.875 ± 0.056 | 1.118 ± 0.172 | 1.560 ± 0.663 |
('5', 'CT (ours)') | 0.756 ± 0.057 | 0.786 ± 0.039 | 0.870 ± 0.048 | 1.154 ± 0.111 | 1.922 ± 0.569 |
('6', 'MSMs') | 1.970 ± 0.155 | 2.205 ± 0.228 | 2.509 ± 0.469 | 2.732 ± 0.662 | 2.422 ± 1.084 |
('6', 'RMSNs') | 0.889 ± 0.112 | 0.936 ± 0.091 | 1.081 ± 0.211 | 1.473 ± 0.433 | 1.436 ± 0.290 |
('6', 'CRN') | 0.807 ± 0.097 | 0.790 ± 0.035 | 1.047 ± 0.092 | 2.480 ± 1.078 | 1.827 ± 0.326 |
('6', 'G-Net') | 1.538 ± 0.091 | 1.493 ± 0.121 | 2.062 ± 0.172 | 1.758 ± 0.286 | 1.994 ± 0.500 |
('6', 'CT ( | 0.790 ± 0.058 | 0.827 ± 0.036 | 0.915 ± 0.063 | 1.177 ± 0.193 | 1.654 ± 0.704 |
('6', 'CT (ours)') | 0.789 ± 0.059 | 0.821 ± 0.034 | 0.909 ± 0.054 | 1.205 ± 0.100 | 2.052 ± 0.608 |
The project is based on the cookiecutter data science project template.