In [19]:
import os
import warnings
import pandas as pd
import numpy as np
import pyedflib
from scipy.signal import find_peaks, butter, filtfilt
from hrvanalysis import get_time_domain_features, get_frequency_domain_features, get_poincare_plot_features
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.utils import concordance_index
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import logging
from matplotlib.backends.backend_pdf import PdfPages
from sklearn.metrics import roc_curve

# Disable warnings and suppress output
warnings.filterwarnings("ignore")
logging.basicConfig(level=logging.CRITICAL)
import matplotlib
matplotlib.use('Agg')

# ----------------------- Data Processing Functions -----------------------

def load_edf_file(file_path: str):
    """Read the first channel of an EDF recording.

    Channel 0 is assumed to contain the ECG trace. This is best-effort: any
    error raised while opening or reading the file is swallowed and reported
    as ``None`` so the caller can simply skip the recording.
    """
    signal = None
    try:
        with pyedflib.EdfReader(file_path) as reader:
            signal = reader.readSignal(0)
    except Exception:
        signal = None
    return signal

def filter_ecg_signal(ecg_signal: np.ndarray, lowcut: float = 0.5, highcut: float = 50.0,
                      fs: float = 1000.0, order: int = 1) -> np.ndarray:
    """Apply a zero-phase Butterworth band-pass filter to an ECG trace.

    ``lowcut`` and ``highcut`` are in Hz and are normalised by the Nyquist
    frequency (``fs / 2``) before the filter is designed. ``filtfilt`` runs
    the filter forward and backward, so the output has no phase shift and the
    same length as the input.
    """
    normalised_band = np.array([lowcut, highcut]) / (fs / 2.0)
    numerator, denominator = butter(order, normalised_band, btype='band')
    return filtfilt(numerator, denominator, ecg_signal)

def detect_r_peaks(ecg_signal: np.ndarray, distance: int = 300, *,
                   height=None, prominence=None) -> np.ndarray:
    """Detect R-peaks in the given ECG signal.

    Args:
        ecg_signal: 1-D signal to search.
        distance: Minimum spacing between accepted peaks, in samples. When two
            candidates are closer than this, the taller one is kept. (At the
            1000 Hz rate assumed elsewhere in this file, 300 samples is 0.3 s.)
        height: Optional minimum peak amplitude, forwarded to
            ``scipy.signal.find_peaks``.
        prominence: Optional minimum peak prominence, forwarded to
            ``scipy.signal.find_peaks``.

    Returns:
        Sample indices of the detected peaks, in ascending order.

    With ``height`` and ``prominence`` left as ``None`` (the defaults, giving the
    original behaviour), every local maximum that survives the ``distance``
    filter is reported, including small noise or T-wave bumps. Pass one of them
    to reject low-amplitude maxima.
    """
    r_peaks, _ = find_peaks(ecg_signal, distance=distance, height=height, prominence=prominence)
    return r_peaks

def compute_hrv_metrics(r_peaks: np.ndarray, fs: float = 1000.0) -> dict:
    """Summarise heart-rate variability from R-peak sample indices.

    Consecutive peak-to-peak distances (in samples) are converted to RR
    intervals in milliseconds using ``fs``. The hrvanalysis time-domain,
    frequency-domain and Poincaré feature dicts are then merged into one dict.
    If two feature sets share a key, the later one wins
    (time < frequency < Poincaré).
    """
    rr_ms = np.diff(r_peaks) / fs * 1000.0
    merged = {}
    for extractor in (get_time_domain_features,
                      get_frequency_domain_features,
                      get_poincare_plot_features):
        merged.update(extractor(rr_ms))
    return merged

