In [1]:
import os
# set wd to cancer
os.chdir("/work/postresearch/Shared/Researchers/Farbod/cancer/code/CausalDiff")
print(os.getcwd())

/work/postresearch/Shared/Researchers/Farbod/cancer/code/CausalDiff


In [2]:
import os
import sys
import math
import copy
import logging
import requests
import zipfile
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from itertools import chain
from typing import Union

from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.multioutput import MultiOutputClassifier, MultiOutputRegressor

from pytorch_lightning import LightningModule, Trainer
from omegaconf import DictConfig, OmegaConf
from omegaconf.errors import MissingMandatoryValue

import ray
from ray import tune, ray_constants

import hydra
from hydra import initialize, compose
from hydra.utils import instantiate

from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from IPython.core.interactiveshell import InteractiveShell

from src.data import RealDatasetCollection, SyntheticDatasetCollection
from src.models import TimeVaryingCausalModel
from src.models.utils import (
    grad_reverse,
    BRTreatmentOutcomeHead,
    AlphaRise,
    clip_normalize_stabilized_weights,
)
from src.models.utils_lstm import VariationalLSTM
from copy import deepcopy

InteractiveShell.ast_node_interactivity = "all"

In [3]:
# show pytorch version and cuda version
print(torch.__version__)
print(torch.version.cuda)

if torch.cuda.is_available():
    num_gpus = torch.cuda.device_count()
    print(f"Number of available GPUs: {num_gpus}")
    for i in range(num_gpus):
        gpu_name = torch.cuda.get_device_name(i)
        print(f"GPU {i}: {gpu_name}")
# else:
    print("CUDA is not available. No GPUs detected.")

2.0.0+cu118
11.8
Number of available GPUs: 1
GPU 0: NVIDIA A100 80GB PCIe
CUDA is not available. No GPUs detected.


# MSM


In [4]:
import logging
import hydra
import torch
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate
from pytorch_lightning import seed_everything

from src.models.utils import FilteringMlFlowLogger
from src.models.msm import MSM


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
torch.set_default_dtype(torch.double)

# Load the configuration file
config_path = '/work/postresearch/Shared/Researchers/Farbod/cancer/code/CausalDiff/config/cancer_sim_MSM.yaml'
args = OmegaConf.load(config_path)
# @hydra.main(config_name=f'config.yaml', config_path='../config/')


