<center>
    <h1>Target Trial Emulation</h1>
    By Christian Abay-abay & Thristan Jay Nakila
</center>

## **Instructions**

1. Extract the dummy data from [RPubs - TTE](https://rpubs.com/alanyang0924/TTE) and save it as `data_censored.csv`.
2. Convert the R code to Python in a Jupyter Notebook, ensuring the results match the original.
3. Create a second version (`TTE-v2.ipynb`) with additional analysis.
4. Integrate clustering in `TTE-v2`, determine where it fits, and generate insights.
5. Work in pairs, preferably with your thesis partner.
6. Push your Jupyter Notebooks (`TTE.ipynb` and `TTE-v2.ipynb`) to GitHub.
7. 📅 **Deadline:** February 28, 2025, at **11:59 PM**.


***
<center>
    <h2>R Code converted to Python</h2>
    R Code from [RPubs - TTE](https://rpubs.com/alanyang0924/TTE) converted to Python code for this notebook.
</center>

In [91]:
import os
import pandas as pd
import numpy as np
import statsmodels.api as sm
import matplotlib.pyplot as plt
from scipy.stats import norm


In [20]:
import os
import pickle
import pandas as pd
import statsmodels.api as sm
from sklearn.cluster import KMeans
from sklearn.preprocessing import Binarizer

def stats_glm_logit(y, X):
    return sm.Logit(y, X).fit(disp=0)
# Define the TrialEmulation class
class TrialEmulation:
    def __init__(self, data=None, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None):
        self.data = data
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        self.censor_weights = None
        self.switch_weights = {}  # Initialize switch weights as an empty dictionary
###################################################
    def set_data(self, data, id_col, period_col, treatment_col, outcome_col, eligible_col):
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data must be a pandas DataFrame.")
        self.data = data.copy()
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        self.data['prev_treatment'] = self.data.groupby(self.id_col)[self.treatment_col].shift(1)
        return self
    
###################################################
    def show(self, num_rows=5):
        """Display dataset head."""
        if self.data is not None:
            print(f"Dataset Head ({num_rows} rows):")
            print(self.data.head(num_rows))
        else:
            print("No data available in this instance.")
###################################################
    def set_switch_weight_model(self, numerator, denominator, model_fitter):
        """Set the switch weight model formulas and the model fitter."""
        # Allow model_fitter to be a string identifier or callable; we'll handle it later
        if not isinstance(model_fitter, str) and not callable(model_fitter):
            raise ValueError("model_fitter must be a string identifier or a callable function.")
        
        self.switch_weights = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": model_fitter
        }
        return self  # Allow method chaining
###################################################
    def calculate_weights(self, save_path=None):
        """Calculate stabilized weights for treatment switching and censoring."""
        if self.data is None:
            raise ValueError("No data has been set.")
    
        # Ensure prev_treatment is available
        if 'prev_treatment' not in self.data.columns:
            self.data['prev_treatment'] = self.data.groupby(self.id_col)[self.treatment_col].shift(1)
    
        # 1. Treatment Switching Weights (for PP only)
        if self.switch_weights and self.switch_weights.get("numerator"):
            try:
                num_features = [f.strip() for f in self.switch_weights["numerator"].split(" + ")]
                denom_features = [f.strip() for f in self.switch_weights["denominator"].split(" + ")]
                y = self.data[self.treatment_col]
    
                # Subset data where prev_treatment is defined (period > 0)
                data_switch = self.data.dropna(subset=['prev_treatment'])
                y_switch = data_switch[self.treatment_col]
    
                # Fit models stratified by prev_treatment
                switch_weights = {}
                numerator_models = {}
                denominator_models = {}
                for prev_trt in [0, 1]:
                    subset = data_switch[data_switch['prev_treatment'] == prev_trt]
                    y_current = y_switch.loc[subset.index]
                    X_num = sm.add_constant(subset[num_features])
                    X_denom = sm.add_constant(subset[denom_features])
    
                    num_model = self.switch_weights["model_fitter"](y_current, X_num)
                    denom_model = self.switch_weights["model_fitter"](y_current, X_denom)
                    numerator_models[prev_trt] = num_model
                    denominator_models[prev_trt] = denom_model
    
                    # Predict probabilities
                    p_num = num_model.predict(X_num)
                    p_denom = denom_model.predict(X_denom)
                    switch_weights[prev_trt] = p_num / p_denom  # Stabilized weights
    
                # Assign weights back to data
                self.data['switch_weight'] = 1.0  # Default for period = 0
                for prev_trt in [0, 1]:
                    mask = self.data['prev_treatment'] == prev_trt
                    indices = data_switch[data_switch['prev_treatment'] == prev_trt].index
                    self.data.loc[indices, 'switch_weight'] = switch_weights[prev_trt]
    
                self.switch_weights["numerator_models"] = numerator_models
                self.switch_weights["denominator_models"] = denominator_models
                self.switch_weights["fitted"] = True
    
            except KeyError as e:
                raise ValueError(f"Column not found in data for switch weights: {e}")
    
        # 2. Censoring Weights (for both PP and ITT)
        if self.censor_weights:
            try:
                num_features = [f.strip() for f in self.numerator_formula.split("+")]
                denom_features = [f.strip() for f in self.denominator_formula.split("+")]
                y_censor = 1 - self.data[self.censor_event].astype(int)
    
                if self.pool_models == "numerator":
                    # Pooled numerator model
                    X_num = sm.add_constant(self.data[num_features])
                    numerator_model = sm.Logit(y_censor, X_num).fit(disp=0)
                    p_num_all = numerator_model.predict(X_num)
    
                    # Denominator models by prev_treatment (period > 0)
                    data_censor = self.data.dropna(subset=['prev_treatment'])
                    y_censor_subset = 1 - data_censor[self.censor_event].astype(int)
                    denominator_models = {}
                    censor_weights = {}
                    for prev_trt in [0, 1]:
                        subset = data_censor[data_censor['prev_treatment'] == prev_trt]
                        X_denom = sm.add_constant(subset[denom_features])
                        denom_model = sm.Logit(y_censor_subset.loc[subset.index], X_denom).fit(disp=0)
                        denominator_models[prev_trt] = denom_model
                        p_denom = denom_model.predict(X_denom)
                        p_num_subset = p_num_all.loc[subset.index]
                        censor_weights[prev_trt] = p_num_subset / p_denom
    
                    self.censor_weights["Numerator Model"] = numerator_model
                    self.censor_weights["Denominator Models"] = denominator_models
    
                elif self.pool_models == "none":
                    # Both models stratified by prev_treatment (period > 0)
                    data_censor = self.data.dropna(subset=['prev_treatment'])
                    y_censor_subset = 1 - data_censor[self.censor_event].astype(int)
                    numerator_models = {}
                    denominator_models = {}
                    censor_weights = {}
                    for prev_trt in [0, 1]:
                        subset = data_censor[data_censor['prev_treatment'] == prev_trt]
                        y_current = y_censor_subset.loc[subset.index]
                        X_num = sm.add_constant(subset[num_features])
                        X_denom = sm.add_constant(subset[denom_features])
    
                        num_model = sm.Logit(y_current, X_num).fit(disp=0)
                        denom_model = sm.Logit(y_current, X_denom).fit(disp=0)
                        numerator_models[prev_trt] = num_model
                        denominator_models[prev_trt] = denom_model
    
                        p_num = num_model.predict(X_num)
                        p_denom = denom_model.predict(X_denom)
                        censor_weights[prev_trt] = p_num / p_denom
    
                    self.censor_weights["Numerator Models"] = numerator_models
                    self.censor_weights["Denominator Models"] = denominator_models
    
                # Assign censoring weights to data
                self.data['censor_weight'] = 1.0  # Default for period = 0
                for prev_trt in [0, 1]:
                    mask = self.data['prev_treatment'] == prev_trt
                    indices = data_censor[data_censor['prev_treatment'] == prev_trt].index
                    self.data.loc[indices, 'censor_weight'] = censor_weights[prev_trt]
    
                self.censor_weights["Weight models fitted"] = True
    
            except KeyError as e:
                raise ValueError(f"Column not found in data for censor weights: {e}")

    # Save models if path is provided
    if save_path:
        os.makedirs(save_path, exist_ok=True)
        if self.switch_weights.get("fitted"):
            for prev_trt in [0, 1]:
                with open(os.path.join(save_path, f"switch_num_model_prev_{prev_trt}.pkl"), "wb") as f:
                    pickle.dump(self.switch_weights["numerator_models"][prev_trt], f)
                with open(os.path.join(save_path, f"switch_denom_model_prev_{prev_trt}.pkl"), "wb") as f:
                    pickle.dump(self.switch_weights["denominator_models"][prev_trt], f)
        if self.censor_weights and self.censor_weights["Weight models fitted"]:
            if self.pool_models == "numerator":
                with open(os.path.join(save_path, "censor_num_model.pkl"), "wb") as f:
                    pickle.dump(self.censor_weights["Numerator Model"], f)
                for prev_trt in [0, 1]:
                    with open(os.path.join(save_path, f"censor_denom_model_prev_{prev_trt}.pkl"), "wb") as f:
                        pickle.dump(self.censor_weights["Denominator Models"][prev_trt], f)
            elif self.pool_models == "none":
                for prev_trt in [0, 1]:
                    with open(os.path.join(save_path, f"censor_num_model_prev_{prev_trt}.pkl"), "wb") as f:
                        pickle.dump(self.censor_weights["Numerator Models"][prev_trt], f)
                    with open(os.path.join(save_path, f"censor_denom_model_prev_{prev_trt}.pkl"), "wb") as f:
                        pickle.dump(self.censor_weights["Denominator Models"][prev_trt], f)

    print("Weights calculated and models saved at:", save_path)
    return self
    ###################################################
    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter):
        if censor_event not in self.data.columns:
            raise ValueError(f"Censor event column '{censor_event}' not found in data.")

        self.censor_event = censor_event
        self.numerator_formula = numerator  # Use consistent attribute names
        self.denominator_formula = denominator
        self.pool_models = pool_models
        self.model_fitter = model_fitter
    
        # Extract feature lists from formulas
        num_features = [col.strip() for col in numerator.split("+")]
        denom_features = [col.strip() for col in denominator.split("+")]
        y = 1 - self.data[censor_event].astype(int)  # Probability of not being censored
    
        if pool_models == "numerator":
            # Pooled numerator model on all data
            X_num = sm.add_constant(self.data[num_features])
            numerator_model = sm.Logit(y, X_num).fit(disp=0)
    
            # Separate denominator models for each prev_treatment (period > 0)
            data_subset = self.data.dropna(subset=['prev_treatment'])  # Exclude period = 0
            y_subset = 1 - data_subset[censor_event].astype(int)
            denominator_models = {}
            for prev_trt in [0, 1]:
                subset = data_subset[data_subset['prev_treatment'] == prev_trt]
                X_denom = sm.add_constant(subset[denom_features])
                model = sm.Logit(y_subset.loc[subset.index], X_denom).fit(disp=0)
                denominator_models[prev_trt] = model
    
            self.censor_weights = {
                "Numerator formula": f"1 - {censor_event} ~ {numerator}",
                "Denominator formula": f"1 - {censor_event} ~ {denominator}",
                "Model fitter type": "te_stats_glm_logit",
                "Numerator Model": numerator_model,
                "Denominator Models": denominator_models
            }
    
        elif pool_models == "none":
            # Separate models for each prev_treatment (period > 0)
            data_subset = self.data.dropna(subset=['prev_treatment'])  # Exclude period = 0
            y_subset = 1 - data_subset[censor_event].astype(int)
            numerator_models = {}
            denominator_models = {}
            for prev_trt in [0, 1]:
                subset = data_subset[data_subset['prev_treatment'] == prev_trt]
                y_current = y_subset.loc[subset.index]
                # Numerator model
                X_num = sm.add_constant(subset[num_features])
                num_model = sm.Logit(y_current, X_num).fit(disp=0)
                numerator_models[prev_trt] = num_model
                # Denominator model
                X_denom = sm.add_constant(subset[denom_features])
                denom_model = sm.Logit(y_current, X_denom).fit(disp=0)
                denominator_models[prev_trt] = denom_model
    
            self.censor_weights = {
                "Numerator formula": f"1 - {censor_event} ~ {numerator}",
                "Denominator formula": f"1 - {censor_event} ~ {denominator}",
                "Model fitter type": "te_stats_glm_logit",
                "Numerator Models": numerator_models,
                "Denominator Models": denominator_models
            }
    
        else:
            raise ValueError("pool_models must be 'none' or 'numerator'.")
    
        print(f"Set censoring weight model:\n{self.censor_weights}")
        return self
