In [1]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

In [2]:
THRESH_SPO2_LOW = 92
THRESH_SPO2_CRITICAL = 88
THRESH_SHOCK_INDEX_WARNING = 0.9
THRESH_SHOCK_INDEX_CRITICAL = 1.0
THRESH_MAP_LOW = 65  # Mean Arterial Pressure threshold for shock
THRESH_LACTATE_HIGH = 2.0  # mmol/L for tissue hypoperfusion
THRESH_LACTATE_CRITICAL = 4.0
THRESH_UO_LOW = 0.5  # mL/kg/hr (oliguria)

# Trend Analysis
TREND_WINDOW = 6  # Number of readings for short-term trend analysis

# AGE_THRESHOLDS (kept same as user-provided)
AGE_THRESHOLDS = {
    'neonate': {
        'rr_low': 30, 'rr_normal': 40, 'rr_high': 60,
        'hr_low': 100, 'hr_normal': 140, 'hr_high': 160,
        'sbp_low': 60, 'sbp_normal': 70, 'sbp_high': 90,
        'temp_low': 36.0, 'temp_normal': 37.2, 'temp_high': 38.0
    },
    'infant': {
        'rr_low': 24, 'rr_normal': 30, 'rr_high': 40,
        'hr_low': 80, 'hr_normal': 120, 'hr_high': 140,
        'sbp_low': 70, 'sbp_normal': 85, 'sbp_high': 100,
        'temp_low': 36.0, 'temp_normal': 37.2, 'temp_high': 38.0
    },
    'child': {
        'rr_low': 16, 'rr_normal': 20, 'rr_high': 30,
        'hr_low': 70, 'hr_normal': 90, 'hr_high': 110,
        'sbp_low': 80, 'sbp_normal': 95, 'sbp_high': 110,
        'temp_low': 36.0, 'temp_normal': 37.0, 'temp_high': 38.0
    },
    'adolescent': {
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 20,
        'hr_low': 60, 'hr_normal': 75, 'hr_high': 100,
        'sbp_low': 90, 'sbp_normal': 105, 'sbp_high': 120,
        'temp_low': 35.8, 'temp_normal': 36.8, 'temp_high': 37.8
    },
    'adult': {
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 20,
        'hr_low': 60, 'hr_normal': 80, 'hr_high': 100,
        'sbp_low': 90, 'sbp_normal': 115, 'sbp_high': 130,
        'temp_low': 35.5, 'temp_normal': 36.8, 'temp_high': 38.0
    },
    'geriatric': {
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 24,
        'hr_low': 55, 'hr_normal': 70, 'hr_high': 90,
        'sbp_low': 90, 'sbp_normal': 125, 'sbp_high': 140,
        'temp_low': 35.5, 'temp_normal': 36.5, 'temp_high': 37.5
    }
}

In [3]:
def assign_age_category(df):
    df = df.copy()
    def _categorize(age):
        if age <= 0.083: return 'neonate'
        elif age <= 1:   return 'infant'
        elif age < 5:    return 'child'
        elif age < 13:   return 'adolescent'
        elif age < 65:   return 'adult'
        else:            return 'geriatric'
    if 'age' in df.columns:
        df['age_category'] = df['age'].apply(_categorize)
    else:
        df['age_category'] = 'adult'
    return df

