In [None]:
import os
import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.duration.hazard_regression import PHReg
import pickle
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import Optional, List, Callable, Dict, Any, Union
import tempfile
import random

# Dataclass to mimic the trial_sequence object in R
@dataclass
class TrialSequence:
    estimand: str  # "PP" or "ITT"
    data: Optional[pd.DataFrame] = None
    id_col: Optional[str] = None
    period_col: Optional[str] = None
    treatment_col: Optional[str] = None
    outcome_col: Optional[str] = None
    eligible_col: Optional[str] = None
    switch_weights: Optional[pd.DataFrame] = None
    censor_weights: Optional[pd.DataFrame] = None
    combined_weights: Optional[pd.DataFrame] = None
    outcome_model: Optional[Any] = None
    expansion: Optional[pd.DataFrame] = None
    expansion_options: Optional[Dict] = None
    
    def set_data(self, data, id, period, treatment, outcome, eligible):
        """Set the data and column names for the trial sequence."""
        self.data = data
        self.id_col = id
        self.period_col = period
        self.treatment_col = treatment
        self.outcome_col = outcome
        self.eligible_col = eligible
        return self
    
    def set_switch_weight_model(self, numerator, denominator, model_fitter):
        """Set the switch weight model specifications."""
        # Convert R formula strings to lists of variable names
        num_vars = self._formula_to_vars(numerator)
        denom_vars = self._formula_to_vars(denominator)
        
        # Calculate switch weights
        self.switch_weights = model_fitter.fit(
            self.data, 
            self.treatment_col,
            num_vars,
            denom_vars,
            self.id_col,
            self.period_col
        )
        
        return self
    
    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter):
        """Set the censoring weight model specifications."""
        # Convert R formula strings to lists of variable names
        num_vars = self._formula_to_vars(numerator)
        denom_vars = self._formula_to_vars(denominator)
        
        # Calculate censoring weights
        censor_calculator = CensorWeightCalculator(
            model_fitter=model_fitter,
            censor_event=censor_event,
            pool_models=pool_models
        )
        
        self.censor_weights = censor_calculator.fit(
            self.data,
            num_vars,
            denom_vars,
            self.id_col,
            self.period_col
        )
        
        return self
    
    def calculate_weights(self):
        """Combine switch and censor weights."""
        if self.switch_weights is None and self.censor_weights is None:
            raise ValueError("No weights have been calculated yet.")
        
        # Create a base DataFrame with all patient-periods
        all_periods = pd.DataFrame({
            self.id_col: self.data[self.id_col].unique()
        }).merge(
            pd.DataFrame({self.period_col: self.data[self.period_col].unique()}),
            how='cross'
        )
        
        # Merge with switch weights if available
        if self.switch_weights is not None:
            all_periods = pd.merge(
                all_periods,
                self.switch_weights[[self.id_col, self.period_col, 'weight']],
                on=[self.id_col, self.period_col],
                how='left'
            )
            all_periods.rename(columns={'weight': 'switch_weight'}, inplace=True)
            all_periods['switch_weight'].fillna(1.0, inplace=True)
        else:
            all_periods['switch_weight'] = 1.0
        
        # Merge with censor weights if available
        if self.censor_weights is not None:
            all_periods = pd.merge(
                all_periods,
                self.censor_weights[[self.id_col, self.period_col, 'weight']],
                on=[self.id_col, self.period_col],
                how='left'
            )
            all_periods.rename(columns={'weight': 'censor_weight'}, inplace=True)
            all_periods['censor_weight'].fillna(1.0, inplace=True)
        else:
            all_periods['censor_weight'] = 1.0
        
        # Calculate combined weight
        all_periods['weight'] = all_periods['switch_weight'] * all_periods['censor_weight']
        
        # Store combined weights
        self.combined_weights = all_periods
        
        return self
    
    def set_outcome_model(self, adjustment_terms=None):
        """Set up the outcome model for survival analysis."""
        if adjustment_terms is None:
            self.outcome_model = OutcomeModel()
        else:
            adj_vars = self._formula_to_vars(adjustment_terms)
            self.outcome_model = OutcomeModel(adjustment_vars=adj_vars)
        
        return self
    
    def set_expansion_options(self, output=None, chunk_size=500):
        """Set options for trial expansion."""
        self.expansion_options = {
            'output_handler': output,
            'chunk_size': chunk_size
        }
        return self
    
    def expand_trials(self):
        """Expand the trial data for analysis."""
        if self.expansion_options is None:
            raise ValueError("Expansion options not set. Call set_expansion_options first.")
        
        # Get unique individuals
        individuals = self.data[self.id_col].unique()
        
        # Process in chunks
        chunk_size = self.expansion_options['chunk_size']
        results = []
        
        for i in range(0, len(individuals), chunk_size):
            chunk_ids = individuals[i:i+chunk_size]
            
            # Filter data for current chunk
            chunk_data = self.data[self.data[self.id_col].isin(chunk_ids)].copy()
            
            # Create expanded data
            expanded = self._expand_individuals(chunk_data)
            results.append(expanded)
        
        # Combine results
        self.expansion = pd.concat(results, ignore_index=True)
        
        return self
    
    def _expand_individuals(self, data):
        """Create expanded data for a set of individuals."""
        expanded_data = []
        
        for id_val in data[self.id_col].unique():
            # Get data for this individual
            indiv_data = data[data[self.id_col] == id_val].sort_values(by=self.period_col)
            
            # For each period where the individual is eligible
            for _, row in indiv_data[indiv_data[self.eligible_col] == 1].iterrows():
                period = row[self.period_col]
                treatment = row[self.treatment_col]
                
                # Create a trial record
                trial_record = {
                    self.id_col: id_val,
                    'trial_period': period,
                    'trial_arm': treatment,
                    'original_' + self.period_col: period
                }
                
                # Add original outcome and other columns
                for col in [self.outcome_col, self.treatment_col]:
                    trial_record['original_' + col] = row[col]
                
                expanded_data.append(trial_record)
        
        return pd.DataFrame(expanded_data)
    
    def load_expanded_data(self, seed=None, p_control=0.5):
        """Load expanded data and apply sampling weights."""
        if self.expansion is None:
            raise ValueError("No expanded data available. Call expand_trials first.")
        
        # Set random seed if provided
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        
        # Sample from expanded data
        expanded_data = self.expansion.copy()
        
        # Calculate sampling weights based on treatment assignment
        expanded_data['sample_weight'] = np.where(
            expanded_data['trial_arm'] == 0,  # Assuming 0 is control
            1.0 / p_control,
            1.0 / (1.0 - p_control)
        )
        
        # Store back to expansion
        self.expansion = expanded_data
        
        return self
    
    def fit_msm(self, weight_cols, modify_weights=None):
        """Fit a marginal structural model."""
        if self.expansion is None:
            raise ValueError("No expanded data available. Call expand_trials and load_expanded_data first.")
        
        # Prepare data for modeling
        model_data = self.expansion.copy()
        
        # Combine weights
        model_data['combined_weight'] = 1.0
        for col in weight_cols:
            model_data['combined_weight'] *= model_data[col]
        
        # Apply weight modification if provided
        if modify_weights is not None:
            model_data['combined_weight'] = modify_weights(model_data['combined_weight'])
        
        # Fit outcome model
        if self.outcome_model is None:
            self.set_outcome_model()
        
        self.outcome_model.fit(
            data=model_data,
            id_col=self.id_col,
            time_col='followup_time',  # Assuming this is in the expanded data
            event_col=self.outcome_col,
            treatment_col='trial_arm',
            weight_col='combined_weight'
        )
        
        return self
    
    def predict(self, newdata, predict_times, type="survival"):
        """Predict outcomes based on the fitted model."""
        if self.outcome_model is None or not self.outcome_model.is_fitted():
            raise ValueError("Outcome model not fitted. Call fit_msm first.")
        
        # Prepare prediction data
        pred_data = newdata.copy()
        
        # Make predictions for each treatment arm
        results = {}
        for arm in [0, 1]:  # Assuming binary treatment (0=control, 1=treatment)
            pred_data['trial_arm'] = arm
            surv_curves = self.outcome_model.predict(pred_data, predict_times)
            results[f'arm_{arm}'] = surv_curves
        
        # Calculate difference (treatment effect)
        diff_data = pd.DataFrame({
            'followup_time': predict_times,
            'survival_diff': results['arm_1']['survival'] - results['arm_0']['survival'],
            '2.5%': results['arm_1']['lower'] - results['arm_0']['upper'],  # Conservative CI
            '97.5%': results['arm_1']['upper'] - results['arm_0']['lower']  # Conservative CI
        })
        
        return {
            'arm_0': results['arm_0'],
            'arm_1': results['arm_1'],
            'difference': diff_data
        }
    
    def _formula_to_vars(self, formula):
        """Convert an R-style formula to a list of variable names."""
        if isinstance(formula, str):
            # Remove ~ and split by +
            parts = formula.replace("~", "").split("+")
            return [part.strip() for part in parts]
        else:
            # For our example, we'll handle the simpler case where the formula is a string like "~ age + x1 + x3"
            # For the case where formula is passed as a tilde object in R, we just extract the string after the tilde
            formula_str = formula.replace("~", "")
            parts = formula_str.split("+")
            return [part.strip() for part in parts]


