In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from sklearn.datasets import make_classification
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy's global random number generator so that any later draws made
# through the module-level np.random functions are repeatable between runs.
np.random.seed(42)

print("Libraries imported successfully!")


In [None]:
def generate_survival_data(n_samples=200, censoring_rate=0.3, random_state=42):
    """Generate a synthetic two-arm survival dataset with right censoring.

    Each subject is assigned to control (0) or treatment (1) with equal
    probability. True survival times are exponential with hazard 0.05 for
    control and 0.03 for treatment; independent censoring times are
    exponential with hazard ``censoring_rate * 0.02``. The observed time is
    the smaller of the two.

    Parameters
    ----------
    n_samples : int
        Number of subjects to simulate. Must be non-negative.
    censoring_rate : float
        Multiplier applied to the base censoring hazard of 0.02. Note that
        this is NOT the realised fraction of censored subjects; that
        fraction depends on the ratio between censoring and event hazards.
        ``0`` disables censoring entirely. Must be non-negative.
    random_state : int or None
        Seed for a private ``numpy.random.RandomState``. The default (42)
        is the seed that used to be hard-coded here. Pass ``None`` for
        non-reproducible draws.

    Returns
    -------
    pandas.DataFrame
        Columns ``time`` (observed time), ``event`` (1 = event observed,
        0 = censored), ``group`` (0 = control, 1 = treatment), ``age``
        (normal(50, 15) clipped to [20, 80]) and ``gender`` (0/1).

    Raises
    ------
    ValueError
        If ``n_samples`` or ``censoring_rate`` is negative.
    """
    if n_samples < 0:
        raise ValueError("n_samples must be non-negative")
    if censoring_rate < 0:
        raise ValueError("censoring_rate must be non-negative")

    # A private generator keeps this function from resetting the caller's
    # global np.random state as a side effect.
    rng = np.random.RandomState(random_state)

    # Group indicator (treatment vs control)
    group = rng.binomial(1, 0.5, n_samples)

    # Base hazard rates
    lambda_control = 0.05  # Control group hazard
    lambda_treatment = 0.03  # Treatment group hazard (lower = better)

    # True survival times: one exponential draw per subject, with the mean
    # (1 / hazard) selected by the subject's group.
    mean_survival = np.where(group == 0, 1 / lambda_control, 1 / lambda_treatment)
    survival_times = rng.exponential(mean_survival)

    # Censoring times; a zero rate means nobody is ever censored, which also
    # avoids dividing by zero when computing the exponential scale.
    if censoring_rate == 0:
        censoring_times = np.full(n_samples, np.inf)
    else:
        censoring_times = rng.exponential(1 / (censoring_rate * 0.02), n_samples)

    # Observed times are minimum of survival and censoring times
    observed_times = np.minimum(survival_times, censoring_times)

    # Event indicator (1 = event observed, 0 = censored)
    events = (survival_times <= censoring_times).astype(int)

    # Additional covariates (generated independently of the survival times)
    age = np.clip(rng.normal(50, 15, n_samples), 20, 80)  # Realistic age range
    gender = rng.binomial(1, 0.5, n_samples)  # 0=Female, 1=Male

    return pd.DataFrame({
        'time': observed_times,
        'event': events,
        'group': group,
        'age': age,
        'gender': gender,
    })

# Generate the synthetic dataset used by the rest of the analysis
survival_data = generate_survival_data(n_samples=300, censoring_rate=0.3)

# Text summary: sample size, event/censoring counts and the realised
# censoring fraction (share of rows with event == 0).
print("Survival Data Summary:")
print(f"Total samples: {len(survival_data)}")
print(f"Events observed: {survival_data['event'].sum()}")
print(f"Censored observations: {(1 - survival_data['event']).sum()}")
print(f"Censoring rate: {(1 - survival_data['event'].mean()):.2%}")
# A real newline here; the doubled backslash used to print a literal "\n".
print("\nData preview:")
print(survival_data.head(10))

# Visualize the data on a 2x2 grid of panels
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Top-left: observed-time distribution, one histogram per group
axes[0, 0].hist(survival_data[survival_data['group'] == 0]['time'],
                alpha=0.7, label='Control', bins=20)
axes[0, 0].hist(survival_data[survival_data['group'] == 1]['time'],
                alpha=0.7, label='Treatment', bins=20)
axes[0, 0].set_title('Distribution of Observed Times')
axes[0, 0].set_xlabel('Time')
axes[0, 0].set_ylabel('Frequency')
axes[0, 0].legend()

