In [1]:
import enum
import math

import matplotlib.pyplot as plt


import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from itertools import chain

import zipfile
import sys
import os
import requests
import pandas as pd
import pickle
import copy
from torch.utils.data import DataLoader, Dataset
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce

from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

In [3]:
# show pytorch version and cuda version
print(torch.__version__)
print(torch.version.cuda)

2.5.0.dev20240808+cu121
12.1


## collecting and preparing data

In [4]:
%pwd

'/work/postresearch/Shared/Researchers/Farbod/cancer/code/CausalDiff'

In [5]:
# change working directory
os.chdir("/work/postresearch/Shared/Researchers/Farbod/cancer")

# Cancer data preparation

In [7]:
tau = 0
# read from pickle file
training_data = pd.read_pickle(f"training_data_tau_{tau}.pkl")
validation_data = pd.read_pickle(f"validation_data_tau_{tau}.pkl")

test_data_factuals = pd.read_pickle(f"test_data_factuals_tau_{tau}.pkl")
#:return: simulated data dict with number of rows equal to num_patients * seq_length * num_treatments
test_data_counterfactuals = pd.read_pickle(f"test_data_counterfactuals_tau_{tau}.pkl")
#:return: simulated data dict with number of rows equal to num_patients * seq_length * 2 * projection_horizon
test_data_seq = pd.read_pickle(f"test_data_seq_tau_{tau}.pkl")
means = pd.read_pickle(f"means_tau_{tau}.pkl")
stds = pd.read_pickle(f"stds_tau_{tau}.pkl")

In [8]:
means["chemo_application"] = 0
means["radio_application"] = 0
stds["chemo_application"] = 1
stds["radio_application"] = 1

In [9]:
means
stds

cancer_volume        9.796984
chemo_dosage         4.789335
radio_dosage         0.979020
patient_types        2.000600
chemo_application    0.000000
radio_application    0.000000
dtype: float64

cancer_volume        63.363042
chemo_dosage          2.948792
radio_dosage          0.999780
patient_types         0.819268
chemo_application     1.000000
radio_application     1.000000
dtype: float64

In [10]:
# scale the data
for key in training_data.keys():
    if key in means.keys():
        training_data[key] = (training_data[key] - means[key]) / stds[key]
        validation_data[key] = (validation_data[key] - means[key]) / stds[key]
        test_data_factuals[key] = (test_data_factuals[key] - means[key]) / stds[key]

for key in test_data_counterfactuals.keys():
    if key in means.keys():
        test_data_counterfactuals[key] = (
            test_data_counterfactuals[key] - means[key]
        ) / stds[key]
        test_data_seq[key] = (test_data_seq[key] - means[key]) / stds[key]

In [12]:
def tile_patient_types(data: dict):
    data["patient_type_tile"] = np.tile(
        data["patient_types"][:, None], (1, data["cancer_volume"].shape[1])
    )
    return data


training_data = tile_patient_types(training_data)
validation_data = tile_patient_types(validation_data)
test_data_factuals = tile_patient_types(test_data_factuals)
test_data_counterfactuals = tile_patient_types(test_data_counterfactuals)
test_data_seq = tile_patient_types(test_data_seq)

In [11]:
horizon = 1
offset = 1

In [13]:
patient_types = training_data["patient_type_tile"]
cancer_volume = training_data["cancer_volume"]

current_covariates = np.concatenate(
    [cancer_volume[:, :-offset, np.newaxis], patient_types[:, :-offset, np.newaxis]],
    axis=-1,
)
outputs = cancer_volume[:, horizon:, np.newaxis]

In [14]:
# add a column that indicates the application of chemo and radio at time t-1
def add_previous_treatment(data: dict):
    data["chemo_application_prev"] = np.roll(data["chemo_application"], offset, axis=1)
    data["radio_application_prev"] = np.roll(data["radio_application"], offset, axis=1)
    data["chemo_application_prev"][:, 0] = 0
    data["radio_application_prev"][:, 0] = 0
    return data


training_data = add_previous_treatment(training_data)
validation_data = add_previous_treatment(validation_data)
test_data_factuals = add_previous_treatment(test_data_factuals)
test_data_counterfactuals = add_previous_treatment(test_data_counterfactuals)
test_data_seq = add_previous_treatment(test_data_seq)

In [15]:
chemo_application = training_data["chemo_application"]
radio_application = training_data["radio_application"]

treatments = np.concatenate(
    [
        chemo_application[:, :-offset, np.newaxis],
        radio_application[:, :-offset, np.newaxis],
    ],
    axis=-1,
)