class StatsGlmLogit:
    def __init__(self, save_path=None):
        self.save_path = save_path
        if save_path and not os.path.exists(save_path):
            os.makedirs(save_path)
    
    def fit(self, data, treatment_col, numerator_vars, denominator_vars, id_col, period_col):
        """Fit logistic regression models and calculate stabilized weights."""
        # Get periods where treatment can switch
        periods = sorted(data[period_col].unique())
        
        # Initialize DataFrame to store weights
        weights_df = pd.DataFrame()
        
        for period in periods[1:]:  # Skip the first period as there's no prior treatment to switch from
            # Get data for current period
            period_data = data[data[period_col] == period].copy()
            
            if len(period_data) == 0:
                continue
                
            # Get previous period for each individual
            prev_period = periods[periods.index(period) - 1]
            prev_data = data[data[period_col] == prev_period].copy()
            
            # Merge current with previous period data
            merged_data = pd.merge(
                period_data,
                prev_data[[id_col, treatment_col]],
                on=id_col,
                suffixes=('', '_prev')
            )
            
            # Identify individuals who switched treatment
            merged_data['switched'] = (merged_data[treatment_col] != merged_data[f"{treatment_col}_prev"]).astype(int)
            
            # Fit numerator model (simpler model)
            X_num = sm.add_constant(merged_data[numerator_vars])
            num_model = sm.Logit(merged_data['switched'], X_num).fit(disp=0)
            
            # Save model if requested
            if self.save_path:
                with open(os.path.join(self.save_path, f"num_model_period_{period}.pkl"), 'wb') as f:
                    pickle.dump(num_model, f)
            
            # Fit denominator model (full model)
            X_denom = sm.add_constant(merged_data[denominator_vars])
            denom_model = sm.Logit(merged_data['switched'], X_denom).fit(disp=0)
            
            # Save model if requested
            if self.save_path:
                with open(os.path.join(self.save_path, f"denom_model_period_{period}.pkl"), 'wb') as f:
                    pickle.dump(denom_model, f)
            
            # Calculate predicted probabilities
            num_probs = num_model.predict(X_num)
            denom_probs = denom_model.predict(X_denom)
            
            # Calculate stabilized weights
            merged_data['weight'] = num_probs / denom_probs
            merged_data['weight'] = merged_data['weight'].fillna(1.0)  # Handle division by zero
            
            # Add to weights DataFrame
            weights_df = pd.concat([weights_df, merged_data[[id_col, period_col, 'weight']]])
        
        return weights_df