In [4]:
def apply_vital_range_flags(df):
    df = df.copy()
    df = assign_age_category(df)

    # SpO₂ flags
    df['flag_spo2_low'] = df['spo2'] < THRESH_SPO2_LOW
    df['flag_spo2_critical'] = df['spo2'] < THRESH_SPO2_CRITICAL

    # Temperature flags
    df['flag_temp_high'] = df.apply(lambda row: row['temperature'] >= AGE_THRESHOLDS[row['age_category']]['temp_high'], axis=1)
    df['flag_temp_low'] = df.apply(lambda row: row['temperature'] < AGE_THRESHOLDS[row['age_category']]['temp_low'], axis=1)

    # Respiratory Rate flags
    df['flag_rr_low'] = df.apply(lambda row: row['resp_rate'] < AGE_THRESHOLDS[row['age_category']]['rr_low'], axis=1)
    df['flag_rr_high'] = df.apply(lambda row: row['resp_rate'] >= AGE_THRESHOLDS[row['age_category']]['rr_high'], axis=1)

    # Heart Rate flags
    df['flag_hr_low'] = df.apply(lambda row: row['heart_rate'] < AGE_THRESHOLDS[row['age_category']]['hr_low'], axis=1)
    df['flag_hr_high'] = df.apply(lambda row: row['heart_rate'] >= AGE_THRESHOLDS[row['age_category']]['hr_high'], axis=1)

    # Shock Index
    df['shock_index'] = df['heart_rate'] / np.clip(df['sbp'], a_min=1, a_max=None)
    df['flag_si_warning'] = df['shock_index'] >= THRESH_SHOCK_INDEX_WARNING
    df['flag_si_critical'] = df['shock_index'] >= THRESH_SHOCK_INDEX_CRITICAL

    # Blood Pressure flags (sbp/dbp)
    df['flag_sbp_low'] = df.apply(lambda row: row['sbp'] < AGE_THRESHOLDS[row['age_category']]['sbp_low'], axis=1)
    df['flag_sbp_high'] = df.apply(lambda row: row['sbp'] >= AGE_THRESHOLDS[row['age_category']]['sbp_high'], axis=1)
    df['flag_dbp_low'] = df.apply(lambda row: row['dbp'] < (AGE_THRESHOLDS[row['age_category']]['sbp_low'] * 0.6), axis=1)
    df['flag_dbp_high'] = df.apply(lambda row: row['dbp'] >= (AGE_THRESHOLDS[row['age_category']]['sbp_high'] * 0.6), axis=1)

    return df

