#### Import the necessary libraries

In [1]:
import os
import pandas as pd
import numpy as np
from lifelines import CoxPHFitter
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt


#### Define the Target Trial Emulator (data loading, clustering, weighting, model fitting)

In [2]:
import pandas as pd
import numpy as np
from lifelines import CoxPHFitter
from sklearn.linear_model import LogisticRegression
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt

class TargetTrialEmulator:
    """Weighted Cox-model pipeline over long-format patient data.

    The methods are meant to be chained, each returning ``self``:
    ``prepare_data`` -> (``cluster_patients``) -> ``calculate_weights`` ->
    ``expand_trials`` -> ``fit_msm``.

    The data frame is expected to provide the columns the methods read:
    ``id``, ``period``, ``age``, ``x1``, ``x2``, ``x3``, ``treatment``,
    ``censored`` and ``outcome``.
    """

    def __init__(self, estimand="ITT"):
        """Create an emulator.

        Args:
            estimand: ``"PP"`` enables treatment-switch weights in
                ``calculate_weights``; any other value (default ``"ITT"``)
                uses unit switch weights.
        """
        self.estimand = estimand
        self.weights = None
        self.model = None
        self.expanded_data = None
        self.data = None
        self.cluster_models = {}  # For cluster-specific models

    def prepare_data(self, data_path):
        """Read the CSV at ``data_path`` and add the derived ``age_s`` column.

        ``age_s`` is ``age + period / 12``.
        """
        self.data = pd.read_csv(data_path)
        self.data['age_s'] = self.data['age'] + self.data['period']/12
        return self

    def cluster_patients(self, n_clusters=3, random_state=None):
        """Cluster patients based on baseline characteristics.

        The first recorded ``age``, ``x1``, ``x2`` and ``x3`` of every ``id``
        are clustered with KMeans, and the resulting label is written to a
        ``cluster`` column on every row of that patient.

        Args:
            n_clusters: Number of KMeans clusters.
            random_state: Optional seed forwarded to KMeans so that cluster
                labels are reproducible between runs. ``None`` keeps the
                library default (unseeded) behaviour.
        """
        baseline = self.data.groupby('id').first()[['age', 'x1', 'x2', 'x3']]
        kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
        clusters = kmeans.fit_predict(baseline)
        self.data['cluster'] = self.data['id'].map(
            pd.Series(clusters, index=baseline.index))
        return self

    def calculate_weights(self):
        """Compute one weight per row of ``self.data``.

        The weight is the product of a switch weight and a censoring weight
        and is stored in ``self.weights`` in the row order of ``self.data``.
        """
        if self.estimand == "PP":
            # Ratio of two logistic models for treatment: age only in the
            # numerator, age + x1 + x3 in the denominator.
            # NOTE(review): both terms use P(treatment=1) for every row,
            # including untreated rows - confirm this is the intended form.
            numer_features = self.data[['age']]
            switch_num = LogisticRegression()
            switch_num.fit(numer_features, self.data['treatment'])
            numer = switch_num.predict_proba(numer_features)[:, 1]

            denom_features = self.data[['age', 'x1', 'x3']]
            switch_den = LogisticRegression()
            switch_den.fit(denom_features, self.data['treatment'])
            denom = switch_den.predict_proba(denom_features)[:, 1]
            switch_weights = numer / denom
        else:
            switch_weights = np.ones(len(self.data))

        self.weights = switch_weights * self._censor_weights()
        return self

    def _censor_weights(self):
        """Return inverse-probability weights from the censoring model."""
        censored = self.data['censored']
        if censored.nunique(dropna=False) < 2:
            # A logistic model cannot be fitted on a single class (this
            # happens for subsets with no censoring events). The predicted
            # probability of the only observed class is 1, so the weight is 1.
            return np.ones(len(self.data))

        features = self.data[['x2', 'x1']]
        censor_model = LogisticRegression()
        censor_model.fit(features, censored)
        # Column 0 is the probability of the smallest class label.
        return 1 / censor_model.predict_proba(features)[:, 0]

    def expand_trials(self):
        """Build ``self.expanded_data`` by stacking the rows of each period.

        Rows are grouped period by period (in order of first appearance) and
        tagged with a ``trial_period`` column equal to their ``period``. The
        original row index labels are preserved, which ``fit_msm`` relies on
        to match weights to rows.
        """
        period_col = self.data['period']
        expanded = [
            self.data[period_col == period].assign(trial_period=period)
            for period in period_col.unique()
        ]
        if expanded:
            self.expanded_data = pd.concat(expanded)
        else:
            # No rows at all: keep the schema instead of failing in concat.
            self.expanded_data = self.data.assign(trial_period=period_col)
        return self

    def _aligned_capped_weights(self, quantile=0.99):
        """Return capped weights ordered like ``self.expanded_data``.

        ``self.weights`` follows the row order of ``self.data`` while
        ``expand_trials`` regroups rows by period, so the weights are matched
        to the expanded rows through the shared index labels rather than by
        position.

        Raises:
            ValueError: If the weights do not match ``self.data`` in length
                or the data index has duplicate labels.
        """
        raw = np.asarray(self.weights, dtype=float)
        if len(raw) != len(self.data):
            raise ValueError(
                "weights must contain exactly one value per row of data")
        if not self.data.index.is_unique:
            raise ValueError(
                "data index must be unique to align weights with expanded rows")

        cap = np.quantile(raw, quantile)
        per_row = pd.Series(np.minimum(raw, cap), index=self.data.index)
        return per_row.reindex(self.expanded_data.index)

    def fit_msm(self):
        """Fit a weighted Cox model on the expanded data.

        Weights are capped at their 99th percentile, stored in a ``weights``
        column of ``self.expanded_data`` and passed to ``CoxPHFitter`` with
        ``period`` as duration, ``outcome`` as event indicator and robust
        standard errors. The covariates are ``treatment`` and ``x2``.
        """
        self.expanded_data['weights'] = self._aligned_capped_weights().to_numpy()
        self.model = CoxPHFitter()
        self.model.fit(
            self.expanded_data[['treatment', 'x2', 'period', 'outcome', 'weights']],
            duration_col='period',
            event_col='outcome',
            weights_col='weights',
            robust=True
        )
        return self