class CensorWeightCalculator:
    def __init__(self, model_fitter, censor_event, pool_models="none"):
        self.model_fitter = model_fitter
        self.censor_event = censor_event
        self.pool_models = pool_models  # "none", "numerator", or "denominator"
    
    def fit(self, data, numerator_vars, denominator_vars, id_col, period_col):
        """Calculate censoring weights."""
        # Create a copy of the data
        data_copy = data.copy()
        
        # Create pooled models if required
        if self.pool_models in ["numerator", "both"]:
            pooled_num_model = self._fit_pooled_model(data_copy, numerator_vars)
        
        if self.pool_models in ["denominator", "both"]:
            pooled_denom_model = self._fit_pooled_model(data_copy, denominator_vars)
        
        # Get periods
        periods = sorted(data_copy[period_col].unique())
        
        # Initialize DataFrame to store weights
        weights_df = pd.DataFrame()
        
        for period in periods:
            # Get data for current period
            period_data = data_copy[data_copy[period_col] == period].copy()
            
            if len(period_data) == 0:
                continue
            
            # Define censoring indicator
            period_data['is_censored'] = (period_data[self.censor_event] == 1).astype(int)
            
            # Check if we have variation in the outcome
            if period_data['is_censored'].nunique() <= 1:
                # No variation, assign weight of 1.0
                period_data['weight'] = 1.0
            else:
                try:
                    # Fit period-specific models or use pooled models
                    if self.pool_models != "numerator":
                        # Fit numerator model for this period with regularization
                        X_num = sm.add_constant(period_data[numerator_vars])
                        num_model = sm.Logit(period_data['is_censored'], X_num).fit_regularized(
                            alpha=0.01, disp=0
                        )
                        num_probs = num_model.predict(X_num)
                    else:
                        # Use pooled numerator model
                        X_num = sm.add_constant(period_data[numerator_vars])
                        num_probs = pooled_num_model.predict(X_num)
                    
                    if self.pool_models != "denominator":
                        # Fit denominator model for this period with regularization
                        X_denom = sm.add_constant(period_data[denominator_vars])
                        denom_model = sm.Logit(period_data['is_censored'], X_denom).fit_regularized(
                            alpha=0.01, disp=0
                        )
                        denom_probs = denom_model.predict(X_denom)
                    else:
                        # Use pooled denominator model
                        X_denom = sm.add_constant(period_data[denominator_vars])
                        denom_probs = pooled_denom_model.predict(X_denom)
                    
                    # Ensure probabilities are not exactly 0 or 1
                    num_probs = np.clip(num_probs, 0.001, 0.999)
                    denom_probs = np.clip(denom_probs, 0.001, 0.999)
                    
                    # Calculate stabilized weights
                    period_data['weight'] = (1 - num_probs) / (1 - denom_probs)
                    
                except (np.linalg.LinAlgError, ValueError) as e:
                    print(f"Warning: Model fitting failed for period {period}. Setting weights to 1.0. Error: {e}")
                    period_data['weight'] = 1.0
                
            # Handle any remaining NaNs or infinities
            period_data['weight'] = period_data['weight'].fillna(1.0)
            period_data.loc[np.isinf(period_data['weight']), 'weight'] = 1.0
            
            # Trim extreme weights
            q99 = np.percentile(period_data['weight'], 99)
            period_data.loc[period_data['weight'] > q99, 'weight'] = q99
            
            # Add to weights DataFrame
            weights_df = pd.concat([weights_df, period_data[[id_col, period_col, 'weight']]])
        
        return weights_df

    def _fit_pooled_model(self, data, vars_list):
        """Fit a pooled model across all periods."""
        # Create censoring indicator
        data['is_censored'] = (data[self.censor_event] == 1).astype(int)
        
        # Check if we have variation in the outcome
        if data['is_censored'].nunique() <= 1:
            # Return a dummy model that always predicts the constant
            constant_prob = data['is_censored'].mean()
            class DummyModel:
                def predict(self, X):
                    return np.ones(len(X)) * constant_prob
            return DummyModel()
        
        try:
            # Fit model with regularization
            X = sm.add_constant(data[vars_list])
            model = sm.Logit(data['is_censored'], X).fit_regularized(alpha=0.01, disp=0)
            return model
        except (np.linalg.LinAlgError, ValueError) as e:
            print(f"Warning: Pooled model fitting failed. Creating dummy model. Error: {e}")
            # Return a dummy model
            constant_prob = data['is_censored'].mean()
            class DummyModel:
                def predict(self, X):
                    return np.ones(len(X)) * constant_prob
            return DummyModel()


class OutcomeModel:
    def __init__(self, adjustment_vars=None):
        self.adjustment_vars = adjustment_vars
        self.fitted_model = None
        self.model_info = None
    
    def fit(self, data, id_col, time_col, event_col, treatment_col, weight_col):
        """Fit a proportional hazards model."""
        model_data = data.copy()
        
        # Prepare formula
        if self.adjustment_vars:
            formula = f"{time_col} ~ {treatment_col} + " + " + ".join(self.adjustment_vars)
        else:
            formula = f"{time_col} ~ {treatment_col}"
        
        # Fit Cox PH model
        model = PHReg.from_formula(
            formula,
            data=model_data,
            status=model_data[event_col],
            weights=model_data[weight_col]
        )
        
        result = model.fit()
        
        # Store fitted model
        self.fitted_model = result
        
        # Store model info
        self.model_info = {
            'model': model,
            'vcov': result.cov_params(),
            'formula': formula
        }
        
        return self
    
    def is_fitted(self):
        """Check if model is fitted."""
        return self.fitted_model is not None
    
    def predict(self, data, times):
        """Predict survival probabilities."""
        if not self.is_fitted():
            raise ValueError("Model not fitted yet.")
        
        # Prepare prediction data
        pred_data = data.copy()
        
        # Get baseline survival
        baseline_surv = self._estimate_baseline_survival(times)
        
        # Get linear predictor for new data
        lp = self._calculate_linear_predictor(pred_data)
        
        # Calculate survival probabilities
        survival = np.power(baseline_surv, np.exp(lp))
        
        # Calculate confidence intervals (simplified)
        # For a proper implementation, we would need to calculate the variance of the survival estimates
        ci_width = 1.96 * 0.1 * survival  # Simplified CI
        
        return {
            'times': times,
            'survival': survival,
            'lower': np.maximum(0, survival - ci_width),
            'upper': np.minimum(1, survival + ci_width)
        }
    
    def _estimate_baseline_survival(self, times):
        """Estimate baseline survival function."""
        # This is a simplified implementation
        # A proper implementation would estimate the baseline hazard from the data
        
        # For simplicity, we'll use an exponential model
        # In real implementation, this would be based on the fitted model's baseline hazard
        lambda_hat = 0.1  # Placeholder hazard rate
        return np.exp(-lambda_hat * np.array(times))
    
    def _calculate_linear_predictor(self, data):
        """Calculate linear predictor for new data."""
        # Extract coefficients from fitted model
        coefs = self.fitted_model.params
        
        # Prepare design matrix
        X = np.ones((len(data), len(coefs)))
        X[:, 0] = data['trial_arm']  # Treatment indicator
        
        # Add adjustment variables if specified
        if self.adjustment_vars:
            for i, var in enumerate(self.adjustment_vars):
                X[:, i+1] = data[var]
        
        # Calculate linear predictor
        return X @ coefs


def trial_sequence(estimand):
    """Create a new trial sequence object."""
    return TrialSequence(estimand=estimand)


def stats_glm_logit(save_path=None):
    """Create a logistic regression model fitter."""
    return StatsGlmLogit(save_path=save_path)


def save_to_datatable():
    """Handler for saving expanded data to a data table."""
    # This is a placeholder for the actual implementation
    def handler(data):
        return data
    return handler


def outcome_data(trial):
    """Extract outcome data from a trial sequence."""
    if trial.expansion is not None:
        return trial.expansion
    return pd.DataFrame()


def show_weight_models(trial):
    """Display information about weight models."""
    print("Weight Models Summary:")
    if hasattr(trial, 'switch_weights') and trial.switch_weights is not None:
        print(f"  Switch Weights: {len(trial.switch_weights)} rows")
    
    if hasattr(trial, 'censor_weights') and trial.censor_weights is not None:
        print(f"  Censor Weights: {len(trial.censor_weights)} rows")
    
    if hasattr(trial, 'combined_weights') and trial.combined_weights is not None:
        print(f"  Combined Weights: {len(trial.combined_weights)} rows")


