# Build custom policy based on CRN model recommendations

In [87]:
import os
import pandas as pd
from tqdm import tqdm
from utils import flatten
from data_loaders.temporAI_dataloader import load_gsu_dataset
import numpy as np
from tempor.utils.serialization import load_from_file
from tempor.data.dataset import TemporalTreatmentEffectsDataset

In [2]:
# Locations of the trained CRN checkpoint and of the folder holding the pre-processed data splits.
# The defaults are the original developer-machine paths; set the CRN_MODEL_PATH / GSU_SPLIT_FOLDER
# environment variables to run the notebook on another machine without editing this cell.
model_path = os.environ.get(
    'CRN_MODEL_PATH',
    '/Users/jk1/temp/treatment_effects/training_test/crn_20231127_203329/crn_model_20231127_203329_split_0.cpkl',
)
split_folder = os.environ.get(
    'GSU_SPLIT_FOLDER',
    '/Users/jk1/temp/treatment_effects/preprocessing/gsu_Extraction_20220815_prepro_25112023_213851/splits',
)

In [3]:
# Index of the data split whose CSV files are loaded below.
# Note: the checkpoint file name in model_path is pinned to split_0, so keep the two in sync.
split = 0

# Load data and model

In [4]:
# Resolve the CSV files that belong to the selected split (validation first, then training)
val_features_path = os.path.join(split_folder, f'val_features_split_{split}.csv')
val_cont_outcomes_path = os.path.join(split_folder, f'val_continuous_outcomes_split_{split}.csv')
train_features_path = os.path.join(split_folder, f'train_features_split_{split}.csv')
train_cont_outcomes_path = os.path.join(split_folder, f'train_continuous_outcomes_split_{split}.csv')

# Only the validation split is turned into a dataset in this cell
val_gsu_dataset = load_gsu_dataset(
    val_features_path,
    val_cont_outcomes_path,
)

# Restore the serialized model from disk
model = load_from_file(model_path)

# Get predictions

Criteria for working inference:
- each treatment scenario must be binary (only 2 different treatments per scenario)
- predictions start at timestep 2 (the CRN needs at least 2 timesteps of run-in)

In [5]:
# Per-subject prediction horizon: every time index from position 2 onwards
# (the first two timesteps are kept as run-in for the model)
horizons = [
    subject_series.time_indexes()[0][2:]
    for subject_series in val_gsu_dataset.time_series
]

In [6]:
# One scenario per subject: the treatment values stored in the dataset from position 2 onwards,
# i.e. aligned with `horizons`.
# np.atleast_1d keeps the scenario one-dimensional when only a single timestep is left after the
# run-in; squeeze() alone would collapse such a one-row frame into a 0-d array.
treatment_scenarios = [
    [np.atleast_1d(ttt.dataframe().values[2:].squeeze()).astype(int)]
    for ttt in val_gsu_dataset.predictive.treatments
]

In [7]:
# Per-subject horizon restricted to one step: only the time index at position 2
single_step_horizon = [
    subject_series.time_indexes()[0][2:3]
    for subject_series in val_gsu_dataset.time_series
]

In [8]:
# For every subject: one single-step scenario per treatment strategy (encoded 0-7)
single_step_treatment_scenarios = [
    [np.array([strategy_idx]) for strategy_idx in range(8)]
    for _ in single_step_horizon
]

In [9]:
# Counterfactual predictions for the first two subjects, one per candidate treatment strategy
predictions = model.predict_counterfactuals(
    val_gsu_dataset[0:2],
    horizons=single_step_horizon[0:2],
    treatment_scenarios=single_step_treatment_scenarios[0:2],
)

In [10]:
# Collapse each subject's per-scenario predictions into one flat list (one value per scenario)
extracted_predictions = [
    flatten(flatten([
        subject_predictions[scenario_idx].to_numpy()
        for scenario_idx in range(len(subject_predictions))
    ]))
    for subject_predictions in predictions
]

In [12]:
# Inspect the flattened predictions of the second subject (one value per treatment scenario)
extracted_predictions[1]

[-0.07683232426643372,
 -0.07031607627868652,
 -0.06379981338977814,
 -0.05728356912732124,
 -0.050767313688993454,
 -0.04425106570124626,
 -0.03773481026291847,
 -0.03121856227517128]

### Absolute decision function

In [13]:
def absolute_decision_function(predicted_counterfactuals_per_ttt):
    """
    Pick the treatment option with the lowest predicted outcome (i.e. the one minimizing delta NIHSS).

    :param predicted_counterfactuals_per_ttt: predicted outcomes, one entry per treatment option
    :return: index of the smallest prediction (the first one in case of ties)
    """
    best_option_idx = np.argmin(predicted_counterfactuals_per_ttt)
    return best_option_idx