### Execution Workflow

In [None]:
if __name__ == "__main__":
    # 1. Load data and cluster patients
    emulator = TargetTrialEmulator(estimand="ITT")
    emulator.prepare_data("data_censored.csv")
    emulator.cluster_patients(n_clusters=3)

    # 2. Analyze each cluster separately
    cluster_results = {}
    for cluster in sorted(emulator.data['cluster'].unique()):
        print(f"\n=== Analyzing Cluster {cluster} ===")

        # Subset cluster data
        cluster_data = emulator.data[emulator.data['cluster'] == cluster].copy()

        # Create new emulator for cluster; the data is injected directly
        # because it has already been loaded and prepared above.
        cluster_emulator = TargetTrialEmulator(estimand="ITT")
        cluster_emulator.data = cluster_data
        cluster_emulator.calculate_weights()
        cluster_emulator.expand_trials()
        cluster_emulator.fit_msm()

        # Store results
        cluster_results[cluster] = {
            'model': cluster_emulator.model,
            'data': cluster_emulator.expanded_data
        }

        # Save cluster data
        cluster_emulator.expanded_data.to_csv(f"cluster_{cluster}_data.csv", index=False)

    # 3. Generate cluster survival curves
    plt.figure(figsize=(10, 6))
    max_period = emulator.data['period'].max()
    times = np.linspace(0, max_period, 50)

    for cluster, result in cluster_results.items():
        # Get baseline prediction data (first observation in cluster)
        baseline_data = result['data'].query("period == 0").iloc[:1]
        if baseline_data.empty:
            # No period-0 row in this cluster: use its earliest observation
            # so the prediction still has exactly one subject.
            baseline_data = result['data'].sort_values('period').iloc[:1]

        # Generate survival predictions
        survival = result['model'].predict_survival_function(
            baseline_data,
            times=times)

        # The prediction is indexed by time with one column per subject.
        # Plot time against that single column: passing the transposed frame
        # would make matplotlib draw one single-point line per time step.
        plt.plot(survival.index, survival.iloc[:, 0],
                 label=f'Cluster {cluster}', linewidth=2)

    plt.title("Survival Probability by Patient Cluster", fontsize=14)
    plt.xlabel("Time Periods", fontsize=12)
    plt.ylabel("Survival Probability", fontsize=12)
    plt.ylim(0.5, 1.05)
    plt.grid(alpha=0.2)
    plt.legend()
    plt.tight_layout()
    plt.savefig("cluster_survival_curves.png", dpi=300)
    plt.show()

    # 4. Print cluster summaries
    print("\nCluster-wise Treatment Effects:")
    for cluster, result in cluster_results.items():
        print(f"\nCluster {cluster}:")
        print(result['model'].summary.loc['treatment'])

SyntaxError: invalid syntax. Perhaps you forgot a comma? (3351636683.py, line 42)