# Example usage
if __name__ == "__main__":
    # Load data
    data_censored = pd.read_csv("data/data_censored.csv")
    
    # Create trial sequence objects
    trial_pp = trial_sequence(estimand="PP")  # Per-protocol
    trial_itt = trial_sequence(estimand="ITT")  # Intention-to-treat
    
    # Create directories
    trial_pp_dir = os.path.join(tempfile.gettempdir(), "trial_pp")
    os.makedirs(trial_pp_dir, exist_ok=True)
    
    trial_itt_dir = os.path.join(tempfile.gettempdir(), "trial_itt")
    os.makedirs(trial_itt_dir, exist_ok=True)
    
    # Set data for both trial sequences
    trial_pp = trial_pp.set_data(
        data=data_censored,
        id="id",
        period="period",
        treatment="treatment",
        outcome="outcome",
        eligible="eligible"
    )
    
    trial_itt = trial_itt.set_data(
        data=data_censored,
        id="id",
        period="period",
        treatment="treatment",
        outcome="outcome",
        eligible="eligible"
    )
    
    # Set switch weight models
    trial_pp = trial_pp.set_switch_weight_model(
        numerator="~ age",
        denominator="~ age + x1 + x3",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_pp_dir, "switch_models"))
    )
    
    print(trial_pp)
    # Set censor weight models
    trial_pp = trial_pp.set_censor_weight_model(
        censor_event="censored",
        numerator="~ x2",
        denominator="~ x2 + x1",
        pool_models="none",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_pp_dir, "switch_models"))
    )
    
    trial_itt = trial_itt.set_censor_weight_model(
        censor_event="censored",
        numerator="~ x2",
        denominator="~ x2 + x1",
        pool_models="numerator",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_itt_dir, "switch_models"))
    )
    
    # Calculate weights
    trial_pp = trial_pp.calculate_weights()
    trial_itt = trial_itt.calculate_weights()
    
    # Show weight models
    show_weight_models(trial_itt)
    show_weight_models(trial_pp)
    
    # Set outcome models
    trial_pp = trial_pp.set_outcome_model()
    trial_itt = trial_itt.set_outcome_model(adjustment_terms="~ x2")
    
    # Set expansion options
    trial_pp = trial_pp.set_expansion_options(
        output=save_to_datatable(),
        chunk_size=500
    )
    
    trial_itt = trial_itt.set_expansion_options(
        output=save_to_datatable(),
        chunk_size=500
    )
    
    # Expand trials
    trial_pp = trial_pp.expand_trials()
    trial_itt = trial_itt.expand_trials()
    
    # Load expanded data and fit MSM
    trial_itt = trial_itt.load_expanded_data(seed=1234, p_control=0.5)
    trial_itt = trial_itt.fit_msm(
        weight_cols=["weight", "sample_weight"],
        modify_weights=lambda w: np.minimum(w, np.quantile(w, 0.99))  # Winsorization
    )
    
    # Make predictions
    prediction_data = outcome_data(trial_itt)
    prediction_data = prediction_data[prediction_data['trial_period'] == 1]
    
    preds = trial_itt.predict(
        newdata=prediction_data,
        predict_times=list(range(11)),  # 0 to 10
        type="survival"
    )
    
    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(preds['difference']['followup_time'], preds['difference']['survival_diff'])
    plt.plot(preds['difference']['followup_time'], preds['difference']['2.5%'], 'r--')
    plt.plot(preds['difference']['followup_time'], preds['difference']['97.5%'], 'r--')
    plt.xlabel('Follow up')
    plt.ylabel('Survival difference')
    plt.title('Treatment Effect on Survival')
    plt.grid(True)
    plt.savefig('survival_difference.png')
    plt.close()

