# In [None]:
import pandas as pd
import numpy as np
import statsmodels.formula.api as smf
from lifelines import CoxPHFitter
import matplotlib.pyplot as plt 
import os
# Simulated person-level data for the trial-emulation demo. A fixed seed makes
# the notebook reproducible: every run yields the same data, weights and fits.
RANDOM_SEED = 2024
rng = np.random.default_rng(RANDOM_SEED)
n_subjects = 200

data_censored = pd.DataFrame({
    'id': range(1, n_subjects + 1),
    'period': rng.choice(np.arange(1, 5), size=n_subjects),
    'treatment': rng.choice([0, 1], size=n_subjects),
    'outcome': rng.normal(size=n_subjects),
    'eligible': rng.choice([True, False], size=n_subjects, p=[0.9, 0.1]),
    # integers() excludes the upper bound, like the legacy randint().
    'age': rng.integers(20, 80, size=n_subjects),
    'x1': rng.normal(size=n_subjects),
    'x2': rng.normal(size=n_subjects),
    'x3': rng.normal(size=n_subjects),
    'censored': rng.choice([True, False], size=n_subjects, p=[0.2, 0.8]),
})

print(data_censored.head())
class TrialSequence:
    """Sequence-of-target-trials analysis using inverse probability weights.

    Fluent configuration object: every ``set_*`` method stores a model
    specification and returns ``self`` so calls can be chained.
    ``calculate_weights`` fits logistic weight models (statsmodels),
    ``fit_msm`` fits a weighted Cox model (lifelines) and ``predict`` returns
    standardised survival differences between treatment arms.

    Args:
        estimand: ``"PP"`` (per-protocol: switch x censoring weights) or
            ``"ITT"`` (intention-to-treat: censoring weights only).
    """

    _SUPPORTED_ESTIMANDS = ("PP", "ITT")

    def __init__(self, estimand):
        self.estimand = estimand
        self.data = None
        # Model specifications (dicts of formula strings) set by set_*_model().
        self.switch_weight_model = None
        self.censor_weight_model = None
        # Fitted lifelines CoxPHFitter, set by fit_msm().
        self.outcome_model = None
        # Initialised here so fit_msm() and __str__ work even if
        # set_outcome_model() was never called.
        self.outcome_model_adjustment_terms = None
        # Copies of ``data`` carrying per-row weights, set by calculate_weights().
        self.switch_weights_data = None
        self.censor_weights_data = None

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _rhs(formula):
        """Return the right-hand side of a formula string.

        ``"~ age + x1"`` and ``"y ~ age + x1"`` both give ``"age + x1"``; a
        string without ``~`` is taken to already be a right-hand side.
        """
        return formula.split('~', 1)[-1].strip()

    @staticmethod
    def _stabilized_ratio(num_prob, den_prob, observed):
        """Per-row ratio P(observed value | numerator) / P(observed value | denominator).

        ``num_prob`` and ``den_prob`` are predicted probabilities that a binary
        variable equals 1. Rows where ``observed`` is 1/True use them directly;
        rows where it is 0/False use the complements ``1 - p``.
        """
        num = np.asarray(num_prob, dtype=float)
        den = np.asarray(den_prob, dtype=float)
        is_one = np.asarray(observed).astype(bool)
        # np.where evaluates both branches; silence warnings from the unused one.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(is_one, num / den, (1.0 - num) / (1.0 - den))

    def _baseline_rows(self):
        """Rows of ``self.data`` in period 1, on which all weight models are fitted."""
        return self.data[self.data['period'] == 1]

    def _switch_weights(self):
        """Fit the treatment models and return a copy of ``self.data`` with switch weights.

        Adds ``num_sw``/``den_sw`` (predicted P(treatment = 1)) and
        ``switch_weight``: the stabilised ratio evaluated at each row's
        observed treatment, or 0 for ineligible rows.
        """
        spec = self.switch_weight_model
        baseline = self._baseline_rows()
        model_num = smf.logit(
            f"treatment ~ {self._rhs(spec['numerator'])}", data=baseline
        ).fit(disp=0)
        model_den = smf.logit(
            f"treatment ~ {self._rhs(spec['denominator'])}", data=baseline
        ).fit(disp=0)

        out = self.data.copy()
        out['num_sw'] = model_num.predict(out)
        out['den_sw'] = model_den.predict(out)
        # Untreated rows are weighted by P(treatment = 0) = 1 - p.
        ratio = self._stabilized_ratio(out['num_sw'], out['den_sw'], out['treatment'])
        out['switch_weight'] = np.where(out['eligible'].astype(bool), ratio, 0.0)
        return out

    def _censor_weights(self):
        """Fit the censoring models and return a copy of ``self.data`` with censoring weights.

        Adds ``num_c``/``den_c`` (predicted probability that the censor-event
        column is 1) and ``censor_weight``: P(uncensored | numerator) /
        P(uncensored | denominator) for uncensored rows, 0 for censored rows.
        """
        spec = self.censor_weight_model
        event = spec['censor_event']
        baseline = self._baseline_rows()
        model_num = smf.logit(
            f"{event} ~ {self._rhs(spec['numerator'])}", data=baseline
        ).fit(disp=0)
        model_den = smf.logit(
            f"{event} ~ {self._rhs(spec['denominator'])}", data=baseline
        ).fit(disp=0)

        out = self.data.copy()
        out['num_c'] = model_num.predict(out)
        out['den_c'] = model_den.predict(out)
        censored = out[event].astype(bool)
        # For uncensored rows (observed == 0) the ratio uses 1 - p, i.e. the
        # probability of remaining uncensored, as inverse-probability-of-
        # censoring weights require.
        ratio = self._stabilized_ratio(out['num_c'], out['den_c'], censored)
        out['censor_weight'] = np.where(censored, 0.0, ratio)
        return out

    # ------------------------------------------------------------ configuration

    def set_data(self, data, id, period, treatment, outcome, eligible):
        """Store a copy of ``data`` with the analysis columns under canonical names.

        Args:
            data: Input DataFrame; it is copied, never modified.
            id, period, treatment, outcome, eligible: Names of the columns in
                ``data`` playing each role; each is duplicated under the
                canonical name (``'id'``, ``'period'``, ...).

        Returns:
            self, for chaining.
        """
        self.data = data.copy()
        roles = {
            'id': id,
            'period': period,
            'treatment': treatment,
            'outcome': outcome,
            'eligible': eligible,
        }
        for canonical, source in roles.items():
            self.data[canonical] = self.data[source]
        # Cast to 0/1 so 'censored' can serve as a numeric logit response and
        # as the Cox event indicator; skipped when the column is absent.
        if 'censored' in self.data.columns:
            self.data['censored'] = self.data['censored'].astype(int)
        return self

    def set_switch_weight_model(self, numerator, denominator, model_fitter=None):
        """Specify the treatment-switch weight models (used for the PP estimand).

        Args:
            numerator: Formula such as ``"~ age"``; its right-hand side is the
                stabilising numerator model for ``treatment``.
            denominator: Formula whose right-hand side is the denominator model.
            model_fitter: Accepted for API compatibility; not used (models are
                always fitted with statsmodels' logit).

        Returns:
            self, for chaining.
        """
        self.switch_weight_model = {'numerator': numerator, 'denominator': denominator}
        return self

    def set_censor_weight_model(self, censor_event, numerator, denominator,
                                pool_models=None, model_fitter=None):
        """Specify the censoring weight models.

        Args:
            censor_event: Column name of the censoring indicator (logit response).
            numerator: Formula whose right-hand side is the numerator model.
            denominator: Formula whose right-hand side is the denominator model.
            pool_models: Stored and reported by ``__str__``; not used when fitting.
            model_fitter: Accepted for API compatibility; not used.

        Returns:
            self, for chaining.
        """
        self.censor_weight_model = {
            'censor_event': censor_event,
            'numerator': numerator,
            'denominator': denominator,
            'pool_models': pool_models,
        }
        return self

    # --------------------------------------------------------------- weighting

    def calculate_weights(self):
        """Fit the weight models and attach the weights to ``self.data``.

        Adds ``switch_weight`` (PP only), ``censor_weight`` and ``weight`` (their
        product for PP, the censoring weight alone for ITT) to ``self.data``,
        and keeps the full model frames in ``switch_weights_data`` /
        ``censor_weights_data``. Calling it again overwrites those columns.

        Returns:
            self, for chaining.

        Raises:
            ValueError: if the estimand is unsupported, no data has been set,
                or a required weight model has not been configured.
        """
        if self.estimand not in self._SUPPORTED_ESTIMANDS:
            raise ValueError(
                f"Unsupported estimand {self.estimand!r}; "
                f"expected one of {self._SUPPORTED_ESTIMANDS}"
            )
        if self.data is None:
            raise ValueError("call set_data() before calculate_weights()")
        if self.censor_weight_model is None:
            raise ValueError("call set_censor_weight_model() before calculate_weights()")
        is_pp = self.estimand == "PP"
        if is_pp and self.switch_weight_model is None:
            raise ValueError(
                "call set_switch_weight_model() before calculate_weights() for PP"
            )

        if is_pp:
            self.switch_weights_data = self._switch_weights()
        self.censor_weights_data = self._censor_weights()

        # The weight frames are index-aligned copies of self.data, so assign the
        # columns directly: merging on 'id' would multiply rows for ids spanning
        # several periods and would clash with existing columns on re-runs.
        if is_pp:
            self.data['switch_weight'] = self.switch_weights_data['switch_weight']
        self.data['censor_weight'] = self.censor_weights_data['censor_weight']
        if is_pp:
            self.data['weight'] = self.data['switch_weight'] * self.data['censor_weight']
        else:
            self.data['weight'] = self.data['censor_weight']
        return self

    # ----------------------------------------------------------- outcome model

    def set_outcome_model(self, adjustment_terms=None):
        """Set extra covariates for the outcome model.

        Args:
            adjustment_terms: Optional formula such as ``"~ x2"`` whose
                right-hand side is added to ``treatment`` in ``fit_msm``.

        Returns:
            self, for chaining.
        """
        self.outcome_model_adjustment_terms = adjustment_terms
        return self

    def fit_msm(self, weight_cols, modify_weights=None):
        """Fit a weighted Cox proportional-hazards model as the outcome model.

        Args:
            weight_cols: Sequence of weight column names; only the first one is
                used, and any further names are ignored.
            modify_weights: Optional callable applied to that weight column
                (e.g. truncation) before fitting.

        Returns:
            self, with the fitted ``CoxPHFitter`` stored in ``outcome_model``.
        """
        weights = self.data[weight_cols[0]]
        if modify_weights is not None:
            weights = modify_weights(weights)
        self.data["computed_weight"] = weights
        # Floor at 1e-6: zero weights (censored / ineligible rows) become a tiny
        # positive value, since lifelines rejects non-positive weights.
        self.data["computed_weight"] = self.data["computed_weight"].clip(lower=1e-6)

        # Right-hand side of the lifelines formula: treatment plus any
        # adjustment terms from set_outcome_model().
        formula = "treatment"
        adjustment = getattr(self, 'outcome_model_adjustment_terms', None)
        if adjustment is not None:
            extra = self._rhs(adjustment)
            if extra:
                formula += " + " + extra

        # NOTE(review): the event indicator is the 'censored' column, and when
        # weighting by censoring weights those rows carry the floored 1e-6
        # weight — confirm this is the intended outcome definition.
        cph = CoxPHFitter()
        cph.fit(self.data, duration_col='period', event_col='censored',
                weights_col='computed_weight', formula=formula)

        self.outcome_model = cph
        return self

    def predict(self, newdata, predict_times, type="survival"):
        """Standardised survival difference (treated minus untreated) over time.

        Every row of ``newdata`` is predicted twice, once with ``treatment``
        set to 1 and once set to 0. The curves are averaged within each
        scenario, so both arms describe the same covariate distribution.

        Args:
            newdata: DataFrame with the outcome model's covariates.
            predict_times: Iterable of follow-up times at which to evaluate survival.
            type: Only ``"survival"`` is supported.

        Returns:
            ``{'difference': DataFrame}`` with columns ``followup_time`` and
            ``survival_diff`` (NaN when ``newdata`` is empty), or ``None`` if
            ``type`` is not ``"survival"`` or no outcome model has been fitted.
        """
        if type != "survival" or not self.outcome_model:
            return None

        times = list(predict_times)
        if len(newdata) == 0:
            diffs = [np.nan] * len(times)
        else:
            mean_survival = {}
            for arm in (1, 0):
                scenario = newdata.copy()
                scenario['treatment'] = arm
                curves = self.outcome_model.predict_survival_function(scenario, times=times)
                # Rows are times and columns are subjects; average across subjects.
                mean_survival[arm] = curves.mean(axis=1)
            diffs = [mean_survival[1].loc[t] - mean_survival[0].loc[t] for t in times]

        difference_df = pd.DataFrame({'followup_time': times, 'survival_diff': diffs})
        return {'difference': difference_df}

    # ---------------------------------------------------------------- display

    def __str__(self):
        """Human-readable description of the configured models and weights."""
        output_str = f"Trial Sequence (Estimand: {self.estimand})\n"
        if self.switch_weight_model:
            output_str += "\nSwitch Weight Model:\n"
            output_str += f"  Numerator: {self.switch_weight_model['numerator']}\n"
            output_str += f"  Denominator: {self.switch_weight_model['denominator']}\n"
            sw_data = getattr(self, 'switch_weights_data', None)
            # Weights exist only after calculate_weights(); before that this is None.
            if sw_data is not None and 'switch_weight' in sw_data.columns:
                output_str += f"  Switch Weights (first 5): {sw_data['switch_weight'].head().tolist()}\n"
        if self.censor_weight_model:
            output_str += "\nCensor Weight Model:\n"
            output_str += f"  Censor Event: {self.censor_weight_model['censor_event']}\n"
            output_str += f"  Numerator: {self.censor_weight_model['numerator']}\n"
            output_str += f"  Denominator: {self.censor_weight_model['denominator']}\n"
            output_str += f"  Pooling: {self.censor_weight_model['pool_models']}\n"
            cw_data = getattr(self, 'censor_weights_data', None)
            if cw_data is not None and 'censor_weight' in cw_data.columns:
                output_str += f"  Censor Weights (first 5): {cw_data['censor_weight'].head().tolist()}\n"
        if self.outcome_model:
            output_str += "\nOutcome Model:\n"
            output_str += f"  Adjustment Terms: {getattr(self, 'outcome_model_adjustment_terms', None)}\n"
            # Use the summary table as text; print_summary() writes to stdout
            # and returns None, which would end up in the string.
            summary = getattr(self.outcome_model, 'summary', None)
            summary_text = summary.to_string() if hasattr(summary, 'to_string') else str(summary)
            output_str += f"  Model Summary:\n{summary_text}"
        return output_str