# Top-right: event status by group. Missing group/status combinations are
# filled with 0 and the columns are pinned to [censored, event] so the
# hard-coded legend labels below always line up with the bars.
event_counts = (
    survival_data.groupby(['group', 'event'])
    .size()
    .unstack(fill_value=0)
    .reindex(columns=[0, 1], fill_value=0)
)
event_counts.plot(kind='bar', ax=axes[0, 1], alpha=0.7)
axes[0, 1].set_title('Event Status by Group')
axes[0, 1].set_xlabel('Group (0=Control, 1=Treatment)')
axes[0, 1].set_ylabel('Count')
axes[0, 1].legend(['Censored', 'Event'])
axes[0, 1].tick_params(axis='x', rotation=0)

# Bottom-left: age distribution
axes[1, 0].hist(survival_data['age'], bins=20, alpha=0.7, edgecolor='black')
axes[1, 0].set_title('Age Distribution')
axes[1, 0].set_xlabel('Age')
axes[1, 0].set_ylabel('Frequency')

# Bottom-right: observed time against age, coloured by event status
colors = ['red' if e == 1 else 'blue' for e in survival_data['event']]
axes[1, 1].scatter(survival_data['age'], survival_data['time'],
                   c=colors, alpha=0.6)
axes[1, 1].set_title('Time vs Age (Red=Event, Blue=Censored)')
axes[1, 1].set_xlabel('Age')
axes[1, 1].set_ylabel('Time')

plt.tight_layout()
plt.show()


In [None]:
class KaplanMeierEstimator:
    """Kaplan-Meier (product-limit) estimator of a survival function.

    After ``fit`` the following attributes are populated, all aligned with
    each other and ordered by increasing event time:

    * ``event_times_``       - distinct times at which at least one event occurred
    * ``survival_function_`` - estimated S(t) just after each event time
    * ``n_at_risk_``         - number of subjects with observed time >= t
    * ``n_events_``          - number of events occurring exactly at t
    """

    def __init__(self):
        # All fitted attributes stay None until fit() has been called.
        self.survival_function_ = None
        self.event_times_ = None
        self.n_at_risk_ = None
        self.n_events_ = None

    def _check_fitted(self):
        """Raise ValueError if fit() has not been called yet."""
        if self.survival_function_ is None:
            raise ValueError("Must fit estimator first")

    def fit(self, times, events):
        """Fit the estimator to right-censored data.

        Parameters
        ----------
        times : array-like of float, 1-D
            Observed time for each subject (event or censoring time).
        events : array-like, 1-D, same length as ``times``
            Event indicator; entries equal to 1 mark an observed event,
            anything else is treated as censored.

        Returns
        -------
        KaplanMeierEstimator
            ``self``, to allow chaining.

        Raises
        ------
        ValueError
            If the inputs are not one-dimensional or differ in length.
        """
        times = np.asarray(times, dtype=float)
        events = np.asarray(events)
        if times.ndim != 1 or events.ndim != 1:
            raise ValueError("times and events must be one-dimensional")
        if times.shape != events.shape:
            raise ValueError("times and events must have the same length")

        # Distinct event times (already sorted by np.unique) together with
        # the number of events d_t at each of them.
        event_times, event_counts = np.unique(times[events == 1],
                                              return_counts=True)

        # n_t = number of subjects with time >= t. On the sorted times this
        # is everything from the first position holding a value >= t onward.
        sorted_times = np.sort(times)
        at_risk = times.size - np.searchsorted(sorted_times, event_times,
                                               side='left')

        # Kaplan-Meier formula: S(t) = prod over event times <= t of
        # (1 - d_t / n_t). n_t >= d_t >= 1 here, so no division by zero.
        survival = np.cumprod(1.0 - event_counts / at_risk)

        self.event_times_ = event_times
        self.survival_function_ = survival
        self.n_at_risk_ = at_risk
        self.n_events_ = event_counts

        return self

    def predict_survival(self, times):
        """Evaluate the fitted step function at the given times.

        Parameters
        ----------
        times : float or array-like of float
            Query times; a scalar is treated as a one-element array.

        Returns
        -------
        numpy.ndarray
            Survival probability for each query time: 1.0 before the first
            event time, otherwise the estimate at the latest event time
            that is <= the query time.

        Raises
        ------
        ValueError
            If the estimator has not been fitted.
        """
        self._check_fitted()
        query = np.atleast_1d(np.asarray(times, dtype=float))

        # Number of event times <= each query time; 0 means "before the
        # first event", which maps to the prepended S = 1.
        n_passed = np.searchsorted(self.event_times_, query, side='right')
        steps = np.concatenate(([1.0], self.survival_function_))
        return steps[n_passed]

    def confidence_interval(self, confidence=0.95):
        """Pointwise confidence bands based on Greenwood's formula.

        Uses the plain (untransformed) normal approximation
        ``S(t) +/- z * SE(t)`` and clips the result to [0, 1].

        Parameters
        ----------
        confidence : float
            Two-sided confidence level, strictly between 0 and 1.

        Returns
        -------
        (numpy.ndarray, numpy.ndarray)
            Lower and upper bounds, aligned with ``event_times_``.

        Raises
        ------
        ValueError
            If the estimator has not been fitted or ``confidence`` is not
            strictly between 0 and 1.
        """
        self._check_fitted()
        if not 0 < confidence < 1:
            raise ValueError("confidence must be strictly between 0 and 1")

        alpha = 1 - confidence
        z_score = stats.norm.ppf(1 - alpha / 2)

        n_risk = self.n_at_risk_.astype(float)
        n_event = self.n_events_.astype(float)

        # Greenwood terms d / (n * (n - d)). When everyone at risk has the
        # event (n == d) the term is undefined; it is skipped, and S(t) is
        # 0 from that point on so the standard error is 0 regardless.
        terms = np.zeros_like(n_risk)
        defined = n_risk > n_event
        terms[defined] = n_event[defined] / (
            n_risk[defined] * (n_risk[defined] - n_event[defined])
        )

        # SE[S(t)] = S(t) * sqrt(cumulative sum of Greenwood terms)
        std_errors = self.survival_function_ * np.sqrt(np.cumsum(terms))

        # Clip to valid probability range
        lower_ci = np.clip(self.survival_function_ - z_score * std_errors, 0, 1)
        upper_ci = np.clip(self.survival_function_ + z_score * std_errors, 0, 1)

        return lower_ci, upper_ci

