In [14]:
import os
import pickle
import pandas as pd
import statsmodels.api as sm

# Define the logistic regression fitter to match R's stats_glm_logit
def stats_glm_logit(y, X):
    """Fit a logistic regression of ``y`` on ``X`` and return the fitted result.

    ``disp=0`` is passed to ``fit`` so the optimizer does not print its
    convergence messages. ``X`` is used as given; no intercept is added here.
    """
    logit_model = sm.Logit(y, X)
    fitted_result = logit_model.fit(disp=0)
    return fitted_result

class TrialEmulation:
    """Hold the data, column mappings and weight models for one emulated trial.

    Typical use is ``set_data`` -> ``set_switch_weight_model`` (PP only) ->
    ``set_censor_weight_model`` -> ``calculate_weights``. The setters and
    ``calculate_weights`` return ``self`` so calls can be chained.
    """

    # Values of ``pool_models`` that calculate_weights knows how to fit.
    _SUPPORTED_POOL_MODELS = ("numerator", "none")

    def __init__(self, estimand=None):
        """Create an empty trial object for the given estimand ("PP" or "ITT")."""
        self.estimand = estimand  # "PP" or "ITT"
        self.data = None
        self.id_col = None
        self.period_col = None
        self.treatment_col = None
        self.outcome_col = None
        self.eligible_col = None
        self.censor_weights = None
        self.switch_weights = {}  # Empty dict for switch weights
        # Filled in by set_censor_weight_model; defined here so the attributes
        # always exist on the instance.
        self.censor_event = None
        self.numerator_formula = None
        self.denominator_formula = None
        self.pool_models = None
        self.model_fitter = None

    @staticmethod
    def _parse_terms(formula):
        """Split a right-hand side such as "x2 + x1" into a list of column names."""
        # Splitting on "+" and stripping tolerates "a+b" as well as "a + b".
        return [term.strip() for term in formula.split("+") if term.strip()]

    @staticmethod
    def _save_model(model, directory, filename):
        """Pickle ``model`` to ``directory/filename``, closing the file afterwards."""
        with open(os.path.join(directory, filename), "wb") as handle:
            pickle.dump(model, handle)

    def set_data(self, data, id_col, period_col, treatment_col, outcome_col, eligible_col):
        """Store a copy of the dataset and the column mappings.

        Also adds a ``prev_treatment`` column holding each patient's treatment
        in the preceding period (NaN for a patient's first period).

        Raises:
            ValueError: if ``data`` is not a pandas DataFrame.
        """
        if not isinstance(data, pd.DataFrame):
            raise ValueError("Data must be a pandas DataFrame.")
        self.data = data.copy()
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col

        # Compute previous treatment as in R. The lag must follow period order
        # within each patient, not raw row order, so sort before shifting.
        # Working on a positional (0..n-1) index lets the result be written
        # back row-by-row whatever index the caller's frame carries.
        positional = self.data.reset_index(drop=True)
        if period_col in positional.columns:
            positional = positional.sort_values([id_col, period_col])
        lagged = positional.groupby(id_col)[treatment_col].shift(1)
        self.data['prev_treatment'] = lagged.sort_index().to_numpy()
        return self

    def show(self, num_rows=5):
        """Print the estimand, the dataset size and the first ``num_rows`` rows."""
        if self.data is not None:
            print(f"Trial Sequence Object\nEstimand: {self.estimand}\n\nData:\n - N: {len(self.data)} observations from {self.data[self.id_col].nunique()} patients")
            print(self.data.head(num_rows))
        else:
            print("No data available in this instance.")

    def set_switch_weight_model(self, numerator, denominator, model_fitter, save_path=None):
        """Set the switch weight model formulas and fitter (PP only).

        Args:
            numerator: "+"-separated covariate names for the numerator model.
            denominator: "+"-separated covariate names for the denominator model.
            model_fitter: callable ``(y, X)`` returning a fitted model that
                offers ``predict`` (and ``summary`` for show_weight_models).
            save_path: optional directory where fitted models are pickled.

        Raises:
            ValueError: for the ITT estimand or a non-callable ``model_fitter``.
        """
        if self.estimand == "ITT":
            raise ValueError("Switch weight models are not applicable for ITT estimand.")
        if not callable(model_fitter):
            raise ValueError("model_fitter must be a callable function.")

        self.switch_weights = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": model_fitter,
            "save_path": save_path,
            "fitted": False
        }
        return self

    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter, save_path=None):
        """Set the censoring weight model.

        Args:
            censor_event: name of the 0/1 censoring indicator column.
            numerator: "+"-separated covariate names for the numerator model.
            denominator: "+"-separated covariate names for the denominator model.
            pool_models: "numerator" (one numerator model on all rows,
                denominator models per previous treatment) or "none"
                (both models fitted per previous treatment).
            model_fitter: callable ``(y, X)`` returning a fitted model.
            save_path: optional directory where fitted models are pickled.

        Raises:
            ValueError: if no data has been set, the censor column is missing,
                ``pool_models`` is unsupported, or ``model_fitter`` is not callable.
        """
        if self.data is None:
            # Without this guard the column check below fails with an opaque
            # AttributeError on None.
            raise ValueError("No data has been set.")
        if censor_event not in self.data.columns:
            raise ValueError(f"Censor event column '{censor_event}' not found in data.")
        if pool_models not in self._SUPPORTED_POOL_MODELS:
            raise ValueError(
                f"pool_models must be one of {self._SUPPORTED_POOL_MODELS}, got {pool_models!r}."
            )
        if not callable(model_fitter):
            raise ValueError("model_fitter must be a callable function.")

        self.censor_event = censor_event
        self.numerator_formula = numerator
        self.denominator_formula = denominator
        self.pool_models = pool_models
        self.model_fitter = model_fitter
        self.censor_weights = {
            "numerator_formula": f"1 - {censor_event} ~ {numerator}",
            "denominator_formula": f"1 - {censor_event} ~ {denominator}",
            "model_fitter_type": "te_stats_glm_logit",
            "save_path": save_path,
            "fitted": False
        }
        return self

    def calculate_weights(self):
        """Calculate stabilized weights for switching (PP) and censoring (PP and ITT).

        Adds ``switch_weight`` and/or ``censor_weight`` columns to ``self.data``
        (1.0 for rows without a previous treatment), stores the fitted models
        and pickles them when a save path was configured.

        Raises:
            ValueError: if no data has been set.
        """
        if self.data is None:
            raise ValueError("No data has been set.")

        # 1. Treatment Switching Weights (PP only)
        if self.switch_weights and self.switch_weights.get("numerator"):
            self._calculate_switch_weights()

        # 2. Censoring Weights (PP and ITT)
        if self.censor_weights:
            self._calculate_censor_weights()

        return self

    def _calculate_switch_weights(self):
        """Fit the switching models per previous treatment and store the weights."""
        spec = self.switch_weights
        num_features = self._parse_terms(spec["numerator"])
        denom_features = self._parse_terms(spec["denominator"])
        fit = spec["model_fitter"]

        # First-period rows have no previous treatment and keep weight 1.0.
        data_switch = self.data.dropna(subset=['prev_treatment'])
        numerator_models = {}
        denominator_models = {}
        weights = {}

        for prev_trt in [0, 1]:
            subset = data_switch[data_switch['prev_treatment'] == prev_trt]
            print(f"PP Switching - Observations for prev_treatment = {prev_trt}: {len(subset)}")  # Debug
            X_num = sm.add_constant(subset[num_features])
            X_denom = sm.add_constant(subset[denom_features])
            y_current = subset[self.treatment_col]

            num_model = fit(y_current, X_num)
            denom_model = fit(y_current, X_denom)
            numerator_models[prev_trt] = num_model
            denominator_models[prev_trt] = denom_model

            # Predicted P(treatment = 1); assumed to come back as a Series
            # indexed like the design matrix (true for statsmodels results
            # predicting on a DataFrame).
            p_num = num_model.predict(X_num)
            p_denom = denom_model.predict(X_denom)
            # A stabilized weight is the ratio of the probabilities of the
            # treatment actually received: p_num / p_denom for treated rows,
            # (1 - p_num) / (1 - p_denom) for untreated rows.
            weights[prev_trt] = (p_num / p_denom).where(
                y_current == 1, (1 - p_num) / (1 - p_denom)
            )

        # Assign weights
        self.data['switch_weight'] = 1.0
        for prev_trt in [0, 1]:
            indices = data_switch[data_switch['prev_treatment'] == prev_trt].index
            self.data.loc[indices, 'switch_weight'] = weights[prev_trt]

        spec["numerator_models"] = numerator_models
        spec["denominator_models"] = denominator_models
        spec["fitted"] = True

        # Save models
        save_path = spec["save_path"]
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            for prev_trt in [0, 1]:
                self._save_model(numerator_models[prev_trt], save_path, f"model_switch_n{prev_trt}.pkl")
                self._save_model(denominator_models[prev_trt], save_path, f"model_switch_d{prev_trt}.pkl")

    def _calculate_censor_weights(self):
        """Fit the censoring models according to ``pool_models`` and store the weights."""
        if self.pool_models not in self._SUPPORTED_POOL_MODELS:
            # Guards against the attribute having been changed after the setter ran.
            raise ValueError(
                f"pool_models must be one of {self._SUPPORTED_POOL_MODELS}, got {self.pool_models!r}."
            )

        num_features = self._parse_terms(self.numerator_formula)
        denom_features = self._parse_terms(self.denominator_formula)
        pooled_numerator = self.pool_models == "numerator"

        if pooled_numerator:
            # Pooled numerator model, fitted once on every row.
            y_censor = 1 - self.data[self.censor_event].astype(int)
            X_num_all = sm.add_constant(self.data[num_features])
            numerator_model = self.model_fitter(y_censor, X_num_all)
            p_num_all = numerator_model.predict(X_num_all)
            label = "ITT Censoring"
        else:
            label = "PP Censoring"

        # Per-previous-treatment models only use rows that have a previous treatment.
        data_censor = self.data.dropna(subset=['prev_treatment'])
        numerator_models = {}
        denominator_models = {}
        weights = {}

        for prev_trt in [0, 1]:
            subset = data_censor[data_censor['prev_treatment'] == prev_trt]
            print(f"{label} - Observations for prev_treatment = {prev_trt}: {len(subset)}")  # Debug
            # The modelled outcome is "not censored".
            y_current = 1 - subset[self.censor_event].astype(int)
            X_denom = sm.add_constant(subset[denom_features])

            if pooled_numerator:
                p_num = p_num_all.loc[subset.index]
            else:
                X_num = sm.add_constant(subset[num_features])
                num_model = self.model_fitter(y_current, X_num)
                numerator_models[prev_trt] = num_model
                p_num = num_model.predict(X_num)

            denom_model = self.model_fitter(y_current, X_denom)
            denominator_models[prev_trt] = denom_model
            p_denom = denom_model.predict(X_denom)
            weights[prev_trt] = p_num / p_denom

        if pooled_numerator:
            self.censor_weights["Numerator Model"] = numerator_model
        else:
            self.censor_weights["Numerator Models"] = numerator_models
        self.censor_weights["Denominator Models"] = denominator_models

        # Assign censoring weights
        self.data['censor_weight'] = 1.0
        for prev_trt in [0, 1]:
            indices = data_censor[data_censor['prev_treatment'] == prev_trt].index
            self.data.loc[indices, 'censor_weight'] = weights[prev_trt]

        self.censor_weights["fitted"] = True

        # Save models
        save_path = self.censor_weights["save_path"]
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            if pooled_numerator:
                self._save_model(self.censor_weights["Numerator Model"], save_path, "model_censor_n.pkl")
            for prev_trt in [0, 1]:
                if not pooled_numerator:
                    self._save_model(numerator_models[prev_trt], save_path, f"model_censor_n{prev_trt}.pkl")
                self._save_model(denominator_models[prev_trt], save_path, f"model_censor_d{prev_trt}.pkl")

    def show_weight_models(self):
        """Print the summaries of whichever weight models have been fitted."""
        if self.switch_weights.get("fitted"):
            print("Weight Models for Treatment Switching")
            print("-------------------------------------")
            for prev_trt in [0, 1]:
                print(f"\n[[n{prev_trt}]]")
                print(f"Model: P(treatment = 1 | previous treatment = {prev_trt}) for numerator")
                print(self.switch_weights["numerator_models"][prev_trt].summary())
                print(f"\n[[d{prev_trt}]]")
                print(f"Model: P(treatment = 1 | previous treatment = {prev_trt}) for denominator")
                print(self.switch_weights["denominator_models"][prev_trt].summary())

        if self.censor_weights and self.censor_weights.get("fitted"):
            print("\nWeight Models for Informative Censoring")
            print("---------------------------------------")
            if self.pool_models == "numerator":
                print("\n[[n]]")
                print("Model: P(censor_event = 0 | X) for numerator")
                print(self.censor_weights["Numerator Model"].summary())
                for prev_trt in [0, 1]:
                    print(f"\n[[d{prev_trt}]]")
                    print(f"Model: P(censor_event = 0 | X, previous treatment = {prev_trt}) for denominator")
                    print(self.censor_weights["Denominator Models"][prev_trt].summary())
            elif self.pool_models == "none":
                for prev_trt in [0, 1]:
                    print(f"\n[[n{prev_trt}]]")
                    print(f"Model: P(censor_event = 0 | X, previous treatment = {prev_trt}) for numerator")
                    print(self.censor_weights["Numerator Models"][prev_trt].summary())
                    print(f"\n[[d{prev_trt}]]")
                    print(f"Model: P(censor_event = 0 | X, previous treatment = {prev_trt}) for denominator")
                    print(self.censor_weights["Denominator Models"][prev_trt].summary())

    def __repr__(self):
        """Return a short description with the estimand and the data shape."""
        return f"TrialEmulation(estimand='{self.estimand}', data_shape={self.data.shape if self.data is not None else None})"