In [5]:
def compute_recent_trends_delta(df):
    """
    Computes recent trends for each vital by differencing consecutive readings.
    Uses age-specific thresholds + detects likely false-positive variations
    (e.g., transient spikes/drops or conflicting multi-vital patterns).
    """
    df = df.copy().sort_values("timestamp").reset_index(drop=True)

    if 'age_category' not in df.columns:
        df = assign_age_category(df)

    trends = {}
    recent = df.tail(TREND_WINDOW)
    if recent.empty:
        return trends

    age_group = recent['age_category'].iloc[-1]
    thresholds = AGE_THRESHOLDS[age_group]

    vital_map = {
        'rr': ('rr_low', 'rr_normal', 'rr_high'),
        'hr': ('hr_low', 'hr_normal', 'hr_high'), 
        'sbp': ('sbp_low', 'sbp_normal', 'sbp_high'),
        'temperature': ('temp_low', 'temp_normal', 'temp_high'),
        'spo2': (None, None, None)
    }

    # --- Enhanced false positive detection ---
    def is_transient_spike(values, delta, threshold_ratio=0.15):
        """Detect short-lived sharp deviations likely due to artifacts."""
        if len(values) < 3 or np.isnan(values).any():
            return False
        median_val = np.median(values)
        if median_val == 0:
            return False
        deviation = abs(values[-1] - median_val)
        return deviation / median_val > threshold_ratio and abs(delta) < 0.2 * deviation

    def is_unstable_signal(values, threshold_ratio=0.5):
        """Detect excessive variability suggesting measurement noise."""
        if len(values) < 2 or np.isnan(values).any():
            return False
        value_std = np.std(values)
        if value_std == 0:
            return False
        diff_std = np.std(np.diff(values))
        return diff_std > threshold_ratio * value_std

    for vital in ['rr', 'hr', 'sbp', 'temperature', 'spo2']:
        if vital not in recent.columns or recent[vital].isnull().all():
            continue

        y = recent[vital].dropna().values
        if len(y) < 2:
            continue

        avg_delta = np.mean(np.diff(y))
        latest = y[-1]
        trends[f"{vital}_trend"] = round(avg_delta, 3)

        # --- Enhanced false positive detection ---
        transient_spike = is_transient_spike(y, avg_delta)
        unstable_signal = is_unstable_signal(y)
        
        fp_evidence = []
        if transient_spike:
            fp_evidence.append("transient_spike")
        if unstable_signal:
            fp_evidence.append("unstable_signal")

        # SPO₂ special handling
        if vital == 'spo2':
            if latest < THRESH_SPO2_LOW:
                if avg_delta > 0:
                    flag = "Still abnormal — but improving"
                elif avg_delta < 0:
                    flag = "Abnormal and worsening"
                else:
                    flag = "Abnormal and flat"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"

        else:
            low_key, norm_key, high_key = vital_map[vital]
            low = thresholds[low_key]
            normal = thresholds[norm_key]  
            high = thresholds[high_key]

            if latest < low or latest > high:
                if (latest > high and avg_delta < 0) or (latest < low and avg_delta > 0):
                    flag = "Still abnormal — but improving"
                else:
                    flag = "Abnormal and worsening"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"

        if fp_evidence:
            flag += f" (possible false-positive: {', '.join(fp_evidence)})"
            trends[f"{vital}_false_positive"] = True
            trends[f"{vital}_fp_evidence"] = fp_evidence
            trends[f"{vital}_confidence"] = "LOW"
        else:
            trends[f"{vital}_false_positive"] = False
            trends[f"{vital}_confidence"] = "HIGH"

        trends[f"{vital}_trend_flag"] = flag

    # --- Shock Index trend ---
    if all(col in recent.columns for col in ['hr', 'sbp']):
        hr = recent['hr'].values
        sbp = np.clip(recent['sbp'].values, a_min=1, a_max=None)
        si = hr / sbp

        if len(si) >= 2:
            avg_si_delta = np.mean(np.diff(si))
            trends['shock_index_trend'] = round(avg_si_delta, 3)

            latest_si = si[-1]
            if latest_si >= THRESH_SHOCK_INDEX_CRITICAL:
                flag = "Shock Index high — improving" if avg_si_delta < 0 else "Shock Index high — worsening"
            else:
                flag = "Normal but improving" if avg_si_delta < 0 else "Normal but rising"

            si_fp_evidence = []
            if is_unstable_signal(si, threshold_ratio=0.3):
                si_fp_evidence.append("unstable_si")

            if latest_si >= THRESH_SHOCK_INDEX_CRITICAL:
                latest_hr = recent['hr'].iloc[-1]
                latest_sbp = recent['sbp'].iloc[-1]
                hr_normal = latest_hr < thresholds['hr_high']
                sbp_normal = latest_sbp >= thresholds['sbp_low']
                
                if (hr_normal and not sbp_normal) or (sbp_normal and not hr_normal):
                    si_fp_evidence.append("single_component_abnormality")

            if si_fp_evidence:
                flag += f" (possible false-positive: {', '.join(si_fp_evidence)})"
                trends['shock_index_false_positive'] = True
                trends['shock_index_fp_evidence'] = si_fp_evidence
                trends['shock_index_confidence'] = "LOW"
            else:
                trends['shock_index_false_positive'] = False  
                trends['shock_index_confidence'] = "HIGH"

            trends['shock_index_trend_flag'] = flag

    # --- Overall confidence summary ---
    false_positive_count = sum(1 for key in trends if key.endswith('_false_positive') and trends[key])
    total_metrics = sum(1 for key in trends if key.endswith('_false_positive'))
    
    if total_metrics > 0:
        fp_ratio = false_positive_count / total_metrics
        if fp_ratio >= 0.5:
            trends['overall_confidence'] = "LOW"
        elif fp_ratio >= 0.25:
            trends['overall_confidence'] = "MEDIUM" 
        else:
            trends['overall_confidence'] = "HIGH"
    else:
        trends['overall_confidence'] = "HIGH"

    return trends