# Working directories for the two trials' artefacts, created if missing.
temp_dir = "temp_trial_dirs"
trial_pp_dir, trial_itt_dir = (
    os.path.join(temp_dir, sub_dir) for sub_dir in ("trial_pp", "trial_itt")
)

for directory in (trial_pp_dir, trial_itt_dir):
    os.makedirs(directory, exist_ok=True)
# Two trial sequences over the same data: per-protocol and intention-to-treat.
trial_pp = TrialSequence(estimand="PP")
trial_itt = TrialSequence(estimand="ITT")

# Both trials map the columns identically, so the mapping is declared once.
_column_roles = dict(
    id="id",
    period="period",
    treatment="treatment",
    outcome="outcome",
    eligible="eligible",
)
trial_pp = trial_pp.set_data(data=data_censored, **_column_roles)
trial_itt = trial_itt.set_data(data=data_censored, **_column_roles)

print(trial_itt)
# Per-protocol trial: switch weights model the treatment received, censoring
# weights model the 'censored' indicator.
trial_pp = trial_pp.set_switch_weight_model(
    numerator="~ age",
    denominator="~ age + x1 + x3",
    model_fitter=None,
)
trial_pp = trial_pp.set_censor_weight_model(
    censor_event="censored",
    numerator="~ x2",
    denominator="~ x2 + x1",
    pool_models="none",
    model_fitter=None,
)