###################################################
    def show_weight_models(self):
        """Display summaries of fitted weight models."""
        if self.switch_weights.get("fitted"):
            print("Weight Models for Treatment Switching")
            print("-------------------------------------")
            for prev_trt in [0, 1]:
                print(f"\nNumerator Model (prev_treatment = {prev_trt}):")
                print(self.switch_weights["numerator_models"][prev_trt].summary())
                print(f"\nDenominator Model (prev_treatment = {prev_trt}):")
                print(self.switch_weights["denominator_models"][prev_trt].summary())
    
        if self.censor_weights and self.censor_weights.get("Weight models fitted"):
            print("\nWeight Models for Informative Censoring")
            print("---------------------------------------")
            if self.pool_models == "numerator":
                print("\nPooled Numerator Model:")
                print(self.censor_weights["Numerator Model"].summary())
                for prev_trt in [0, 1]:
                    print(f"\nDenominator Model (prev_treatment = {prev_trt}):")
                    print(self.censor_weights["Denominator Models"][prev_trt].summary())
            elif self.pool_models == "none":
                for prev_trt in [0, 1]:
                    print(f"\nNumerator Model (prev_treatment = {prev_trt}):")
                    print(self.censor_weights["Numerator Models"][prev_trt].summary())
                    print(f"\nDenominator Model (prev_treatment = {prev_trt}):")
                    print(self.censor_weights["Denominator Models"][prev_trt].summary())