# Fit Kaplan-Meier estimators for each group
print("=== Kaplan-Meier Analysis ===")

# Separate data by group
control_data = survival_data[survival_data['group'] == 0]
treatment_data = survival_data[survival_data['group'] == 1]

# Fit KM estimators
km_control = KaplanMeierEstimator()
km_control.fit(control_data['time'], control_data['event'])

km_treatment = KaplanMeierEstimator()
km_treatment.fit(treatment_data['time'], treatment_data['event'])

print(f"Control group: {len(control_data)} subjects, {control_data['event'].sum()} events")
print(f"Treatment group: {len(treatment_data)} subjects, {treatment_data['event'].sum()} events")

# Calculate confidence intervals
control_ci_lower, control_ci_upper = km_control.confidence_interval()
treatment_ci_lower, treatment_ci_upper = km_treatment.confidence_interval()

# Plot survival curves
plt.figure(figsize=(12, 8))

# Control group
plt.step(km_control.event_times_, km_control.survival_function_,
         where='post', label='Control', linewidth=2, color='red')
plt.fill_between(km_control.event_times_, control_ci_lower, control_ci_upper,
                 alpha=0.3, color='red', step='post')

# Treatment group
plt.step(km_treatment.event_times_, km_treatment.survival_function_,
         where='post', label='Treatment', linewidth=2, color='blue')
plt.fill_between(km_treatment.event_times_, treatment_ci_lower, treatment_ci_upper,
                 alpha=0.3, color='blue', step='post')

plt.xlabel('Time')
plt.ylabel('Survival Probability')
plt.title('Kaplan-Meier Survival Curves with 95% Confidence Intervals')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim(0, 1)
plt.show()


def _median_survival_time(km):
    """Return the first event time at which S(t) <= 0.5, or None.

    None means the fitted curve never drops to 0.5, i.e. the median
    survival time is not reached within the observed data.
    """
    reached = np.flatnonzero(km.survival_function_ <= 0.5)
    if reached.size == 0:
        return None
    return km.event_times_[reached[0]]


# Print survival statistics (same report for both groups).
# A real newline here; the doubled backslash used to print a literal "\n".
print("\n=== Survival Statistics ===")
for heading, km in (("Control Group:", km_control),
                    ("Treatment Group:", km_treatment)):
    print(heading)
    median_time = _median_survival_time(km)
    if median_time is None:
        print("  Median survival time: Not reached")
    else:
        print(f"  Median survival time: {median_time:.2f}")
    # NOTE(review): 365 assumes the time axis is in days; the simulated
    # times in this notebook are unitless - confirm the intended horizon.
    print(f"  1-year survival: {km.predict_survival([365])[0]:.3f}")