In [14]:
# Greedy treatment choice for every subject
list(map(absolute_decision_function, extracted_predictions))

[0, 0]

### Probabilistic decision function

In [15]:
def probabilistic_decision_function(predicted_counterfactuals_per_ttt, epsilon = 1e-6):
    """
    Compute likelihood and log-likelihood of choosing every treatment option based on predicted counterfactuals

    The predictions are min-max scaled onto [0, 1] with the axis flipped: the lowest (most negative)
    prediction is mapped to 1 and the highest to 0. When all predictions are identical every option is
    tied for best and receives likelihood 1 (instead of the NaN a 0/0 division would produce).

    :param predicted_counterfactuals_per_ttt: sequence or array of predicted outcomes, one per treatment option
    :param epsilon: small constant added to the likelihood before taking the log, so that log(0) is avoided
    :return: tuple (likelihood, log_likelihood) of arrays with one entry per treatment option
    """
    predictions = np.asarray(predicted_counterfactuals_per_ttt)

    # dedicated names instead of `min` / `max`, which would shadow the Python builtins
    lowest = np.min(predictions)
    highest = np.max(predictions)
    spread = highest - lowest

    if spread == 0:
        # all options are predicted to be equally good: no scaling possible, treat all as best
        likelihood = np.ones(predictions.shape, dtype=float)
    else:
        # same values as (pred - max) / (min - max), but the worst option yields 0.0 rather than -0.0
        likelihood = (highest - predictions) / spread

    log_likelihood = np.log(likelihood + epsilon)
    return likelihood, log_likelihood

In [16]:
# (likelihood, log-likelihood) pair for every subject
list(map(probabilistic_decision_function, extracted_predictions))

[(array([ 1.        , -0.        ,  0.97755378,  0.96633069,  0.95510759,
          0.94388448,  0.93266138,  0.92143827]),
  array([ 9.99999500e-07, -1.38155106e+01, -2.27009440e-02, -3.42481411e-02,
         -4.59302418e-02, -5.77504355e-02, -6.97120053e-02, -8.18184022e-02])),
 (array([ 1.        ,  0.85714294,  0.71428555,  0.57142857,  0.42857135,
          0.28571429,  0.14285706, -0.        ]),
  array([ 9.99999500e-07, -1.54149418e-01, -3.36471065e-01, -5.59614038e-01,
         -8.47295718e-01, -1.25275947e+00, -1.94590372e+00, -1.38155106e+01]))]

In [None]:
# Predictions for the first two subjects over their full horizon, using the treatment
# sequences taken from the dataset as the only scenario
predicted_factuals = model.predict_counterfactuals(
    val_gsu_dataset[0:2],
    horizons=horizons[0:2],
    treatment_scenarios=treatment_scenarios[0:2],
)
predicted_factuals

In [101]:
# Subset of the validation data used for the policy rollout (first two subjects)
dataset = val_gsu_dataset[0:2]
# Number of timesteps predicted per model call
n_timesteps_to_predict = 1
# Roll out over at most 5 timesteps, but never past the end of the series.
# The length is read from the first subject only - presumably all subjects share it; TODO confirm.
n_timesteps = min(dataset.time_series[0].dataframe().shape[0], 5)
# Treatment strategies are encoded as 0 .. n_treatment_strategies - 1
n_treatment_strategies = 8
# iteratively update dataset with selected treatment
update_treatment = True