###################################################
    def apply_clustering(self, features, n_clusters=3):
        """Apply K-Means clustering to the dataset."""
        if not all(feature in self.data.columns for feature in features):
            raise ValueError("Some clustering features are missing from the dataset.")
        
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        self.data['cluster'] = kmeans.fit_predict(self.data[features])
        print("Clustering applied successfully.")
###################################################
    def apply_dichotomization(self, column, threshold=None):
        """Dichotomize a column based on a given threshold or median."""
        if column not in self.data.columns:
            raise ValueError(f"Column '{column}' not found in data.")
        
        if threshold is None:
            threshold = self.data[column].median()
        binarizer = Binarizer(threshold=threshold)
        self.data[column + "_binary"] = binarizer.fit_transform(self.data[[column]].values.reshape(-1, 1))
        print(f"Dichotomization applied to {column} with threshold {threshold}.")
###################################################
    def __repr__(self):
        """Provide a string representation for debugging."""
        return (f"TrialEmulation(id_col='{self.id_col}', period_col='{self.period_col}', "
                f"treatment_col='{self.treatment_col}', outcome_col='{self.outcome_col}', "
                f"eligible_col='{self.eligible_col}', data_shape={self.data.shape if self.data is not None else None})")


SyntaxError: 'return' outside function (3957436914.py, line 201)