# Intention-to-treat trial: censoring weights only.
trial_itt = trial_itt.set_censor_weight_model(
    censor_event="censored",
    numerator="~x2",
    denominator="~ x2 + x1",
    pool_models="numerator",
    model_fitter=None,
)

# The *_weights_data frames are only populated by calculate_weights(), so the
# weight previews must come after it; before it they are always None.
trial_pp = trial_pp.calculate_weights()
trial_itt = trial_itt.calculate_weights()

if trial_pp.switch_weights_data is not None:
    print("\ntrial_pp.switch_weights (first 5):", trial_pp.switch_weights_data['switch_weight'].head().tolist())
if trial_pp.censor_weights_data is not None:
    print("\ntrial_pp.censor_weights (first 5):", trial_pp.censor_weights_data['censor_weight'].head().tolist())
if trial_itt.censor_weights_data is not None:
    print("\ntrial_itt.censor_weights (first 5):", trial_itt.censor_weights_data['censor_weight'].head().tolist())

print("\nShow Weight Models (trial_itt):")
print(trial_itt)

print("\nShow Weight Models (trial_pp):")
print(trial_pp)
# Outcome models: the PP trial keeps the default (treatment only); the ITT
# trial also adjusts for x2 and is the one fitted below.
trial_pp = trial_pp.set_outcome_model()
trial_itt = trial_itt.set_outcome_model(adjustment_terms="~x2")