# Read the example dataset (expected to correspond to R's data_censored).
file_path = "data/data_censored.csv"
data_censored = pd.read_csv(file_path)

# One output directory per estimand, located under the current working directory.
trial_pp_dir = os.path.join(os.getcwd(), "trial_pp")
os.makedirs(trial_pp_dir, exist_ok=True)
trial_itt_dir = os.path.join(os.getcwd(), "trial_itt")
os.makedirs(trial_itt_dir, exist_ok=True)

# Build the per-protocol and intention-to-treat trial objects and attach the
# same dataset to each (set_data returns the object, so the calls are chained).
trial_pp = TrialEmulation(estimand="PP").set_data(
    data_censored,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible",
)
trial_itt = TrialEmulation(estimand="ITT").set_data(
    data_censored,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible",
)

# Step 3.1: treatment-switching weight model, configured for the PP trial only.
trial_pp.set_switch_weight_model(
    "age",
    "age + x1 + x3",
    stats_glm_logit,
    os.path.join(trial_pp_dir, "switch_models"),
)

# Step 3.2: censoring weight models for both trials. PP fits separate models
# per previous treatment; ITT pools the numerator model.
trial_pp.set_censor_weight_model(
    "censored",
    "x2",
    "x2 + x1",
    "none",
    stats_glm_logit,
    os.path.join(trial_pp_dir, "switch_models"),
)
trial_itt.set_censor_weight_model(
    "censored",
    "x2",
    "x2 + x1",
    "numerator",
    stats_glm_logit,
    os.path.join(trial_itt_dir, "switch_models"),
)

# Step 4: fit the configured models and compute the weights (PP first, then ITT).
trial_pp.calculate_weights()
trial_itt.calculate_weights()

# Report: data preview followed by the fitted weight models, for each trial.
print("Trial PP:")
trial_pp.show()
print("\nWeight Models for Trial PP:")
trial_pp.show_weight_models()

print("\nTrial ITT:")
trial_itt.show()
print("\nWeight Models for Trial ITT:")
trial_itt.show_weight_models()

PP Switching - Observations for prev_treatment = 0: 337
PP Switching - Observations for prev_treatment = 1: 299
PP Censoring - Observations for prev_treatment = 0: 337
PP Censoring - Observations for prev_treatment = 1: 299
ITT Censoring - Observations for prev_treatment = 0: 337
ITT Censoring - Observations for prev_treatment = 1: 299
Trial PP:
Trial Sequence Object
Estimand: PP

Data:
 - N: 725 observations from 89 patients
   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  prev_treatment  switch_weight  censor_weight  
0         0        