def main(args: DictConfig):
    """
    Training / evaluation script for MSMs
    Args:
        args: arguments of run as DictConfig

    Returns: dict with results (one and nultiple-step-ahead RMSEs)
    """

    results = {}

    # Non-strict access to fields
    OmegaConf.set_struct(args, False)
    OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)
    logger.info('\n' + OmegaConf.to_yaml(args, resolve=True))

    # Initialisation of data to calculate dim_outcomes, dim_treatments, dim_vitals and dim_static_features
    seed_everything(args.exp.seed)
    dataset_collection = instantiate(args.dataset, _recursive_=True)
    # Only binary multilabel regime possible
    assert args.dataset.treatment_mode == 'multilabel'
    dataset_collection.process_data_multi()
    args.model.dim_outcomes = dataset_collection.train_f.data['outputs'].shape[-1]
    args.model.dim_treatments = dataset_collection.train_f.data['current_treatments'].shape[-1]
    args.model.dim_vitals = dataset_collection.train_f.data[
        'vitals'].shape[-1] if dataset_collection.has_vitals else 0
    args.model.dim_static_features = dataset_collection.train_f.data['static_features'].shape[-1]

    # MlFlow Logger
    if args.exp.logging:
        experiment_name = f'{args.model.name}/{args.dataset.name}'
        mlf_logger = FilteringMlFlowLogger(filter_submodels=MSM.possible_model_types, experiment_name=experiment_name,
                                           tracking_uri=args.exp.mlflow_uri)
    else:
        mlf_logger = None

    # ============================== Nominator (treatment propensity network) ==============================
    propensity_treatment = instantiate(
        args.model.propensity_treatment, args, dataset_collection, _recursive_=False)
    mlf_logger.log_hyperparams(propensity_treatment.hparams)
    propensity_treatment.fit()

    # ============================== Denominator (history propensity network) ==============================
    propensity_history = instantiate(
        args.model.propensity_history, args, dataset_collection, _recursive_=False)
    mlf_logger.log_hyperparams(propensity_history.hparams)
    propensity_history.fit()

    # ============================== Initialisation & Training of Encoder ==============================
    msm_regressor = instantiate(args.model.msm_regressor, args, propensity_treatment, propensity_history, dataset_collection,
                                _recursive_=False)
    mlf_logger.log_hyperparams(msm_regressor.hparams)
    msm_regressor.fit()
    encoder_results = {}

    if hasattr(dataset_collection, 'test_cf_one_step'):  # Test one_step_counterfactual rmse
        test_rmse_orig, test_rmse_all, test_rmse_last = \
            msm_regressor.get_normalised_masked_rmse(
                dataset_collection.test_cf_one_step, one_step_counterfactual=True)
        logger.info(f'Test normalised RMSE (all): {test_rmse_all}; '
                    f'Test normalised RMSE (orig): {test_rmse_orig}; '
                    f'Test normalised RMSE (only counterfactual): {test_rmse_last}')
        encoder_results = {
            'encoder_test_rmse_all': test_rmse_all,
            'encoder_test_rmse_orig': test_rmse_orig,
            'encoder_test_rmse_last': test_rmse_last
        }
    elif hasattr(dataset_collection, 'test_f'):  # Test factual rmse
        test_rmse_orig, test_rmse_all = msm_regressor.get_normalised_masked_rmse(
            dataset_collection.test_f)
        logger.info(f'Test normalised RMSE (all): {test_rmse_all}; '
                    f'Test normalised RMSE (orig): {test_rmse_orig}.')
        encoder_results = {
            # 'encoder_val_rmse_all': val_rmse_all,
            # 'encoder_val_rmse_orig': val_rmse_orig,
            'encoder_test_rmse_all': test_rmse_all,
            'encoder_test_rmse_orig': test_rmse_orig
        }

    mlf_logger.log_metrics(encoder_results) if args.exp.logging else None
    results.update(encoder_results)

    test_rmses = {}
    if hasattr(dataset_collection, 'test_cf_treatment_seq'):  # Test n_step_counterfactual rmse
        test_rmses = msm_regressor.get_normalised_n_step_rmses(
            dataset_collection.test_cf_treatment_seq)
    elif hasattr(dataset_collection, 'test_f_multi'):  # Test n_step_factual rmse
        test_rmses = msm_regressor.get_normalised_n_step_rmses(
            dataset_collection.test_f_multi)
    test_rmses = {f'{k+2}-step': v for (k, v) in enumerate(test_rmses)}

    logger.info(f'Test normalised RMSE (n-step prediction): {test_rmses}')
    decoder_results = {('decoder_test_rmse_' + k)
                        : v for (k, v) in test_rmses.items()}

    mlf_logger.log_metrics(decoder_results) if args.exp.logging else None
    results.update(decoder_results)

    mlf_logger.experiment.set_terminated(
        mlf_logger.run_id) if args.exp.logging else None

    return results


if __name__ == "__main__":
    main(args)

INFO:__main__:
dataset:
  _target_: src.data.SyntheticCancerDatasetCollection
  name: tumor_generator
  coeff: 4
  chemo_coeff: 4
  radio_coeff: 4
  seed: 100
  num_patients:
    train: 10000
    val: 1000
    test: 100
  window_size: 15
  lag: 0
  max_seq_length: 60
  projection_horizon: 5
  cf_seq_mode: sliding_treatment
  val_batch_size: 512
  treatment_mode: multilabel
