<center>
    <h1>Target Trial Emulation</h1>
    By Christian Abay-abay & Thristan Jay Nakila
</center>

## **Instructions**

1. Extract the dummy data from [RPubs - TTE](https://rpubs.com/alanyang0924/TTE) and save it as `data_censored.csv`.
2. Convert the R code to Python in a Jupyter Notebook, ensuring the results match the original.
3. Create a second version (`TTE-v2.ipynb`) with additional analysis.
4. Integrate clustering in `TTE-v2`, determine where it fits, and generate insights.
5. Work in pairs, preferably with your thesis partner.
6. Push your Jupyter Notebooks (`TTE.ipynb` and `TTE-v2.ipynb`) to GitHub.
7. 📅 **Deadline:** February 28, 2025, at **11:59 PM**.


***
<center>
    <h2>R Code converted to Python</h2>
    R Code from [RPubs - TTE](https://rpubs.com/alanyang0924/TTE) converted to Python code for this notebook.
</center>

In [91]:
import os
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
from scipy.stats import norm


In [92]:
import os
import pickle
import pandas as pd
import statsmodels.api as sm
from sklearn.cluster import KMeans
from sklearn.preprocessing import Binarizer

# Define the TrialEmulation class
class TrialEmulation:
    def __init__(self, data=None, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None):
        self.data = data
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        self.censor_weights = None
        self.switch_weights = {}  # Initialize switch weights as an empty dictionary
###################################################
    def set_data(self, data, id_col, period_col, treatment_col, outcome_col, eligible_col):
        """Set the dataset and column mappings."""
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data must be a pandas DataFrame.")
        
        self.data = data.copy()  # Explicit copy to prevent unintended modifications
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        return self  # Allow method chaining
###################################################
    def show(self, num_rows=5):
        """Display dataset head."""
        if self.data is not None:
            print(f"Dataset Head ({num_rows} rows):")
            print(self.data.head(num_rows))
        else:
            print("No data available in this instance.")
###################################################
    def set_switch_weight_model(self, numerator, denominator, model_fitter):
        """Set the switch weight model formulas and the model fitter."""
        if not callable(model_fitter):
            raise ValueError("model_fitter must be a callable function.")
        
        self.switch_weights = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": model_fitter
        }
        return self  # Allow method chaining
###################################################
    def calculate_weights(self, save_path=None):
        """Fit logistic regression models for treatment weighting."""
        if self.data is None:
            raise ValueError("No data has been set.")
        if not self.switch_weights:
            raise ValueError("Switch weight model has not been set.")

        try:
            num_features = self.switch_weights["numerator"].split(" + ")
            denom_features = self.switch_weights["denominator"].split(" + ")
            X_num = sm.add_constant(self.data[num_features])  # Add intercept
            X_denom = sm.add_constant(self.data[denom_features])  # Add intercept
            y = self.data[self.treatment_col]
        except KeyError as e:
            raise ValueError(f"Column not found in data: {e}")

        # Fit logistic regression models
        model_num = sm.Logit(y, X_num).fit(disp=0)
        model_denom = sm.Logit(y, X_denom).fit(disp=0)

        # Save models if path is provided
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            with open(os.path.join(save_path, "numerator_model.pkl"), "wb") as f:
                pickle.dump(model_num, f)
            with open(os.path.join(save_path, "denominator_model.pkl"), "wb") as f:
                pickle.dump(model_denom, f)

        self.switch_weights["fitted"] = True
        print("Weight models fitted and saved at:", save_path)
###################################################
    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter):
        """Set the censoring weight model with logistic regression."""
        if censor_event not in self.data.columns:
            raise ValueError(f"Censor event column '{censor_event}' not found in data.")

        self.censor_event = censor_event
        self.numerator = numerator
        self.denominator = denominator
        self.pool_models = pool_models
        self.model_fitter = model_fitter

        # Fit numerator model: 1 - censored ~ x2
        X_numerator = sm.add_constant(self.data[[col.strip() for col in numerator.split("+")]])
        y_numerator = 1 - self.data[censor_event].astype(int)  # Ensure binary
        numerator_model = sm.Logit(y_numerator, X_numerator).fit(disp=0)

        # Fit denominator model: 1 - censored ~ x2 + x1
        X_denominator = sm.add_constant(self.data[[col.strip() for col in denominator.split("+")]])
        y_denominator = 1 - self.data[censor_event].astype(int)  # Ensure binary
        denominator_model = sm.Logit(y_denominator, X_denominator).fit(disp=0)

        self.censor_weights = {
            "Numerator formula": f"1 - {censor_event} ~ {numerator}",
            "Denominator formula": f"1 - {censor_event} ~ {denominator}",
            "Model fitter type": "te_stats_glm_logit",
            "Weight models fitted": False,
            "Numerator Model": numerator_model,
            "Denominator Model": denominator_model
        }
        print(f"Set censoring weight model:\n{self.censor_weights}")
        return self