In [16]:
# sequence lengths is reported once per patient, so we tile it to match the shape of the data
def add_sequence_lengths(data: dict):
    data["time_step"] = np.tile(
        np.arange(data["cancer_volume"].shape[1]), (data["cancer_volume"].shape[0], 1)
    )
    data["sequence_length_tile"] = np.tile(
        data["sequence_lengths"][:, None], (1, data["cancer_volume"].shape[1])
    )
    data["distance_from_end"] = data["sequence_length_tile"] - data["time_step"]
    # distance_from_end_scaled is a value between 0 and 1 that indicates how far we are from the end of the sequence for each patient
    data["distance_from_end_scaled"] = (
        data["distance_from_end"] / data["sequence_length_tile"]
    )
    data["active_entries"] = data["distance_from_end"] > 0
    return data


training_data = add_sequence_lengths(training_data)
validation_data = add_sequence_lengths(validation_data)
test_data_factuals = add_sequence_lengths(test_data_factuals)
test_data_counterfactuals = add_sequence_lengths(test_data_counterfactuals)
test_data_seq = add_sequence_lengths(test_data_seq)

In [17]:
active_entries = np.zeros(outputs.shape)

for i in range(training_data["sequence_lengths"].shape[0]):
    sequence_length = int(training_data["sequence_lengths"][i])
    active_entries[i, :sequence_length, :] = 1

In [18]:
def add_unscaled_data(data: dict, means=means, stds=stds):
    temp_data = copy.deepcopy(data)
    for key in data.keys():
        if key in means.keys():
            var_name = key + "_unscaled"
            temp_data[var_name] = data[key] * stds[key] + means[key]
    return temp_data


training_data = add_unscaled_data(training_data)
validation_data = add_unscaled_data(validation_data)
test_data_factuals = add_unscaled_data(test_data_factuals)
test_data_counterfactuals = add_unscaled_data(test_data_counterfactuals)
test_data_seq = add_unscaled_data(test_data_seq)

In [19]:
unscaled_outputs = outputs * stds["cancer_volume"] + means["cancer_volume"]

In [20]:
prev_outputs = current_covariates[:, :, :1]

In [21]:
treatments = np.concatenate(
    [
        training_data["chemo_application"][:, :-offset, np.newaxis],
        training_data["radio_application"][:, :-offset, np.newaxis],
    ],
    axis=-1,
)

current_treatments = treatments
prev_treatments = treatments[:, :-1, :]

In [22]:
zero_init_treatment = np.zeros(
    shape=[current_covariates.shape[0], 1, prev_treatments.shape[-1]]
)
prev_treatments = np.concatenate([zero_init_treatment, prev_treatments], axis=1)

In [23]:
output_means = means[["cancer_volume"]].values.flatten()[0]
output_stds = stds[["cancer_volume"]].values.flatten()[0]

scaling_params = {
    "input_means": means,
    "inputs_stds": stds,
    "output_means": means[["cancer_volume"]].values.flatten()[0],
    "output_stds": stds[["cancer_volume"]].values.flatten()[0],
}

In [24]:
# training_data_sequence_lengths = training_data['sequence_lengths']
# validation_data_sequence_lengths = validation_data['sequence_lengths']
# test_data_factuals_sequence_lengths = test_data_factuals['sequence_lengths']
# test_data_counterfactuals_sequence_lengths = test_data_counterfactuals['sequence_lengths']
# test_data_seq_sequence_lengths = test_data_seq['sequence_lengths']

In [25]:
# needed_keys = [
#     'chemo_application_prev', 'radio_application_prev', 'patient_type_tile', 'cancer_volume'
# ]
# #
# training_data = {key: training_data[key] for key in needed_keys}
# validation_data = {key: validation_data[key] for key in needed_keys}
# test_data_factuals = {key: test_data_factuals[key] for key in needed_keys}
# test_data_counterfactuals = {key: test_data_counterfactuals[key] for key in needed_keys}
# test_data_seq = {key: test_data_seq[key] for key in needed_keys}

In [26]:
# #convert tensors where the first dimension is the number of patients the second dimension is time and the third is the features
# #the keys are the features
# #within the features, the first dimension is the number of patients, the second is time

# #convert the keys to a dimension in the tensor
# def dictionary_to_tensor(data_dict):
#     # Extract keys and values from the dictionary
#     keys = list(data_dict.keys())
#     values = [data_dict[key] for key in keys]

#     # Check consistency in dimensions
#     num_patients = values[0].shape[0]
#     time_steps = values[0].shape[1]

