In [63]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
import statsmodels.formula.api as smf
import os
import matplotlib.pyplot as plt
from patsy import dmatrix

class TrialSequence:
    def __init__(self, estimand):
        self.estimand = estimand
        self.data = None
        self.censor_weights = None
        self.switch_weights = None if estimand != "Per-Protocol" else {}
        self.expanded_data = None
        self.outcome_model = None
        self.outcome_data = None
        self.save_dir = None

    def set_data(self, data, id="id", period="period", treatment="treatment", outcome="outcome", eligible="eligible", censor_event="censored"):
        data = data.rename(columns={id: "id", period: "period", treatment: "treatment", outcome: "outcome", eligible: "eligible", censor_event: "censored"})
        data = data.sort_values(by=["id", "period"])
        data["treatment_lag"] = data.groupby("id")["treatment"].shift(1).fillna(0)
        if self.estimand == "Per-Protocol":
            data["switch"] = (data["treatment"] != data["treatment_lag"]) & (~data["treatment_lag"].isna())
        self.data = data
        return self

    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter, save_path=None):
        self.censor_weights = {
            "censor_event": censor_event,
            "numerator": numerator,
            "denominator": denominator,
            "pool_models": pool_models,
            "model_fitter": model_fitter,
            "save_path": save_path,
            "fitted_models": {}
        }
        return self

    def set_switch_weight_model(self, numerator, denominator, model_fitter, save_path=None):
        if self.estimand != "Per-Protocol":
            raise ValueError("Switch weights are only applicable for Per-Protocol estimand.")
        self.switch_weights = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": model_fitter,
            "save_path": save_path,
            "fitted_models": {}
        }
        return self

    def print_model_summary(self, model, title):
        print(f"{title}:")
        print("term        estimate   std.error statistic p.value")
        for param in model.params.index:
            term = "(Intercept)" if param == "Intercept" else param
            estimate = model.params[param]
            std_error = model.bse[param]
            statistic = model.tvalues[param]
            p_value = model.pvalues[param]
            print(f"{term:<12}{estimate:>10.6f}{std_error:>12.6f}{statistic:>10.4f}{p_value:>12.4e}")
        null_deviance = -2 * model.llnull
        deviance = -2 * model.llf
        df_null = int(model.nobs - 1)
        print("\nnull.deviance df.null logLik    AIC      BIC      deviance df.residual nobs")
        print(f"{null_deviance:>13.4f}{df_null:>8d}{model.llf:>10.4f}{model.aic:>10.4f}{model.bic:>10.4f}{deviance:>10.4f}{int(model.df_resid):>12d}{int(model.nobs):>5d}")
        print()

    def calculate_weights(self):
        if self.censor_weights is None:
            raise ValueError("Censor weight model not specified.")
        
        data = self.data.copy()
        censor_event = self.censor_weights["censor_event"]
        data["not_censored"] = 1 - data[censor_event]
        
        print(f"Step 4: Weight Models for Informative Censoring ({self.estimand})")
        print("---------------------------------------")
        
        if self.censor_weights["pool_models"] == "numerator":
            num_formula = f"not_censored ~ {self.censor_weights['numerator']}"
            model_n = smf.logit(num_formula, data=data).fit(disp=0)
            self.censor_weights["fitted_models"]["n"] = model_n
            
            den_formula = f"not_censored ~ {self.censor_weights['denominator']}"
            model_d0 = smf.logit(den_formula, data=data[data["treatment_lag"] == 0]).fit(disp=0)
            model_d1 = smf.logit(den_formula, data=data[data["treatment_lag"] == 1]).fit(disp=0)
            self.censor_weights["fitted_models"]["d0"] = model_d0
            self.censor_weights["fitted_models"]["d1"] = model_d1
            
            data["censor_prob_num"] = model_n.predict(data)
            data["censor_prob_den"] = np.where(
                data["treatment_lag"] == 0,
                model_d0.predict(data),
                model_d1.predict(data)
            )
            self.print_model_summary(model_n, "Numerator Model (Pooled)")
            self.print_model_summary(model_d0, "Denominator Model (Treatment Lag = 0)")
            self.print_model_summary(model_d1, "Denominator Model (Treatment Lag = 1)")
        else:
            num_formula = f"not_censored ~ {self.censor_weights['numerator']}"
            model_n0 = smf.logit(num_formula, data=data[data["treatment_lag"] == 0]).fit(disp=0)
            model_n1 = smf.logit(num_formula, data=data[data["treatment_lag"] == 1]).fit(disp=0)
            den_formula = f"not_censored ~ {self.censor_weights['denominator']}"
            model_d0 = smf.logit(den_formula, data=data[data["treatment_lag"] == 0]).fit(disp=0)
            model_d1 = smf.logit(den_formula, data=data[data["treatment_lag"] == 1]).fit(disp=0)
            self.censor_weights["fitted_models"]["n0"] = model_n0
            self.censor_weights["fitted_models"]["n1"] = model_n1
            self.censor_weights["fitted_models"]["d0"] = model_d0
            self.censor_weights["fitted_models"]["d1"] = model_d1
            
            data["censor_prob_num"] = np.where(
                data["treatment_lag"] == 0,
                model_n0.predict(data),
                model_n1.predict(data)
            )
            data["censor_prob_den"] = np.where(
                data["treatment_lag"] == 0,
                model_d0.predict(data),
                model_d1.predict(data)
            )
            self.print_model_summary(model_n0, "Numerator Model (Treatment Lag = 0)")
            self.print_model_summary(model_n1, "Numerator Model (Treatment Lag = 1)")
            self.print_model_summary(model_d0, "Denominator Model (Treatment Lag = 0)")
            self.print_model_summary(model_d1, "Denominator Model (Treatment Lag = 1)")
        
        data["wtC"] = data["censor_prob_num"] / data["censor_prob_den"]
        data["wtC"] = data.groupby("id")["wtC"].cumprod()
        
        if self.estimand == "Per-Protocol" and self.switch_weights:
            data["stayed_on_treatment"] = (~data["switch"]).astype(int)
            num_formula = f"stayed_on_treatment ~ {self.switch_weights['numerator']}"
            den_formula = f"stayed_on_treatment ~ {self.switch_weights['denominator']}"
            
            model_n0 = smf.logit(num_formula, data=data[data["treatment_lag"] == 0]).fit(disp=0)
            model_n1 = smf.logit(num_formula, data=data[data["treatment_lag"] == 1]).fit(disp=0)
            model_d0 = smf.logit(den_formula, data=data[data["treatment_lag"] == 0]).fit(disp=0)
            model_d1 = smf.logit(den_formula, data=data[data["treatment_lag"] == 1]).fit(disp=0)
            self.switch_weights["fitted_models"]["n0"] = model_n0
            self.switch_weights["fitted_models"]["n1"] = model_n1
            self.switch_weights["fitted_models"]["d0"] = model_d0
            self.switch_weights["fitted_models"]["d1"] = model_d1
            
            data["switch_prob_num"] = np.where(
                data["treatment_lag"] == 0,
                model_n0.predict(data),
                model_n1.predict(data)
            )
            data["switch_prob_den"] = np.where(
                data["treatment_lag"] == 0,
                model_d0.predict(data),
                model_d1.predict(data)
            )
            data["wtS"] = data["switch_prob_num"] / data["switch_prob_den"]
            data["wtS"] = data.groupby("id")["wtS"].cumprod()
            data["weight"] = data["wtC"] * data["wtS"]
            
            print("Switch Weight Models:")
            self.print_model_summary(model_n0, "Numerator Model (Treatment Lag = 0)")
            self.print_model_summary(model_n1, "Numerator Model (Treatment Lag = 1)")
            self.print_model_summary(model_d0, "Denominator Model (Treatment Lag = 0)")
            self.print_model_summary(model_d1, "Denominator Model (Treatment Lag = 1)")
        else:
            data["weight"] = data["wtC"]
        
        self.data = data
        return self

    def set_outcome_model(self, adjustment_terms="x2" if self.estimand == "Intention-to-Treat" else None):
        print("Step 5: Outcome Model Specification")
        print("-----------------------------------")
        if self.estimand == "Intention-to-Treat":
            terms = "assigned_treatment + followup_time + I(followup_time**2) + trial_period + I(trial_period**2)"
            if adjustment_terms:
                terms += f" + {adjustment_terms}"  # Adds x2 only if specified
        else:  # Per-Protocol
            terms = "assigned_treatment + followup_time + I(followup_time**2)"
        self.outcome_model = {"formula": f"outcome ~ {terms}", "fitted": None}
        print(f"Outcome Model Formula: {self.outcome_model['formula']}\n")
        return self

    def set_expansion_options(self, output="memory", chunk_size=500):
        self.expansion_options = {"output": output, "chunk_size": chunk_size}
        return self

    def expand_trials(self):
        data = self.data.copy()
        periods = data["period"].unique()
        expanded_rows = []
        
        for trial_period in periods:
            eligible = data[(data["period"] == trial_period) & (data["eligible"] == 1)]
            for _, row in eligible.iterrows():
                patient_data = data[data["id"] == row["id"]]
                start_idx = patient_data.index[patient_data["period"] == trial_period][0]
                follow_up = patient_data.loc[start_idx:]
                
                assigned_treatment = row["treatment"]
                for t, f_row in enumerate(follow_up.itertuples()):
                    expanded_row = {
                        "id": row["id"],
                        "trial_period": trial_period,
                        "followup_time": t,
                        "outcome": f_row.outcome,
                        "weight": f_row.weight,
                        "treatment": f_row.treatment,
                        "assigned_treatment": assigned_treatment,
                        "x2": f_row.x2,
                        "age": f_row.age
                    }
                    if self.estimand == "Per-Protocol" and f_row.switch:
                        break
                    expanded_rows.append(expanded_row)
        
        self.expanded_data = pd.DataFrame(expanded_rows)
        print("Step 6: Trial Expansion")
        print("------------------------")
        print(f"Expanded Data: {self.expanded_data.shape[0]} rows, {self.expanded_data.shape[1]} columns")
        print("First few rows of expanded data:")
        print(self.expanded_data.head().to_string(index=True), "\n")
        return self

    def load_expanded_data(self, seed=1234, p_control=0.5):
        np.random.seed(seed)
        expanded = self.expanded_data.copy()
        control_mask = (expanded["outcome"] == 0)
        sample_mask = control_mask & (np.random.random(len(expanded)) < p_control)
        self.outcome_data = expanded[sample_mask | ~control_mask].copy()
        self.outcome_data["sample_weight"] = np.where(self.outcome_data["outcome"] == 0, 1 / p_control, 1)
        print("Step 7: Load Expanded Data")
        print("--------------------------")
        print(f"Loaded Data: {self.outcome_data.shape[0]} rows")
        print("First few rows of loaded data:")
        print(self.outcome_data.head().to_string(index=True), "\n")
        return self

    def fit_msm(self, weight_cols=["weight", "sample_weight"], modify_weights=None):
        data = self.outcome_data.copy()
        data["w"] = data[weight_cols].prod(axis=1)
        if modify_weights:
            data["w"] = modify_weights(data["w"])
        
        model = smf.logit(self.outcome_model["formula"], data=data, weights=data["w"]).fit(disp=0)
        self.outcome_model["fitted"] = model
        print("Step 8: Fit Marginal Structural Model")
        print("-------------------------------------")
        print("Marginal Structural Model Summary:")
        print(model.summary())
        print()
        return self

    def predict(self, newdata, predict_times):
        model = self.outcome_model["fitted"]
        params = model.params
        cov = model.cov_params()
        n_bootstrap = 1000
        bootstrap_params = np.random.multivariate_normal(params, cov, size=n_bootstrap)

        rhs = self.outcome_model["formula"].split("~")[1].strip()
        mean_row = newdata[newdata["trial_period"] == 1].mean().to_frame().T
        ref_data = pd.concat([mean_row] * len(predict_times) * 2, ignore_index=True)
        ref_data["followup_time"] = predict_times * 2
        ref_data["assigned_treatment"] = [0] * len(predict_times) + [1] * len(predict_times)

        X_pred = dmatrix(rhs, ref_data, return_type="dataframe")
        trt0_mask = ref_data["assigned_treatment"] == 0
        X_pred_trt0 = X_pred[trt0_mask]
        X_pred_trt1 = X_pred[~trt0_mask]

        linear_pred_trt0 = np.dot(X_pred_trt0, params)
        p_trt0 = 1 / (1 + np.exp(-linear_pred_trt0))
        survival_trt0 = np.cumprod(1 - p_trt0)

        linear_pred_trt1 = np.dot(X_pred_trt1, params)
        p_trt1 = 1 / (1 + np.exp(-linear_pred_trt1))
        survival_trt1 = np.cumprod(1 - p_trt1)

        diff = survival_trt1 - survival_trt0

        diff_bootstrap = np.zeros((n_bootstrap, len(predict_times)))
        for i in range(n_bootstrap):
            beta_i = bootstrap_params[i]
            linear_pred_trt0_i = np.dot(X_pred_trt0, beta_i)
            p_trt0_i = 1 / (1 + np.exp(-linear_pred_trt0_i))
            survival_trt0_i = np.cumprod(1 - p_trt0_i)

            linear_pred_trt1_i = np.dot(X_pred_trt1, beta_i)
            p_trt1_i = 1 / (1 + np.exp(-linear_pred_trt1_i))
            survival_trt1_i = np.cumprod(1 - p_trt1_i)

            diff_bootstrap[i, :] = survival_trt1_i - survival_trt0_i

        ci_lower = np.percentile(diff_bootstrap, 2.5, axis=0)
        ci_upper = np.percentile(diff_bootstrap, 97.5, axis=0)

        return {
            "difference": {
                "followup_time": predict_times,
                "survival_diff": diff,
                "ci_lower": ci_lower,
                "ci_upper": ci_upper
            }
        }