###################################################
    def show_weight_models(self):
        """Display the fitted weight models."""
        if not self.switch_weights.get("fitted"):
            print("Switch weight models have not been fitted yet.")
        else:
            print("Switch Weight Models:")
            print(self.switch_weights)
###################################################
    def apply_clustering(self, features, n_clusters=3):
        """Apply K-Means clustering to the dataset."""
        if not all(feature in self.data.columns for feature in features):
            raise ValueError("Some clustering features are missing from the dataset.")
        
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.data['cluster'] = kmeans.fit_predict(self.data[features])
        print("Clustering applied successfully.")
###################################################
    def apply_dichotomization(self, column, threshold=None):
        """Dichotomize a column based on a given threshold or median."""
        if column not in self.data.columns:
            raise ValueError(f"Column '{column}' not found in data.")
        
        if threshold is None:
            threshold = self.data[column].median()
        binarizer = Binarizer(threshold=threshold)
        self.data[column + "_binary"] = binarizer.fit_transform(self.data[[column]].values.reshape(-1, 1))
        print(f"Dichotomization applied to {column} with threshold {threshold}.")
###################################################
    def __repr__(self):
        """Provide a string representation for debugging."""
        return (f"TrialEmulation(id_col='{self.id_col}', period_col='{self.period_col}', "
                f"treatment_col='{self.treatment_col}', outcome_col='{self.outcome_col}', "
                f"eligible_col='{self.eligible_col}', data_shape={self.data.shape if self.data is not None else None})")


# 1. SETUP

In [93]:

# Load data
file_path = "data/data_censored.csv"
data_censored = pd.read_csv(file_path)


# Create directories for saving models
trial_pp_dir = os.path.join(os.getcwd(), "trial_pp")
trial_itt_dir = os.path.join(os.getcwd(), "trial_itt")
os.makedirs(trial_pp_dir, exist_ok=True)
os.makedirs(trial_itt_dir, exist_ok=True)


# 2. Data Preparation

In [94]:
print("Dataset Head:")
print(data_censored.head())

Dataset Head:
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         1  
1         0         0  
2         0         0  
3         0         0  
4         0         0  


In [95]:
# Initialize and set data for per-protocol analysis
trial_pp = TrialEmulation().set_data(
    data_censored, 
    id_col="id", 
    period_col="period", 
    treatment_col="treatment", 
    outcome_col="outcome", 
    eligible_col="eligible"
)

# Initialize and set data for ITT (Intention-To-Treat) without method chaining
trial_itt = TrialEmulation()
trial_itt.set_data(
    data_censored, 
    id_col="id", 
    period_col="period", 
    treatment_col="treatment", 
    outcome_col="outcome", 
    eligible_col="eligible"
)

print("---------------------------------------------------------------")
# Show the first few rows of trial_pp
trial_pp.show()


---------------------------------------------------------------
Dataset Head (5 rows):
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         1  
1         0         0  
2         0         0  
3         0         0  
4         0         0  


# 3. Weight models and censoring

## 3.1 Censoring due to treatment switching

In [96]:
# Set switch weight model
trial_pp.set_switch_weight_model(
    numerator="age",
    denominator="age + x1 + x3",
    model_fitter="te_stats_glm_logit"
)
trial_pp.calculate_weights(save_path=os.path.join(trial_pp_dir, "switch_models"))

ValueError: model_fitter must be a callable function.

## 3.2 Other informative censoring

In [97]:
# Apply censoring weight model to trial_pp
trial_pp.set_censor_weight_model(
    censor_event="censored",
    numerator="x2",
    denominator="x2 + x1",
    pool_models="none",
    model_fitter=f"{trial_pp_dir}/switch_models"
)

# Apply censoring weight model to trial_itt
trial_itt.set_censor_weight_model(
    censor_event="censored",
    numerator="x2",
    denominator="x2 + x1",
    pool_models="numerator",
    model_fitter=f"{trial_itt_dir}/switch_models"
)

Set censoring weight model:
{'Numerator formula': '1 - censored ~ x2', 'Denominator formula': '1 - censored ~ x2 + x1', 'Model fitter type': 'te_stats_glm_logit', 'Weight models fitted': False, 'Numerator Model': <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x00000256CB958230>, 'Denominator Model': <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x00000256CBB79430>}
Set censoring weight model:
{'Numerator formula': '1 - censored ~ x2', 'Denominator formula': '1 - censored ~ x2 + x1', 'Model fitter type': 'te_stats_glm_logit', 'Weight models fitted': False, 'Numerator Model': <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x00000256CBBDAE70>, 'Denominator Model': <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x00000256CBBD80B0>}


TrialEmulation(id_col='id', period_col='period', treatment_col='treatment', outcome_col='outcome', eligible_col='eligible', data_shape=(725, 12))

# 4. Calculate Weights

In [98]:
trial_pp = trial_pp.calculate_weights()
trial_itt = trial_itt.calculate_weights()

trial_itt.show_weight_models()
trial_pp.show_weight_models()

ValueError: Switch weight model has not been set.