TrialSequence(estimand='PP', data=     id  period  treatment  x1        x2  x3        x4  age     age_s  \
0     1       0          1   1  1.146148   0  0.734203   36  0.083333   
1     1       1          1   1  0.002200   0  0.734203   37  0.166667   
2     1       2          1   0 -0.481762   0  0.734203   38  0.250000   
3     1       3          1   0  0.007872   0  0.734203   39  0.333333   
4     1       4          1   1  0.216054   0  0.734203   40  0.416667   
..   ..     ...        ...  ..       ...  ..       ...  ...       ...   
720  99       3          0   0 -0.747906   1  0.575268   68  2.750000   
721  99       4          0   0 -0.790056   1  0.575268   69  2.833333   
722  99       5          1   1  0.387429   1  0.575268   70  2.916667   
723  99       6          1   1 -0.033762   1  0.575268   71  3.000000   
724  99       7          0   0 -1.340497   1  0.575268   72  3.083333   

     outcome  censored  eligible  
0          0         0         1  
1          0       

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  all_periods['switch_weight'].fillna(1.0, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  all_periods['censor_weight'].fillna(1.0, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we ar

Weight Models Summary:
  Censor Weights: 725 rows
  Combined Weights: 1780 rows
Weight Models Summary:
  Switch Weights: 636 rows
  Censor Weights: 725 rows
  Combined Weights: 1780 rows


KeyError: 'weight'

In [None]:
import os
import pandas as pd
import numpy as np
import statsmodels.api as sm
from statsmodels.duration.hazard_regression import PHReg
import pickle
import matplotlib.pyplot as plt
from dataclasses import dataclass, field
from typing import Optional, List, Callable, Dict, Any, Union
import tempfile
import random

# Dataclass to mimic the trial_sequence object in R
@dataclass
class TrialSequence:
    estimand: str  # "PP" or "ITT"
    data: Optional[pd.DataFrame] = None
    id_col: Optional[str] = None
    period_col: Optional[str] = None
    treatment_col: Optional[str] = None
    outcome_col: Optional[str] = None
    eligible_col: Optional[str] = None
    switch_weights: Optional[pd.DataFrame] = None
    censor_weights: Optional[pd.DataFrame] = None
    combined_weights: Optional[pd.DataFrame] = None
    outcome_model: Optional[Any] = None
    expansion: Optional[pd.DataFrame] = None
    expansion_options: Optional[Dict] = None
    
    def set_data(self, data, id, period, treatment, outcome, eligible):
        """Set the data and column names for the trial sequence."""
        self.data = data
        self.id_col = id
        self.period_col = period
        self.treatment_col = treatment
        self.outcome_col = outcome
        self.eligible_col = eligible
        return self
    
    def set_switch_weight_model(self, numerator, denominator, model_fitter):
        """Set the switch weight model specifications."""
        # Convert R formula strings to lists of variable names
        num_vars = self._formula_to_vars(numerator)
        denom_vars = self._formula_to_vars(denominator)
        
        # Calculate switch weights
        self.switch_weights = model_fitter.fit(
            self.data, 
            self.treatment_col,
            num_vars,
            denom_vars,
            self.id_col,
            self.period_col
        )
        
        return self
    
    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, model_fitter):
        """Set the censoring weight model specifications."""
        # Convert R formula strings to lists of variable names
        num_vars = self._formula_to_vars(numerator)
        denom_vars = self._formula_to_vars(denominator)
        
        # Calculate censoring weights
        censor_calculator = CensorWeightCalculator(
            model_fitter=model_fitter,
            censor_event=censor_event,
            pool_models=pool_models
        )
        
        self.censor_weights = censor_calculator.fit(
            self.data,
            num_vars,
            denom_vars,
            self.id_col,
            self.period_col
        )
        
        return self
    
    def calculate_weights(self):
        """Combine switch and censor weights."""
        if self.switch_weights is None and self.censor_weights is None:
            raise ValueError("No weights have been calculated yet.")
        
        # Create a base DataFrame with all patient-periods
        all_periods = pd.DataFrame({
            self.id_col: self.data[self.id_col].unique()
        }).merge(
            pd.DataFrame({self.period_col: self.data[self.period_col].unique()}),
            how='cross'
        )
        
        # Merge with switch weights if available
        if self.switch_weights is not None:
            all_periods = pd.merge(
                all_periods,
                self.switch_weights[[self.id_col, self.period_col, 'weight']],
                on=[self.id_col, self.period_col],
                how='left'
            )
            all_periods.rename(columns={'weight': 'switch_weight'}, inplace=True)
            all_periods['switch_weight'].fillna(1.0, inplace=True)
        else:
            all_periods['switch_weight'] = 1.0
        
        # Merge with censor weights if available
        if self.censor_weights is not None:
            all_periods = pd.merge(
                all_periods,
                self.censor_weights[[self.id_col, self.period_col, 'weight']],
                on=[self.id_col, self.period_col],
                how='left'
            )
            all_periods.rename(columns={'weight': 'censor_weight'}, inplace=True)
            all_periods['censor_weight'].fillna(1.0, inplace=True)
        else:
            all_periods['censor_weight'] = 1.0
        
        # Calculate combined weight
        all_periods['weight'] = all_periods['switch_weight'] * all_periods['censor_weight']
        
        # Store combined weights
        self.combined_weights = all_periods
        
        return self
    
    def set_outcome_model(self, adjustment_terms=None):
        """Set up the outcome model for survival analysis."""
        if adjustment_terms is None:
            self.outcome_model = OutcomeModel()
        else:
            adj_vars = self._formula_to_vars(adjustment_terms)
            self.outcome_model = OutcomeModel(adjustment_vars=adj_vars)
        
        return self
    
    def set_expansion_options(self, output=None, chunk_size=500):
        """Set options for trial expansion."""
        self.expansion_options = {
            'output_handler': output,
            'chunk_size': chunk_size
        }
        return self
    
    def expand_trials(self):
        """Expand the trial data for analysis."""
        if self.expansion_options is None:
            raise ValueError("Expansion options not set. Call set_expansion_options first.")
        
        # Get unique individuals
        individuals = self.data[self.id_col].unique()
        
        # Process in chunks
        chunk_size = self.expansion_options['chunk_size']
        results = []
        
        for i in range(0, len(individuals), chunk_size):
            chunk_ids = individuals[i:i+chunk_size]
            
            # Filter data for current chunk
            chunk_data = self.data[self.data[self.id_col].isin(chunk_ids)].copy()
            
            # Create expanded data
            expanded = self._expand_individuals(chunk_data)
            results.append(expanded)
        
        # Combine results
        self.expansion = pd.concat(results, ignore_index=True)
        
        return self
    
    def _expand_individuals(self, data):
        """Create expanded data for a set of individuals with survival times."""
        expanded_data = []
        
        for id_val in data[self.id_col].unique():
            # Get individual's data, sorted by period
            indiv_data = data[data[self.id_col] == id_val].sort_values(by=self.period_col)
            
            # Find event time: first period where outcome == 1 or censored == 1
            event_rows = indiv_data[(indiv_data[self.outcome_col] == 1) | (indiv_data['censored'] == 1)]
            if not event_rows.empty:
                event_time = event_rows.iloc[0][self.period_col]
                event_status = event_rows.iloc[0][self.outcome_col]
            else:
                # No event or censoring; assume censored at last period
                event_time = indiv_data[self.period_col].max()
                event_status = 0
            
            # Eligible start periods are before or at event_time
            eligible_data = indiv_data[(indiv_data[self.eligible_col] == 1) & 
                                    (indiv_data[self.period_col] <= event_time)]
            
            for _, row in eligible_data.iterrows():
                start_period = row[self.period_col]
                trial_arm = row[self.treatment_col]
                survival_time = event_time - start_period
                event = event_status
                
                # Create trial record
                record = {
                    self.id_col: id_val,
                    'trial_period': start_period,
                    'trial_arm': trial_arm,
                    'survival_time': survival_time,
                    'event': event,
                }
                
                # Include baseline covariates from start period
                for col in self.data.columns:
                    if col not in [self.id_col, self.period_col, self.treatment_col, 
                                self.outcome_col, 'censored', self.eligible_col]:
                        record[col] = row[col]
                
                # Include weight for this trial from combined_weights
                if self.combined_weights is not None:
                    weight_row = self.combined_weights[
                        (self.combined_weights[self.id_col] == id_val) & 
                        (self.combined_weights[self.period_col] == start_period)
                    ]
                    record['weight'] = weight_row['weight'].values[0] if not weight_row.empty else 1.0
                else:
                    record['weight'] = 1.0
                
                expanded_data.append(record)
        
        return pd.DataFrame(expanded_data)
    
    def load_expanded_data(self, seed=None, p_control=0.5):
        """Load expanded data and apply sampling weights."""
        if self.expansion is None:
            raise ValueError("No expanded data available. Call expand_trials first.")
        
        # Set random seed if provided
        if seed is not None:
            random.seed(seed)
            np.random.seed(seed)
        
        # Sample from expanded data
        expanded_data = self.expansion.copy()
        
        # Calculate sampling weights based on treatment assignment
        expanded_data['sample_weight'] = np.where(
            expanded_data['trial_arm'] == 0,  # Assuming 0 is control
            1.0 / p_control,
            1.0 / (1.0 - p_control)
        )
        
        # Store back to expansion
        self.expansion = expanded_data
        
        return self
    
    def fit_msm(self, weight_cols, modify_weights=None):
        """
        Fit a marginal structural model using a Cox proportional hazards model.

        Parameters:
        - weight_cols: List of column names for additional weights (e.g., censoring weights).
        - modify_weights: Optional function to modify combined weights (e.g., winsorization).
        """
        # Check if expanded data exists
        if self.expansion is None:
            raise ValueError("No expanded data available. Call expand_trials first.")
        
        # Prepare a copy of the expanded data
        model_data = self.expansion.copy()
        
        # Combine weights: start with trial weights, then multiply by additional weights
        model_data['combined_weight'] = model_data['weight']
        for col in weight_cols:
            if col in model_data.columns:
                model_data['combined_weight'] *= model_data[col]
        
        # Apply weight modification if provided
        if modify_weights is not None:
            model_data['combined_weight'] = modify_weights(model_data['combined_weight'])
        
        # Ensure the outcome model is set
        if self.outcome_model is None:
            self.set_outcome_model()
        
        # Define the time variable (endog), event status (status), and covariates (exog)
        endog = model_data['survival_time']  # Time-to-event variable
        status = model_data['event']         # Event indicator (1 if event occurred, 0 if censored)
        exog = sm.add_constant(model_data[['trial_arm']])  # Covariates: intercept + treatment
        
        # Include adjustment variables if specified
        if self.outcome_model.adjustment_vars:
            exog = pd.concat([exog, model_data[self.outcome_model.adjustment_vars]], axis=1)
        
        # Fit the Cox PH model
        model = PHReg(endog, exog, status=status, weights=model_data['combined_weight'])
        self.outcome_model.fitted_model = model.fit()
        
        # Store model information
        self.outcome_model.model_info = {
            'model': model,
            'vcov': self.outcome_model.fitted_model.cov_params(),
            'exog_names': exog.columns.tolist()
        }
        
        return self
    
    def predict(self, newdata, predict_times, type="survival"):
        """Predict outcomes based on the fitted model."""
        if self.outcome_model is None or not self.outcome_model.is_fitted():
            raise ValueError("Outcome model not fitted. Call fit_msm first.")
        
        # Prepare prediction data
        pred_data = newdata.copy()
        
        # Make predictions for each treatment arm
        results = {}
        for arm in [0, 1]:  # Binary treatment (0=control, 1=treatment)
            pred_data['trial_arm'] = arm
            surv_curves = self.outcome_model.predict(pred_data, predict_times)
            # Compute mean over individuals (axis=0) for each time point
            mean_survival = np.mean(surv_curves['survival'], axis=0)
            mean_lower = np.mean(surv_curves['lower'], axis=0)
            mean_upper = np.mean(surv_curves['upper'], axis=0)
            results[f'arm_{arm}'] = {
                'times': surv_curves['times'],
                'survival': mean_survival,
                'lower': mean_lower,
                'upper': mean_upper
            }
        
        # Calculate difference (treatment effect)
        diff_data = pd.DataFrame({
            'followup_time': predict_times,
            'survival_diff': results['arm_1']['survival'] - results['arm_0']['survival'],
            '2.5%': results['arm_1']['lower'] - results['arm_0']['upper'],  # Simplified CI
            '97.5%': results['arm_1']['upper'] - results['arm_0']['lower']  # Simplified CI
        })
        
        return {
            'arm_0': results['arm_0'],
            'arm_1': results['arm_1'],
            'difference': diff_data
        }
    
    def _formula_to_vars(self, formula):
        """Convert an R-style formula to a list of variable names."""
        if isinstance(formula, str):
            # Remove ~ and split by +
            parts = formula.replace("~", "").split("+")
            return [part.strip() for part in parts]
        else:
            # For our example, we'll handle the simpler case where the formula is a string like "~ age + x1 + x3"
            # For the case where formula is passed as a tilde object in R, we just extract the string after the tilde
            formula_str = formula.replace("~", "")
            parts = formula_str.split("+")
            return [part.strip() for part in parts]


class StatsGlmLogit:
    def __init__(self, save_path=None):
        self.save_path = save_path
        if save_path and not os.path.exists(save_path):
            os.makedirs(save_path)
    
    def fit(self, data, treatment_col, numerator_vars, denominator_vars, id_col, period_col):
        """Fit logistic regression models and calculate stabilized weights."""
        # Get periods where treatment can switch
        periods = sorted(data[period_col].unique())
        
        # Initialize DataFrame to store weights
        weights_df = pd.DataFrame()
        
        for period in periods[1:]:  # Skip the first period as there's no prior treatment to switch from
            # Get data for current period
            period_data = data[data[period_col] == period].copy()
            
            if len(period_data) == 0:
                continue
                
            # Get previous period for each individual
            prev_period = periods[periods.index(period) - 1]
            prev_data = data[data[period_col] == prev_period].copy()
            
            # Merge current with previous period data
            merged_data = pd.merge(
                period_data,
                prev_data[[id_col, treatment_col]],
                on=id_col,
                suffixes=('', '_prev')
            )
            
            # Identify individuals who switched treatment
            merged_data['switched'] = (merged_data[treatment_col] != merged_data[f"{treatment_col}_prev"]).astype(int)
            
            # Fit numerator model (simpler model)
            X_num = sm.add_constant(merged_data[numerator_vars])
            num_model = sm.Logit(merged_data['switched'], X_num).fit(disp=0)
            
            # Save model if requested
            if self.save_path:
                with open(os.path.join(self.save_path, f"num_model_period_{period}.pkl"), 'wb') as f:
                    pickle.dump(num_model, f)
            
            # Fit denominator model (full model)
            X_denom = sm.add_constant(merged_data[denominator_vars])
            denom_model = sm.Logit(merged_data['switched'], X_denom).fit(disp=0)
            
            # Save model if requested
            if self.save_path:
                with open(os.path.join(self.save_path, f"denom_model_period_{period}.pkl"), 'wb') as f:
                    pickle.dump(denom_model, f)
            
            # Calculate predicted probabilities
            num_probs = num_model.predict(X_num)
            denom_probs = denom_model.predict(X_denom)
            
            # Calculate stabilized weights
            merged_data['weight'] = num_probs / denom_probs
            merged_data['weight'] = merged_data['weight'].fillna(1.0)  # Handle division by zero
            
            # Add to weights DataFrame
            weights_df = pd.concat([weights_df, merged_data[[id_col, period_col, 'weight']]])
        
        return weights_df


class CensorWeightCalculator:
    def __init__(self, model_fitter, censor_event, pool_models="none"):
        self.model_fitter = model_fitter
        self.censor_event = censor_event
        self.pool_models = pool_models  # "none", "numerator", or "denominator"
    
    def fit(self, data, numerator_vars, denominator_vars, id_col, period_col):
        """Calculate censoring weights."""
        # Create a copy of the data
        data_copy = data.copy()
        
        # Create pooled models if required
        if self.pool_models in ["numerator", "both"]:
            pooled_num_model = self._fit_pooled_model(data_copy, numerator_vars)
        
        if self.pool_models in ["denominator", "both"]:
            pooled_denom_model = self._fit_pooled_model(data_copy, denominator_vars)
        
        # Get periods
        periods = sorted(data_copy[period_col].unique())
        
        # Initialize DataFrame to store weights
        weights_df = pd.DataFrame()
        
        for period in periods:
            # Get data for current period
            period_data = data_copy[data_copy[period_col] == period].copy()
            
            if len(period_data) == 0:
                continue
            
            # Define censoring indicator
            period_data['is_censored'] = (period_data[self.censor_event] == 1).astype(int)
            
            # Check if we have variation in the outcome
            if period_data['is_censored'].nunique() <= 1:
                # No variation, assign weight of 1.0
                period_data['weight'] = 1.0
            else:
                try:
                    # Fit period-specific models or use pooled models
                    if self.pool_models != "numerator":
                        # Fit numerator model for this period with regularization
                        X_num = sm.add_constant(period_data[numerator_vars])
                        num_model = sm.Logit(period_data['is_censored'], X_num).fit_regularized(
                            alpha=0.01, disp=0
                        )
                        num_probs = num_model.predict(X_num)
                    else:
                        # Use pooled numerator model
                        X_num = sm.add_constant(period_data[numerator_vars])
                        num_probs = pooled_num_model.predict(X_num)
                    
                    if self.pool_models != "denominator":
                        # Fit denominator model for this period with regularization
                        X_denom = sm.add_constant(period_data[denominator_vars])
                        denom_model = sm.Logit(period_data['is_censored'], X_denom).fit_regularized(
                            alpha=0.01, disp=0
                        )
                        denom_probs = denom_model.predict(X_denom)
                    else:
                        # Use pooled denominator model
                        X_denom = sm.add_constant(period_data[denominator_vars])
                        denom_probs = pooled_denom_model.predict(X_denom)
                    
                    # Ensure probabilities are not exactly 0 or 1
                    num_probs = np.clip(num_probs, 0.001, 0.999)
                    denom_probs = np.clip(denom_probs, 0.001, 0.999)
                    
                    # Calculate stabilized weights
                    period_data['weight'] = (1 - num_probs) / (1 - denom_probs)
                    
                except (np.linalg.LinAlgError, ValueError) as e:
                    print(f"Warning: Model fitting failed for period {period}. Setting weights to 1.0. Error: {e}")
                    period_data['weight'] = 1.0
                
            # Handle any remaining NaNs or infinities
            period_data['weight'] = period_data['weight'].fillna(1.0)
            period_data.loc[np.isinf(period_data['weight']), 'weight'] = 1.0
            
            # Trim extreme weights
            q99 = np.percentile(period_data['weight'], 99)
            period_data.loc[period_data['weight'] > q99, 'weight'] = q99
            
            # Add to weights DataFrame
            weights_df = pd.concat([weights_df, period_data[[id_col, period_col, 'weight']]])
        
        return weights_df

    def _fit_pooled_model(self, data, vars_list):
        """Fit a pooled model across all periods."""
        # Create censoring indicator
        data['is_censored'] = (data[self.censor_event] == 1).astype(int)
        
        # Check if we have variation in the outcome
        if data['is_censored'].nunique() <= 1:
            # Return a dummy model that always predicts the constant
            constant_prob = data['is_censored'].mean()
            class DummyModel:
                def predict(self, X):
                    return np.ones(len(X)) * constant_prob
            return DummyModel()
        
        try:
            # Fit model with regularization
            X = sm.add_constant(data[vars_list])
            model = sm.Logit(data['is_censored'], X).fit_regularized(alpha=0.01, disp=0)
            return model
        except (np.linalg.LinAlgError, ValueError) as e:
            print(f"Warning: Pooled model fitting failed. Creating dummy model. Error: {e}")
            # Return a dummy model
            constant_prob = data['is_censored'].mean()
            class DummyModel:
                def predict(self, X):
                    return np.ones(len(X)) * constant_prob
            return DummyModel()


class OutcomeModel:
    def __init__(self, adjustment_vars=None):
        self.adjustment_vars = adjustment_vars
        self.fitted_model = None
        self.model_info = None
    
    def fit(self, data, id_col, time_col, event_col, treatment_col, weight_col):
        """Fit a proportional hazards model."""
        model_data = data.copy()
        
        # Prepare formula
        if self.adjustment_vars:
            formula = f"{time_col} ~ {treatment_col} + " + " + ".join(self.adjustment_vars)
        else:
            formula = f"{time_col} ~ {treatment_col}"
        
        # Fit Cox PH model
        model = PHReg.from_formula(
            formula,
            data=model_data,
            status=model_data[event_col],
            weights=model_data[weight_col]
        )
        
        result = model.fit()
        
        # Store fitted model
        self.fitted_model = result
        
        # Store model info
        self.model_info = {
            'model': model,
            'vcov': result.cov_params(),
            'formula': formula
        }
        
        return self
    
    def is_fitted(self):
        """Check if model is fitted."""
        return self.fitted_model is not None
    
    def predict(self, data, times):
        """Predict survival probabilities."""
        if not self.is_fitted():
            raise ValueError("Model not fitted yet.")
        
        # Prepare prediction data
        pred_data = data.copy()
        
        # Get baseline survival (shape: (11,))
        baseline_surv = self._estimate_baseline_survival(times)
        
        # Get linear predictor (shape: (32,))
        lp = self._calculate_linear_predictor(pred_data)
        
        # Calculate survival probabilities (shape: (32, 11))
        survival = np.power(baseline_surv, np.exp(lp)[:, np.newaxis])
        
        # Calculate confidence intervals (simplified)
        ci_width = 1.96 * 0.1 * survival  # Simplified CI
        
        return {
            'times': times,
            'survival': survival,
            'lower': np.maximum(0, survival - ci_width),
            'upper': np.minimum(1, survival + ci_width)
        }
    
    def _estimate_baseline_survival(self, times):
        """Estimate baseline survival function."""
        # This is a simplified implementation
        # A proper implementation would estimate the baseline hazard from the data
        
        # For simplicity, we'll use an exponential model
        # In real implementation, this would be based on the fitted model's baseline hazard
        lambda_hat = 0.1  # Placeholder hazard rate
        return np.exp(-lambda_hat * np.array(times))
    
    def _calculate_linear_predictor(self, data):
        exog_names = self.fitted_model.model.exog_names  # ['Intercept', 'trial_arm', 'x2']
        X = pd.DataFrame(index=data.index)
        for name in exog_names:
            if name == 'Intercept':
                X[name] = 1
            else:
                X[name] = data[name]
        return X.values @ self.fitted_model.params


def trial_sequence(estimand):
    """Create a new trial sequence object."""
    return TrialSequence(estimand=estimand)


def stats_glm_logit(save_path=None):
    """Create a logistic regression model fitter."""
    return StatsGlmLogit(save_path=save_path)


def save_to_datatable():
    """Handler for saving expanded data to a data table."""
    # This is a placeholder for the actual implementation
    def handler(data):
        return data
    return handler


def outcome_data(trial):
    """Extract outcome data from a trial sequence."""
    if trial.expansion is not None:
        return trial.expansion
    return pd.DataFrame()


def show_weight_models(trial):
    """Display information about weight models."""
    print("Weight Models Summary:")
    if hasattr(trial, 'switch_weights') and trial.switch_weights is not None:
        print(f"  Switch Weights: {len(trial.switch_weights)} rows")
    
    if hasattr(trial, 'censor_weights') and trial.censor_weights is not None:
        print(f"  Censor Weights: {len(trial.censor_weights)} rows")
    
    if hasattr(trial, 'combined_weights') and trial.combined_weights is not None:
        print(f"  Combined Weights: {len(trial.combined_weights)} rows")


# Example usage
if __name__ == "__main__":
    # Load data
    data_censored = pd.read_csv("data/data_censored.csv")
    
    # Create trial sequence objects
    trial_pp = trial_sequence(estimand="PP")  # Per-protocol
    trial_itt = trial_sequence(estimand="ITT")  # Intention-to-treat
    
    # Create directories
    trial_pp_dir = os.path.join(tempfile.gettempdir(), "trial_pp")
    os.makedirs(trial_pp_dir, exist_ok=True)
    
    trial_itt_dir = os.path.join(tempfile.gettempdir(), "trial_itt")
    os.makedirs(trial_itt_dir, exist_ok=True)
    
    # Set data for both trial sequences
    trial_pp = trial_pp.set_data(
        data=data_censored,
        id="id",
        period="period",
        treatment="treatment",
        outcome="outcome",
        eligible="eligible"
    )
    
    trial_itt = trial_itt.set_data(
        data=data_censored,
        id="id",
        period="period",
        treatment="treatment",
        outcome="outcome",
        eligible="eligible"
    )
    
    # Set switch weight models
    trial_pp = trial_pp.set_switch_weight_model(
        numerator="~ age",
        denominator="~ age + x1 + x3",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_pp_dir, "switch_models"))
    )
    
    print(trial_pp)
    # Set censor weight models
    trial_pp = trial_pp.set_censor_weight_model(
        censor_event="censored",
        numerator="~ x2",
        denominator="~ x2 + x1",
        pool_models="none",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_pp_dir, "switch_models"))
    )
    
    trial_itt = trial_itt.set_censor_weight_model(
        censor_event="censored",
        numerator="~ x2",
        denominator="~ x2 + x1",
        pool_models="numerator",
        model_fitter=stats_glm_logit(save_path=os.path.join(trial_itt_dir, "switch_models"))
    )
    
    # Calculate weights
    trial_pp = trial_pp.calculate_weights()
    trial_itt = trial_itt.calculate_weights()
    
    # Show weight models
    show_weight_models(trial_itt)
    show_weight_models(trial_pp)
    
    # Set outcome models
    trial_pp = trial_pp.set_outcome_model()
    trial_itt = trial_itt.set_outcome_model(adjustment_terms="~ x2")
    
    # Set expansion options
    trial_pp = trial_pp.set_expansion_options(
        output=save_to_datatable(),
        chunk_size=500
    )
    
    trial_itt = trial_itt.set_expansion_options(
        output=save_to_datatable(),
        chunk_size=500
    )
    
    # Expand trials
    trial_pp = trial_pp.expand_trials()
    trial_itt = trial_itt.expand_trials()
    
    # Load expanded data and fit MSM
    trial_itt = trial_itt.load_expanded_data(seed=1234, p_control=0.5)
    trial_itt = trial_itt.fit_msm(
        weight_cols=["weight", "sample_weight"],
        modify_weights=lambda w: np.minimum(w, np.quantile(w, 0.99))  # Winsorization
    )
    
    # Prepare the prediction data
    prediction_data = outcome_data(trial_itt)
    prediction_data = prediction_data[prediction_data['trial_period'] == 1]

    # Add the 'const' column with value 1
    prediction_data['const'] = 1

    # Now call predict
    preds = trial_itt.predict(
        newdata=prediction_data,
        predict_times=list(range(11)),  # Assuming you want predictions from 0 to 10
        type="survival"
    )
    
    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(preds['difference']['followup_time'], preds['difference']['survival_diff'])
    plt.plot(preds['difference']['followup_time'], preds['difference']['2.5%'], 'r--')
    plt.plot(preds['difference']['followup_time'], preds['difference']['97.5%'], 'r--')
    plt.xlabel('Follow up')
    plt.ylabel('Survival difference')
    plt.title('Treatment Effect on Survival')
    plt.grid(True)
    plt.savefig('survival_difference.png')
    plt.close()