In [104]:
# Roll the policy forward one step at a time, starting at position 2 (the first two steps are run-in).
# For every step: predict the outcome under each treatment strategy, derive the greedy choice and the
# per-strategy likelihoods, and (optionally) write the chosen treatment back into the dataset so that
# later steps are conditioned on it.
# `dataset` is rebuilt inside this loop, so re-run the cell defining `dataset` before re-running this one.
per_step_frames = []  # one frame per timestep, concatenated once after the loop
for ts in tqdm(range(2, n_timesteps - n_timesteps_to_predict + 1)):
    # predict single timestep at a time
    horizon = [tc.time_indexes()[0][ts:ts + n_timesteps_to_predict] for tc in dataset.time_series]
    # One single-step scenario per treatment strategy for every subject. Kept under its own name so the
    # notebook-level `treatment_scenarios` is not overwritten by this loop.
    step_treatment_scenarios = [[np.array([ttt_strat]) for ttt_strat in range(n_treatment_strategies)] for _ in horizon]

    predictions = model.predict_counterfactuals(dataset, horizons=horizon, treatment_scenarios=step_treatment_scenarios)
    extracted_predictions = [flatten(flatten([subj_pred[ttt_strat_idx].to_numpy() for ttt_strat_idx in range(len(subj_pred))])) for subj_pred in predictions]

    # greedy choice: strategy with the lowest predicted outcome
    optimal_treatment_option = [absolute_decision_function(subj_extracted_predictions) for subj_extracted_predictions in extracted_predictions]

    # (likelihood, log-likelihood) arrays per subject
    treatment_probas = [probabilistic_decision_function(subj_extracted_predictions) for subj_extracted_predictions in extracted_predictions]
    treatment_likelihoods = [treatment_proba[0] for treatment_proba in treatment_probas]
    treatment_log_likelihoods = [treatment_proba[1] for treatment_proba in treatment_probas]

    temp_df = pd.DataFrame({'case_admission_id': dataset.time_series.dataframe().reset_index()['sample_idx'].unique(),
                            'time_idx': ts,
                            'optimal_treatment_option': optimal_treatment_option})
    # add a column for every likelihood treatment option
    for ttt_strat_idx in range(n_treatment_strategies):
        temp_df[f'treatment_likelihood_strat_{ttt_strat_idx}'] = [treatment_likelihood[ttt_strat_idx] for treatment_likelihood in treatment_likelihoods]
        temp_df[f'treatment_log_likelihood_strat_{ttt_strat_idx}'] = [treatment_log_likelihood[ttt_strat_idx] for treatment_log_likelihood in treatment_log_likelihoods]
    per_step_frames.append(temp_df)

    if update_treatment:
        # update dataset with predicted treatment
        # Work on a copy so the assignment cannot write through to a frame shared with the source
        # dataset (whether dataframe() already returns a copy is not visible from this notebook).
        temp = dataset.predictive.treatments.dataframe().copy()
        # NOTE(review): `ts` is a position in time_indexes() above but is used as an index label here,
        # and the list is assigned in row order - assumes labels are 0..n-1 and rows follow subject order; confirm.
        temp.loc[(slice(None), ts), 'anti_hypertensive_strategy'] = optimal_treatment_option
        dataset = TemporalTreatmentEffectsDataset(
            time_series=dataset.time_series.dataframe(),
            treatments=temp,
            targets=dataset.predictive.targets.dataframe()
        )

# Single concatenation instead of growing the frame inside the loop; stays an empty frame when the
# loop did not run at all.
predicted_treatment_strategies_df = pd.concat(per_step_frames, axis=0) if per_step_frames else pd.DataFrame()
    

100%|██████████| 3/3 [00:11<00:00,  3.83s/it]


In [105]:
predicted_treatment_strategies_df

Unnamed: 0,case_admission_id,time_idx,optimal_treatment_option,treatment_likelihood_strat_0,treatment_log_likelihood_strat_0,treatment_likelihood_strat_1,treatment_log_likelihood_strat_1,treatment_likelihood_strat_2,treatment_log_likelihood_strat_2,treatment_likelihood_strat_3,treatment_log_likelihood_strat_3,treatment_likelihood_strat_4,treatment_log_likelihood_strat_4,treatment_likelihood_strat_5,treatment_log_likelihood_strat_5,treatment_likelihood_strat_6,treatment_log_likelihood_strat_6,treatment_likelihood_strat_7,treatment_log_likelihood_strat_7
0,10189_1690,2,0,1.0,9.999995e-07,0.857143,-0.15415,0.714286,-0.336471,0.571429,-0.559614,0.428571,-0.847295,0.285714,-1.25276,0.142857,-1.945903,-0.0,-13.815511
1,1025279_1586,2,0,1.0,9.999995e-07,0.857143,-0.154149,0.714286,-0.336471,0.571429,-0.559614,0.428571,-0.847296,0.285714,-1.252759,0.142857,-1.945904,-0.0,-13.815511
0,10189_1690,3,0,1.0,9.999995e-07,0.857143,-0.154149,0.714286,-0.336471,0.571429,-0.559614,0.428571,-0.847296,0.285714,-1.25276,0.142857,-1.945903,-0.0,-13.815511
1,1025279_1586,3,0,1.0,9.999995e-07,0.857143,-0.154149,0.714286,-0.336471,0.571429,-0.559614,0.428572,-0.847295,0.285714,-1.25276,0.142857,-1.945903,-0.0,-13.815511
0,10189_1690,4,0,1.0,9.999995e-07,0.857143,-0.154149,0.714286,-0.336471,0.571429,-0.559614,0.428571,-0.847295,0.285714,-1.252759,0.142857,-1.945902,-0.0,-13.815511
1,1025279_1586,4,0,1.0,9.999995e-07,0.857143,-0.154149,0.714286,-0.336471,0.571429,-0.559614,0.428571,-0.847295,0.285714,-1.252759,0.142857,-1.945902,-0.0,-13.815511
