In [1]:
import pandas as pd
from src.data.mimic_iii.real_dataset import MIMIC3RealDataset
from data_loaders.causal_transformer_dataloader import load_gsu_dataset

2024-01-06 14:14:18.071016: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
gsu_features_path = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/preprocessed_features_25112023_213851.csv'
gsu_continuous_outcomes_path = '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/preprocessed_outcomes_continuous_25112023_213851.csv'

# Load dataset and preprocess

In [3]:
train_ds = load_gsu_dataset(gsu_features_path, gsu_continuous_outcomes_path)

In [4]:
from src.data import RealDatasetCollection


class MIMIC3RealDatasetCollection(RealDatasetCollection):
    """
    Dataset collection (train_f, val_f, test_f)
    """
    def __init__(self,
                 ds: MIMIC3RealDataset,
                 projection_horizon: int = 5,
                 autoregressive=True,
                 **kwargs):
        """
        Args:
            path: Path with MIMIC-3 dataset (HDFStore)
            min_seq_length: Min sequence lenght in cohort
            max_seq_length: Max sequence lenght in cohort
            seed: Seed for random cohort patient selection
            max_number: Maximum number of patients in cohort
            split: Ratio of train / val / test split
            projection_horizon: Range of tau-step-ahead prediction (tau = projection_horizon + 1)
            autoregressive:
        """
        super(MIMIC3RealDatasetCollection, self).__init__()
        self.train_f = ds
        
        # only declared for compatibility
        self.val_f = ds
        self.test_f = ds

        self.projection_horizon = projection_horizon
        self.has_vitals = True
        self.autoregressive = autoregressive
        self.processed_data_encoder = True

In [9]:
ds_collection = MIMIC3RealDatasetCollection(train_ds)

In [10]:
ds_collection.train_f.data