TrialSequence(estimand='PP', data=     id  period  treatment  x1        x2  x3        x4  age     age_s  \
0     1       0          1   1  1.146148   0  0.734203   36  0.083333   
1     1       1          1   1  0.002200   0  0.734203   37  0.166667   
2     1       2          1   0 -0.481762   0  0.734203   38  0.250000   
3     1       3          1   0  0.007872   0  0.734203   39  0.333333   
4     1       4          1   1  0.216054   0  0.734203   40  0.416667   
..   ..     ...        ...  ..       ...  ..       ...  ...       ...   
720  99       3          0   0 -0.747906   1  0.575268   68  2.750000   
721  99       4          0   0 -0.790056   1  0.575268   69  2.833333   
722  99       5          1   1  0.387429   1  0.575268   70  2.916667   
723  99       6          1   1 -0.033762   1  0.575268   71  3.000000   
724  99       7          0   0 -1.340497   1  0.575268   72  3.083333   

     outcome  censored  eligible  
0          0         0         1  
1          0       

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  all_periods['switch_weight'].fillna(1.0, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  all_periods['censor_weight'].fillna(1.0, inplace=True)
The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we ar

Weight Models Summary:
  Censor Weights: 725 rows
  Combined Weights: 1780 rows
Weight Models Summary:
  Switch Weights: 636 rows
  Censor Weights: 725 rows
  Combined Weights: 1780 rows


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  prediction_data['const'] = -1
