In [163]:
import os

import pandas as pd
from tqdm import tqdm
from data_loaders.temporAI_dataloader import load_gsu_dataset
import numpy as np
from tempor.utils.serialization import load_from_file

In [6]:
# Base directory of all treatment-effect experiments; set the
# TREATMENT_EFFECTS_DIR environment variable to override the local default
# instead of editing this cell.
treatment_effects_dir = os.environ.get('TREATMENT_EFFECTS_DIR', '/Users/jk1/temp/treatment_effects')
split = 0
# Identifier of the CRN training run to evaluate (presumably its timestamp).
model_run_id = '20231127_203329'
# The model file name is derived from `split`, so the evaluated model and the
# data split below cannot drift apart when `split` is changed.
model_path = os.path.join(treatment_effects_dir, 'training_test', f'crn_{model_run_id}',
                          f'crn_model_{model_run_id}_split_{split}.cpkl')
split_folder = os.path.join(treatment_effects_dir, 'preprocessing',
                            'gsu_Extraction_20220815_prepro_25112023_213851', 'splits')

# Load data and model

In [4]:
# Feature / continuous-outcome CSVs of the selected split, all in `split_folder`.
(train_features_path, train_cont_outcomes_path,
 val_features_path, val_cont_outcomes_path) = (
    os.path.join(split_folder, file_name)
    for file_name in (
        f'train_features_split_{split}.csv',
        f'train_continuous_outcomes_split_{split}.csv',
        f'val_features_split_{split}.csv',
        f'val_continuous_outcomes_split_{split}.csv',
    )
)

In [5]:
# Load the validation split (features + continuous outcomes) as a tempor dataset.
# NOTE(review): the train_* paths from the previous cell are not used in the visible cells.
val_gsu_dataset = load_gsu_dataset(val_features_path, val_cont_outcomes_path)

In [8]:
# Load the trained CRN model for this split.
# NOTE(review): the .cpkl file is presumably a pickle-style serialization — only load trusted files.
model = load_from_file(model_path)

# Get predictions for factuals on validation dataset

In [53]:
# Prediction horizons: every time index from timestep 2 onwards, per subject
# (CRN needs at least two timesteps of run-in history).
horizons = [
    series.time_indexes()[0][2:]
    for series in val_gsu_dataset.time_series
]

In [122]:
# Factual (observed) treatments from timestep 2 onwards, one single-scenario
# list per subject, cast to int as expected by predict_counterfactuals.
treatment_scenarios = [
    [treatment.dataframe().values[2:].squeeze().astype(int)]
    for treatment in val_gsu_dataset.predictive.treatments
]

Criteria for working inference:
- the treatment scenario must be binary (only 2 distinct treatment values per scenario)
- prediction must start at timestep 2 (CRN needs a run-in of at least 2 timesteps)

In [125]:
# Multi-step forecast for the first two validation subjects under their
# factual treatments, starting at timestep 2.
predicted_factuals = model.predict_counterfactuals(
    val_gsu_dataset[0:2],
    horizons=horizons[0:2],
    treatment_scenarios=treatment_scenarios[0:2],
)
predicted_factuals

