# In [None]:
# =============================================================================
# 7_generate_catch22_features_sliding_window.py
#
# Description:
# This script implements the "Sliding Window Statistics" approach to address
# the weak supervision problem. Instead of one feature set per surgery, it
# analyzes the waveform in windows (e.g., 10-minute windows, sliding by 5 mins).
# It calculates catch22 features for each window, then computes summary
# statistics (mean, std, min, max) across all windows for each feature.
# The result is a richer feature set per patient that captures the dynamics
# and volatility of the intraoperative period.
#
# =============================================================================

# --- Standard Library Imports ---
import os
import logging
from multiprocessing import Pool, cpu_count
from functools import partial

# --- Third-Party Imports ---
import pandas as pd
import numpy as np
import vitaldb
import pycatch22
from tqdm import tqdm

# =============================================================================
# 1. CONFIGURATION
# =============================================================================
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)

# --- Paths ---
# PROJECT_ROOT is the parent of the directory that holds this script.
# __file__ is undefined in a Jupyter notebook or other interactive session,
# so in that case fall back to the parent of the current working directory.
try:
    _script_dir = os.path.dirname(os.path.abspath(__file__))
    PROJECT_ROOT = os.path.dirname(_script_dir)
except NameError:
    PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

_data_dir = os.path.join(PROJECT_ROOT, "data")
PROCESSED_DATA_DIR = os.path.join(_data_dir, "processed")
RAW_DATA_DIR = os.path.join(_data_dir, "raw")

# --- File Names ---
# Both files live in PROCESSED_DATA_DIR: the cohort is read, the features written.
COHORT_FILE = "final_cohort_with_death_label.csv"
OUTPUT_FILE = "waveform_catch22_features_sliding_window.csv"

# --- Parameters ---
WAVEFORM_NAME = 'SNUADC/PLETH'
ORIGINAL_SR = 500    # Hz; rate at which the waveform is requested (SNUADC/PLETH is documented as 500Hz)
DOWNSAMPLED_SR = 10  # Hz; rate of the signal that is actually windowed

# Sliding window geometry: 10-minute windows advanced in 5-minute steps,
# expressed both in seconds and in samples of the downsampled signal.
WINDOW_SIZE_SECONDS = 10 * 60
SLIDE_SECONDS = 5 * 60
WINDOW_SIZE_SAMPLES = DOWNSAMPLED_SR * WINDOW_SIZE_SECONDS
SLIDE_SAMPLES = DOWNSAMPLED_SR * SLIDE_SECONDS

# =============================================================================
# 2. WORKER FUNCTION FOR PARALLEL PROCESSING
# =============================================================================
def _catch22_output_to_row(features):
    """Flatten one pycatch22 result into a ``{feature_name: value}`` mapping.

    ``pycatch22.catch22_all`` returns ``{'names': [...], 'values': [...]}``.
    A DataFrame built from a list of such dicts has two columns whose cells
    are lists, which cannot be aggregated per feature, so every result is
    turned into a flat row first. Input without that shape is returned as-is.
    """
    if isinstance(features, dict) and 'names' in features and 'values' in features:
        return dict(zip(features['names'], features['values']))
    return features


def process_case(case_info):
    """
    Worker function to process a single case.

    Loads the waveform, keeps the [opstart, opend) portion, downsamples it,
    slices it into overlapping windows, extracts catch22 features per window
    and summarises every feature across windows.

    Args:
        case_info: ``(caseid, opstart, opend)``; opstart/opend are offsets in
            seconds into the recording.

    Returns:
        On success a dict with ``caseid`` plus ``<feature>_mean``, ``_std``,
        ``_min`` and ``_max`` entries. On failure ``{'caseid': ..., 'error': ...}``.
        The function never raises, so one bad case cannot abort the pool.
    """
    caseid, opstart, opend = case_info
    try:
        # Reject unusable timings before paying for the waveform download.
        start_sec, end_sec = float(opstart), float(opend)
        if not (np.isfinite(start_sec) and np.isfinite(end_sec)):
            return {'caseid': caseid, 'error': 'Invalid opstart/opend (missing or non-finite)'}

        # Request the track at its original resolution.
        wave = vitaldb.load_case(caseid, WAVEFORM_NAME, 1 / ORIGINAL_SR)

        if wave is None or len(wave) == 0:
            return {'caseid': caseid, 'error': 'Waveform is empty or None'}

        # load_case yields one column per requested track; keep the single
        # requested track as a 1-D signal so windows are 1-D as well.
        wave = np.asarray(wave)
        if wave.ndim > 1:
            wave = wave[:, 0]

        downsample_factor = ORIGINAL_SR // DOWNSAMPLED_SR

        # Seconds -> sample indices. Clamp at 0 because a negative index would
        # silently slice from the end of the recording.
        start_index = max(0, int(start_sec * ORIGINAL_SR))
        end_index = max(0, int(end_sec * ORIGINAL_SR))
        intraop_wave = wave[start_index:end_index]

        # Validate that the waveform is long enough for at least one window
        if len(intraop_wave) < (WINDOW_SIZE_SECONDS * ORIGINAL_SR):
            return {'caseid': caseid, 'error': 'Waveform shorter than one window'}

        # NOTE(review): plain decimation, no anti-alias filter is applied.
        downsampled_wave = intraop_wave[::downsample_factor]

        window_rows = []
        # --- Sliding Window Loop ---
        last_start = len(downsampled_wave) - WINDOW_SIZE_SAMPLES
        for i in range(0, last_start + 1, SLIDE_SAMPLES):
            window = downsampled_wave[i : i + WINDOW_SIZE_SAMPLES]

            # Skip invalid windows (all NaNs or flatline). The all-NaN test
            # comes first so nanstd is never called on an all-NaN slice.
            if np.all(np.isnan(window)) or np.nanstd(window) < 1e-6:
                continue

            # NOTE(review): windows with only some NaNs are still passed on;
            # confirm how pycatch22 treats NaN samples.
            # A plain list is passed, which pycatch22 accepts.
            features = pycatch22.catch22_all(window.tolist(), catch24=False)
            window_rows.append(_catch22_output_to_row(features))

        if not window_rows:
            return {'caseid': caseid, 'error': 'No valid windows found'}

        # --- Aggregate Features ---
        # One row per window, one column per catch22 feature.
        windows_df = pd.DataFrame(window_rows)

        final_features = pd.concat([
            windows_df.mean().add_suffix('_mean'),
            # std is NaN when only one window exists (sample std, ddof=1).
            windows_df.std().add_suffix('_std'),
            windows_df.min().add_suffix('_min'),
            windows_df.max().add_suffix('_max'),
        ])

        result = final_features.to_dict()
        result['caseid'] = caseid
        return result

    except Exception as e:
        # Deliberately broad: this is the worker boundary, and every failure
        # must come back as a row for the error log. Include the exception
        # type because str(e) alone can be empty or ambiguous.
        return {'caseid': caseid, 'error': f"{type(e).__name__}: {e}"}

# =============================================================================
# 3. MAIN EXECUTION
# =============================================================================
def main():
    """Main function to orchestrate the feature extraction.

    Reads COHORT_FILE from PROCESSED_DATA_DIR (columns ``caseid``,
    ``opstart``, ``opend``), runs ``process_case`` for every case in a
    process pool, then writes the per-case feature table to OUTPUT_FILE and,
    if any case failed, an error log CSV, both in PROCESSED_DATA_DIR.
    A missing cohort file is logged and ends the run. Returns None.
    """
    logging.info("Starting sliding window catch22 feature extraction...")

    cohort_path = os.path.join(PROCESSED_DATA_DIR, COHORT_FILE)
    try:
        cohort_df = pd.read_csv(cohort_path)
    except FileNotFoundError:
        logging.error(f"FATAL: Cohort file not found at {cohort_path}")
        return
    logging.info(f"Loaded {len(cohort_df)} cases from {COHORT_FILE}")

    cases_to_process = list(zip(cohort_df['caseid'], cohort_df['opstart'], cohort_df['opend']))

    # Leave one core free for the parent process, but always use at least one.
    num_processes = max(1, cpu_count() - 1)
    logging.info(f"Using {num_processes} processes for extraction.")

    all_features = []
    with Pool(processes=num_processes) as pool:
        with tqdm(total=len(cases_to_process), desc="Extracting Sliding Window Features") as pbar:
            # Results arrive in completion order; each carries its own caseid.
            for result in pool.imap_unordered(process_case, cases_to_process):
                all_features.append(result)
                pbar.update()

    features_df = pd.DataFrame(all_features)

    # The 'error' column exists only if at least one worker reported a
    # failure. Add it when absent so a fully successful (or empty) run does
    # not raise KeyError below.
    if 'error' not in features_df.columns:
        features_df['error'] = None

    failed = features_df['error'].notna()
    error_df = features_df[failed]
    success_df = features_df[~failed].drop(columns=['error'])

    if not error_df.empty:
        logging.warning(f"Encountered {len(error_df)} errors during processing.")
        error_log_path = os.path.join(PROCESSED_DATA_DIR, "catch22_sliding_window_errors.csv")
        error_df.to_csv(error_log_path, index=False)
        logging.warning(f"Error log saved to {error_log_path}")

    if not success_df.empty:
        # Fill potential NaN values from std dev calculation with 0.
        # Assigned rather than done in place, to avoid mutating a derived frame.
        success_df = success_df.fillna(0)

        output_path = os.path.join(PROCESSED_DATA_DIR, OUTPUT_FILE)
        # Reorder columns to have caseid first for readability
        cols = ['caseid'] + [col for col in success_df.columns if col != 'caseid']
        success_df = success_df[cols]
        success_df.to_csv(output_path, index=False)

        logging.info(f"Successfully extracted features for {len(success_df)} cases.")
        logging.info(f"Features saved to {output_path}")
    else:
        logging.error("No features were successfully extracted. Check error log.")

    logging.info("Feature extraction complete.")


# Entry-point guard: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


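# NOTE: stale notebook output from an earlier run, kept for reference:
#   NameError: name '__file__' is not defined
# The try/except around PROJECT_ROOT in the configuration section handles this case.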