def extract_apnea_events(xml_path: str):
    """Extract apnea events from an annotation XML file.

    Every ``ScoredEvent`` whose ``EventConcept`` text contains "apnea"
    (case-insensitive) and that has numeric ``Start`` and ``Duration`` children
    is returned as a ``(start, duration, event_type)`` tuple. ``event_type`` is
    "obstructive", "central" or "mixed" when that word appears in the concept
    text, otherwise "unknown".

    Events with an empty concept or a missing/non-numeric start or duration
    are skipped one at a time, so they no longer cut the scan short. A file
    that cannot be read or parsed yields an empty list (best-effort).
    """
    import xml.etree.ElementTree as ET

    try:
        root = ET.parse(xml_path).getroot()
    except (OSError, ET.ParseError):
        return []

    events = []
    for event in root.findall('.//ScoredEvent'):
        event_concept = event.find('EventConcept')
        # An element like <EventConcept/> has .text == None; treat it as empty.
        concept_text = (event_concept.text or '').lower() if event_concept is not None else ''
        if 'apnea' not in concept_text:
            continue
        start = event.find('Start')
        duration = event.find('Duration')
        if start is None or duration is None:
            continue
        try:
            start_time = float(start.text)
            duration_time = float(duration.text)
        except (TypeError, ValueError):
            # Empty or non-numeric value: drop only this event.
            continue
        if 'obstructive' in concept_text:
            event_type = 'obstructive'
        elif 'central' in concept_text:
            event_type = 'central'
        elif 'mixed' in concept_text:
            event_type = 'mixed'
        else:
            event_type = 'unknown'
        events.append((start_time, duration_time, event_type))
    return events

def extract_rem_events(xml_path: str):
    """Extract REM events from an annotation XML file.

    Returns ``(start, duration)`` tuples for every ``ScoredEvent`` whose
    ``EventConcept`` mentions "rem" as a stand-alone token (case-insensitive,
    e.g. "REM sleep|5" or "stage_rem"). Words that merely contain the letters,
    such as "premature" or "remark", are not matched.

    Events with an empty concept or a missing/non-numeric start or duration
    are skipped one at a time. A file that cannot be read or parsed yields an
    empty list (best-effort).
    """
    import re
    import xml.etree.ElementTree as ET

    # "rem" must not be glued to other letters; digits, spaces, "|" and "_" may surround it.
    rem_token = re.compile(r'(?<![a-z])rem(?![a-z])')

    try:
        root = ET.parse(xml_path).getroot()
    except (OSError, ET.ParseError):
        return []

    events = []
    for event in root.findall('.//ScoredEvent'):
        event_concept = event.find('EventConcept')
        # An element like <EventConcept/> has .text == None; treat it as empty.
        concept_text = (event_concept.text or '').lower() if event_concept is not None else ''
        if not rem_token.search(concept_text):
            continue
        start = event.find('Start')
        duration = event.find('Duration')
        if start is None or duration is None:
            continue
        try:
            start_time = float(start.text)
            duration_time = float(duration.text)
        except (TypeError, ValueError):
            # Empty or non-numeric value: drop only this event.
            continue
        events.append((start_time, duration_time))
    return events

# ----------------------- ROC-Threshold Helper Functions -----------------------

def get_optimal_threshold(data: pd.DataFrame, variable: str) -> float:
    """
    Compute the optimal threshold for a given variable using ROC analysis.
    The optimal threshold maximizes the Youden index (tpr - fpr)
    with respect to the binary outcome 'vital'.

    Only thresholds that actually split the data are considered, i.e. values
    greater than the variable's minimum and no larger than its maximum.
    ``roc_curve`` always includes a first threshold above every score (np.inf
    in recent scikit-learn, max + 1 in older releases) and a last threshold
    equal to the minimum. Either one would leave the "High" or "Low" group
    empty, and the first was picked whenever the best Youden index tied at 0
    (e.g. for variables inversely related to 'vital').

    If no splitting threshold exists (e.g. a constant variable), or every
    candidate has an undefined Youden index, the plain argmax over all ROC
    thresholds is returned, as before.
    """
    fpr, tpr, thresholds = roc_curve(data['vital'], data[variable])
    youden_index = tpr - fpr

    values = np.asarray(data[variable])
    splits_data = (
        (thresholds > np.min(values))
        & (thresholds <= np.max(values))
        & ~np.isnan(youden_index)
    )
    if splits_data.any():
        candidates = np.flatnonzero(splits_data)
        optimal_idx = candidates[np.argmax(youden_index[candidates])]
    else:
        optimal_idx = np.argmax(youden_index)
    return thresholds[optimal_idx]