# 1. SETUP

In [21]:

# Load data
file_path = "data/data_censored.csv"
data_censored = pd.read_csv(file_path)


# Create directories for saving models
trial_pp_dir = os.path.join(os.getcwd(), "trial_pp")
trial_itt_dir = os.path.join(os.getcwd(), "trial_itt")
os.makedirs(trial_pp_dir, exist_ok=True)
os.makedirs(trial_itt_dir, exist_ok=True)


# 2. Data Preparation

In [22]:
print("Dataset Head:")
print(data_censored.head())

Dataset Head:
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         1  
1         0         0  
2         0         0  
3         0         0  
4         0         0  


In [23]:
# Initialize and set data for per-protocol analysis
trial_pp = TrialEmulation().set_data(
    data_censored, 
    id_col="id", 
    period_col="period", 
    treatment_col="treatment", 
    outcome_col="outcome", 
    eligible_col="eligible"
)

# Initialize and set data for ITT (Intention-To-Treat) without method chaining
trial_itt = TrialEmulation()
trial_itt.set_data(
    data_censored, 
    id_col="id", 
    period_col="period", 
    treatment_col="treatment", 
    outcome_col="outcome", 
    eligible_col="eligible"
)

print("---------------------------------------------------------------")
# Show the first few rows of trial_pp
trial_pp.show()