In [6]:
def calculate_odi(nightly_spo2_df):
    """
    Processes a DataFrame of nocturnal SpO2 readings to calculate the Oxygen Desaturation Index (ODI).
    The DataFrame must have 'timestamp' and 'spo2' columns.

    Args:
        nightly_spo2_df (pd.DataFrame): DataFrame containing SpO2 data for a single night.

    Returns:
        dict: A dictionary containing ODI, event count, valid hours, and the list of events.
    """
    # SLEEP APNEA SPECIFIC THRESHOLDS for event detection
    EVENT_DROP_PERCENT = 3
    EVENT_RECOVERY_PERCENT = 1.5
    MIN_EVENT_DURATION = 10

    if nightly_spo2_df.empty or len(nightly_spo2_df) < 2:
        return {'odi': 0, 'event_count': 0, 'valid_hours': 0, 'events': []}

    df = nightly_spo2_df.sort_values('timestamp').reset_index(drop=True).copy()
    df['time_diff'] = df['timestamp'].diff().dt.total_seconds().fillna(0)
    total_duration_seconds = (df['timestamp'].iloc[-1] - df['timestamp'].iloc[0]).total_seconds()
    valid_hours = total_duration_seconds / 3600

    events = []
    in_event = False
    event_start_index = None
    baseline_spo2 = None

    for i in range(1, len(df)):
        current_spo2 = df['spo2'].iloc[i]
        prev_spo2 = df['spo2'].iloc[i-1]

        if not in_event:
            # Look for the start of an event: a drop of >= EVENT_DROP_PERCENT
            if prev_spo2 - current_spo2 >= EVENT_DROP_PERCENT:
                in_event = True
                event_start_index = i-1
                baseline_spo2 = df['spo2'].iloc[event_start_index]
        else:
            # We are in an event. Look for recovery: a rise of >= EVENT_RECOVERY_PERCENT from the nadir
            current_nadir = df['spo2'].iloc[event_start_index:i+1].min()
            if current_spo2 - current_nadir >= EVENT_RECOVERY_PERCENT:
                # Event has ended
                event_end_index = i
                event_duration = (df['timestamp'].iloc[event_end_index] - df['timestamp'].iloc[event_start_index]).total_seconds()
                
                if event_duration >= MIN_EVENT_DURATION:
                    # Record the event
                    nadir_value = current_nadir
                    drop_from_baseline = baseline_spo2 - nadir_value
                    events.append({
                        'start': df['timestamp'].iloc[event_start_index],
                        'end': df['timestamp'].iloc[event_end_index],
                        'nadir_spo2': nadir_value,
                        'drop_amount': drop_from_baseline
                    })
                # Reset for next event
                in_event = False
                event_start_index = None
                baseline_spo2 = None

    # Handle case where an event is still ongoing at the end of the data
    if in_event:
        event_duration = (df['timestamp'].iloc[-1] - df['timestamp'].iloc[event_start_index]).total_seconds()
        if event_duration >= MIN_EVENT_DURATION:
            nadir_value = df['spo2'].iloc[event_start_index:].min()
            drop_from_baseline = baseline_spo2 - nadir_value
            events.append({
                'start': df['timestamp'].iloc[event_start_index],
                'end': df['timestamp'].iloc[-1],
                'nadir_spo2': nadir_value,
                'drop_amount': drop_from_baseline
            })

    event_count = len(events)
    odi = event_count / valid_hours if valid_hours > 0 else 0

    return {'odi': odi, 'event_count': event_count, 'valid_hours': valid_hours, 'events': events}