{'sequence_lengths': array([71, 71, 71, ..., 71, 71, 71]),
 'prev_treatments': array([[[4.],
         [4.],
         [4.],
         ...,
         [0.],
         [0.],
         [0.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [4.],
         [4.],
         [4.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        ...,
 
        [[7.],
         [7.],
         [7.],
         ...,
         [3.],
         [3.],
         [3.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]],
 
        [[7.],
         [7.],
         [7.],
         ...,
         [7.],
         [7.],
         [7.]]]),
 'vitals': array([[[-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 

In [11]:
# preprocess 
ds_collection.process_data_multi()

# Test dataset with model 

In [12]:
config_dict = {
    'model': {
        'name': 'CT',
        'multi': {
            '_target_': 'src.models.ct.CT',
            'max_seq_length': '${sum:${dataset.max_seq_length},${dataset.projection_horizon}}',
            'seq_hidden_units': None,
            'br_size': None,
            'fc_hidden_units': None,
            'dropout_rate': None,
            'num_layer': 1,
            'num_heads': 2,
            'max_grad_norm': None,
            'batch_size': None,
            'attn_dropout': True,
            'disable_cross_attention': False,
            'isolate_subnetwork': '_',
            'self_positional_encoding': {
                'absolute': False,
                'trainable': True,
                'max_relative_position': 15,
            },
            'optimizer': {
                'optimizer_cls': 'adam',
                'learning_rate': None,
                'weight_decay': 0.0,
                'lr_scheduler': False,
            },
            'augment_with_masked_vitals': True,
            'tune_hparams': False,
            'tune_range': 50,
            'hparams_grid': None,
            'resources_per_trial': None,
        }
    },
    'exp': {
        'weights_ema': True,
        'balancing': 'domain_confusion',
        'alpha': 0.01,
    }
}

# Example usage
print(config_dict['model']['multi']['max_seq_length'])  # Accessing a specific value in the dictionary


${sum:${dataset.max_seq_length},${dataset.projection_horizon}}


In [13]:
from omegaconf import OmegaConf
from hydra import initialize, compose

OmegaConf.register_new_resolver("sum", lambda x, y: x + y, replace=True)

with initialize(version_base=None, config_path="./config"):
    cfg=compose(config_name='config.yaml', overrides=['+backbone=CT', '+dataset=gsu', '+hparams=mimic3_diastolic_blood_pressure'])
    print(cfg)

{'model': {'dim_treatments': '???', 'dim_vitals': '???', 'dim_static_features': '???', 'dim_outcomes': '???', 'name': 'CT', 'multi': {'_target_': 'src.models.ct.CT', 'max_seq_length': '${sum:${dataset.max_seq_length},${dataset.projection_horizon}}', 'seq_hidden_units': 24, 'br_size': 22, 'fc_hidden_units': 22, 'dropout_rate': 0.2, 'num_layer': 2, 'num_heads': 3, 'max_grad_norm': None, 'batch_size': 64, 'attn_dropout': True, 'disable_cross_attention': False, 'isolate_subnetwork': '_', 'self_positional_encoding': {'absolute': False, 'trainable': True, 'max_relative_position': 30}, 'optimizer': {'optimizer_cls': 'adam', 'learning_rate': 0.0001, 'weight_decay': 0.0, 'lr_scheduler': False}, 'augment_with_masked_vitals': True, 'tune_hparams': False, 'tune_range': 50, 'hparams_grid': None, 'resources_per_trial': None}}, 'dataset': {'val_batch_size': 512, 'treatment_mode': 'multilabel', 'seed': '${exp.seed}', 'name': 'gsu', 'min_seq_length': 71, 'max_seq_length': 71, 'max_number': 5000, 'proje

In [14]:
cfg.model.dim_outcomes = ds_collection.train_f.data['outputs'].shape[-1]
cfg.model.dim_treatments = ds_collection.train_f.data['current_treatments'].shape[-1]
cfg.model.dim_vitals = ds_collection.train_f.data['vitals'].shape[-1] if ds_collection.has_vitals else 0
cfg.model.dim_static_features = ds_collection.train_f.data['static_features'].shape[-1]

In [15]:
cfg['dataset']['treatment_list'] = train_ds.treatment_names
cfg['dataset']['vital_list'] = train_ds.feature_names
cfg['dataset']['static_list'] = None
cfg['dataset']['outcome_list'] = train_ds.outcome_names

In [16]:
# from prior hyperparams
cfg['model']['multi']['seq_hidden_units'], cfg['model']['multi']['br_size'], cfg['model']['multi']['fc_hidden_units'], cfg['model']['multi']['dropout_rate']

(24, 22, 22, 0.2)

# Train model

In [17]:
from src.models import CT

model = CT(cfg, ds_collection)

In [18]:
import hydra
from pytorch_lightning.callbacks import LearningRateMonitor
from src.models.utils import AlphaRise, FilteringMlFlowLogger

# Train_callbacks
multimodel_callbacks = [AlphaRise(rate=cfg.exp.alpha_rate)]

# MlFlow Logger
if cfg.exp.logging:
    experiment_name = f'{cfg.model.name}/{cfg.dataset.name}'
    mlf_logger = FilteringMlFlowLogger(filter_submodels=[], experiment_name=experiment_name, tracking_uri=args.exp.mlflow_uri)
    multimodel_callbacks += [LearningRateMonitor(logging_interval='epoch')]
    artifacts_path = hydra.utils.to_absolute_path(mlf_logger.experiment.get_run(mlf_logger.run_id).info.artifact_uri)
else:
    mlf_logger = None
    artifacts_path = None

In [19]:
from pytorch_lightning import Trainer

cfg.exp.gpus = None
multimodel_trainer = Trainer(gpus=eval(str(cfg.exp.gpus)), logger=mlf_logger, max_epochs=cfg.exp.max_epochs,
                                 callbacks=multimodel_callbacks, terminate_on_nan=True,
                                 gradient_clip_val=cfg.model.multi.max_grad_norm)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [44]:
multimodel_trainer.fit(model)


  | Name                            | Type                       | Params
-------------------------------------------------------------------------------
0 | input_transformation            | Linear                     | 2.0 K 
1 | self_positional_encoding_k      | RelativePositionalEncoding | 488   
2 | self_positional_encoding_v      | RelativePositionalEncoding | 488   
3 | transformer_blocks              | ModuleList                 | 62.9 K
4 | output_dropout                  | Dropout                    | 0     
5 | br_treatment_outcome_head       | BRTreatmentOutcomeHead     | 1.9 K 
6 | treatments_input_transformation | Linear                     | 216   
7 | vitals_input_transformation     | Linear                     | 2.0 K 
8 | outputs_input_transformation    | Linear                     | 48    
9 | static_input_transformation     | Linear                     | 24    
-------------------------------------------------------------------------------
69.2 K    Trainable param

Training: -1it [00:00, ?it/s]

IndexError: index 820 is out of bounds for axis 0 with size 0

In [20]:
dl = model.train_dataloader()

In [5]:
index = 0
{k: v[index] for k, v in train_ds.data.items()}

{'sequence_lengths': 71,
 'prev_treatments': array([[4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [4.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
        [0.],
    

In [6]:
list(train_ds.data.items())[3]

('next_vitals',
 array([[[-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         ...,
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429],
         [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
           1.        ,  1.15325429]],
 
        [[-0.16476602, -0.16509521, -0.2857695 , ...,  1.06150768,
           1.        ,  0.74380008],
         [-0.16476602, -0.16509521, -0.2857695 , ...,  1.06150768,
           1.        ,  0.88028481],
         [-0.16476602, -0.16509521, -0.2857695 , ...,  1.06150768,
           1.        ,  0.88028481],
         ...,
         [ 0.3834

In [7]:
for k, v in train_ds.data.items():
    print(k)
    print(v.shape)

sequence_lengths
(2493,)
prev_treatments
(2493, 71, 1)
vitals
(2493, 71, 84)
next_vitals
(2493, 70, 84)
current_treatments
(2493, 71, 1)
static_features
(179496, 0)
active_entries
(2493, 71, 1)
outputs
(2493, 71, 1)
unscaled_outputs
(2493, 71, 1)
prev_outputs
(2493, 71, 1)


In [22]:
temp = next(iter(dl))

In [24]:
len(temp)

10

In [42]:
train_ds.data.items()

dict_items([('sequence_lengths', array([71, 71, 71, ..., 71, 71, 71])), ('prev_treatments', array([[[4.],
        [4.],
        [4.],
        ...,
        [0.],
        [0.],
        [0.]],

       [[7.],
        [7.],
        [7.],
        ...,
        [4.],
        [4.],
        [4.]],

       [[7.],
        [7.],
        [7.],
        ...,
        [7.],
        [7.],
        [7.]],

       ...,

       [[7.],
        [7.],
        [7.],
        ...,
        [3.],
        [3.],
        [3.]],

       [[7.],
        [7.],
        [7.],
        ...,
        [7.],
        [7.],
        [7.]],

       [[7.],
        [7.],
        [7.],
        ...,
        [7.],
        [7.],
        [7.]]])), ('vitals', array([[[-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
          1.        ,  1.15325429],
        [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
          1.        ,  1.15325429],
        [-0.30182825, -0.64252919, -0.2857695 , ...,  0.01302788,
          1.      