---------------------------------------------------------------
Dataset Head (5 rows):
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  prev_treatment  
0         0         1             NaN  
1         0         0             1.0  
2         0         0             1.0  
3         0         0             1.0  
4         0         0             1.0  


# 3. Weight models and censoring

## 3.1 Censoring due to treatment switching

In [24]:
# Set switch weight model
trial_pp.set_switch_weight_model(
    numerator="age",
    denominator="age + x1 + x3",
    model_fitter=stats_glm_logit
)


TrialEmulation(id_col='id', period_col='period', treatment_col='treatment', outcome_col='outcome', eligible_col='eligible', data_shape=(725, 13))

## 3.2 Other informative censoring

In [25]:
# Apply censoring weight model to trial_pp
trial_pp.set_censor_weight_model(
    censor_event="censored",
    numerator="x2",
    denominator="x2 + x1",
    pool_models="none",
    model_fitter=f"{trial_pp_dir}/switch_models"
)

# Apply censoring weight model to trial_itt
trial_itt.set_censor_weight_model(
    censor_event="censored",
    numerator="x2",
    denominator="x2 + x1",
    pool_models="numerator",
    model_fitter=f"{trial_itt_dir}/switch_models"
)

Set censoring weight model:
{'Numerator formula': '1 - censored ~ x2', 'Denominator formula': '1 - censored ~ x2 + x1', 'Model fitter type': 'te_stats_glm_logit', 'Numerator Models': {0: <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062C9B50>, 1: <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062CAB10>}, 'Denominator Models': {0: <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062C8E90>, 1: <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062CA3C0>}}
Set censoring weight model:
{'Numerator formula': '1 - censored ~ x2', 'Denominator formula': '1 - censored ~ x2 + x1', 'Model fitter type': 'te_stats_glm_logit', 'Numerator Model': <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062C92B0>, 'Denominator Models': {0: <statsmodels.discrete.discrete_model.BinaryResultsWrapper object at 0x0000024E062CB7A0>, 1: <statsmodels.discrete.discrete_mod

TrialEmulation(id_col='id', period_col='period', treatment_col='treatment', outcome_col='outcome', eligible_col='eligible', data_shape=(725, 13))

# 4. Calculate Weights

In [26]:
trial_pp.calculate_weights(save_path=os.path.join(trial_pp_dir, "switch_models"))
trial_itt.calculate_weights(save_path=os.path.join(trial_itt_dir, "switch_models"))

trial_pp.show_weight_models()
trial_itt.show_weight_models()

Weight models fitted and saved at: C:\Users\thris\Documents\Repo\TTE_PY\trial_pp\switch_models


ValueError: Switch weight model has not been set.