def create_group_column(data: pd.DataFrame, variable: str, threshold: float) -> pd.DataFrame:
    """
    Return a copy of ``data`` with a 'group' column derived from a cut-off.

    Rows where ``data[variable] >= threshold`` are labelled 'High'. All other
    rows, including NaN values (which fail the comparison), are labelled
    'Low'. The input DataFrame is not modified.
    """
    is_high = data[variable] >= threshold
    labelled = data.copy()
    labelled['group'] = np.where(is_high, 'High', 'Low')
    return labelled

# ----------------------- Plotting Functions for Reports -----------------------

def plot_forest(model: CoxPHFitter, title: str):
    """Generate a forest plot similar to R's forestplot.

    Draws one row per covariate of ``model.summary``, sorted by hazard ratio.
    Each row has a square at ``exp(coef)`` and a bar spanning the 95%
    confidence interval. A dashed reference line marks HR = 1, and the x axis
    is log-scaled. Calling this also sets the global seaborn theme to
    "whitegrid".

    Returns the matplotlib Figure; the caller is responsible for closing it.
    """
    # Plot the unrounded estimates. Rounding to 2 decimals (as the summary
    # table does) would shift points and can turn a small lower CI bound into
    # 0.0, which cannot be placed on a log axis.
    summary_sorted = model.summary.sort_values(by='exp(coef)')
    num_features = summary_sorted.shape[0]

    sns.set_theme(style="whitegrid")
    fig, ax = plt.subplots(figsize=(8, num_features * 0.4 + 2), dpi=300)

    y_positions = range(num_features)
    hr_vals = summary_sorted['exp(coef)'].to_numpy()
    lower_vals = summary_sorted['exp(coef) lower 95%'].to_numpy()
    upper_vals = summary_sorted['exp(coef) upper 95%'].to_numpy()
    # errorbar wants distances from the point estimate, not absolute bounds.
    errors = [hr_vals - lower_vals, upper_vals - hr_vals]

    ax.errorbar(hr_vals, y_positions, xerr=errors, fmt='s', color='black',
                ecolor='green', elinewidth=1.5, capsize=3)
    ax.axvline(x=1, color='red', linestyle='--', linewidth=1)
    ax.set_yticks(y_positions)
    ax.set_yticklabels(summary_sorted.index, fontsize=7)
    ax.set_xlabel('Hazard Ratio (log scale)', fontsize=8)
    ax.set_title(title, fontsize=10)
    ax.set_xscale('log')
    fig.subplots_adjust(left=0.3, right=0.95, top=0.92, bottom=0.1)
    return fig

def plot_model_summary(model: CoxPHFitter, title: str):
    """Render a fitted Cox model's coefficient summary as a table figure.

    Every numeric cell of ``model.summary`` is rounded to two decimals, and the
    figure height grows with the number of covariates. Returns the Figure;
    the caller is responsible for closing it.
    """
    rounded = model.summary.round(2)
    row_count = rounded.shape[0]

    fig, ax = plt.subplots(figsize=(8, row_count * 0.25 + 1.5), dpi=300)
    for axis_mode in ('tight', 'off'):
        ax.axis(axis_mode)

    summary_table = ax.table(
        cellText=rounded.values,
        colLabels=rounded.columns,
        rowLabels=rounded.index,
        loc='center',
    )
    summary_table.auto_set_font_size(False)
    summary_table.set_fontsize(6)

    ax.set_title(f"{title} - Model Summary", fontweight="bold", fontsize=10)
    fig.subplots_adjust(left=0.2, right=0.95, top=0.9, bottom=0.1)
    return fig