def _cap_at_p99(w):
    """Truncate weights at their 99th percentile."""
    return np.minimum(w, np.quantile(w, 0.99))


trial_itt = trial_itt.fit_msm(
    weight_cols=["weight", "sample_weight"],
    modify_weights=_cap_at_p99,
)

fitted_msm = trial_itt.outcome_model

print("\ntrial_itt.outcome_model:")
if fitted_msm:
    fitted_msm.print_summary()

print("\ntrial_itt.outcome_model.model (statsmodels model summary - if applicable):")
print(
    fitted_msm.summary
    if hasattr(fitted_msm, 'summary')
    else "Outcome model summary not directly available for CoxPHFitter."
)

print("\ntrial_itt.outcome_model.covariance (vcov - if applicable):")
print(
    fitted_msm.variance_matrix_
    if hasattr(fitted_msm, 'variance_matrix_')
    else "Outcome model covariance matrix not directly available for CoxPHFitter."
)

print("\ntrial_itt:")
print(trial_itt)
# Predict survival for everyone observed in period 1, over follow-up 0..10.
newdata_predict = data_censored.loc[data_censored['period'] == 1].copy()
predict_times = range(0, 11)

preds = trial_itt.predict(
    newdata=newdata_predict,
    predict_times=predict_times,
    type="survival",
)

if not (preds and 'difference' in preds):
    print("Prediction failed or difference data not available.")
else:
    difference_df = preds['difference']
    times = difference_df['followup_time']
    # Confidence bands are not computed: the all-NaN series draw nothing and
    # only reserve the legend entry.
    ci_placeholder = [np.nan] * len(difference_df)

    plt.figure(figsize=(8, 6))
    plt.plot(times, difference_df['survival_diff'], label="Survival difference")
    plt.plot(times, ci_placeholder, color='red', linestyle='--', label="2.5% and 97.5% CI (Placeholder)")
    plt.plot(times, ci_placeholder, color='red', linestyle='--')

    plt.xlabel("Follow up")
    plt.ylabel("Survival difference")
    plt.title("Survival Difference over Follow-up Time")
    plt.legend()
    plt.show()