# TTE-v2 Analysis and Clustering Integration Notebook

This notebook merges the code from **Assignment 1** with a Time-to-Event (TTE) analysis workflow in which the model-fitting steps are simulated by placeholders. We:
1. Load and prepare the dummy data (`data_censored.csv`).
2. Demonstrate how to set up **TrialSequence** objects (`PP` and `ITT`).
3. Simulate training of the TTE-v2 model by generating **risk scores**.
4. Perform **K-means clustering** on these risk scores to identify subgroups.
5. Visualize the distribution of risk scores and show additional bar plots highlighting the effect of clustering.
6. Optionally perform a **Kaplan-Meier survival analysis** comparing treated vs. control groups.

Throughout this notebook, each step is documented, and steps that stand in for real model fitting are marked as placeholders, to show how the TTE-v2 approach integrates with clustering.

## 1. Setup

We first define our `TrialSequence` class, which mimics the structure from the R package in Assignment 1. We also create directories to save output models if needed.

In [None]:
import os
import tempfile
import logging

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

class TrialSequence:
    def __init__(self, estimand):
        self.estimand = estimand
        self.data = None
        self.switch_weight_model = None
        self.censor_weight_model = None
        # These are placeholders for demonstration.
        self.weight_models = {}  # for storing fitted weight models
        self.expansion = None    # for storing expansion details

    def set_data(self, data, id_col, period_col, treatment_col, outcome_col, eligible_col):
        """Assign dataset columns to the trial object."""
        self.data = data.copy()
        self.data.rename(
            columns={
                id_col: "id",
                period_col: "period",
                treatment_col: "treatment",
                outcome_col: "outcome",
                eligible_col: "eligible",
            }, 
            inplace=True
        )
        return self

    def set_switch_weight_model(self, numerator, denominator, save_path):
        """Placeholder for setting a switch weight model (PP)."""
        os.makedirs(save_path, exist_ok=True)
        self.switch_weight_model = {
            "numerator": numerator,
            "denominator": denominator,
            "model_fitter": "stats_glm_logit",
            "save_path": save_path
        }
        return self

    def set_censor_weight_model(self, censor_event, numerator, denominator, pool_models, save_path):
        """Placeholder for setting a censor weight model."""
        os.makedirs(save_path, exist_ok=True)
        self.censor_weight_model = {
            "censor_event": censor_event,
            "numerator": numerator,
            "denominator": denominator,
            "pool_models": pool_models,
            "model_fitter": "stats_glm_logit",
            "save_path": save_path
        }
        return self

    def calculate_weights(self):
        """Placeholder to simulate weight calculation."""
        logging.info("Weight models not truly fitted. This is a placeholder.")
        # In a real scenario, you'd fit logistic models for numerator/denominator.
        return self

    def set_outcome_model(self, adjustment_terms=None):
        """Placeholder for setting outcome model, possibly with adjustment terms."""
        logging.info("Outcome model setup with terms: %s.", adjustment_terms)
        return self

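    def set_expansion_options(self, output="datatable", chunk_size=500):
        """Placeholder for specifying expansion options."""
        self.expansion = {
            "output": output,
            "chunk_size": chunk_size,
            # Per-protocol follow-up stops when a patient deviates from the assigned
            # treatment; intention-to-treat follow-up continues regardless.
            "censor_at_switch": self.estimand == "PP",
            "first_period": 0,
            "last_period": float('inf')
        }
        return self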

    def expand_trials(self):
        """Placeholder for creating sequence of target trials data."""
        logging.info("Expanding trials (placeholder).")
        # Real logic would expand the dataset by trial period, etc.
        return self

    def load_expanded_data(self, seed=None, p_control=1.0):
        """Placeholder for loading or sampling from expanded data."""
        logging.info(
            "Loading expanded data with seed=%s, p_control=%.2f (placeholder).",
            str(seed), p_control
        )
        # Real logic would do subsetting/sampling here.
        return self

    def fit_msm(self, weight_cols=None, modify_weights=None):
        """Placeholder for fitting a marginal structural model."""
        logging.info("Fitting MSM (placeholder). Weight columns: %s.", weight_cols)
        # Real logic would fit a logistic or other model.
        return self

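    def show_weight_models(self):
        """Placeholder: print the stored weight-model specifications (nothing is fitted)."""
        print("Weight models (placeholder, specifications only):")
        for name, spec in (("Switch", self.switch_weight_model), ("Censor", self.censor_weight_model)):
            if spec is None:
                print(f"  {name} model: not set")
            else:
                print(f"  {name} model: numerator = {spec['numerator']}, denominator = {spec['denominator']}")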


### Create directories for trial_pp and trial_itt (as in Assignment 1)
We replicate the R code's approach of creating directories in a temporary folder. In a real scenario, these could be replaced with persistent paths.

In [None]:
trial_pp_dir = os.path.join(tempfile.gettempdir(), "trial_pp")
os.makedirs(trial_pp_dir, exist_ok=True)

trial_itt_dir = os.path.join(tempfile.gettempdir(), "trial_itt")
os.makedirs(trial_itt_dir, exist_ok=True)


## 2. Data Preparation

We load `data_censored.csv` (a dummy dataset) as in Assignment 1, then instantiate `trial_pp` (Per-Protocol) and `trial_itt` (Intention-to-Treat).

In [None]:
import pandas as pd

data_censored = pd.read_csv("data_censored.csv")
print("--- Loaded data_censored.csv (head) ---")
print(data_censored.head())

# Per-protocol
trial_pp = TrialSequence(estimand="PP").set_data(
    data_censored,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible"
)

# Intention-to-treat
trial_itt = TrialSequence(estimand="ITT").set_data(
    data_censored,
    id_col="id",
    period_col="period",
    treatment_col="treatment",
    outcome_col="outcome",
    eligible_col="eligible"
)

print("\n--- Trial ITT Data (head) ---")
print(trial_itt.data.head())
print(f"Number of observations: {len(trial_itt.data)}")
print(f"Number of unique patients: {trial_itt.data['id'].nunique()}")

### 2.1 Weight Models and Censoring (Placeholder)
For completeness, we mimic the steps from Assignment 1, setting up switch and censor weight models. (In a real scenario, these would be fitted with logistic regression.)

In [None]:
trial_pp = trial_pp.set_switch_weight_model(
    numerator="treatment ~ age",
    denominator="treatment ~ age + x1 + x3",
    save_path=os.path.join(trial_pp_dir, "switch_models")
)

print("\n--- Switch Weight Model (PP) ---")
print(f"Numerator: {trial_pp.switch_weight_model['numerator']}")
print(f"Denominator: {trial_pp.switch_weight_model['denominator']}")

# Censor weight model for PP
trial_pp = trial_pp.set_censor_weight_model(
    censor_event="censored",
    numerator="1 - censored ~ x2",
    denominator="1 - censored ~ x2 + x1",
    pool_models="none",
    save_path=os.path.join(trial_pp_dir, "censor_models")
)

print("\n--- Censor Weight Model (PP) ---")
print(f"Numerator: {trial_pp.censor_weight_model['numerator']}")
print(f"Denominator: {trial_pp.censor_weight_model['denominator']}")

# Censor weight model for ITT
trial_itt = trial_itt.set_censor_weight_model(
    censor_event="censored",
    numerator="1 - censored ~ x2",
    denominator="1 - censored ~ x2 + x1",
    pool_models="numerator",  # Numerator model is pooled across arms
    save_path=os.path.join(trial_itt_dir, "censor_models")
)
print("\n--- Censor Weight Model (ITT) ---")
print(f"Numerator: {trial_itt.censor_weight_model['numerator']}")
print(f"Denominator: {trial_itt.censor_weight_model['denominator']}")
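The weight models above are stored only as formula strings. As a minimal sketch of what fitting them could involve, the block below fits the numerator and denominator logistic regressions by Newton-Raphson (plain NumPy, so it runs on its own) and forms a stabilized switch weight: the probability of the treatment actually received under the numerator model divided by that under the denominator model. The simulated columns and the helpers `fit_logit` and `predict_logit` are hypothetical; a full implementation would typically fit the models separately by previous treatment and use a library such as `statsmodels`.

```python
import numpy as np

def fit_logit(X, y, n_iter=50, tol=1e-8):
    """Logistic regression by Newton-Raphson (IRLS); X must include an intercept column."""
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))
        # Newton step: (X' W X)^{-1} X' (y - p) with W = diag(p * (1 - p))
        hessian = X.T @ (X * (p * (1.0 - p))[:, None])
        step = np.linalg.solve(hessian, X.T @ (y - p))
        beta += step
        if np.max(np.abs(step)) < tol:
            break
    return beta

def predict_logit(X, beta):
    return 1.0 / (1.0 + np.exp(-X @ beta))

# Hypothetical data: treatment depends on age and on the confounder x1; x3 is noise
rng = np.random.default_rng(0)
n = 2000
age = rng.normal(50, 10, n)
x1 = rng.binomial(1, 0.5, n)
x3 = rng.binomial(1, 0.5, n)
treatment = rng.binomial(1, 1.0 / (1.0 + np.exp(-(-3.0 + 0.05 * age + 0.8 * x1))))

ones = np.ones(n)
X_num = np.column_stack([ones, age])           # numerator:   treatment ~ age
X_den = np.column_stack([ones, age, x1, x3])   # denominator: treatment ~ age + x1 + x3
p_num = predict_logit(X_num, fit_logit(X_num, treatment))
p_den = predict_logit(X_den, fit_logit(X_den, treatment))

# Probability of the treatment actually received under each model
num = np.where(treatment == 1, p_num, 1.0 - p_num)
den = np.where(treatment == 1, p_den, 1.0 - p_den)
switch_weight = num / den
print(f"mean stabilized weight: {switch_weight.mean():.3f}")
```

When the denominator model contains the true confounders, stabilized weights average close to 1, which is a quick sanity check on a fitted weight model.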

### 2.2 Calculate Weights (Placeholder)
We call `calculate_weights()` to mimic the weighting procedure. This is a stub here.

In [None]:
trial_pp.calculate_weights()
trial_itt.calculate_weights()

# We can also show weight models, though they're placeholders.
print("\n--- Show Weight Models (PP) ---")
trial_pp.show_weight_models()
print("\n--- Show Weight Models (ITT) ---")
trial_itt.show_weight_models()
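`calculate_weights()` is a stub here. In inverse-probability-weighting workflows for sequential trials, the weight used for a row is commonly the cumulative product of the per-period switch and censoring weights over follow-up so far, restarting for each patient (or, after expansion, for each patient-trial). A minimal sketch of that accumulation step, assuming the per-row weights are already computed and rows are sorted by `id` and then `period`; the helper `cumulative_weights` and the numbers are hypothetical:

```python
import numpy as np

def cumulative_weights(ids, row_weights):
    """Cumulative product of per-row weights within each id (rows sorted by id, then period)."""
    ids = np.asarray(ids)
    row_weights = np.asarray(row_weights, dtype=float)
    out = np.empty_like(row_weights)
    running = 1.0
    for i in range(len(ids)):
        if i == 0 or ids[i] != ids[i - 1]:
            running = 1.0  # new patient: restart the product
        running *= row_weights[i]
        out[i] = running
    return out

# Hypothetical rows for two patients, already sorted by id and period
ids = np.array([1, 1, 1, 2, 2])
switch_w = np.array([1.0, 1.2, 0.9, 1.0, 1.1])
censor_w = np.array([1.0, 1.0, 1.1, 1.0, 0.8])
weights = cumulative_weights(ids, switch_w * censor_w)
print(weights)
```

With pandas, the same accumulation on a sorted frame is `df.groupby('id')['w'].cumprod()`.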

### 2.3 Specify Outcome Model and Expand Trials (Placeholder)
These steps replicate the R code from Assignment 1 but do not perform real expansions or outcome model fitting.

In [None]:
trial_pp.set_outcome_model()
trial_itt.set_outcome_model(adjustment_terms=["x2"])

# Set expansion options
trial_pp.set_expansion_options(output="datatable", chunk_size=500)
trial_itt.set_expansion_options(output="datatable", chunk_size=500)

# Expand trials
trial_pp.expand_trials()
trial_itt.expand_trials()

print("\n--- Trial PP expansion details ---")
print(trial_pp.expansion)

print("\n--- Trial ITT expansion details ---")
print(trial_itt.expansion)

# Optionally load expanded data with sampling
trial_itt.load_expanded_data(seed=1234, p_control=0.5)
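Because `expand_trials()` is a stub, the sketch below illustrates what sequential trial expansion does: each period at which a patient is eligible starts a new emulated trial, and the patient's later rows are copied into it with a `trial_period` and a `followup_time`. With `censor_at_switch=True` (the per-protocol setting), follow-up stops at the first period whose treatment differs from the baseline treatment. The function `expand_trials_sketch` and its dict-based record layout are hypothetical and do not mirror the R package's internals.

```python
def expand_trials_sketch(rows, censor_at_switch=False):
    """Expand person-period rows into a sequence of emulated target trials.

    `rows` is a list of dicts with keys id, period, treatment, outcome, eligible.
    """
    by_id = {}
    for r in sorted(rows, key=lambda r: (r["id"], r["period"])):
        by_id.setdefault(r["id"], []).append(r)

    expanded = []
    for pid, history in by_id.items():
        for start, baseline in enumerate(history):
            if not baseline["eligible"]:
                continue  # only eligible periods start a trial
            for r in history[start:]:
                # Per-protocol: stop following once treatment deviates from baseline
                if censor_at_switch and r["treatment"] != baseline["treatment"]:
                    break
                expanded.append({
                    "id": pid,
                    "trial_period": baseline["period"],
                    "followup_time": r["period"] - baseline["period"],
                    "assigned_treatment": baseline["treatment"],
                    "outcome": r["outcome"],
                })
    return expanded

# Hypothetical patient: treated in periods 0-1, switches off treatment in period 2
example_rows = [
    {"id": 1, "period": 0, "treatment": 1, "outcome": 0, "eligible": 1},
    {"id": 1, "period": 1, "treatment": 1, "outcome": 0, "eligible": 1},
    {"id": 1, "period": 2, "treatment": 0, "outcome": 1, "eligible": 0},
]
itt_rows = expand_trials_sketch(example_rows)                        # trials start at periods 0 and 1
pp_rows = expand_trials_sketch(example_rows, censor_at_switch=True)  # follow-up stops at the switch
```

Here the ITT expansion keeps 5 rows (3 in the trial starting at period 0, 2 in the trial starting at period 1), while the per-protocol expansion keeps 3, because the switch in period 2 ends follow-up in both trials.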

## 3. Simulate TTE-v2 Model Training and Generate Risk Scores

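Here we simulate the TTE-v2 model by generating a `risk_score` column. The scores are drawn uniformly at random, so they carry no information about the outcome, and any differences between the clusters found below are illustrative only. In a real scenario, you'd fit an actual time-to-event model.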

In [None]:
import numpy as np

def train_tte_model(data, seed=42):
    """Simulate training of the TTE-v2 model by assigning a random risk score to each row.

    The seed makes the simulated scores reproducible across runs.
    """
    rng = np.random.default_rng(seed)
    data['risk_score'] = rng.random(len(data))
    logging.info("TTE-v2 model simulated: risk scores generated.")
    return data

# Apply the simulated training to the ITT data, for example.
trial_itt.data = train_tte_model(trial_itt.data)

print("\n--- Trial ITT Data with Risk Score (head) ---")
print(trial_itt.data.head())
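The random scores above stand in for a fitted model. One common way to turn person-period data into a time-to-event risk score is a discrete-time hazard model: a pooled logistic regression of the per-period outcome on period and covariates, whose fitted hazards `h_t` give the risk of an event by period T as `1 - Π_t (1 - h_t)`. The sketch below does this on simulated data. It is not the TTE-v2 model, whose details are not given here; the `x2` effect size, the five-period horizon, and the helper `fit_logit` are assumptions.

```python
import numpy as np

def fit_logit(X, y, n_iter=50, tol=1e-10):
    """Logistic regression by Newton-Raphson; X must include an intercept column."""
    beta = np.zeros(X.shape[1])
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-X @ beta))
        hessian = X.T @ (X * (p * (1.0 - p))[:, None])
        step = np.linalg.solve(hessian, X.T @ (y - p))
        beta += step
        if np.max(np.abs(step)) < tol:
            break
    return beta

# Hypothetical person-period data: 500 patients followed for up to 5 periods
rng = np.random.default_rng(1)
records = []
for pid in range(500):
    x2 = rng.normal()
    for period in range(5):
        hazard = 1.0 / (1.0 + np.exp(-(-3.0 + 0.2 * period + 0.7 * x2)))
        event = int(rng.random() < hazard)
        records.append((pid, period, x2, event))
        if event:
            break  # follow-up ends at the first event
ids, periods, x2_col, outcome = np.array(records, dtype=float).T

# Pooled logistic regression: logit h(t) = b0 + b1 * period + b2 * x2
X = np.column_stack([np.ones(len(outcome)), periods, x2_col])
beta = fit_logit(X, outcome)

# Risk score: probability of an event by the end of period 4, 1 - prod_t (1 - h_t)
grid = np.arange(5)
baseline_x2 = dict(zip(ids.astype(int), x2_col))
risk_score = {}
for pid, x2 in baseline_x2.items():
    h = 1.0 / (1.0 + np.exp(-(beta[0] + beta[1] * grid + beta[2] * x2)))
    risk_score[pid] = 1.0 - np.prod(1.0 - h)
```

Unlike uniform random scores, these risk scores increase with `x2` whenever its fitted coefficient is positive, so clustering them would separate patients by a real covariate effect.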

## 4. Clustering Mechanism Integration

After generating the risk scores, we apply **K-means clustering** to uncover potential subgroups within the data.

**Rationale**:
- Clustering can reveal subpopulations (e.g., high, moderate, low risk). 
- This can inform how we tailor interventions or interpret TTE results.

**Implementation**:
1. Extract the `risk_score` column.
2. Use K-means to partition the data (e.g., 3 clusters).
3. Append cluster labels to the dataset.

In [None]:
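import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA

def perform_clustering(data, feature_col='risk_score', n_clusters=3, use_pca=False):
    """Apply K-means clustering to one feature column (str) or several (list of str).

    K-means label numbers are arbitrary, so clusters are relabelled in order of
    increasing mean of the first feature: with `risk_score`, cluster 0 is the
    lowest-risk group. Note that `kmeans.labels_` keeps the original numbering.
    """
    cols = [feature_col] if isinstance(feature_col, str) else list(feature_col)
    features = data[cols].values

    # PCA only applies when more than one feature column is supplied
    if use_pca and features.shape[1] > 1:
        pca = PCA(n_components=2)
        features = pca.fit_transform(features)
        logging.info("PCA applied for dimensionality reduction.")

    # n_init is set explicitly to avoid version-dependent defaults in scikit-learn
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=42)
    raw_labels = kmeans.fit_predict(features)

    # Rank clusters by the mean of the first feature and relabel them 0..n_clusters-1
    cluster_means = [data.loc[raw_labels == k, cols[0]].mean() for k in range(n_clusters)]
    rank = np.argsort(np.argsort(cluster_means))
    data['cluster'] = rank[raw_labels]
    logging.info("K-means clustering completed with %d clusters.", n_clusters)
    return data, kmeans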

# Perform clustering on the ITT data's risk score
trial_itt.data, kmeans_model = perform_clustering(trial_itt.data, feature_col='risk_score', n_clusters=3)

print("\n--- Clustered ITT Data (head) ---")
print(trial_itt.data.head())
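With a single feature such as `risk_score`, K-means amounts to Lloyd's algorithm on the real line: assign each score to the nearest centroid, move each centroid to the mean of its assigned scores, and repeat until the assignments stop changing. A minimal NumPy sketch for intuition; the quantile-based initialization is an assumption (scikit-learn's `KMeans` uses k-means++ initialization by default), and `kmeans_1d` is a hypothetical helper.

```python
import numpy as np

def kmeans_1d(x, k, n_iter=100):
    """Lloyd's algorithm for 1-D data; returns (labels, centroids)."""
    x = np.asarray(x, dtype=float)
    # Initialize centroids at evenly spaced quantiles of the data
    centroids = np.quantile(x, (np.arange(k) + 0.5) / k)
    labels = None
    for _ in range(n_iter):
        # Assignment step: index of the nearest centroid for every point
        new_labels = np.argmin(np.abs(x[:, None] - centroids[None, :]), axis=1)
        if labels is not None and np.array_equal(new_labels, labels):
            break  # assignments unchanged: converged
        labels = new_labels
        # Update step: each centroid moves to the mean of its points (empty clusters stay put)
        for j in range(k):
            if np.any(labels == j):
                centroids[j] = x[labels == j].mean()
    return labels, centroids

x = np.array([0.05, 0.1, 0.12, 0.5, 0.52, 0.55, 0.9, 0.95])
labels, centroids = kmeans_1d(x, 3)
```

In one dimension the resulting clusters are contiguous intervals of the score, so three clusters amount to choosing two cut points on the risk scale.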

## 5. Visualizing Clusters and Generating Insights

We visualize the clustering results to interpret the risk profiles.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(8, 6))
sns.boxplot(x='cluster', y='risk_score', data=trial_itt.data)
plt.title('Risk Score Distribution Across Clusters')
plt.xlabel('Cluster')
plt.ylabel('Risk Score')
plt.show()

# Summary statistics per cluster
cluster_summary = trial_itt.data.groupby('cluster')['risk_score'].agg(['mean', 'median', 'std']).reset_index()
print("--- Cluster Summary Statistics ---")
print(cluster_summary)

# Count of observations per cluster
plt.figure(figsize=(6, 4))
sns.countplot(x='cluster', data=trial_itt.data)
plt.title('Number of Observations per Cluster')
plt.xlabel('Cluster')
plt.ylabel('Count')
plt.show()

### Additional Visualization: Effect of Clustering Implementation

To further illustrate the impact of clustering, we create bar plots:
- **Mean Risk Score by Cluster**: Shows average risk score (with standard deviation as error bars).
- **Event Rate by Cluster**: Proportion of events (`outcome`) in each cluster.

In [None]:
# Ensure 'outcome' is numeric (a 0/1 event indicator) so that its mean is an event rate.
if not pd.api.types.is_numeric_dtype(trial_itt.data['outcome']):
    # Coerce non-numeric values to NaN and drop those rows
    trial_itt.data['outcome'] = pd.to_numeric(trial_itt.data['outcome'], errors='coerce')
    trial_itt.data.dropna(subset=['outcome'], inplace=True)

# Compute mean risk score and standard deviation per cluster
cluster_stats = trial_itt.data.groupby('cluster')['risk_score'].agg(['mean', 'std']).reset_index()

# Compute event rate per cluster (assuming 'outcome' is 0 or 1)
event_rate = trial_itt.data.groupby('cluster')['outcome'].mean().reset_index()

# Create a figure with two subplots
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Bar plot for mean risk score with ±1 standard deviation as error bars.
# matplotlib's bar() places each bar at its actual cluster label, so the
# error bars always line up with the bars they belong to.
axes[0].bar(cluster_stats['cluster'], cluster_stats['mean'], yerr=cluster_stats['std'],
            capsize=5, color='tab:blue', edgecolor='black')
axes[0].set_xticks(cluster_stats['cluster'])
axes[0].set_title('Mean Risk Score by Cluster')
axes[0].set_xlabel('Cluster')
axes[0].set_ylabel('Mean Risk Score')

# Bar plot for event rate by cluster
sns.barplot(x='cluster', y='outcome', data=event_rate, ax=axes[1])
axes[1].set_title('Event Rate (Outcome) by Cluster')
axes[1].set_xlabel('Cluster')
axes[1].set_ylabel('Event Rate')

plt.tight_layout()
plt.show()
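Both panels are per-cluster aggregates; the event rate is simply the mean of a 0/1 `outcome` within each cluster. A minimal NumPy sketch of the same statistics using `np.bincount`, assuming integer cluster labels `0..k-1` (the arrays are hypothetical):

```python
import numpy as np

# Hypothetical per-row cluster labels, risk scores, and 0/1 outcomes
cluster = np.array([0, 0, 1, 1, 1, 2, 2])
risk = np.array([0.1, 0.2, 0.5, 0.6, 0.4, 0.9, 0.8])
outcome = np.array([0, 0, 1, 0, 1, 1, 0])

counts = np.bincount(cluster)
mean_risk = np.bincount(cluster, weights=risk) / counts
# Sample standard deviation (ddof=1), which is what pandas' .agg('std') reports
sq_dev = (risk - mean_risk[cluster]) ** 2
std_risk = np.sqrt(np.bincount(cluster, weights=sq_dev) / (counts - 1))
# Event rate: the fraction of rows with outcome == 1 in each cluster
event_rate = np.bincount(cluster, weights=outcome) / counts
```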

## 6. Survival Analysis Using Kaplan-Meier Estimator (Optional)

We optionally compare survival probabilities between treated and control groups using the Kaplan-Meier approach. This step uses `lifelines`.

In [None]:
from lifelines import KaplanMeierFitter

# Build one row per patient for a simple Kaplan-Meier comparison, unless
# expanded-trial columns (followup_time, assigned_treatment) already exist.
if {'followup_time', 'assigned_treatment'}.issubset(trial_itt.data.columns):
    newdata = trial_itt.data.copy()
else:
    # followup_time      = number of periods the patient was observed
    # outcome            = 1 if an event occurred in any period, else 0 (censored)
    # assigned_treatment = treatment received in the first observed period
    newdata = (
        trial_itt.data.sort_values(['id', 'period'])
        .groupby('id')
        .agg(first_period=('period', 'min'),
             last_period=('period', 'max'),
             outcome=('outcome', 'max'),
             assigned_treatment=('treatment', 'first'))
        .reset_index()
    )
    newdata['followup_time'] = newdata['last_period'] - newdata['first_period'] + 1

kmf_treated = KaplanMeierFitter()
kmf_control = KaplanMeierFitter()

# Fit the Kaplan-Meier model for the treated group
treated_mask = newdata['assigned_treatment'] == 1
kmf_treated.fit(
    durations=newdata[treated_mask]['followup_time'],
    event_observed=newdata[treated_mask]['outcome'],
    label="Treated"
)

# Fit the Kaplan-Meier model for the control group
control_mask = newdata['assigned_treatment'] == 0
kmf_control.fit(
    durations=newdata[control_mask]['followup_time'],
    event_observed=newdata[control_mask]['outcome'],
    label="Control"
)

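# Predict survival probabilities at follow-up times 0 to 10
predict_times = list(range(0, 11))
surv_prob_treated = kmf_treated.predict(predict_times)
surv_prob_control = kmf_control.predict(predict_times)

def greenwood_variance(kmf, times):
    """Greenwood variance of a fitted Kaplan-Meier curve, evaluated at `times`."""
    table = kmf.event_table
    d = table['observed'].to_numpy(dtype=float)
    n = table['at_risk'].to_numpy(dtype=float)
    # Each event time adds d / (n * (n - d)); the term is undefined when n == d,
    # in which case S(t) drops to 0 and the variance is 0 anyway.
    with np.errstate(divide='ignore', invalid='ignore'):
        terms = np.where(n > d, d / (n * (n - d)), 0.0)
    cum_terms = np.cumsum(terms)
    # Step-function lookup: use the last event-table time <= t
    idx = np.searchsorted(table.index.to_numpy(dtype=float),
                          np.asarray(times, dtype=float), side='right') - 1
    cum_at_t = np.where(idx >= 0, cum_terms[np.clip(idx, 0, None)], 0.0)
    surv = np.asarray(kmf.predict(times), dtype=float)
    return surv ** 2 * cum_at_t

# Difference in survival (treated minus control) with a pointwise 95% CI,
# treating the two arms as independent samples
survival_diff = surv_prob_treated - surv_prob_control
se_diff = np.sqrt(greenwood_variance(kmf_treated, predict_times)
                  + greenwood_variance(kmf_control, predict_times))
ci_lower = survival_diff - 1.96 * se_diff
ci_upper = survival_diff + 1.96 * se_diff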

# Plot the survival difference
plt.figure(figsize=(8, 5))
plt.plot(predict_times, survival_diff, label="Survival Difference", color="blue")
plt.fill_between(predict_times, ci_lower, ci_upper, color="red", alpha=0.2, label="95% CI")
plt.xlabel("Follow-up Time")
plt.ylabel("Survival Difference")
plt.title("Survival Probability Difference Over Time")
plt.legend()
plt.show()
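The `lifelines` fits above implement the product-limit estimator S(t) = Π_{t_i ≤ t} (1 - d_i / n_i), where d_i is the number of events and n_i the number at risk at event time t_i. Greenwood's formula estimates its variance as S(t)^2 Σ_{t_i ≤ t} d_i / (n_i (n_i - d_i)), which is what a pointwise confidence band for a survival difference needs. A from-scratch sketch for checking the numbers on a toy dataset (the helper `kaplan_meier` and the data are hypothetical):

```python
import numpy as np

def kaplan_meier(durations, events):
    """Product-limit survival estimate and Greenwood variance at each distinct event time."""
    durations = np.asarray(durations, dtype=float)
    events = np.asarray(events, dtype=int)
    times = np.unique(durations[events == 1])
    surv, var = [], []
    s, greenwood_sum = 1.0, 0.0
    for t in times:
        n_at_risk = np.sum(durations >= t)             # still under observation just before t
        d = np.sum((durations == t) & (events == 1))   # events at time t
        s *= 1.0 - d / n_at_risk
        if n_at_risk > d:                              # term undefined when n == d (S drops to 0)
            greenwood_sum += d / (n_at_risk * (n_at_risk - d))
        surv.append(s)
        var.append(s ** 2 * greenwood_sum)
    return times, np.array(surv), np.array(var)

# Hypothetical toy data: follow-up times and event indicators (0 = censored)
times, surv, var = kaplan_meier([1, 2, 2, 3, 4, 5], [1, 1, 0, 1, 0, 1])
```

On the same toy data, `KaplanMeierFitter().fit(durations, events).survival_function_` should show the same step values.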

## Conclusion

This unified notebook shows how we:
1. **Load** the dummy data (`data_censored.csv`) from Assignment 1.
2. **Set up** `TrialSequence` objects (`PP` and `ITT`).
3. **Simulate** a TTE-v2 model by generating risk scores.
4. **Cluster** those risk scores via K-means and visualize them with boxplots, counts, and additional bar charts.
5. **Optionally** perform Kaplan-Meier survival analysis.

The bar plots demonstrate the **effect** of the clustering implementation:
- One plot shows the **mean risk score** (with standard deviation) for each cluster.
- Another shows the **event rate** (the mean of `outcome`) per cluster.

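With fitted risk scores, this approach would show how subgroups differ in predicted risk and observed outcomes. Because the scores here are simulated at random, the plots illustrate the mechanics of the clustering step rather than a practical effect.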