def plot_performance_metrics(model: CoxPHFitter, c_index: float, title: str):
    """Render C-index, log-likelihood and partial AIC as a small table figure.

    ``log_likelihood_`` and ``AIC_partial_`` are read from the model. If
    reading either one raises, that metric is shown as NaN. All values are
    rounded to two decimals. Returns the Figure; the caller closes it.
    """
    def _metric_or_nan(attribute_name):
        # Attribute access may raise for models that cannot provide the value.
        try:
            return getattr(model, attribute_name)
        except Exception:
            return np.nan

    raw_values = [c_index, _metric_or_nan('log_likelihood_'), _metric_or_nan('AIC_partial_')]
    df_metrics = pd.DataFrame({
        "Metric": ["C-index", "Log-likelihood", "AIC"],
        "Value": [round(value, 2) for value in raw_values],
    })

    fig, ax = plt.subplots(figsize=(4, 1.2), dpi=300)
    ax.axis('tight')
    ax.axis('off')

    metrics_table = ax.table(cellText=df_metrics.values, colLabels=df_metrics.columns, loc='center')
    metrics_table.auto_set_font_size(False)
    metrics_table.set_fontsize(8)

    ax.set_title(f"{title} - Performance Metrics", fontweight="bold", fontsize=10)
    fig.subplots_adjust(left=0.05, right=0.95, top=0.9, bottom=0)
    return fig

def plot_km_scatter(variable: str, data: pd.DataFrame, title: str):
    """
    Generate Kaplan-Meier plots using a ROC-derived threshold.
    The ROC analysis determines the optimal cutoff for the variable,
    splitting the data into "High" and "Low" groups.

    One KM curve (with confidence band) is drawn per group, using 'censdate'
    as the duration and 'vital' as the event indicator. A group with no rows
    is skipped rather than fitted, so one degenerate split does not abort the
    caller's report. Returns the Figure; the caller closes it.
    """
    optimal_threshold = get_optimal_threshold(data, variable)
    print(f"Optimal threshold for {variable}: {optimal_threshold:.2f}")

    group_data = create_group_column(data, variable, optimal_threshold)

    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

    # High first, then Low, matching the original legend order.
    for group_label in ("High", "Low"):
        group = group_data[group_data['group'] == group_label]
        if group.empty:
            continue
        kmf.fit(group['censdate'], event_observed=group['vital'], label=f"{variable} {group_label}")
        kmf.plot(ax=ax, marker='o', ci_show=True, linewidth=1.2)

    ax.set_title(title, fontsize=10)
    ax.set_xlabel("Time", fontsize=9)
    ax.set_ylabel("Survival Probability", fontsize=9)
    fig.subplots_adjust(left=0.15, right=0.95, top=0.9, bottom=0.15)
    return fig

def plot_density_event_status_with_threshold(variable: str, data: pd.DataFrame, title: str):
    """
    Plot filled KDE curves of ``variable`` for the two sides of a ROC-derived cut-off.

    The cut-off comes from ``get_optimal_threshold`` and is printed. Rows are
    labelled "High"/"Low" with ``create_group_column``. Dashed vertical lines
    mark the cut-off (red), the High-group median (blue) and the Low-group
    median (green). Returns the Figure; the caller closes it.

    NOTE(review): the grouping here is by the threshold, not by 'vital', so
    each curve only covers values on its own side of the cut. Callers that
    title this plot "by Vital" may expect an event-status split — confirm the
    intent.
    """
    cutoff = get_optimal_threshold(data, variable)
    print(f"Optimal threshold for {variable}: {cutoff:.2f}")

    labelled = create_group_column(data, variable, cutoff)
    high_values = labelled.loc[labelled['group'] == 'High', variable]
    low_values = labelled.loc[labelled['group'] == 'Low', variable]

    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

    for values, curve_label in ((high_values, "High Group"), (low_values, "Low Group")):
        sns.kdeplot(x=values, label=curve_label, fill=True, common_norm=False, alpha=0.5, ax=ax)

    high_median = high_values.median()
    low_median = low_values.median()
    reference_lines = [
        (cutoff, 'red', f"Threshold = {cutoff:.2f}"),
        (high_median, 'blue', f"High Median = {high_median:.2f}"),
        (low_median, 'green', f"Low Median = {low_median:.2f}"),
    ]
    for x_value, line_colour, line_label in reference_lines:
        ax.axvline(x_value, linestyle='--', color=line_colour, linewidth=1, label=line_label)

    ax.set_title(title, fontsize=10)
    ax.set_xlabel(f"{variable} Value", fontsize=9)
    ax.set_ylabel("Density", fontsize=9)
    ax.legend(fontsize=8)
    fig.tight_layout()

    return fig