In [7]:
def detect_sleep_apnea(df, window_hours=24):
    """
    Detect Sleep Apnea Concern using nocturnal SpO₂ and HR variability.
    Now integrates trend analysis for false positive reduction.

    Parameters:
        df: DataFrame with ['timestamp','spo2','heart_rate','age'] at minimum
        window_hours: rolling window for analysis (default: last 24h)

    Returns:
        dict with concern level, reasons, and features
    """
    # --- DATA PREPARATION --- 
    df = df.copy().sort_values("timestamp")
    df['hour'] = df['timestamp'].dt.hour

    # Use last 24h data
    cutoff = df['timestamp'].max() - timedelta(hours=window_hours)
    df_24h = df[df['timestamp'] >= cutoff]

    # Apply your standard flags FIRST (reusing your function)
    df_24h = apply_vital_range_flags(df_24h) 

    # Isolate nocturnal data
    nocturnal_mask = df_24h['hour'].between(0, 6)
    nocturnal_data = df_24h[nocturnal_mask].copy()
    
    features = {}

    # --- FEATURE CALCULATION ---
    # 1. Reuse your existing Nocturnal Drop feature
    night_spo2 = nocturnal_data['spo2'].dropna()
    day_spo2   = df_24h[df_24h['hour'].between(8, 20)]['spo2'].dropna()
    if not night_spo2.empty and not day_spo2.empty:
        features['nocturnal_drop'] = round(night_spo2.median() - day_spo2.median(), 2)
    else:
        features['nocturnal_drop'] = np.nan

    # 2. Calculate TRUE ODI and events
    odi_results = calculate_odi(nocturnal_data[['timestamp', 'spo2']])
    features['odi'] = round(odi_results['odi'], 2)
    features['desat_events'] = odi_results['event_count']
    features['valid_nocturnal_hours'] = round(odi_results['valid_hours'], 2)

    # 3. Calculate % time low (REUSING existing flags!)
    total_nocturnal_readings = len(nocturnal_data)
    if total_nocturnal_readings > 0:
        features['percent_time_below_90'] = (nocturnal_data['flag_spo2_low'].sum() / total_nocturnal_readings) * 100
        features['percent_time_below_88'] = (nocturnal_data['flag_spo2_critical'].sum() / total_nocturnal_readings) * 100
    else:
        features['percent_time_below_90'] = 0
        features['percent_time_below_88'] = 0

    # 4. Heart Rate Variability
    nocturnal_hr = nocturnal_data['heart_rate'].dropna()
    if not nocturnal_hr.empty:
        features['hr_sd'] = round(nocturnal_hr.std(), 2)
        features['hr_brady_episodes'] = int((nocturnal_hr < 60).sum())
        features['hr_tachy_episodes'] = int((nocturnal_hr > 100).sum())
    else:
        features['hr_sd'] = np.nan
        features['hr_brady_episodes'] = 0
        features['hr_tachy_episodes'] = 0

    # --- TREND ANALYSIS ON NOCTURNAL DATA FOR FALSE POSITIVE REDUCTION ---
    trends = compute_recent_trends_delta(nocturnal_data)  # Analyze nocturnal trends only
    
    # Get trend context for false positive reduction
    spo2_trend_flag = trends.get('spo2_trend_flag', 'Normal and stable')
    spo2_trend_value = trends.get('spo2_trend', 0)
    
    # FALSE POSITIVE FILTERS
    is_improving_rapidly = spo2_trend_flag == "Still abnormal — but improving" and spo2_trend_value > 0.5
    is_transient_dip = spo2_trend_flag == "Normal but deteriorating" and features['odi'] < 2.0
    is_very_mild_and_stable = features['odi'] < 1.0 and spo2_trend_flag == "Normal and stable"

    # --- CLASSIFICATION LOGIC WITH FALSE POSITIVE REDUCTION ---
    reasons = []
    concern = "green"

    # PRIMARY RULE: Based on true ODI and prolonged hypoxemia
    if (features['odi'] >= 5.0 or features['percent_time_below_88'] > 5) and not is_improving_rapidly:
        concern = "red"
        reasons.append(f"Oxygen Desaturation Index (ODI) ≥5.0 ({features['odi']}/hr) or significant time with SpO₂ <88%")
    
    # SECONDARY RULE: Based on lower ODI or less severe hypoxemia
    elif (features['odi'] >= 3.0 or features['percent_time_below_90'] > 10) and not is_transient_dip:
        concern = "orange"
        reasons.append(f"Moderate ODI ({features['odi']}/hr) or prolonged time with SpO₂ <90%")
    
    # TERTIARY RULE: Mild indicators (skip very mild stable patterns)
    elif (features['odi'] > 0 or (features['nocturnal_drop'] is not np.nan and features['nocturnal_drop'] <= -2)) and not is_very_mild_and_stable:
        concern = "yellow"
        reasons.append("Mild nocturnal desaturation detected")

    # MODIFIER RULE: HR Variability can upgrade concern
    if features['hr_sd'] > 8.0 and concern != "red": 
        if concern == "green":  # FIXED: was mismatched quote
            concern = "yellow"
            reasons.append("Elevated nocturnal heart rate variability.")
        elif concern == "yellow":
            concern = "orange"
            reasons.append("Elevated nocturnal heart rate variability plus desaturation.")

    # Add trend context to reasons if it influenced the decision
    if is_improving_rapidly and (features['odi'] >= 5.0 or features['percent_time_below_88'] > 5):
        reasons.append("Note: Pattern shows improving trend despite current severity")
    elif is_transient_dip and (features['odi'] >= 3.0 or features['percent_time_below_90'] > 10):
        reasons.append("Note: Appears to be transient dip based on trend analysis")
    elif is_very_mild_and_stable:
        reasons.append("Note: Very mild pattern with stable trend")

    # Include trend info in features for transparency
    features['trend_context'] = {
        'spo2_trend_flag': spo2_trend_flag,
        'spo2_trend_value': spo2_trend_value,
        'heart_rate_trend_flag': trends.get('heart_rate_trend_flag', 'Not available'),
        'resp_rate_trend_flag': trends.get('resp_rate_trend_flag', 'Not available')
    }

    return {
        "concern": concern,
        "features": features,
        "reasons": reasons
    }

