# Assignment 1 for Clustering: Target Trial Emulation
- New and novel methods in Machine Learning are made either by borrowing formulas and concepts from other scientific fields and redefining it based on new sets of assumptions, or by adding an extra step to an already existing framework of methodology.

- In this exercise (Assignment 1 of the Clustering Topic), we will try to develop a novel method of Target Trial Emulation by integrating concepts of Clustering into the already existing framework. Target Trial Emulation is a new methodological framework in epidemiology which tries to account for the biases in old and traditional designs.

These are the instructions:
1. Look at this website: https://rpubs.com/alanyang0924/TTE
2. Extract the dummy data in the package and save it as "data_censored.csv"
2. Convert the R codes into Python Codes (use Jupyter Notebook), replicate the results using your python code.
3. Create another copy of your Python Codes, name it TTE-v2 (use Jupyter Notebook).
4. Using TTE-v2, think of a creative way on where you would integrate a clustering mechanism, understand each step carefully and decide at which step a clustering method can be implemented. Generate insights from your results.
5. Do this by pair, preferably your thesis partner.
6. Push to your github repository.
7. Deadline is: February 28, 2025 at 11:59 pm.

## I. Necessary Imports

In [19]:
import pandas as pd
import numpy as np
import os
import patsy
import joblib
import json
from sklearn.linear_model import LogisticRegression
from IPython.display import display
import statsmodels.api as sm
import statsmodels.formula.api as smf
from dataclasses import dataclass
from typing import List, Optional, Any




## II. Class Definition and Required Functions

In [20]:
def stats_glm_logit(save_path):
    if save_path is not None:
        os.makedirs(save_path, exist_ok=True)

    def fit_model(numerator, denominator, data):
        formula = numerator
        try:
            model = smf.logit(formula, data).fit(disp=0)  # Suppress convergence messages
        except (np.linalg.LinAlgError, sm.tools.sm_exceptions.PerfectSeparationError):
            print(f"Warning: Perfect separation or singular matrix detected for {formula}. Falling back to intercept-only model.")
            formula = f"{formula.split('~')[0].strip()} ~ 1"
            model = smf.logit(formula, data).fit(disp=0)
        model_path = os.path.join(save_path, "logit_model.pkl")
        joblib.dump(model, model_path)
        model_details = {
            "numerator": numerator,
            "denominator": denominator,
            "model_type": "te_stats_glm_logit",
            "file_path": model_path
        }
        json.dump(model_details, open(os.path.join(save_path, "model_details.json"), "w"))
        return model
    
    return fit_model

@dataclass
class TEDatastore:
    data: pd.DataFrame = None

    def save_expanded_data(self, switch_data: pd.DataFrame):
        if self.data is None:
            self.data = switch_data
        else:
            self.data = pd.concat([self.data, switch_data], ignore_index=True)
        return self

@dataclass
class TEExpansion:
    chunk_size: int = 0
    datastore: TEDatastore = None
    first_period: int = 0
    last_period: float = float('inf')
    censor_at_switch: bool = False