model:
  dim_treatments: 4
  dim_vitals: 0
  dim_static_features: 1
  dim_outcomes: 1
  min_length: 1
  lag_features: 1
  name: MSM
  propensity_treatment:
    _target_: src.models.msm.MSMPropensityTreatment
  propensity_history:
    _target_: src.models.msm.MSMPropensityHistory
  msm_regressor:
    _target_: src.models.msm.MSMRegressor
exp:
  unscale_rmse: true
  percentage_rmse: true
  seed: 100
  max_epochs: 100
  gpus:
  - 0
  logging: true
  mlflow_uri: http://127.0.0.1:8081
  alpha: 1.0
  update_alpha: false
  alpha_rate: exp
  balancing: null
  bce_weight: false
  weights_ema: null
  beta: 0.99

Global seed se

INFO:root:Got correlated params for 13793 patients
INFO:root:Simulating beta_c parameters
INFO:root:Randomising outputs
  if recovery_rvs[i, t] < np.exp(-cancer_volume[i, t] * TUMOUR_CELL_DENSITY):
100%|██████████| 10000/10000 [00:12<00:00, 795.51it/s]
INFO:root:Simulating initial volumes for stage I  with norm params: mu=1.72, sigma=4.7, lb=-0.6221218732608373, ub=-0.023523848418276528
INFO:root:Simulating initial volumes for stage II  with norm params: mu=1.96, sigma=1.63, lb=-1.9410876100159118, ub=0.3711345751297772
INFO:root:Simulating initial volumes for stage IIIA  with norm params: mu=1.91, sigma=9.4, lb=-0.33127370258786554, ub=0.06967546355973796
INFO:root:Simulating initial volumes for stage IIIB  with norm params: mu=2.76, sigma=6.87, lb=-0.5769974969906748, ub=-0.02839165102452155
INFO:root:Simulating initial volumes for stage IV  with norm params: mu=3.86, sigma=8.82, lb=-0.5741465764541877, ub=-0.14683113860980307
INFO:root:Got correlated params for 444 patients
INFO:roo

Call to simulate counterfactuals data


100%|██████████| 100/100 [00:02<00:00, 43.78it/s]
INFO:src.data.cancer_sim.dataset:Processing train dataset before training
INFO:src.data.cancer_sim.dataset:Shape of processed train data: {'cancer_volume': (10000, 60), 'chemo_dosage': (10000, 60), 'radio_dosage': (10000, 60), 'chemo_application': (10000, 60), 'radio_application': (10000, 60), 'chemo_probabilities': (10000, 60), 'radio_probabilities': (10000, 60), 'sequence_lengths': (10000,), 'death_flags': (10000, 60), 'recovery_flags': (10000, 60), 'patient_types': (10000,), 'prev_treatments': (10000, 59, 2), 'current_treatments': (10000, 59, 2), 'current_covariates': (10000, 59, 2), 'outputs': (10000, 59, 1), 'active_entries': (10000, 59, 1), 'unscaled_outputs': (10000, 59, 1), 'prev_outputs': (10000, 59, 1), 'static_features': (10000, 1)}
INFO:src.data.cancer_sim.dataset:Processing val dataset before training
INFO:src.data.cancer_sim.dataset:Shape of processed val data: {'cancer_volume': (1000, 60), 'chemo_dosage': (1000, 60), 'rad

🏃 View run able-shoat-598 at: http://127.0.0.1:8081/#/experiments/997620564669772123/runs/fc7d02f5a5ab46d2961e07eed761e1aa
🧪 View experiment at: http://127.0.0.1:8081/#/experiments/997620564669772123


{'encoder_test_rmse_all': 3.770219048532587,
 'encoder_test_rmse_orig': 2.9484263938352724,
 'encoder_test_rmse_last': 1.305809047323435,
 'decoder_test_rmse_2-step': 2.3833088435242646,
 'decoder_test_rmse_3-step': 2.5292776897711144,
 'decoder_test_rmse_4-step': 2.5056101764583745,
 'decoder_test_rmse_5-step': 2.420787517464479,
 'decoder_test_rmse_6-step': 2.293027285986994}