#     for table in values:
#         assert table.shape[0] == num_patients, "Number of patients mismatch."
#         assert table.shape[1] == time_steps, "Number of time steps mismatch."

#     # Stack tables along the new feature dimension
#     tensor = np.stack(values, axis=-1)
#     tensor = torch.tensor(tensor, dtype=torch.float32)

#     return tensor, keys

# training_data_tensor, keys = dictionary_to_tensor(training_data)
# validation_data_tensor, keys = dictionary_to_tensor(validation_data)
# test_data_factuals_tensor, keys = dictionary_to_tensor(test_data_factuals)
# test_data_counterfactuals_tensor, keys = dictionary_to_tensor(test_data_counterfactuals)
# test_data_seq_tensor, keys = dictionary_to_tensor(test_data_seq)

# Data for CT

In [27]:
import numpy as np


class DataPreparer:
    def __init__(self, data, means, stds, horizon=1, offset=1):
        self.data = data
        self.means = means
        self.stds = stds
        self.horizon = horizon
        self.offset = offset
        self.training_data_ct = {}
        self.scaling_params = {}

    def prepare_data(self):
        training_data = self.data
        patient_types = training_data["patient_type_tile"]
        cancer_volume = training_data["cancer_volume"]
        input_means = self.means
        input_stds = self.stds
        output_means = self.means[["cancer_volume"]].values.flatten()[0]
        output_stds = self.stds[["cancer_volume"]].values.flatten()[0]

        chemo_application = training_data["chemo_application"]
        radio_application = training_data["radio_application"]
        sequence_lengths = training_data["sequence_lengths"]

        self.training_data_ct["current_covariates"] = np.concatenate(
            [
                training_data["cancer_volume"][:, : -self.horizon, np.newaxis],
                training_data["patient_type_tile"][:, : -self.horizon, np.newaxis],
            ],
            axis=-1,
        )
        self.training_data_ct["outputs"] = training_data["cancer_volume"][
            :, self.horizon :, np.newaxis
        ]
        self.training_data_ct["active_entries"] = training_data["active_entries"][
            :, : -self.offset, np.newaxis
        ]
        self.training_data_ct["unscaled_outputs"] = training_data[
            "cancer_volume_unscaled"
        ][:, self.horizon :, np.newaxis]
        self.training_data_ct["prev_outputs"] = np.concatenate(
            [
                training_data["cancer_volume"][:, : -self.horizon, np.newaxis],
                training_data["patient_type_tile"][:, : -self.horizon, np.newaxis],
            ],
            axis=-1,
        )[:, :, :1]
        self.training_data_ct["static_features"] = np.concatenate(
            [
                training_data["cancer_volume"][:, : -self.horizon, np.newaxis],
                training_data["patient_type_tile"][:, : -self.horizon, np.newaxis],
            ],
            axis=-1,
        )[:, 0, 1:]
        self.training_data_ct["current_treatments"] = np.concatenate(
            [
                training_data["chemo_application"][:, : -self.offset, np.newaxis],
                training_data["radio_application"][:, : -self.offset, np.newaxis],
            ],
            axis=-1,
        )
        self.training_data_ct["prev_treatments"] = np.concatenate(
            [
                training_data["chemo_application_prev"][:, : -self.offset, np.newaxis],
                training_data["radio_application_prev"][:, : -self.offset, np.newaxis],
            ],
            axis=-1,
        )
        self.training_data_ct["sequence_lengths"] = training_data["sequence_lengths"]
        self.training_data_ct["patient_types"] = patient_types
        self.training_data_ct["cancer_volume"] = cancer_volume
        self.training_data_ct["chemo_application"] = chemo_application
        self.training_data_ct["radio_application"] = radio_application
        self.training_data_ct["sequence_lengths"] = sequence_lengths

        self.scaling_params = {
            "input_means": input_means,
            "inputs_stds": input_stds,
            "output_means": output_means,
            "output_stds": output_stds,
        }

        result = {"data": self.training_data_ct, "scaling_params": self.scaling_params}

        self.prepared_data = result

        return result

    def explode_trajectories(self, data, scaling_params, projection_horizon):

        outputs = data["outputs"]
        prev_outputs = data["prev_outputs"]
        sequence_lengths = data["sequence_lengths"]
        active_entries = data["active_entries"]
        current_treatments = data["current_treatments"]
        previous_treatments = data["prev_treatments"]
        static_features = data["static_features"]
        # if 'stabilized_weights' in data:
        #     stabilized_weights = data['stabilized_weights']

        num_patients, max_seq_length, num_features = outputs.shape
        num_seq2seq_rows = num_patients * max_seq_length

        seq2seq_previous_treatments = np.zeros(
            (num_seq2seq_rows, max_seq_length, previous_treatments.shape[-1])
        )
        seq2seq_current_treatments = np.zeros(
            (num_seq2seq_rows, max_seq_length, current_treatments.shape[-1])
        )
        seq2seq_static_features = np.zeros(
            (num_seq2seq_rows, static_features.shape[-1])
        )
        seq2seq_outputs = np.zeros(
            (num_seq2seq_rows, max_seq_length, outputs.shape[-1])
        )
        seq2seq_prev_outputs = np.zeros(
            (num_seq2seq_rows, max_seq_length, prev_outputs.shape[-1])
        )
        # seq2seq_vitals = np.zeros((num_seq2seq_rows, max_seq_length, vitals.shape[-1]))
        # seq2seq_next_vitals = np.zeros((num_seq2seq_rows, max_seq_length - 1, next_vitals.shape[-1]))
        seq2seq_active_entries = np.zeros(
            (num_seq2seq_rows, max_seq_length, active_entries.shape[-1])
        )
        seq2seq_sequence_lengths = np.zeros(num_seq2seq_rows)
        # if 'stabilized_weights' in data:
        #     seq2seq_stabilized_weights = np.zeros(
        #         (num_seq2seq_rows, max_seq_length))

        total_seq2seq_rows = 0  # we use this to shorten any trajectories later

        for i in range(num_patients):
            sequence_length = int(sequence_lengths[i])

            for t in range(
                projection_horizon, sequence_length
            ):  # shift outputs back by 1
                seq2seq_active_entries[total_seq2seq_rows, : (t + 1), :] = (
                    active_entries[i, : (t + 1), :]
                )
                # if 'stabilized_weights' in data:
                #     seq2seq_stabilized_weights[total_seq2seq_rows, :(
                #         t + 1)] = stabilized_weights[i, :(t + 1)]
                seq2seq_previous_treatments[total_seq2seq_rows, : (t + 1), :] = (
                    previous_treatments[i, : (t + 1), :]
                )
                seq2seq_current_treatments[total_seq2seq_rows, : (t + 1), :] = (
                    current_treatments[i, : (t + 1), :]
                )
                seq2seq_outputs[total_seq2seq_rows, : (t + 1), :] = outputs[
                    i, : (t + 1), :
                ]
                seq2seq_prev_outputs[total_seq2seq_rows, : (t + 1), :] = prev_outputs[
                    i, : (t + 1), :
                ]
                # seq2seq_vitals[total_seq2seq_rows, :(t + 1), :] = vitals[i, :(t + 1), :]
                # seq2seq_next_vitals[total_seq2seq_rows, :min(t + 1, sequence_length - 1), :] = \
                #     next_vitals[i, :min(t + 1, sequence_length - 1), :]
                seq2seq_sequence_lengths[total_seq2seq_rows] = t + 1
                seq2seq_static_features[total_seq2seq_rows] = static_features[i]

                total_seq2seq_rows += 1

        # Filter everything shorter
        seq2seq_previous_treatments = seq2seq_previous_treatments[
            :total_seq2seq_rows, :, :
        ]
        seq2seq_current_treatments = seq2seq_current_treatments[
            :total_seq2seq_rows, :, :
        ]
        seq2seq_static_features = seq2seq_static_features[:total_seq2seq_rows, :]
        seq2seq_outputs = seq2seq_outputs[:total_seq2seq_rows, :, :]
        seq2seq_prev_outputs = seq2seq_prev_outputs[:total_seq2seq_rows, :, :]
        # seq2seq_vitals = seq2seq_vitals[:total_seq2seq_rows, :, :]
        # seq2seq_next_vitals = seq2seq_next_vitals[:total_seq2seq_rows, :, :]
        seq2seq_active_entries = seq2seq_active_entries[:total_seq2seq_rows, :, :]
        seq2seq_sequence_lengths = seq2seq_sequence_lengths[:total_seq2seq_rows]

        # if 'stabilized_weights' in data:
        #     seq2seq_stabilized_weights = seq2seq_stabilized_weights[:total_seq2seq_rows]

        new_data = {
            "prev_treatments": seq2seq_previous_treatments,
            "current_treatments": seq2seq_current_treatments,
            "static_features": seq2seq_static_features,
            "prev_outputs": seq2seq_prev_outputs,
            "outputs": seq2seq_outputs,
            # 'vitals': seq2seq_vitals,
            # 'next_vitals': seq2seq_next_vitals,
            "unscaled_outputs": seq2seq_outputs * scaling_params["output_stds"]
            + scaling_params["output_means"],
            "sequence_lengths": seq2seq_sequence_lengths,
            "active_entries": seq2seq_active_entries,
        }
        # if 'stabilized_weights' in data:
        #     new_data['stabilized_weights'] = seq2seq_stabilized_weights

        return new_data