def plot_violin_event_status_with_threshold(variable: str, data: pd.DataFrame, title: str):
    """
    Generate a violin plot for the given variable using a ROC-derived optimal threshold.
    The data is grouped into "High" and "Low" based on the optimal threshold.
    Median values for each group are overlaid on the plot.

    The category order is fixed to ["High", "Low"], so the median markers are
    drawn over the matching violins. Without it, seaborn orders categories by
    first appearance in ``data``, and the hard-coded marker positions would be
    swapped whenever the first row was "Low". Returns the Figure; the caller
    closes it.
    """
    optimal_threshold = get_optimal_threshold(data, variable)
    print(f"Optimal threshold for {variable}: {optimal_threshold:.2f}")

    # Create groups based on the threshold
    data = create_group_column(data, variable, optimal_threshold)

    fig, ax = plt.subplots(figsize=(6, 5), dpi=300)

    group_order = ["High", "Low"]
    sns.violinplot(x="group", y=variable, data=data, order=group_order,
                   palette="Set2", ax=ax, inner="quartile")

    # Overlay each group's median at the x position seaborn gave that category.
    group_medians = data.groupby("group")[variable].median()
    for group_label, median_value in group_medians.items():
        x_pos = group_order.index(group_label)
        ax.scatter(x_pos, median_value, color="black", zorder=10, s=50,
                   label=f"{group_label} Median = {median_value:.2f}")

    # Avoid duplicate labels in legend
    handles, labels = ax.get_legend_handles_labels()
    unique = dict(zip(labels, handles))
    ax.legend(unique.values(), unique.keys(), fontsize=8)

    ax.set_title(title, fontsize=10)
    ax.set_xlabel("Group (based on optimal threshold)", fontsize=9)
    ax.set_ylabel(f"{variable} Value", fontsize=9)
    fig.tight_layout()

    return fig

def generate_model_report(model: CoxPHFitter, group_name: str, c_index: float, subset: pd.DataFrame):
    """
    Generate a multipage PDF report for the given model group.

    The file is written to ``model_report_<group_name lower-cased, spaces -> _>.pdf``
    in the current directory. The report includes:
      - A forest plot, model summary table, and performance metrics.
      - For every covariate with p < 0.05 that is also a column of ``subset``:
        Kaplan-Meier, density and violin plots based on the ROC-derived threshold.

    Every figure is closed after it is saved, even if saving fails. If one
    variable's plots raise, the error is logged and the remaining variables
    are still reported. Errors in the first three pages propagate.
    """
    report_filename = f"model_report_{group_name.lower().replace(' ', '_')}.pdf"
    with PdfPages(report_filename) as pdf:

        def _save_page(fig):
            """Append fig as a PDF page and always release it."""
            try:
                pdf.savefig(fig, bbox_inches="tight")
            finally:
                plt.close(fig)

        # Pages 1-3: forest plot, model summary, performance metrics.
        _save_page(plot_forest(model, f"Forest Plot: {group_name}"))
        _save_page(plot_model_summary(model, group_name))
        _save_page(plot_performance_metrics(model, c_index, group_name))

        # For each significant variable (p < 0.05), plot KM, density, and violin plots.
        significant_vars = model.summary[model.summary['p'] < 0.05].index.tolist()
        for var in significant_vars:
            if var not in subset.columns:
                continue
            try:
                _save_page(plot_km_scatter(var, subset, f"Kaplan-Meier Plot for {var} ({group_name})"))
                _save_page(plot_density_event_status_with_threshold(
                    var, subset, f"Density Plot for {var} by Vital ({group_name})"))
                _save_page(plot_violin_event_status_with_threshold(
                    var, subset, f"Violin Plot for {var} by Vital ({group_name})"))
            except Exception:
                # One degenerate variable should not discard pages for the others.
                logging.getLogger(__name__).exception(
                    "Skipping plots for %s in %s", var, group_name)