In [8]:

# Create a timestamp range for 6 hours at 1-minute intervals (00:00 to 06:00)
timestamps = pd.date_range(start="2024-01-01 00:00", end="2024-01-01 06:00", freq='1min')

# Create a baseline SpO2 value
baseline_spo2 = 97
spo2_values = np.full(len(timestamps), baseline_spo2)

# --- Simulate 3 CLEAR Desaturation Events ---
# Event 1: A deep, long event (typical of apnea)
event_1_start = 60  # After 1 hour (01:00 AM)
spo2_values[event_1_start:event_1_start+20] = 92  # Drop to 92
spo2_values[event_1_start+20:event_1_start+25] = 88  # Nadir at 88
spo2_values[event_1_start+25:event_1_start+45] = 92  # Slow recovery
# spo2 returns to baseline after event_1_start+45

# Event 2: A shorter, shallower event
event_2_start = 180  # After 3 hours (03:00 AM)
spo2_values[event_2_start:event_2_start+15] = 93  # Drop to 93
spo2_values[event_2_start+15:event_2_start+20] = 90  # Nadir at 90
spo2_values[event_2_start+20:event_2_start+35] = 93  # Recovery
# spo2 returns to baseline after event_2_start+35

# Event 3: Another clear event
event_3_start = 270  # After 4.5 hours (04:30 AM)
spo2_values[event_3_start:event_3_start+15] = 91
spo2_values[event_3_start+15:event_3_start+20] = 87  # Nadir < 88
spo2_values[event_3_start+20:event_3_start+40] = 91

# Add some small random noise to make it look more realistic, but keep the events clear
np.random.seed(42) # For reproducible results
spo2_values = spo2_values + np.random.normal(0, 0.5, len(spo2_values))
spo2_values = np.clip(spo2_values, 80, 100) # Ensure values stay within possible range

# --- Create the Test DataFrame ---
df_test = pd.DataFrame({
    'timestamp': timestamps,
    'spo2': spo2_values,
    'heart_rate': np.random.normal(75, 5, len(timestamps)), # Random HR around 75
    'resp_rate': np.random.normal(16, 2, len(timestamps)),  # Random RR around 16
    'age': 45, # Set an age
    # Add other columns your pipeline might expect, even if with dummy data
    'sbp': 120,
    'dbp': 80,
    'temperature': 36.8
})

# --- RUN THE TEST ---
print("Running sleep apnea detection on sample data...")
result = detect_sleep_apnea(df_test, window_hours=24)

# --- PRINT THE RESULTS ---
print(f"\nOverall Concern Level: {result['concern'].upper()}")
print("Reasons:")
for reason in result['reasons']:
    print(f"  - {reason}")
print("\nCalculated Features:")
for key, value in result['features'].items():
    print(f"  - {key}: {value}")

Running sleep apnea detection on sample data...

Overall Concern Level: ORANGE
Reasons:
  - Moderate ODI (0.83/hr) or prolonged time with SpO₂ <90%

