# Counterfactual Analysis of Target Temperature Management (TTM) Protocols

## Introduction
This notebook explores how different Target Temperature Management (TTM) protocols affect predicted patient outcomes, with the aim of identifying which protocol the model favors for specific patient subpopulations. Using counterfactual scenarios, we assess how switching the TTM protocol (TTM at 33°C, TTM at 36°C, or no TTM) changes the model's prediction for each patient. Through this analysis, we aim to:
- Determine which TTM protocol gives the best predicted outcome for each patient.
- Identify distinct patient groups that benefit most from TTM 33, TTM 36, or no TTM.
- Test whether the differences between predictions under the original and counterfactual protocols are statistically significant.
- Explore shared characteristics within each group to better understand optimal TTM conditions.

## Table of Contents
1. [Setup & Imports](#setup-imports)
2. [Data Loading and Preparation](#data-loading-preparation)
3. [Define Counterfactual Scenarios for TTM Protocols](#define-counterfactual-scenarios)
4. [Generate and Compare Predictions for Each Protocol](#generate-and-compare-predictions)
5. [Group Patients by Optimal Protocol](#group-patients-by-optimal-protocol)
6. [Significance Testing of Prediction Changes](#significance-testing)
7. [Analysis of Commonalities Within Each Group](#group-commonalities-analysis)

---

Each section provides a step-by-step approach to answer our main questions and assess the impact of different TTM protocols on patient outcomes. By the end of this notebook, we will gain insights into which TTM protocol is most beneficial for distinct patient groups and identify key characteristics that define these groups.


## 1. Setup & Imports <a id="setup-imports"></a>

In this section, we import the necessary libraries for data manipulation, model prediction, statistical testing, and visualization. These libraries will enable us to efficiently load, process, and analyze data, as well as interpret model outputs.

- **NumPy and Pandas**: Essential for data handling and manipulation.
- **TensorFlow**: Used to load and work with the neural network model for generating predictions under different TTM protocols.
- **sklearn**: Provides metrics for evaluating model performance.
- **scipy.stats**: Useful for conducting statistical tests to evaluate the significance of prediction changes.
- **SHAP**: Helps interpret model predictions, which is useful in analyzing the feature contributions for different TTM protocols.
- **Matplotlib and Seaborn**: Visualization libraries for creating insightful plots.

[Back to Table of Contents](#table-of-contents)

In [2]:
# Setup & Imports
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras.models import load_model
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score, accuracy_score
from scipy.stats import ttest_rel, wilcoxon
import shap  # For SHAP values if needed
import matplotlib.pyplot as plt
import seaborn as sns

# Optional: Set up visualization settings
sns.set_theme(style="whitegrid")
plt.rcParams["figure.figsize"] = (10, 6)

# Check if GPU is available
if tf.config.list_physical_devices('GPU'):
    print("GPU is available and will be used for training.")
else:
    print("GPU is not available. Training will use the CPU.")


GPU is not available. Training will use the CPU.


## 2. Data Loading and Preparation <a id="data-loading-preparation"></a>

Here, we load the preprocessed datasets required for counterfactual analysis. The data includes:
- **machine_learning_patient_data**: Contains patient data already prepared for model predictions.
- **ecg_data**: Features specific to ECG measurements.
- **patient_data**: Contains the target outcome, which we map to 1 for "Good" and 0 for "Poor".

These datasets are merged into a single DataFrame for analysis. The time-series features are reshaped to match the model's input format in Section 3.

[Back to Table of Contents](#table-of-contents)


In [None]:
import pickle
import pandas as pd

# Loading the dataset that has been fully prepared for distance-based prediction
with open('data/machine_learning_patient_data.pkl', 'rb') as f:
    machine_learning_patient_data = pickle.load(f)

# Load ECG data
ecg_data = pd.read_csv('data/ecg_data.csv')

# Loading the dataset that has the target value and original features and values
with open('data/patient_data.pkl', 'rb') as f:
    patient_data = pickle.load(f)

# Map 'Good' to 1 and 'Poor' to 0 for target outcome
patient_data['outcome'] = patient_data['outcome'].map({'Good': 1, 'Poor': 0}).astype(int)

# Step 1: Merge machine_learning_patient_data and ecg_data on patient identifiers
combined_data = pd.merge(
    machine_learning_patient_data, ecg_data, 
    left_on='Patient', right_on='Patient_ID', 
    how='inner'
)

# Step 2: Merge the resulting combined_data with patient_data to add the 'outcome' column
combined_data = pd.merge(
    combined_data, patient_data[['Patient', 'outcome']], 
    on='Patient', 
    how='inner'
)

# Display the first few rows of the combined DataFrame to verify
combined_data.head()


Unnamed: 0,Patient,age,sex_Female,sex_Male,ohca_True,ohca_Unknown,ttm_33.0,ttm_36.0,ttm_No TTM,shockable_rhythm_False,...,Segment_287_HRV_SDNN,Segment_287_LF_Power,Segment_287_HF_Power,Segment_287_LF_HF_Ratio,Segment_288_Mean_HR,Segment_288_HRV_SDNN,Segment_288_LF_Power,Segment_288_HF_Power,Segment_288_LF_HF_Ratio,outcome
0,284,-0.522787,False,True,True,False,True,False,False,False,...,0.981133,-0.093946,-0.088465,-0.096152,0.34452,1.128406,-0.063206,-0.062188,-0.238693,1
1,286,1.525272,True,False,False,False,False,False,True,True,...,1.81649,-0.073634,-0.100386,-0.059189,0.48228,0.0,-0.061845,-0.061378,-0.176915,1
2,296,-0.842797,False,True,True,False,False,True,False,False,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1
3,299,-1.034802,False,True,True,False,True,False,False,False,...,2.456931,-0.111831,-0.085917,-0.109046,-0.516477,0.50152,-0.063652,-0.062835,-0.192359,1
4,303,-0.650791,False,True,True,False,True,False,False,False,...,-0.759196,-0.182218,-0.129349,-0.12194,-1.170835,0.00286,-0.063798,-0.062743,-0.30356,1
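
Because both merges above use `how='inner'`, any patient missing from one of the inputs is dropped without warning. The following is a minimal sketch of how that could be audited, using toy frames in place of the real data. The frame names `features` and `ecg`, the variables `audit`, `n_matched` and `unmatched`, and all values are illustrative assumptions, not part of the notebook.

```python
import pandas as pd

# Toy stand-ins for machine_learning_patient_data and ecg_data (illustrative values)
features = pd.DataFrame({"Patient": [284, 286, 296], "age": [-0.52, 1.53, -0.84]})
ecg = pd.DataFrame({"Patient_ID": [284, 286, 299], "Segment_1_Mean_HR": [0.34, 0.48, -0.52]})

# An outer merge with indicator=True labels each row 'both', 'left_only' or 'right_only'.
# validate="one_to_one" raises a MergeError if a patient ID is duplicated on either side.
audit = pd.merge(
    features, ecg,
    left_on="Patient", right_on="Patient_ID",
    how="outer", indicator=True, validate="one_to_one",
)

# Rows marked 'both' are the ones an inner merge would keep
n_matched = int((audit["_merge"] == "both").sum())
unmatched = audit.loc[audit["_merge"] != "both", ["Patient", "Patient_ID", "_merge"]]
print(f"{n_matched} patients survive an inner merge; {len(unmatched)} would be dropped")
```

Running the same check on the real frames before the inner merges would show whether the final patient count reflects missing ECG or outcome records.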


## 3. Define Counterfactual Scenarios for TTM Protocols <a id="define-counterfactual-scenarios"></a>

In this section, we define counterfactual scenarios to evaluate how each TTM protocol (TTM 33, TTM 36, and No TTM) affects patient outcomes. For each patient, we will create variations with each TTM protocol and then generate predictions based on these alternative treatments.

Steps:
1. **Set Up Counterfactual Treatments**: Define TTM 33, TTM 36, and No TTM variations for each patient.
2. **Generate Predictions for Each Protocol**: Use the trained model to predict outcomes for each TTM scenario.
3. **Prepare Data for Comparison**: Organize the predictions to allow easy comparison between the different TTM protocols.
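
The three steps can be sketched on toy data as follows. This is only an illustration under stated assumptions: `toy_model` is a hypothetical fixed logistic score standing in for the trained network, `set_protocol` is a hypothetical helper, and the patient values are made up. Only the three `ttm_*` column names come from the dataset.

```python
import numpy as np
import pandas as pd

TTM_COLS = ["ttm_33.0", "ttm_36.0", "ttm_No TTM"]

def set_protocol(df, active_col):
    """Return a copy of df with exactly one TTM indicator switched on."""
    out = df.copy()
    for col in TTM_COLS:
        out[col] = 1.0 if col == active_col else 0.0
    return out

def toy_model(X):
    """Hypothetical stand-in for the trained network: a fixed logistic score."""
    weights = np.array([0.5, 0.2, -0.1, 0.0])  # age, ttm_33.0, ttm_36.0, ttm_No TTM
    return 1.0 / (1.0 + np.exp(-(X @ weights)))

# Step 1: toy patients with their originally assigned protocol (illustrative values)
patients = pd.DataFrame({
    "age": [-0.5, 1.5],
    "ttm_33.0": [True, False],
    "ttm_36.0": [False, False],
    "ttm_No TTM": [False, True],
})
feature_cols = ["age"] + TTM_COLS

# Steps 2 and 3: predict under each forced protocol and collect one column per protocol
preds = pd.DataFrame({
    col: toy_model(set_protocol(patients, col)[feature_cols].to_numpy(dtype="float32"))
    for col in TTM_COLS
})
print(preds)
```

Each row of `preds` then holds one patient's predicted probability under all three protocols, which is the layout used for the comparisons later in this section.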

[Back to Table of Contents](#table-of-contents)


In [8]:
# Identify time-series columns for the selected features
selected_time_series_cols = [col for col in combined_data.columns if 
                             any(feature in col for feature in ["Mean_HR", "HRV_SDNN", "LF_HF_Ratio"])]

# Reshape the time-series data into (samples, time steps, features)
X_time_series = combined_data[selected_time_series_cols].values.reshape(
    len(combined_data),  # samples (number of patients)
    -1,                  # time steps (number of segments per patient)
    3                    # features (3 values per time point: Mean_HR, HRV_SDNN, LF_HF_Ratio)
)

X_time_series.shape

(607, 288, 3)
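
The reshape above is only correct if the selected columns are ordered segment by segment, with the three features adjacent within each segment. A small check of that assumption on toy data is sketched below. The column names follow the `Segment_<n>_<feature>` pattern seen in the merged DataFrame, while the `toy` frame and its values are illustrative.

```python
import numpy as np
import pandas as pd

features = ["Mean_HR", "HRV_SDNN", "LF_HF_Ratio"]
n_segments = 2

# Segment-major column order: all features of segment 1, then all features of segment 2
cols = [f"Segment_{s}_{f}" for s in range(1, n_segments + 1) for f in features]
toy = pd.DataFrame([np.arange(6), np.arange(6) + 10], columns=cols)

# Same reshape as in the notebook: (patients, segments, features)
X = toy[cols].values.reshape(len(toy), -1, len(features))

# X[p, s, f] should equal column Segment_{s+1}_{features[f]} for patient p
for p in range(len(toy)):
    for s in range(n_segments):
        for f, name in enumerate(features):
            assert X[p, s, f] == toy.loc[p, f"Segment_{s + 1}_{name}"]
print(X.shape)
```

If the real columns were instead grouped feature by feature, the same reshape would still succeed but would mix features across time steps, so a check like this is worth running once on `combined_data`.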

In [19]:
# Load the neural network model from the specified path
from tensorflow.keras.models import load_model

# Load the model from models/neural_network.keras
best_model = load_model('models/neural_network.keras')

# Helper function to set the TTM protocol and generate predictions.
# NOTE: relies on the globals `tabular_cols` (the tabular feature columns the model
# expects; not defined in the cells shown here) and `selected_time_series_cols`.
def generate_predictions_for_ttm(ttm_value, combined_data, model):
    # Map each protocol label to its one-hot indicator column
    ttm_columns = {'33': 'ttm_33.0', '36': 'ttm_36.0', 'No TTM': 'ttm_No TTM'}
    if ttm_value not in ttm_columns:
        # Fail loudly instead of silently predicting on the unmodified data
        raise ValueError(f"Unknown TTM protocol: {ttm_value!r}")

    # Make a copy of the data to avoid altering the original DataFrame
    data_copy = combined_data.copy()

    # Switch on the requested protocol and switch off the other two
    for protocol, column in ttm_columns.items():
        data_copy[column] = 1 if protocol == ttm_value else 0

    # Extract tabular and time-series data with proper shape
    X_tabular = data_copy[tabular_cols].values.astype('float32')
    X_time_series = data_copy[selected_time_series_cols].values.astype('float32').reshape(len(data_copy), 288, 3)

    # Generate predictions
    predictions = model.predict([X_tabular, X_time_series])
    return predictions

# Function to determine the original TTM protocol based on the columns
def get_original_ttm_protocol(row):
    if row['ttm_33.0'] == 1:
        return '33'
    elif row['ttm_36.0'] == 1:
        return '36'
    elif row['ttm_No TTM'] == 1:
        return 'No TTM'
    return 'Unknown'

# Add original TTM protocol column to the combined_data DataFrame
combined_data['Original_TTM'] = combined_data.apply(get_original_ttm_protocol, axis=1)

# Generate predictions for each TTM protocol
predictions_ttm_33 = generate_predictions_for_ttm('33', combined_data, best_model)
predictions_ttm_36 = generate_predictions_for_ttm('36', combined_data, best_model)
predictions_no_ttm = generate_predictions_for_ttm('No TTM', combined_data, best_model)

# Combine predictions into a DataFrame for easy comparison
ttm_predictions_df = pd.DataFrame({
    'Patient': combined_data['Patient'],
    'Original_TTM': combined_data['Original_TTM'],
    'Outcome': combined_data['outcome'],
    'Pred_TTM_33': predictions_ttm_33.flatten(),
    'Pred_TTM_36': predictions_ttm_36.flatten(),
    'Pred_No_TTM': predictions_no_ttm.flatten()
})

# Display 20 random samples from the predictions DataFrame
ttm_predictions_df.sample(20, random_state=42)





19/19 ━━━━━━━━━━━━━━━━━━━━ 2s 75ms/step
19/19 ━━━━━━━━━━━━━━━━━━━━ 1s 54ms/step
19/19 ━━━━━━━━━━━━━━━━━━━━ 1s 50ms/step


Unnamed: 0,Patient,Original_TTM,Outcome,Pred_TTM_33,Pred_TTM_36,Pred_No_TTM
563,970,33,0,0.519208,0.519119,0.53349
289,655,33,0,0.593681,0.587946,0.593627
76,406,36,0,0.589975,0.552358,0.584292
78,409,33,0,0.481568,0.514257,0.51233
182,529,36,0,0.52478,0.523994,0.513721
495,890,33,0,0.462534,0.489999,0.502369
10,319,33,1,0.522154,0.498199,0.507243
131,465,No TTM,0,0.525421,0.485569,0.516957
445,835,33,1,0.477855,0.495791,0.504191
86,417,33,0,0.525144,0.526891,0.523006
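
One way to read this table is to pick, for each patient, the protocol with the highest predicted probability of a good outcome. The sketch below does this for three rows copied from the sample output above. The column names `Best_Protocol`, `Best_Pred` and `Spread` are assumptions introduced for illustration.

```python
import pandas as pd

# Three rows copied from the sample output above
sample = pd.DataFrame({
    "Patient": [970, 406, 409],
    "Pred_TTM_33": [0.519208, 0.589975, 0.481568],
    "Pred_TTM_36": [0.519119, 0.552358, 0.514257],
    "Pred_No_TTM": [0.533490, 0.584292, 0.512330],
})
pred_cols = ["Pred_TTM_33", "Pred_TTM_36", "Pred_No_TTM"]

# Column with the highest predicted probability, and that probability
sample["Best_Protocol"] = sample[pred_cols].idxmax(axis=1)
sample["Best_Pred"] = sample[pred_cols].max(axis=1)

# Gap between the best and worst protocol for each patient
sample["Spread"] = sample[pred_cols].max(axis=1) - sample[pred_cols].min(axis=1)
print(sample[["Patient", "Best_Protocol", "Best_Pred", "Spread"]])
```

The spread matters as much as the winner: for patient 970 the three predictions differ by less than 0.015, so the "best" protocol there is a very narrow preference.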


In [36]:
# Define the threshold for considering predictions as "Same"
threshold = 0.02

# Compare a protocol's prediction with the prediction under the patient's original protocol
# (prints debug information for every counterfactual comparison)
def assess_prediction(row, original_protocol, protocol_prediction, protocol_name):
    # Select the original prediction value based on the original TTM protocol
    if original_protocol == '33':
        original_prediction = row['Pred_TTM_33']
    elif original_protocol == '36':
        original_prediction = row['Pred_TTM_36']
    elif original_protocol == 'No TTM':
        original_prediction = row['Pred_No_TTM']
    else:
        return "Unknown Original TTM"

    if original_protocol == protocol_name:
        return "Original TTM"
    
    # Calculate the difference and check with threshold
    diff = protocol_prediction - original_prediction
    print(f"Original protocol: {original_protocol}, Original value: {original_prediction}, "
          f"{protocol_name} is {protocol_prediction}, Difference: {abs(diff)}")
    
    if abs(diff) <= threshold:
        return "Same"
    elif diff > 0:
        return "Higher"
    else:
        return "Lower"

# Add assessment columns for each TTM protocol, relative to the original-protocol prediction
ttm_predictions_df['TTM_33_Asses'] = ttm_predictions_df.apply(
    lambda row: assess_prediction(row, row['Original_TTM'], row['Pred_TTM_33'], '33'), axis=1
)
ttm_predictions_df['TTM_36_Asses'] = ttm_predictions_df.apply(
    lambda row: assess_prediction(row, row['Original_TTM'], row['Pred_TTM_36'], '36'), axis=1
)
ttm_predictions_df['No_TTM_Asses'] = ttm_predictions_df.apply(
    lambda row: assess_prediction(row, row['Original_TTM'], row['Pred_No_TTM'], 'No TTM'), axis=1
)

# Define color mapping function for assessments
def color_assessment(val):
    if val == "Lower":
        return "background-color: lightcoral"
    elif val == "Higher":
        return "background-color: lightgreen"
    elif val == "Same":
        return "background-color: lightyellow"
    elif val == "Original TTM":
        return "background-color: lightblue"
    return ""

# Display the DataFrame with conditional formatting on the assessment columns.
# Styler.map replaces Styler.applymap, which is deprecated since pandas 2.1.
ttm_predictions_df[['Patient', 'Outcome', 'Original_TTM', 
                    'Pred_TTM_33', 'TTM_33_Asses', 
                    'Pred_TTM_36', 'TTM_36_Asses', 
                    'Pred_No_TTM', 'No_TTM_Asses']].sample(20, random_state=42)\
    .style.map(color_assessment, subset=['TTM_33_Asses', 'TTM_36_Asses', 'No_TTM_Asses'])


Original protocol: No TTM, Original value: 0.6255857944488525, 33 is 0.6533419489860535, Difference: 0.027756154537200928
Original protocol: 36, Original value: 0.494961142539978, 33 is 0.4765375256538391, Difference: 0.018423616886138916
Original protocol: No TTM, Original value: 0.5136523246765137, 33 is 0.4996236562728882, Difference: 0.014028668403625488
Original protocol: 36, Original value: 0.551856279373169, 33 is 0.6071951389312744, Difference: 0.05533885955810547
Original protocol: 36, Original value: 0.5861823558807373, 33 is 0.5956017971038818, Difference: 0.009419441223144531
Original protocol: 36, Original value: 0.49919047951698303, 33 is 0.5093176364898682, Difference: 0.010127156972885132
Original protocol: No TTM, Original value: 0.5094816088676453, 33 is 0.5149195790290833, Difference: 0.005437970161437988
Original protocol: 36, Original value: 0.5193232893943787, 33 is 0.521173357963562, Difference: 0.0018500685691833496
Original protocol: No TTM, Original value: 0.5



Unnamed: 0,Patient,Outcome,Original_TTM,Pred_TTM_33,TTM_33_Asses,Pred_TTM_36,TTM_36_Asses,Pred_No_TTM,No_TTM_Asses
563,970,0,33,0.519208,Original TTM,0.519119,Same,0.53349,Same
289,655,0,33,0.593681,Original TTM,0.587946,Same,0.593627,Same
76,406,0,36,0.589975,Higher,0.552358,Original TTM,0.584292,Higher
78,409,0,33,0.481568,Original TTM,0.514257,Higher,0.51233,Higher
182,529,0,36,0.52478,Same,0.523994,Original TTM,0.513721,Same
495,890,0,33,0.462534,Original TTM,0.489999,Higher,0.502369,Higher
10,319,1,33,0.522154,Original TTM,0.498199,Lower,0.507243,Same
131,465,0,No TTM,0.525421,Same,0.485569,Lower,0.516957,Original TTM
445,835,1,33,0.477855,Original TTM,0.495791,Same,0.504191,Higher
86,417,0,33,0.525144,Original TTM,0.526891,Same,0.523006,Same
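
The row-wise `apply` calls above can also be written in vectorized form, which avoids the per-row debug printing. The sketch below is one possible version, not the notebook's implementation: `assess_all`, `PRED_COL` and the `*_label` column names are hypothetical, and it assumes every `Original_TTM` value is one of '33', '36' or 'No TTM'. The three input rows are copied from the table above.

```python
import numpy as np
import pandas as pd

# Protocol label -> prediction column (insertion order defines the column index)
PRED_COL = {"33": "Pred_TTM_33", "36": "Pred_TTM_36", "No TTM": "Pred_No_TTM"}

def assess_all(df, threshold=0.02):
    """Label each protocol's prediction relative to the original-protocol prediction."""
    out = df.copy()
    preds = df[list(PRED_COL.values())].to_numpy()
    original = df["Original_TTM"].to_numpy()

    # Prediction under each patient's original protocol
    col_idx = df["Original_TTM"].map({p: i for i, p in enumerate(PRED_COL)}).to_numpy()
    baseline = preds[np.arange(len(df)), col_idx]

    for j, (protocol, col) in enumerate(PRED_COL.items()):
        diff = preds[:, j] - baseline
        # Conditions are checked in order; the first match wins
        out[f"{col}_label"] = np.select(
            [original == protocol, np.abs(diff) <= threshold, diff > 0],
            ["Original TTM", "Same", "Higher"],
            default="Lower",
        )
    return out

# Three rows copied from the table above
rows = pd.DataFrame({
    "Patient": [406, 319, 465],
    "Original_TTM": ["36", "33", "No TTM"],
    "Pred_TTM_33": [0.589975, 0.522154, 0.525421],
    "Pred_TTM_36": [0.552358, 0.498199, 0.485569],
    "Pred_No_TTM": [0.584292, 0.507243, 0.516957],
})
result = assess_all(rows)
print(result.filter(like="_label"))
```

On these three rows the labels match the `TTM_33_Asses`, `TTM_36_Asses` and `No_TTM_Asses` values shown in the table.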