# ----------------------- Main Processing Pipeline -----------------------

if __name__ == "__main__":
    dataset_folder = 'dataset'
    # Recordings, annotation files and the rows of outcomes.csv are matched purely
    # by position below, and os.listdir returns entries in arbitrary order, so sort.
    # NOTE(review): this assumes EDF and XML names sort into the same subject order
    # and that outcomes.csv rows follow that order too — confirm.
    dataset_entries = sorted(os.listdir(dataset_folder))
    edf_files = [os.path.join(dataset_folder, f) for f in dataset_entries if f.endswith('.edf')]
    xml_files = [os.path.join(dataset_folder, f) for f in dataset_entries if f.endswith('.xml')]
    if len(edf_files) != len(xml_files):
        raise ValueError(
            f"Found {len(edf_files)} EDF files but {len(xml_files)} XML files in "
            f"{dataset_folder!r}; they are paired by position, so the counts must match."
        )

    # Process EDF files with TQDM progress bar. Remember the position of each
    # recording that loaded so its XML file and outcome row stay aligned with it.
    ecg_signals = []
    loaded_indices = []
    for idx, file in enumerate(tqdm(edf_files, desc="Loading EDF files")):
        signal = load_edf_file(file)
        if signal is not None:
            ecg_signals.append(signal)
            loaded_indices.append(idx)
    xml_files = [xml_files[idx] for idx in loaded_indices]

    filtered_signals = [filter_ecg_signal(signal) for signal in ecg_signals]
    r_peaks_list = [detect_r_peaks(signal) for signal in filtered_signals]
    hrv_metrics_list = [compute_hrv_metrics(r_peaks) for r_peaks in r_peaks_list]
    hrv_df = pd.DataFrame(hrv_metrics_list)

    # Process XML files for apnea and REM events with TQDM progress.
    # NOTE(review): AHI and REM fraction assume an 8-hour recording for every
    # subject rather than the actual recording length.
    apnea_events_list = [extract_apnea_events(file) for file in tqdm(xml_files, desc="Extracting Apnea Events")]
    recording_duration_hours = 8
    ahi_list = [len(events) / recording_duration_hours for events in apnea_events_list]
    hrv_df['AHI'] = ahi_list
    hrv_df['apnea_status'] = hrv_df['AHI'].apply(lambda x: 1 if x >= 5 else 0)

    rem_events_list = [extract_rem_events(file) for file in tqdm(xml_files, desc="Extracting REM Events")]
    recording_duration_seconds = 8 * 3600
    rem_total_durations = [sum(event[1] for event in events) for events in rem_events_list]
    rem_fraction = [duration / recording_duration_seconds for duration in rem_total_durations]
    hrv_df['rem_fraction'] = rem_fraction
    hrv_df['rem_status'] = hrv_df['rem_fraction'].apply(lambda x: 'REM' if x >= 0.2 else 'Not REM')

    # Load mortality data and combine with HRV data. Keep only the outcome rows
    # of recordings that loaded; rows outcomes.csv lacks become NaN and are
    # dropped by dropna() below.
    mortality_data_path = 'outcomes.csv'
    mortality_data = pd.read_csv(mortality_data_path).reset_index(drop=True).reindex(loaded_indices)
    combined_data = pd.concat([hrv_df.reset_index(drop=True), mortality_data.reset_index(drop=True)], axis=1)
    combined_data_clean = combined_data.dropna().reset_index(drop=True)

    non_hrv = ["AHI", "apnea_status", "rem_fraction", "rem_status", "censdate", "vital"]
    hrv_metric_columns = [col for col in hrv_df.columns if col not in non_hrv]
    scaler = StandardScaler()
    data_for_scaling = combined_data_clean[hrv_metric_columns]
    scaled_data = scaler.fit_transform(data_for_scaling)
    scaled_hrv_df = pd.DataFrame(scaled_data, columns=hrv_metric_columns)

    combined_data_clean_scaled = combined_data_clean.copy()
    for col in hrv_metric_columns:
        combined_data_clean_scaled[col] = scaled_hrv_df[col]

    # NOTE(review): 'vital' is used as the Cox event indicator (1 = event). If
    # outcomes.csv codes vital status as 1 = alive, the event is inverted — confirm.
    survival_columns = ['censdate', 'vital']
    model_columns = hrv_metric_columns + survival_columns

    # Define data subsets for four groups based on apnea and REM status
    subset_apnea_rem = combined_data_clean_scaled[
        (combined_data_clean_scaled['apnea_status'] == 1) & (combined_data_clean_scaled['rem_status'] == 'REM')]
    subset_apnea_notrem = combined_data_clean_scaled[
        (combined_data_clean_scaled['apnea_status'] == 1) & (combined_data_clean_scaled['rem_status'] == 'Not REM')]
    subset_noapnea_rem = combined_data_clean_scaled[
        (combined_data_clean_scaled['apnea_status'] == 0) & (combined_data_clean_scaled['rem_status'] == 'REM')]
    subset_noapnea_notrem = combined_data_clean_scaled[
        (combined_data_clean_scaled['apnea_status'] == 0) & (combined_data_clean_scaled['rem_status'] == 'Not REM')]

    models = [
        ("Apnea & REM", subset_apnea_rem),
        ("Apnea & Not REM", subset_apnea_notrem),
        ("No Apnea & REM", subset_noapnea_rem),
        ("No Apnea & Not REM", subset_noapnea_notrem)
    ]

    for group_name, subset in tqdm(models, desc="Fitting Models", unit="group"):
        if subset.empty:
            continue
        try:
            cox_model = CoxPHFitter(penalizer=0.1)
            cox_model.fit(subset[model_columns], duration_col='censdate', event_col='vital', show_progress=False)

            # Higher partial hazard means shorter expected survival, hence the sign flip.
            c_index = concordance_index(
                subset['censdate'],
                -cox_model.predict_partial_hazard(subset),
                subset['vital']
            )

            generate_model_report(cox_model, group_name, c_index, subset)
        except Exception as exc:
            # Keep going with the other groups, but say why this one has no report.
            tqdm.write(f"Skipping model for {group_name}: {exc!r}")
            continue

    hrv_df.drop(columns=["rem_status"]).to_csv("hrv_data.csv", index=False)

Loading EDF files: 100%|██████████| 50/50 [00:03<00:00, 16.22it/s]
Extracting Apnea Events: 100%|██████████| 50/50 [00:00<00:00, 257.11it/s]
Extracting REM Events: 100%|██████████| 50/50 [00:00<00:00, 301.13it/s]
Fitting Models:  75%|███████▌  | 3/4 [00:02<00:00,  1.20group/s]

Optimal threshold for pnni_20: 0.46
Optimal threshold for pnni_20: 0.46
Optimal threshold for pnni_20: 0.46
Optimal threshold for hf: -0.19
Optimal threshold for hf: -0.19
Optimal threshold for hf: -0.19


Fitting Models: 100%|██████████| 4/4 [00:04<00:00,  1.10s/group]