In [28]:
data_preparer = DataPreparer(training_data, means, stds)
result = data_preparer.prepare_data()

training_data_ct = result["data"]
scaling_params = result["scaling_params"]

In [29]:
training_data_ct_exploded = data_preparer.explode_trajectories(
    training_data_ct, scaling_params, horizon
)

training_data_ct_exploded.keys()

dict_keys(['prev_treatments', 'current_treatments', 'static_features', 'prev_outputs', 'outputs', 'unscaled_outputs', 'sequence_lengths', 'active_entries'])

# MSM

In [43]:
def get_propensity_scores(exploded_dataset, model_type, lag_features, dim_treatments):

    dataset = exploded_dataset
    active_entries = dataset['active_entries']
    prev_treatments = dataset['prev_treatments']
    inputs = (prev_treatments * active_entries).sum(1)
    
    classifier = model_type

    propensity_scores = np.stack(classifier.predict_proba(inputs), 1)[:, :, 1]
    propensity_scores = propensity_scores.reshape(dataset['active_entries'].shape[0],
                                                    dataset['active_entries'].shape[1] - lag_features,
                                                    dim_treatments)
    propensity_scores = np.concatenate([0.5 * np.ones((propensity_scores.shape[0], lag_features, dim_treatments)),
                                        propensity_scores], axis=1)
    return propensity_scores