[[TimeSeries() with data:
            nihss_delta_at_next_ts
  time_idx                        
  2                      -0.034997
  3                      -0.042279
  4                      -0.042230
  5                      -0.042230
  6                      -0.042230
  ...                          ...
  66                     -0.042230
  67                     -0.042230
  68                     -0.042230
  69                     -0.042230
  70                     -0.042230
  
  [69 rows x 1 columns]],
 [TimeSeries() with data:
            nihss_delta_at_next_ts
  time_idx                        
  2                      -0.031219
  3                      -0.038598
  4                      -0.038547
  5                      -0.038548
  6                      -0.038548
  ...                          ...
  66                     -0.038548
  67                     -0.038548
  68                     -0.038548
  69                     -0.038548
  70                     -0.038548
  
  [69 

In [133]:
def root_mean_square_error(x1, x2):
    """Root-mean-square error between two equally sized arrays.

    Both inputs are converted with ``np.asarray`` and flattened before being
    compared, so e.g. an ``(n, 1)`` column vector and an ``(n,)`` vector are
    compared element-wise instead of being silently broadcast to ``(n, n)``.

    Parameters
    ----------
    x1, x2 : array-like
        Values to compare; must contain the same number of elements.

    Returns
    -------
    float
        ``sqrt(mean((x1 - x2) ** 2))``; NaN (with a numpy warning) if both
        inputs are empty.

    Raises
    ------
    ValueError
        If ``x1`` and ``x2`` contain a different number of elements.
    """
    a = np.asarray(x1, dtype=float).ravel()
    b = np.asarray(x2, dtype=float).ravel()
    # Compare total element counts: len() only checks the first axis, which
    # lets mismatched shapes through to broadcasting.
    if a.size != b.size:
        raise ValueError(f'x1 and x2 should be of same length and not {a.size} != {b.size}')
    return float(np.sqrt(np.mean((a - b) ** 2)))

In [143]:
# RMSE of subject 0's multi-step forecast (timestep 2 onwards, see the
# prediction cell above) against its observed targets over the same timesteps.
root_mean_square_error(val_gsu_dataset[0].predictive.targets.dataframe().values[2:], 
                       predicted_factuals[0][0].to_numpy())

0.12265416790795378

Always predict only a single timestep (TS) at a time, rolling the start of the prediction window forward through the series.

In [316]:
def flatten(l):
    """Concatenate the items of every sub-iterable of ``l`` into one list.

    Removes exactly one level of nesting: ``flatten([[1, 2], [3]]) == [1, 2, 3]``.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat

In [303]:
# Number of timesteps available to the rolling evaluation. Take the shortest
# series over all subjects (not just the first subject's), so every subject
# has a complete prediction window at every evaluated timestep.
n_timesteps = min(len(tc.time_indexes()[0]) for tc in val_gsu_dataset.time_series)

In [329]:
# Rolling evaluation: for every prediction start `ts` the model conditions on
# the history before `ts` and predicts the next `n_timesteps_to_predict` steps
# under the factual (observed) treatments.
n_timesteps_to_predict = 1
# CRN needs at least 2 timesteps of run-in before the first prediction.
min_history = 2
# TODO: set to slice(None) to evaluate all validation subjects.
subject_subset = slice(0, 2)

# Loop-invariant inputs, extracted once instead of on every iteration.
time_indexes = [tc.time_indexes()[0] for tc in val_gsu_dataset.time_series][subject_subset]
factual_treatments = [ttt.dataframe().values.astype(int)
                      for ttt in val_gsu_dataset.predictive.treatments][subject_subset]
targets_subset = val_gsu_dataset.predictive.targets[subject_subset].dataframe()
target_column = targets_subset.columns[0]
targets_long = targets_subset.reset_index()

# Collect per-timestep results in dicts and build each DataFrame once at the end.
predicted_by_ts, true_by_ts, rmse_by_ts = {}, {}, {}
for ts in tqdm(range(min_history, n_timesteps - n_timesteps_to_predict + 1)):
    window = slice(ts, ts + n_timesteps_to_predict)
    horizons_at_ts = [idx[window] for idx in time_indexes]
    # Distinct name so the full-length `treatment_scenarios` defined in an
    # earlier cell is not overwritten by this loop.
    treatment_scenarios_at_ts = [[treatments[window]] for treatments in factual_treatments]

    predicted_factuals_at_ts = model.predict_counterfactuals(
        val_gsu_dataset[subject_subset],
        horizons=horizons_at_ts,
        treatment_scenarios=treatment_scenarios_at_ts,
    )
    # Each subject's prediction converts to a 2-D array (one column); stack the
    # subjects row-wise and flatten to one value per subject and timestep.
    predicted_by_ts[ts] = np.concatenate(
        [pfts[0].to_numpy() for pfts in predicted_factuals_at_ts]).ravel()

    # NOTE(review): rows are selected by time_idx *value*, while the horizon is
    # a positional slice; this assumes time_idx equals the position (true in
    # the outputs shown) — confirm for other datasets.
    true_by_ts[ts] = targets_long.loc[
        targets_long.time_idx.isin(range(ts, ts + n_timesteps_to_predict)), target_column
    ].to_numpy()

    rmse_by_ts[ts] = root_mean_square_error(true_by_ts[ts], predicted_by_ts[ts])

predicted_factuals_df = pd.DataFrame(predicted_by_ts)
true_factuals_df = pd.DataFrame(true_by_ts)
rmse_per_ts_df = pd.DataFrame({ts: [rmse] for ts, rmse in rmse_by_ts.items()})
    

100%|██████████| 69/69 [06:52<00:00,  5.98s/it]


In [330]:
# Observed targets: one column per prediction start timestep; with
# n_timesteps_to_predict = 1 each row is one subject.
true_factuals_df

Unnamed: 0,2,3,4,5,6,7,8,9,10,11,...,61,62,63,64,65,66,67,68,69,70
0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,-1.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [331]:
# Model predictions laid out like true_factuals_df (columns = prediction start
# timestep; rows = subjects when n_timesteps_to_predict = 1).
predicted_factuals_df

Unnamed: 0,2,3,4,5,6,7,8,9,10,11,...,61,62,63,64,65,66,67,68,69,70
0,-0.034997,-0.055535,-0.067548,-0.078275,-0.080675,-0.083704,-0.087344,-0.090192,-0.092651,-0.095064,...,-0.074182,-0.076579,-0.076311,-0.075674,-0.075303,-0.075022,-0.073745,-0.073277,-0.073052,-0.072976
1,-0.031219,-0.047482,-0.057662,-0.063753,-0.074397,-0.068745,-0.071192,-0.071263,-0.073377,-0.062573,...,-0.077067,-0.077068,-0.077068,-0.077068,-0.077068,-0.077068,-0.077068,-0.077068,-0.077068,-0.077068


In [332]:
# Debug check: first subject's prediction from the *last* loop iteration
# (relies on `predicted_factuals_at_ts` leaking out of the loop cell).
predicted_factuals_at_ts[0][0].to_numpy()

array([[-0.07297616]])

In [333]:
# RMSE across the evaluated subjects for each prediction start timestep.
rmse_per_ts_df

Unnamed: 0,2,3,4,5,6,7,8,9,10,11,...,61,62,63,64,65,66,67,68,69,70
0,0.033162,0.051666,0.0628,0.75422,0.077599,0.07659,0.079678,0.08128,0.658488,0.080475,...,0.075638,0.076823,0.07669,0.076374,0.076191,0.076052,0.075425,0.075197,0.075087,0.07505


In [334]:
# Pooled RMSE over every (subject, timestep) cell. Both frames are flattened
# column by column (Fortran order, as DataFrame.melt does), so cells stay aligned.
predicted_values = predicted_factuals_df.to_numpy().ravel(order='F')
true_values = true_factuals_df.to_numpy().ravel(order='F')
overall_rmse = root_mean_square_error(predicted_values, true_values)
overall_rmse

0.16185998240383107

In [299]:
# NOTE(review): stale, out-of-order cell (In [299] precedes the loop's In [329]).
# Its output comes from an earlier run with more subjects and fewer timesteps;
# it duplicates the display above — re-run or delete.
predicted_factuals_df

Unnamed: 0,2,3,4,5,6,7
0,-0.034997,-0.055535,-0.067548,-0.078275,-0.080675,-0.083704
1,-0.031219,-0.047482,-0.057662,-0.063753,-0.074397,-0.068745
2,-0.028617,-0.039621,-0.045561,-0.045533,-0.048170,-0.050694
3,-0.081144,-0.063939,-0.091070,-0.095970,-0.103121,-0.100882
4,-0.035518,-0.057192,-0.069503,-0.076238,-0.078553,-0.080481
...,...,...,...,...,...,...
351,-0.026591,-0.001958,-0.048967,-0.055947,-0.058771,-0.056175
352,-0.029668,-0.040226,-0.042179,-0.042969,-0.045682,-0.047156
353,-0.031401,-0.045244,-0.056741,-0.058990,-0.061918,-0.061960
354,-0.040367,-0.058833,-0.067498,-0.075655,-0.081521,-0.082423


In [300]:
# NOTE(review): stale, out-of-order cell (In [300] precedes the loop's In [329]).
# Its output comes from an earlier run with more subjects and fewer timesteps;
# it duplicates the display above — re-run or delete.
true_factuals_df

Unnamed: 0,2,3,4,5,6,7
0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,0.0,0.0,1.0,0.0,0.0
2,0.0,0.0,0.0,0.0,0.0,0.0
3,8.5,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...
351,-6.0,0.0,0.0,0.0,0.0,0.0
352,0.0,0.0,0.0,0.0,0.0,0.0
353,0.0,0.0,0.0,0.0,0.0,0.0
354,0.0,0.0,0.0,0.0,0.0,0.0