Calculated Features:
  - nocturnal_drop: nan
  - odi: 0.83
  - desat_events: 5
  - valid_nocturnal_hours: 6.0
  - percent_time_below_90: 19.667590027700832
  - percent_time_below_88: 2.21606648199446
  - hr_sd: 5.12
  - hr_brady_episodes: 0
  - hr_tachy_episodes: 0
  - trend_context: {'spo2_trend_flag': 'Normal and stable (possible false-positive: unstable_signal)', 'spo2_trend_value': np.float64(0.152), 'heart_rate_trend_flag': 'Not available', 'resp_rate_trend_flag': 'Not available'}


In [9]:
def compute_recent_trends_delta(df, trend_window=TREND_WINDOW):
    df = df.copy().sort_values("timestamp").reset_index(drop=True)
    if 'age_category' not in df.columns:
        df = assign_age_category(df)
    trends = {}
    recent = df.tail(trend_window)
    if recent.empty:
        return trends
    age_group = recent['age_category'].iloc[-1]
    thresholds = AGE_THRESHOLDS[age_group]

    vital_map = {
        'resp_rate': ('rr_low', 'rr_normal', 'rr_high'),
        'heart_rate': ('hr_low', 'hr_normal', 'hr_high'),
        'sbp': ('sbp_low', 'sbp_normal', 'sbp_high'),
        'temperature': ('temp_low', 'temp_normal', 'temp_high'),
        'spo2': (None, None, None)
    }

    for vital in ['resp_rate', 'heart_rate', 'sbp', 'temperature', 'spo2']:
        if vital not in recent.columns or recent[vital].isnull().all():
            continue
        y = recent[vital].dropna().values
        if len(y) < 2:
            continue
        avg_delta = np.mean(np.diff(y))
        latest = y[-1]
        trends[f"{vital}_trend"] = round(avg_delta, 3)

        if vital == 'spo2':
            if latest < THRESH_SPO2_LOW:
                if avg_delta > 0:
                    flag = "Still abnormal — but improving"
                elif avg_delta < 0:
                    flag = "Abnormal and worsening"
                else:
                    flag = "Abnormal and flat"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"
        else:
            low_key, norm_key, high_key = vital_map[vital]
            low = thresholds[low_key]
            normal = thresholds[norm_key]
            high = thresholds[high_key]
            if latest < low or latest > high:
                if (latest > high and avg_delta < 0) or (latest < low and avg_delta > 0):
                    flag = "Still abnormal — but improving"
                else:
                    flag = "Abnormal and worsening"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"
        trends[f"{vital}_trend_flag"] = flag

    # Shock Index trend
    if all(col in recent.columns for col in ['heart_rate', 'sbp']):
        hr = recent['heart_rate'].values
        sbp = np.clip(recent['sbp'].values, a_min=1, a_max=None)
        si = hr / sbp
        if len(si) >= 2:
            avg_si_delta = np.mean(np.diff(si))
            trends['shock_index_trend'] = round(avg_si_delta, 3)
            latest_si = si[-1]
            if latest_si >= THRESH_SHOCK_INDEX_CRITICAL:
                flag = "Shock Index critical — improving" if avg_si_delta < 0 else "Shock Index critical — worsening"
            else:
                flag = "Normal but improving" if avg_si_delta < 0 else "Normal but rising"
            trends['shock_index_trend_flag'] = flag
    return trends