class TrialSequence:
    def __init__(self, estimand, **kwargs):
        self.estimand = estimand
        self.data = None
        self.censor_weights = None
        self.switch_weights = None
        self.outcome_model = None
        self.expansion = None
        self.outcome_data = None

    def set_data(self, data):
        self.data = data
        self.data["followup_time"] = self.data.groupby("id")["period"].transform(
            lambda x: x[(self.data.loc[x.index, "censored"] == 1) | (self.data.loc[x.index, "outcome"] == 1)].min()
            if ((self.data.loc[x.index, "censored"] == 1) | (self.data.loc[x.index, "outcome"] == 1)).any()
            else x.max()
        )

    def show(self):
        print(f"Trial Sequence Object\nEstimand: {self.estimand}\n")
        if self.data is not None:
            display(self.data)
        else:
            print("No data set")
        print("\nIPW for informative censoring:")
        print(self.censor_weights if self.censor_weights is not None else "Not calculated.")
        if self.switch_weights is not None:
            print("\nIPW for treatment switch censoring:")
            print(self.switch_weights)
        print("\nOutcome model:")
        print(self.outcome_model if self.outcome_model is not None else "Not specified.")
        if self.outcome_data is not None:
            print("\nOutcome data:")
            print(self.outcome_data)

    def set_switch_weight_model(self, numerator=None, denominator=None, model_fitter=None, eligible_wts_0=None, eligible_wts_1=None):
        if self.data is None:
            raise ValueError("set_data() before setting switch weight models")
        if self.estimand == "ITT":
            raise ValueError("Switching weights are not supported for intention-to-treat analyses")
        if eligible_wts_0 and eligible_wts_0 in self.data.columns:
            self.data = self.data.rename(columns={eligible_wts_0: "eligible_wts_0"})
        if eligible_wts_1 and eligible_wts_1 in self.data.columns:
            self.data = self.data.rename(columns={eligible_wts_1: "eligible_wts_1"})
        if numerator is None:
            numerator = "1"
        if denominator is None:
            denominator = "1"
        if "time_on_regime" in denominator:
            raise ValueError("time_on_regime should not be used in denominator.")
        formula_numerator = f"treatment ~ {numerator}"
        formula_denominator = f"treatment ~ {denominator}"
        self.switch_weights = {
            "numerator": formula_numerator,
            "denominator": formula_denominator,
            "model_fitter": "te_stats_glm_logit",
        }
        if model_fitter is not None:
            self.switch_weights["fitted_model_0_numerator"] = model_fitter(formula_numerator, denominator, self.data[self.data["previous_treatment"] == 0])
            self.switch_weights["fitted_model_1_numerator"] = model_fitter(formula_numerator, denominator, self.data[self.data["previous_treatment"] == 1])
            self.switch_weights["fitted_model_0_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 0])
            self.switch_weights["fitted_model_1_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 1])
            self.data["switch_prob_0"] = self.switch_weights["fitted_model_0_denominator"].predict(self.data[self.data["previous_treatment"] == 0])
            self.data["switch_prob_1"] = self.switch_weights["fitted_model_1_denominator"].predict(self.data[self.data["previous_treatment"] == 1])
            self.data["switch_weight"] = np.where(self.data["previous_treatment"] == 0, 
                                                  1 / self.data["switch_prob_0"], 
                                                  1 / self.data["switch_prob_1"])
            self.data["switch_weight"] = self.data["switch_weight"].fillna(1)
            print("✅ Switch weights computed and stored in self.data")

    def show_switch_weights(self):
        return self.switch_weights if self.switch_weights else "Not calculated"
    
    def show_censor_weights(self):
        return self.censor_weights if self.censor_weights else "Not calculated"
    
    def set_censor_weight_model(self, censor_event, numerator="1", denominator="1", pool_models="none", model_fitter=None):
        if model_fitter is None:
            model_fitter = stats_glm_logit()
        if censor_event not in self.data.columns:
            raise ValueError(f"'{censor_event}' must be a column in the dataset.")
        formula_numerator = f"1 - {censor_event} ~ {numerator}"
        formula_denominator = f"1 - {censor_event} ~ {denominator}"
        self.censor_weights = {
            "numerator": formula_numerator,
            "denominator": formula_denominator,
            "pool_numerator": pool_models in ["numerator", "both"],
            "pool_denominator": pool_models == "both",
            "model_fitter": "te_stats_glm_logit"
        }
        if self.estimand == "PP":
            self.censor_weights["fitted_model_0_numerator"] = model_fitter(formula_numerator, denominator, self.data[self.data["previous_treatment"] == 0])
            self.censor_weights["fitted_model_1_numerator"] = model_fitter(formula_numerator, denominator, self.data[self.data["previous_treatment"] == 1])
            self.censor_weights["fitted_model_0_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 0])
            self.censor_weights["fitted_model_1_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 1])
        elif self.estimand == "ITT":
            self.censor_weights["fitted_model_numerator"] = model_fitter(formula_numerator, denominator, self.data)
            if not self.censor_weights["pool_denominator"]:
                self.censor_weights["fitted_model_0_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 0])
                self.censor_weights["fitted_model_1_denominator"] = model_fitter(formula_denominator, denominator, self.data[self.data["previous_treatment"] == 1])

    def calculate_weights(self, quiet=False):
        use_censor_weights = isinstance(self.censor_weights, dict) and (
            "fitted_model_0_denominator" in self.censor_weights or "fitted_model_numerator" in self.censor_weights
        )
        if self.estimand == "PP":
            if not (isinstance(self.switch_weights, dict) and "fitted_model_0_denominator" in self.switch_weights):
                raise ValueError("Switch weight models are not specified. Use set_switch_weight_model()")
            self._calculate_weights_trial_seq(quiet, switch_weights=True, censor_weights=use_censor_weights)
        elif self.estimand == "ITT":
            self._calculate_weights_trial_seq(quiet, switch_weights=False, censor_weights=use_censor_weights)
        else:
            raise ValueError(f"Unknown estimand: {self.estimand}")

    def _calculate_weights_trial_seq(self, quiet, switch_weights, censor_weights):
        if switch_weights:
            if not quiet:
                print("Calculating switch weights...")
            switch_model_0 = self.switch_weights["fitted_model_0_denominator"]
            switch_model_1 = self.switch_weights["fitted_model_1_denominator"]
            mask_0 = self.data["previous_treatment"] == 0
            mask_1 = self.data["previous_treatment"] == 1
            self.data.loc[mask_0, "switch_prob"] = switch_model_0.predict(self.data[mask_0])
            self.data.loc[mask_1, "switch_prob"] = switch_model_1.predict(self.data[mask_1])
            self.data["switch_prob"] = self.data["switch_prob"].fillna(1.0)
            self.data["switch_weight"] = 1 / self.data["switch_prob"]
        if censor_weights:
            if not quiet:
                print("Calculating censor weights...")
            if self.estimand == "PP":
                censor_model_0 = self.censor_weights["fitted_model_0_denominator"]
                censor_model_1 = self.censor_weights["fitted_model_1_denominator"]
                mask_0 = self.data["previous_treatment"] == 0
                mask_1 = self.data["previous_treatment"] == 1
                self.data.loc[mask_0, "censor_prob"] = censor_model_0.predict(self.data[mask_0])
                self.data.loc[mask_1, "censor_prob"] = censor_model_1.predict(self.data[mask_1])
            elif self.estimand == "ITT":
                censor_model = self.censor_weights["fitted_model_numerator"]
                self.data["censor_prob"] = censor_model.predict(self.data)
            self.data["censor_prob"] = self.data["censor_prob"].fillna(1.0)
            self.data["censor_weight"] = 1 / self.data["censor_prob"]
        if switch_weights and censor_weights:
            self.data["final_weight"] = self.data["switch_weight"] * self.data["censor_weight"]
        elif switch_weights:
            self.data["final_weight"] = self.data["switch_weight"]
        elif censor_weights:
            self.data["final_weight"] = self.data["censor_weight"]
        if "switch_weight" in self.data.columns:
            print("\nWeight Summary for PP:")
            print(self.data[["switch_weight", "censor_weight", "final_weight"]].describe())
        else:
            print("\nWeight Summary for ITT:")
            print(self.data[["censor_weight", "final_weight"]].describe())

    def show_weight_models(self):
        if "censored" not in self.data.columns:
            raise ValueError("Column 'censored' not found in dataset.")
        self.data["censored"] = self.data["censored"].astype(int)
        self.data["censored_inv"] = 1 - self.data["censored"]

        if self.estimand == "PP":
            print("===== PP Estimand (No Pooling) =====")
            if self.censor_weights is not None:
                print("\n## Informative Censoring Weights ##")
                print("\n# n0: Numerator Model (previous_treatment = 0)")
                n0_model = smf.logit("censored_inv ~ x2", data=self.data[self.data["previous_treatment"] == 0]).fit(method="newton", disp=0)
                print(n0_model.summary2().tables[1].round(6))
                print("\n# n1: Numerator Model (previous_treatment = 1)")
                n1_model = smf.logit("censored_inv ~ x2", data=self.data[self.data["previous_treatment"] == 1]).fit(method="newton", disp=0)
                print(n1_model.summary2().tables[1].round(6))
                print("\n# d0: Denominator Model (previous_treatment = 0)")
                d0_model = smf.logit("censored_inv ~ x2 + x1", data=self.data[self.data["previous_treatment"] == 0]).fit(method="newton", disp=0)
                print(d0_model.summary2().tables[1].round(6))
                print("\n# d1: Denominator Model (previous_treatment = 1)")
                d1_model = smf.logit("censored_inv ~ x2 + x1", data=self.data[self.data["previous_treatment"] == 1]).fit(method="newton", disp=0)
                print(d1_model.summary2().tables[1].round(6))
            if self.switch_weights is not None:
                print("\n## Treatment Switch Weights ##")
                print("\n# n0: Numerator Model (previous_treatment = 0)")
                n0_switch = smf.logit("treatment ~ age", data=self.data[self.data["previous_treatment"] == 0]).fit(method="newton", disp=0)
                print(n0_switch.summary2().tables[1].round(6))
                print("\n# n1: Numerator Model (previous_treatment = 1)")
                n1_switch = smf.logit("treatment ~ age", data=self.data[self.data["previous_treatment"] == 1]).fit(method="newton", disp=0)
                print(n1_switch.summary2().tables[1].round(6))
                print("\n# d0: Denominator Model (previous_treatment = 0)")
                d0_switch = smf.logit("treatment ~ age + x1 + x3", data=self.data[self.data["previous_treatment"] == 0]).fit(method="newton", disp=0)
                print(d0_switch.summary2().tables[1].round(6))
                print("\n# d1: Denominator Model (previous_treatment = 1)")
                d1_switch = smf.logit("treatment ~ age + x1 + x3", data=self.data[self.data["previous_treatment"] == 1]).fit(method="newton", disp=0)
                print(d1_switch.summary2().tables[1].round(6))
        elif self.estimand == "ITT":
            print("===== ITT Estimand =====")
            if self.censor_weights is not None:
                print("\n## Informative Censoring Weights ##")
                print("\n# n: Numerator Model (pooled)")
                n_pooled = smf.logit("censored_inv ~ x2", data=self.data).fit(method="newton", disp=0)
                print(n_pooled.summary2().tables[1].round(6))
                print("\n# d0: Denominator Model (previous_treatment = 0)")
                d0_model = smf.logit("censored_inv ~ x2 + x1", data=self.data[self.data["previous_treatment"] == 0]).fit(method="newton", disp=0)
                print(d0_model.summary2().tables[1].round(6))
                print("\n# d1: Denominator Model (previous_treatment = 1)")
                d1_model = smf.logit("censored_inv ~ x2 + x1", data=self.data[self.data["previous_treatment"] == 1]).fit(method="newton", disp=0)
                print(d1_model.summary2().tables[1].round(6))

#STEP 5
    def set_outcome_model(self, adjustment_terms=None):
        if self.data is None:
            raise ValueError("set_data() before defining the outcome model.")

        # Determine treatment variable
        treatment_var = "treatment" if self.estimand in ["ITT", "PP"] else "dose"

        # Dynamically retrieve stabilized weight terms
        stabilised_weight_terms = []
        if self.switch_weights:
            stabilised_weight_terms.append(self.switch_weights["numerator"].split("~")[1].strip())
        if self.censor_weights:
            stabilised_weight_terms.append(self.censor_weights["numerator"].split("~")[1].strip())
        stabilised_weight_terms = " + ".join(stabilised_weight_terms) if stabilised_weight_terms else "1"

        # **Dynamically determine adjustment terms (Mimicking R)**
        if adjustment_terms is None:
            if self.estimand == "PP":
                adjustment_terms = ["x1", "x2", "x3", "age"]
            elif self.estimand == "ITT":
                adjustment_terms = ["x2"]
            else:
                adjustment_terms = ["1"]
        elif isinstance(adjustment_terms, str):
            adjustment_terms = adjustment_terms.split(" + ")

        # Add polynomial terms **only if the columns exist**
        additional_terms = []
        if "followup_time" in self.data.columns:
            additional_terms.append("followup_time")
            self.data["followup_time_squared"] = self.data["followup_time"] ** 2
            additional_terms.append("followup_time_squared")

        if "period" in self.data.columns:
            additional_terms.append("period")
            self.data["trial_period_squared"] = self.data["period"] ** 2
            additional_terms.append("trial_period_squared")

        # Build the final regression formula
        all_terms = [treatment_var] + adjustment_terms + additional_terms + [stabilised_weight_terms]
        formula = "outcome ~ " + " + ".join(filter(None, all_terms))  # Remove empty terms

        # Check if weights exist before fitting the model
        if "final_weight" not in self.data.columns:
            raise ValueError("Weights have not been calculated. Run calculate_weights() first.")

        # Fit GLM model
        predictor_vars = [term.strip() for term in formula.split("~")[1].strip().split(" + ") if term.strip() != "1"]
        model = sm.GLM(
            self.data["outcome"],
            sm.add_constant(self.data[predictor_vars]),  # Use only available predictors
            family=sm.families.Binomial(),
            weights=self.data["final_weight"]
        ).fit()

        # Store predictions and residuals
        self.data["predicted_outcome"] = model.predict(sm.add_constant(self.data[predictor_vars]))
        self.data["residuals"] = self.data["outcome"] - self.data["predicted_outcome"]
        self.outcome_model = model

        return model

        
    def show_outcome_model(self):
        if self.outcome_model is None:
            return "Outcome model not specified."
        return self.outcome_model.summary()
    
    #step 6

    def set_expansion_options(self, output: TEDatastore, chunk_size: int = 0, first_period: int = 0, last_period: float = float('inf'), censor_at_switch: bool = False):
        
        self.expansion = TEExpansion(chunk_size = chunk_size, datastore = output, first_period = first_period, last_period = last_period, censor_at_switch = censor_at_switch)

        return self
    
    def expand_trials(self):
        data = self.data.copy()
        outcome_adj_vars = self.get_outcome_adjustment_vars()
        keeplist = ['id', 'trial_period', 'followup_time', 'outcome', 'weight', 'treatment', 'x2', 'age'] + outcome_adj_vars

        if 'wt' not in data.columns:
            data['wt'] =  1

        all_ids = data['id'].unique()
        if self.expansion.chunk_size == 0:
            ids_split = [all_ids]
        else: 
            ids_split = np.array_split(all_ids, np.ceil(len(all_ids) / self.expansion.chunk_size))

        for ids in ids_split:
            switch_data = self._expand_chunk(data, ids, outcome_adj_vars, keeplist)
            self.expansion.datastore = self.expansion.datastore.save_expanded_data(switch_data)
        
        return self
    
    def _expand_chunk(self, data: pd.DataFrame, ids: np.ndarray, outcome_adj_vars: List[str], keeplist: List[str]):
        chunk_data = data[data['id'].isin(ids)].copy()

        first_period = max([self.expansion.first_period, chunk_data[chunk_data['eligible'] == 1]['period'].min() or self.expansion.first_period])
        last_period = min([self.expansion.last_period, chunk_data[chunk_data['eligible'] == 1]['period'].max() or self.expansion.last_period])
        
        expanded_data = []
        for _, row in chunk_data.iterrows():
            if row['eligible'] == 1 and first_period <= row['period'] <= last_period:
                trial_start = row['period']
                trial_data = self._generate_trial_instance(row, chunk_data, trial_start, last_period, outcome_adj_vars, keeplist)
                expanded_data.append(trial_data)

        result = pd.concat(expanded_data, ignore_index=True) if expanded_data else pd.DataFrame()

        return result[keeplist]
    

    def _generate_trial_instance(self, baseline_row: pd.Series, data: pd.DataFrame, trial_start: int, last_period: float, outcome_adj_vars: List[str], keeplist: List[str]):

        id_val = baseline_row['id']
        patient_data = data[data['id'] == id_val].sort_values('period')
        rows = []

        if pd.isna(last_period) or last_period == float('inf'):
            last_period_value = patient_data['period'].max()
        else:
            last_period_value = last_period

        # Convert float to integer to handle errors
        if pd.notna(last_period_value):
            last_period_int = int(np.floor(float(last_period_value)))
        else:
            last_period_int = int(trial_start)

        max_period_value = patient_data['period'].max()
        if pd.notna(max_period_value):
            max_period = int(np.floor(float(max_period_value)))
        else:
            max_period = last_period_int 

        last_period_int = int(last_period_int)
        max_period = int(max_period)

        for period in range(int(trial_start), int(min(last_period_int + 1, max_period + 1))):
            period_row = patient_data[patient_data['period'] == period].iloc[0] if not patient_data[patient_data['period'] == period].empty else None
            
            if period_row is None:
                continue

            if self.expansion.censor_at_switch and period > trial_start:
                prev_row = patient_data[patient_data['period'] == (period - 1)].iloc[0]
                if prev_row['treatment'] != period_row['treatment']:
                    break  # Censor at switch

            trial_period = period - trial_start
            followup_time = period - trial_start
            final_weight = self.data[(self.data['id'] == id_val) & (self.data['period'] == period)]['final_weight'].iloc[0] if not self.data[(self.data['id'] == id_val) & (self.data['period'] == period)].empty else 1.0
            row_dict = {
                'id': id_val,
                'trial_period': trial_period,
                'followup_time': followup_time,
                'outcome': period_row['outcome'],
                'weight': final_weight,  
                'treatment': period_row['treatment'],
            }
            
            for var in outcome_adj_vars + ['age', 'x2']:
                if var in patient_data.columns:
                    row_dict[var] = period_row.get(var, np.nan)
                else:
                    row_dict[var] = np.nan 

            rows.append(pd.Series(row_dict))

        df = pd.DataFrame(rows)
        int_columns = ['id', 'trial_period', 'followup_time', 'outcome', 'treatment', 'age']
        df[int_columns] = df[int_columns].astype(int)

        return df
    
    def get_outcome_adjustment_vars(self):
        return getattr(self.outcome_model, 'adjustment_vars', [])
    

    # step 7
    def load_expanded_data(self, p_control: Optional[float] = None, period: Optional[List[int]] = None, subset_condition: Optional[str] = None, seed: Optional[int] = None):
        
        if p_control is None:
            data_table = self.expansion.datastore.data.copy()
            data_table['sample_weight'] = 1
        else:
            np.random.seed(seed) if seed is not None else np.random.seed()
            data_table = self.expansion.datastore.data.copy()

            mask_outcome_1 = data_table['outcome'] == 1
            mask_outcome_0 = data_table['outcome'] == 0
            sampled_0 = data_table[mask_outcome_0].sample(frac=p_control, replace=False)
            data_table = pd.concat([data_table[mask_outcome_1], sampled_0])

            data_table.loc[mask_outcome_0, 'sample_weight'] = 1 / p_control if p_control > 0 else 1
            data_table.loc[mask_outcome_1, 'sample_weight'] = 1

        if period is not None:
            data_table = data_table[data_table['trial_period'].isin(period) | data_table['followup_time'].isin(period)]
        
        if subset_condition is not None:
            data_table = data_table.query(subset_condition)
        
        data_table = data_table.sort_values(['id', 'trial_period', 'followup_time'])
        data_table = data_table.reset_index(drop=True)
        
        self.outcome_data = data_table
        
        return self


#Subclass of Trial Sequence, handles the PP (hehe) estimand
class TrialSequencePP(TrialSequence):
    def __init__(self, **kwargs):
        super().__init__("PP", **kwargs)
 
#Subclass of Trial Sequence, handles the ITT estimand
class TrialSequenceITT(TrialSequence):
    def __init__(self, **kwargs):
        super().__init__("ITT", **kwargs)

#trial_sequence function equivalent used in the article
def trial_sequence(estimand, **kwargs):
    estimand_classes = {
        "PP": TrialSequencePP,
        "ITT": TrialSequenceITT
    }

    if estimand not in estimand_classes:
        raise ValueError(f"{estimand} is not a valid estimand, choose either PP or ITT")
    
    return estimand_classes[estimand](**kwargs)

## III. Process

### 1. Setup
A sequence of target trials analysis starts by specifying which estimand will be used:

In [21]:
trial_pp = trial_sequence("PP")
trial_itt = trial_sequence("ITT")

### 2. Data Preparation
Next the user must specify the observational input data that will be used for the target trial emulation. Here we need to specify which columns contain which values and how they should be used.

In [22]:
data_censored = pd.read_csv("data_censored.csv")
print("Extracted Dummy Data")
display(data_censored)
data_censored["previous_treatment"] = data_censored["treatment"].shift(1).fillna(0)
#Setting the dataset to the data field
trial_pp.set_data(data_censored.copy())  # Create a separate copy
trial_itt.set_data(data_censored.copy())  


#Displaying the info stored in each class
trial_pp.show()

trial_itt.show()

Extracted Dummy Data


Unnamed: 0,id,period,treatment,x1,x2,x3,x4,age,age_s,outcome,censored,eligible
0,1,0,1,1,1.146148,0,0.734203,36,0.083333,0,0,1
1,1,1,1,1,0.002200,0,0.734203,37,0.166667,0,0,0
2,1,2,1,0,-0.481762,0,0.734203,38,0.250000,0,0,0
3,1,3,1,0,0.007872,0,0.734203,39,0.333333,0,0,0
4,1,4,1,1,0.216054,0,0.734203,40,0.416667,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...
720,99,3,0,0,-0.747906,1,0.575268,68,2.750000,0,0,0
721,99,4,0,0,-0.790056,1,0.575268,69,2.833333,0,0,0
722,99,5,1,1,0.387429,1,0.575268,70,2.916667,0,0,0
723,99,6,1,1,-0.033762,1,0.575268,71,3.000000,0,0,0


Trial Sequence Object
Estimand: PP



Unnamed: 0,id,period,treatment,x1,x2,x3,x4,age,age_s,outcome,censored,eligible,previous_treatment,followup_time
0,1,0,1,1,1.146148,0,0.734203,36,0.083333,0,0,1,0.0,5
1,1,1,1,1,0.002200,0,0.734203,37,0.166667,0,0,0,1.0,5
2,1,2,1,0,-0.481762,0,0.734203,38,0.250000,0,0,0,1.0,5
3,1,3,1,0,0.007872,0,0.734203,39,0.333333,0,0,0,1.0,5
4,1,4,1,1,0.216054,0,0.734203,40,0.416667,0,0,0,1.0,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
720,99,3,0,0,-0.747906,1,0.575268,68,2.750000,0,0,0,0.0,7
721,99,4,0,0,-0.790056,1,0.575268,69,2.833333,0,0,0,0.0,7
722,99,5,1,1,0.387429,1,0.575268,70,2.916667,0,0,0,0.0,7
723,99,6,1,1,-0.033762,1,0.575268,71,3.000000,0,0,0,1.0,7



IPW for informative censoring:
Not calculated.

Outcome model:
Not specified.
Trial Sequence Object
Estimand: ITT



Unnamed: 0,id,period,treatment,x1,x2,x3,x4,age,age_s,outcome,censored,eligible,previous_treatment,followup_time
0,1,0,1,1,1.146148,0,0.734203,36,0.083333,0,0,1,0.0,5
1,1,1,1,1,0.002200,0,0.734203,37,0.166667,0,0,0,1.0,5
2,1,2,1,0,-0.481762,0,0.734203,38,0.250000,0,0,0,1.0,5
3,1,3,1,0,0.007872,0,0.734203,39,0.333333,0,0,0,1.0,5
4,1,4,1,1,0.216054,0,0.734203,40,0.416667,0,0,0,1.0,5
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
720,99,3,0,0,-0.747906,1,0.575268,68,2.750000,0,0,0,0.0,7
721,99,4,0,0,-0.790056,1,0.575268,69,2.833333,0,0,0,0.0,7
722,99,5,1,1,0.387429,1,0.575268,70,2.916667,0,0,0,0.0,7
723,99,6,1,1,-0.033762,1,0.575268,71,3.000000,0,0,0,1.0,7



IPW for informative censoring:
Not calculated.

Outcome model:
Not specified.


### 3. Weight Models
To adjust for the effects of informative censoring, inverse probability of censoring weights (IPCW) can be applied. To estimate these weights, we construct time-to-(censoring) event models. Two sets of models are fit for the two censoring mechanisms which may apply: censoring due to deviation from assigned treatment and other informative censoring.
#### 3.1 Censoring due to treatment switching
We specify model formulas to be used for calculating the probability of receiving treatment in the current period. Separate models are fitted for patients who had treatment = 1 and those who had treatment = 0 in the previous period. Stabilized weights are used by fitting numerator and denominator models.

There are optional arguments to specify columns which can include/exclude observations from the treatment models. These are used in case it is not possible for a patient to deviate from a certain treatment assignment in that period.

In [23]:
path = "Models"

#debug
data_0 = trial_pp.data[trial_pp.data["previous_treatment"] == 0]
data_1 = trial_pp.data[trial_pp.data["previous_treatment"] == 1]
print(f"Switch (PP): previous_treatment = 0, nobs = {len(data_0)}")
print(f"Switch (PP): previous_treatment = 1, nobs = {len(data_1)}")
print(trial_pp.data.columns)
print(trial_pp.data["previous_treatment"].value_counts())


trial_pp.set_switch_weight_model(numerator="age", denominator="age + x1 + x3", model_fitter=stats_glm_logit(save_path=os.path.join(path, "switch_models")))
trial_pp.show_switch_weights()

#debug




Switch (PP): previous_treatment = 0, nobs = 386
Switch (PP): previous_treatment = 1, nobs = 339
Index(['id', 'period', 'treatment', 'x1', 'x2', 'x3', 'x4', 'age', 'age_s',
       'outcome', 'censored', 'eligible', 'previous_treatment',
       'followup_time'],
      dtype='object')
previous_treatment
0.0    386
1.0    339
Name: count, dtype: int64
✅ Switch weights computed and stored in self.data


{'numerator': 'treatment ~ age',
 'denominator': 'treatment ~ age + x1 + x3',
 'model_fitter': 'te_stats_glm_logit',
 'fitted_model_0_numerator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d07eef290>,
 'fitted_model_1_numerator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d080154f0>,
 'fitted_model_0_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d080148f0>,
 'fitted_model_1_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d0dcad0>}

#### 3.2 Other informative censoring
In case there is other informative censoring occurring in the data, we can create similar models to estimate the IPCW. These can be used with all types of estimand. We need to specifycensor_event which is the column containing the censoring indicator.

In [24]:
data_0 = trial_pp.data[trial_pp.data["previous_treatment"] == 0]
data_1 = trial_pp.data[trial_pp.data["previous_treatment"] == 1]
print(f"Censor (PP): previous_treatment = 0, nobs = {len(data_0)}")
print(f"Censor (PP): previous_treatment = 1, nobs = {len(data_1)}")
print(trial_pp.data.columns)  # Make sure 'previous_treatment' is in the dataset
print(trial_pp.data["previous_treatment"].value_counts())
trial_pp.set_censor_weight_model(censor_event="censored", numerator="x2", denominator="x2 + x1", pool_models="none", model_fitter=stats_glm_logit(save_path=os.path.join(path, "censor_models")))
trial_pp.show_censor_weights()

Censor (PP): previous_treatment = 0, nobs = 386
Censor (PP): previous_treatment = 1, nobs = 339
Index(['id', 'period', 'treatment', 'x1', 'x2', 'x3', 'x4', 'age', 'age_s',
       'outcome', 'censored', 'eligible', 'previous_treatment',
       'followup_time', 'switch_prob_0', 'switch_prob_1', 'switch_weight'],
      dtype='object')
previous_treatment
0.0    386
1.0    339
Name: count, dtype: int64


  return 1/(1+np.exp(-X))
  return np.sum(np.log(self.cdf(q * linpred)))


{'numerator': '1 - censored ~ x2',
 'denominator': '1 - censored ~ x2 + x1',
 'pool_numerator': False,
 'pool_denominator': False,
 'model_fitter': 'te_stats_glm_logit',
 'fitted_model_0_numerator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d08015c10>,
 'fitted_model_1_numerator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d101610>,
 'fitted_model_0_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d102030>,
 'fitted_model_1_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d0f0110>}

In [25]:
print(f"ITT: total nobs = {len(trial_itt.data)}")
data_0 = trial_itt.data[trial_itt.data["previous_treatment"] == 0]
data_1 = trial_itt.data[trial_itt.data["previous_treatment"] == 1]
print(f"ITT: previous_treatment = 0, nobs = {len(data_0)}")
print(f"ITT: previous_treatment = 1, nobs = {len(data_1)}")
trial_itt.set_censor_weight_model(censor_event="censored", numerator="x2", denominator="x2 + x1", pool_models="numerator", model_fitter=stats_glm_logit(save_path=os.path.join(path, "censor_models")))
trial_itt.show_censor_weights()

ITT: total nobs = 725
ITT: previous_treatment = 0, nobs = 386
ITT: previous_treatment = 1, nobs = 339


  return 1/(1+np.exp(-X))
  return np.sum(np.log(self.cdf(q * linpred)))
  return 1/(1+np.exp(-X))
  return np.sum(np.log(self.cdf(q * linpred)))


{'numerator': '1 - censored ~ x2',
 'denominator': '1 - censored ~ x2 + x1',
 'pool_numerator': True,
 'pool_denominator': False,
 'model_fitter': 'te_stats_glm_logit',
 'fitted_model_numerator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d101d90>,
 'fitted_model_0_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d0f2750>,
 'fitted_model_1_denominator': <statsmodels.discrete.discrete_model.BinaryResultsWrapper at 0x13d0d0f3fb0>}

#### 4. Calculate Weights
Next we need to fit the individual models and combine them into weights. This is done with calculate_weights().

In [26]:

trial_pp.calculate_weights()
trial_itt.calculate_weights()
# print(trial_pp.data.columns)
# print(trial_itt.data.columns)


Calculating switch weights...
Calculating censor weights...

Weight Summary for PP:
       switch_weight  censor_weight  final_weight
count     725.000000   7.250000e+02    725.000000
mean        2.733546   1.000000e+00      2.733546
std         1.732471   2.221979e-16      1.732471
min         1.246125   1.000000e+00      1.246125
25%         1.620576   1.000000e+00      1.620576
50%         1.955091   1.000000e+00      1.955091
75%         3.258089   1.000000e+00      3.258089
max        12.525849   1.000000e+00     12.525849
Calculating censor weights...

Weight Summary for ITT:
       censor_weight  final_weight
count   7.250000e+02  7.250000e+02
mean    1.000000e+00  1.000000e+00
std     2.221979e-16  2.221979e-16
min     1.000000e+00  1.000000e+00
25%     1.000000e+00  1.000000e+00
50%     1.000000e+00  1.000000e+00
75%     1.000000e+00  1.000000e+00
max     1.000000e+00  1.000000e+00


In [27]:
trial_pp.show_weight_models()

===== PP Estimand (No Pooling) =====

## Informative Censoring Weights ##

# n0: Numerator Model (previous_treatment = 0)
              Coef.  Std.Err.          z     P>|z|    [0.025   0.975]
Intercept  2.329680  0.184529  12.625032  0.000000  1.968011  2.69135
x2        -0.469171  0.184234  -2.546610  0.010877 -0.830262 -0.10808

# n1: Numerator Model (previous_treatment = 1)
              Coef.  Std.Err.          z     P>|z|    [0.025    0.975]
Intercept  2.617963  0.221940  11.795796  0.000000  2.182968  3.052958
x2        -0.390321  0.209994  -1.858719  0.063067 -0.801902  0.021261

# d0: Denominator Model (previous_treatment = 0)
              Coef.  Std.Err.         z     P>|z|    [0.025    0.975]
Intercept  1.861908  0.215667  8.633235  0.000000  1.439207  2.284608
x2        -0.479630  0.185757 -2.582034  0.009822 -0.843706 -0.115554
x1         1.225127  0.402734  3.042029  0.002350  0.435784  2.014471

# d1: Denominator Model (previous_treatment = 1)
              Coef.  Std.Er

In [28]:
trial_itt.show_weight_models()

===== ITT Estimand =====

## Informative Censoring Weights ##

# n: Numerator Model (pooled)
              Coef.  Std.Err.          z     P>|z|    [0.025    0.975]
Intercept  2.448091  0.140575  17.414876  0.000000  2.172569  2.723612
x2        -0.448648  0.136878  -3.277724  0.001046 -0.716924 -0.180372

# d0: Denominator Model (previous_treatment = 0)
              Coef.  Std.Err.         z     P>|z|    [0.025    0.975]
Intercept  1.861908  0.215667  8.633235  0.000000  1.439207  2.284608
x2        -0.479630  0.185757 -2.582034  0.009822 -0.843706 -0.115554
x1         1.225127  0.402734  3.042029  0.002350  0.435784  2.014471

# d1: Denominator Model (previous_treatment = 1)
              Coef.  Std.Err.         z     P>|z|    [0.025    0.975]
Intercept  2.624317  0.268212  9.784493  0.000000  2.098632  3.150003
x2        -0.389484  0.210920 -1.846598  0.064805 -0.802878  0.023911
x1        -0.020322  0.478823 -0.042442  0.966146 -0.958798  0.918153


### 5. Specify Outcome Model
Now we can specify the outcome model. Here we can include adjustment terms for any variables in the dataset. The numerator terms from the stabilised weight models are automatically included in the outcome model formula.

In [29]:
# print(trial_pp.data.columns)
# print(trial_itt.data.columns)
# print(trial_pp.data["outcome"].value_counts())
trial_pp.set_outcome_model()  
print(trial_pp.show_outcome_model())

trial_itt.set_outcome_model(adjustment_terms="x2")  
print(trial_itt.show_outcome_model())

                 Generalized Linear Model Regression Results                  
Dep. Variable:                outcome   No. Observations:                  725
Model:                            GLM   Df Residuals:                      715
Model Family:                Binomial   Df Model:                            9
Link Function:                  Logit   Scale:                          1.0000
Method:                          IRLS   Log-Likelihood:                -25.350
Date:                Thu, 06 Mar 2025   Deviance:                       50.699
Time:                        17:20:06   Pearson chi2:                     117.
No. Iterations:                    29   Pseudo R-squ. (CS):            0.08358
Covariance Type:            nonrobust                                         
                            coef    std err          z      P>|z|      [0.025      0.975]
-----------------------------------------------------------------------------------------
const                    -5.54



### 6. Expand Trials
Now we are ready to create the data set with all of the sequence of target trials.

In [30]:
output = TEDatastore()
trial_pp.set_expansion_options(output, chunk_size=500, first_period = 0, last_period= float('inf'), censor_at_switch = True)

<__main__.TrialSequencePP at 0x13d07ea2900>

#### 6.1 Create Sequence of Trials Data

In [31]:
trial_pp.expand_trials()
print("\nExpanded Data:")
print(trial_pp.expansion.datastore.data)
trial_pp.expansion.datastore.data.to_csv("output2.csv", index=False)


Expanded Data:
     id  trial_period  followup_time  outcome    weight  treatment        x2  \
0     1             0              0        0  1.660829          1  1.146148   
1     1             1              1        0  1.404781          1  0.002200   
2     1             2              2        0  1.687483          1 -0.481762   
3     1             3              3        0  1.699942          1  0.007872   
4     1             4              4        0  1.427189          1  0.216054   
..   ..           ...            ...      ...       ...        ...       ...   
495  98             0              0        0  3.880790          1  1.392339   
496  98             1              1        0  1.669296          1 -0.934798   
497  98             2              2        0  2.136737          1 -0.735241   
498  99             0              0        0  4.747953          1 -0.346378   
499  99             1              1        0  1.904885          1 -1.106481   

     age  
0     36  
1

### 7. Load or Sample Expanded Data
Now that the expanded data has been created, we can prepare the data to fit the outcome model. For data that can fit comfortably in memory, this is a trivial step using load_expanded_data.

For large datasets, it may be necessary to sample from the expanded by setting the p_control argument. This sets the probability that an observation with outcome == 0 will be included in the loaded data. A seed can be set for reproducibility. Additionally, a vector of periods to include can be specified, e.g., period = 1:60, and/or a subsetting condition, subset_condition = "age > 65".

In [32]:
trial_pp.load_expanded_data(p_control = 0.5, seed=1234)


<__main__.TrialSequencePP at 0x13d07ea2900>