class TrialSequenceITT(TrialSequence):
    def __init__(self):
        super().__init__("Intention-to-Treat")

    def set_censor_weight_model(self, censor_event="censored", numerator="x2", denominator="x2 + x1", pool_models="numerator", model_fitter="stats_glm_logit", save_path=None):
        return super().set_censor_weight_model(censor_event, numerator, denominator, pool_models, model_fitter, save_path)

class TrialSequencePP(TrialSequence):
    def __init__(self):
        super().__init__("Per-Protocol")

    def set_censor_weight_model(self, censor_event="censored", numerator="x2", denominator="x2 + x1", pool_models="none", model_fitter="stats_glm_logit", save_path=None):
        return super().set_censor_weight_model(censor_event, numerator, denominator, pool_models, model_fitter, save_path)

if __name__ == "__main__":
    trial_itt = TrialSequenceITT()
    trial_pp = TrialSequencePP()

    trial_itt_dir = os.path.join(os.path.abspath("."), "trial_itt")
    trial_pp_dir = os.path.join(os.path.abspath("."), "trial_pp")
    os.makedirs(trial_itt_dir, exist_ok=True)
    os.makedirs(trial_pp_dir, exist_ok=True)
    trial_itt.save_dir = trial_itt_dir
    trial_pp.save_dir = trial_pp_dir

    file_path = "data/data_censored.csv"
    data_censored = pd.read_csv(file_path)

    trial_itt.set_data(data_censored)
    trial_pp.set_data(data_censored)

    trial_itt.set_censor_weight_model(save_path=trial_itt_dir)
    trial_pp.set_censor_weight_model(save_path=trial_pp_dir)
    trial_pp.set_switch_weight_model(numerator="age", denominator="age + x1 + x3", model_fitter="stats_glm_logit", save_path=trial_pp_dir)

    trial_itt.calculate_weights()
    trial_pp.calculate_weights()

    trial_itt.set_outcome_model(adjustment_terms="x2")
    trial_itt.set_expansion_options().expand_trials()
    trial_itt.load_expanded_data()
    trial_itt.fit_msm(modify_weights=lambda w: np.minimum(w, np.quantile(w, 0.99)))

    predict_times = list(range(11))
    newdata = trial_itt.outcome_data[trial_itt.outcome_data["trial_period"] == 1]
    preds = trial_itt.predict(newdata, predict_times)

    plt.figure(figsize=(8, 6))
    plt.plot(preds["difference"]["followup_time"], preds["difference"]["survival_diff"], 'b-', label="Survival Difference")
    plt.plot(preds["difference"]["followup_time"], preds["difference"]["ci_lower"], 'r--', label="2.5% CI")
    plt.plot(preds["difference"]["followup_time"], preds["difference"]["ci_upper"], 'r--', label="97.5% CI")
    plt.xlabel("Follow up")
    plt.ylabel("Survival difference")
    plt.ylim(-0.8, 0.1)
    plt.legend()
    plt.grid(True)
    plt.show()

NameError: name 'self' is not defined