In [10]:
def detect_sleep_apnea(df, window_hours=24):
    """
    Detect Sleep Apnea Concern using nocturnal SpO₂ and HR variability.
    Now integrates with existing pipeline functions for robust event detection.

    Parameters:
        df: DataFrame with ['timestamp','spo2','heart_rate','age'] at minimum
        window_hours: rolling window for analysis (default: last 24h)

    Returns:
        dict with concern level, reasons, and features
    """
    # --- DATA PREPARATION --- 
    # Reuse your existing data prep pattern
    df = df.copy().sort_values("timestamp")
    df['hour'] = df['timestamp'].dt.hour

    # Use last 24h data
    cutoff = df['timestamp'].max() - timedelta(hours=window_hours)
    df_24h = df[df['timestamp'] >= cutoff]

    # Apply your standard flags FIRST (reusing your function)
    # This adds 'flag_spo2_low', 'flag_spo2_critical', etc.
    df_24h = apply_vital_range_flags(df_24h) 

    # Isolate nocturnal data
    nocturnal_mask = df_24h['hour'].between(0, 6)
    nocturnal_data = df_24h[nocturnal_mask].copy()
    
    features = {}

    # --- FEATURE CALCULATION ---
    # 1. Reuse your existing Nocturnal Drop feature
    night_spo2 = nocturnal_data['spo2'].dropna()
    day_spo2   = df_24h[df_24h['hour'].between(8, 20)]['spo2'].dropna()
    if not night_spo2.empty and not day_spo2.empty:
        features['nocturnal_drop'] = round(night_spo2.median() - day_spo2.median(), 2)
    else:
        features['nocturnal_drop'] = np.nan

    # 2. NEW & IMPROVED: Calculate TRUE ODI and events
    odi_results = calculate_odi(nocturnal_data[['timestamp', 'spo2']])
    features['odi'] = round(odi_results['odi'], 2)
    features['desat_events'] = odi_results['event_count'] # Renamed to reflect it's true "events"
    features['valid_nocturnal_hours'] = round(odi_results['valid_hours'], 2)

    # 3. IMPROVED: Calculate % time low (REUSING existing flags!)
    total_nocturnal_readings = len(nocturnal_data)
    if total_nocturnal_readings > 0:
        features['percent_time_below_90'] = (nocturnal_data['flag_spo2_low'].sum() / total_nocturnal_readings) * 100
        features['percent_time_below_88'] = (nocturnal_data['flag_spo2_critical'].sum() / total_nocturnal_readings) * 100
    else:
        features['percent_time_below_90'] = 0
        features['percent_time_below_88'] = 0

    # 4. IMPROVED: Heart Rate Variability (Standard Deviation)
    nocturnal_hr = nocturnal_data['heart_rate'].dropna()
    if not nocturnal_hr.empty:
        features['hr_sd'] = round(nocturnal_hr.std(), 2) # Key metric for variability
        features['hr_brady_episodes'] = int((nocturnal_hr < 60).sum()) # Keep as context
        features['hr_tachy_episodes'] = int((nocturnal_hr > 100).sum()) # Keep as context
    else:
        features['hr_sd'] = np.nan
        features['hr_brady_episodes'] = 0
        features['hr_tachy_episodes'] = 0

    # --- CLASSIFICATION LOGIC (Updated for correct metrics) ---
    reasons = []
    concern = "green"  # default

    # PRIMARY RULE: Based on true ODI and prolonged hypoxemia
    if features['odi'] >= 5.0 or features['percent_time_below_88'] > 5: # >5% time below 88% is severe
        concern = "red"
        reasons.append(f"Oxygen Desaturation Index (ODI) ≥5.0 ({features['odi']}/hr) or significant time with SpO₂ <88%")
    
    # SECONDARY RULE: Based on lower ODI or less severe hypoxemia
    elif features['odi'] >= 3.0 or features['percent_time_below_90'] > 10:
        concern = "orange"
        reasons.append(f"Moderate ODI ({features['odi']}/hr) or prolonged time with SpO₂ <90%")
    
    # TERTIARY RULE: Mild indicators
    elif features['odi'] > 0 or (features['nocturnal_drop'] is not np.nan and features['nocturnal_drop'] <= -2):
        concern = "yellow"
        reasons.append("Mild nocturnal desaturation detected")

    # MODIFIER RULE: HR Variability can upgrade concern one level (e.g., yellow to orange, orange to red)
    # Check if HR variability is high, suggesting autonomic arousal
    if features['hr_sd'] > 8.0 and concern != "red": 
        if concern == "green":
            concern = "yellow"
            reasons.append("Elevated nocturnal heart rate variability.")
        elif concern == "yellow":
            concern = "orange"
            reasons.append("Elevated nocturnal heart rate variability plus desaturation.")
        # if already orange, it stays orange (don't upgrade to red on HR alone)

    return {
        "concern": concern,
        "features": features,
        "reasons": reasons
    }