In [None]:
class MSM:
    """
    Pytorch-Lightning implementation of Marginal Structural Models (MSMs) (https://pubmed.ncbi.nlm.nih.gov/10955408/)
    """
    model_type = None  # Will be defined in subclasses
    possible_model_types = {'msm_regressor',
                            'propensity_treatment', 'propensity_history'}
    tuning_criterion = None

    def __init__(self,
                 args: DictConfig,
                 dataset_collection: Union[RealDatasetCollection,
                                           SyntheticDatasetCollection] = None,
                 autoregressive: bool = None,
                 has_vitals: bool = None,
                 **kwargs):
        """
        Args:
            args: DictConfig of model hyperparameters
            dataset_collection: Dataset collection
            autoregressive: Flag of including previous outcomes to modelling
            has_vitals: Flag of vitals in dataset
            **kwargs: Other arguments
        """
        super().__init__(args, dataset_collection, autoregressive, has_vitals)
        self.lag_features = args.model.lag_features

    def prepare_data(self) -> None:
        if self.dataset_collection is not None and not self.dataset_collection.processed_data_multi:
            # Only binary multilabel regime possible
            assert self.hparams.dataset.treatment_mode == 'multilabel'
            self.dataset_collection.process_data_multi()

    def get_exploded_dataset(self, dataset: Dataset, min_length: int, only_active_entries=True, max_length=None) -> Dataset:
        exploded_dataset = deepcopy(dataset)
        if max_length is None:
            max_length = max(exploded_dataset.data['sequence_lengths'][:])
        if not only_active_entries:
            exploded_dataset.data['active_entries'][:, :, :] = 1.0
            exploded_dataset.data['sequence_lengths'][:] = max_length
        exploded_dataset.explode_trajectories(min_length)
        return exploded_dataset

    def get_propensity_scores(self, dataset: Dataset) -> np.array:
        logger.info(f'Propensity scores for {dataset.subset_name}.')
        exploded_dataset = self.get_exploded_dataset(
            dataset, min_length=self.lag_features, only_active_entries=False)

        inputs = self.get_inputs(exploded_dataset)
        classifier = getattr(self, self.model_type)

        propensity_scores = np.stack(
            classifier.predict_proba(inputs), 1)[:, :, 1]
        propensity_scores = propensity_scores.reshape(dataset.data['active_entries'].shape[0],
                                                      dataset.data['active_entries'].shape[1] -
                                                      self.lag_features,
                                                      self.dim_treatments)
        propensity_scores = np.concatenate([0.5 * np.ones((propensity_scores.shape[0], self.lag_features, self.dim_treatments)),
                                            propensity_scores], axis=1)
        return propensity_scores

In [46]:
import sklearn
from sklearn.linear_model import LogisticRegression

model_type = LogisticRegression(max_iter=1000)

propensity_scores = get_propensity_scores(training_data_ct_exploded, model_type, 1, 2)

NotFittedError: This LogisticRegression instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.