In [1]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta

In [2]:
# Absolute (not age-specific) alert thresholds.
# NOTE(review): units for SpO2 are not stated here; the values look like percent saturation — confirm.
THRESH_SPO2_LOW = 92  # readings strictly below this are flagged as low SpO2
THRESH_SPO2_CRITICAL = 88  # readings strictly below this are flagged as critical SpO2
# Shock index is computed in this notebook as heart_rate / sbp; values at or above
# these cut-offs raise the warning / critical flags.
THRESH_SHOCK_INDEX_WARNING = 0.9
THRESH_SHOCK_INDEX_CRITICAL = 1.0
THRESH_MAP_LOW = 65  # Mean Arterial Pressure threshold for shock
THRESH_LACTATE_HIGH = 2.0  # mmol/L for tissue hypoperfusion
THRESH_LACTATE_CRITICAL = 4.0  # presumably mmol/L as well — confirm
THRESH_UO_LOW = 0.5  # mL/kg/hr (oliguria)

# Trend Analysis
TREND_WINDOW = 6  # Number of readings for short-term trend analysis


# AGE-SPECIFIC VITAL SIGN THRESHOLDS (Low, Normal, High)
# Outer keys are the age-category labels produced by assign_age_category().
# Inner keys follow the pattern <vital>_<low|normal|high> for rr (respiratory rate),
# hr (heart rate), sbp (systolic blood pressure) and temp (temperature).
# NOTE(review): units are not stated; temperature values look like degrees Celsius — confirm.

AGE_THRESHOLDS = {
    'neonate': {
        'rr_low': 30, 'rr_normal': 40, 'rr_high': 60,
        'hr_low': 100, 'hr_normal': 140, 'hr_high': 160,
        'sbp_low': 60, 'sbp_normal': 70, 'sbp_high': 90,
        'temp_low': 36.0, 'temp_normal': 37.2, 'temp_high': 38.0
    },
    'infant': {
        'rr_low': 24, 'rr_normal': 30, 'rr_high': 40,
        'hr_low': 80, 'hr_normal': 120, 'hr_high': 140,
        'sbp_low': 70, 'sbp_normal': 85, 'sbp_high': 100,
        'temp_low': 36.0, 'temp_normal': 37.2, 'temp_high': 38.0
    },
    'child': {
        'rr_low': 16, 'rr_normal': 20, 'rr_high': 30,
        'hr_low': 70, 'hr_normal': 90, 'hr_high': 110,
        'sbp_low': 80, 'sbp_normal': 95, 'sbp_high': 110,
        'temp_low': 36.0, 'temp_normal': 37.0, 'temp_high': 38.0
    },
    'adolescent': {
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 20,
        'hr_low': 60, 'hr_normal': 75, 'hr_high': 100,
        'sbp_low': 90, 'sbp_normal': 105, 'sbp_high': 120,
        'temp_low': 35.8, 'temp_normal': 36.8, 'temp_high': 37.8
    },
    'adult': {
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 20,
        'hr_low': 60, 'hr_normal': 80, 'hr_high': 100,
        'sbp_low': 90, 'sbp_normal': 115, 'sbp_high': 130,
        'temp_low': 35.5, 'temp_normal': 36.8, 'temp_high': 38.0
    },
    'geriatric': {  # Added for elderly patients who may have different baselines
        'rr_low': 12, 'rr_normal': 16, 'rr_high': 24, # Often higher RR baseline
        'hr_low': 55, 'hr_normal': 70, 'hr_high': 90, # Often lower HR
        'sbp_low': 90, 'sbp_normal': 125, 'sbp_high': 140, # Often higher SBP
        'temp_low': 35.5, 'temp_normal': 36.5, 'temp_high': 37.5 # Often lower temp
    }
}


# Cumulative score cut-offs per pipeline, mapped to escalation colours.
# NOTE(review): not referenced by the cells visible in this section — presumably
# consumed by a scoring step elsewhere; confirm.
SCORE_THRESHOLDS = {
    "hypotensive_shock": {
        "yellow": 3,   # Monitor/Review
        "orange": 6,   # Urgent
        "red": 9       # Critical
    },
}

In [3]:
def assign_age_category(df):
    """
    Return a copy of ``df`` with an ``age_category`` column added.

    The category is derived from the ``age`` column (in years; 0.083 is
    roughly one month) using these bands:

        age <= 0.083        -> 'neonate'
        0.083 < age <= 1    -> 'infant'
        1 < age < 5         -> 'child'
        5 <= age < 13       -> 'adolescent'
        13 <= age < 65      -> 'adult'
        age >= 65           -> 'geriatric'

    Rows whose age is missing (NaN/None) are labelled 'adult', the same
    default that is used when the frame has no ``age`` column at all.

    Parameters:
        df : DataFrame, optionally containing an 'age' column

    Returns:
        DataFrame : a copy of the input; the input frame is not modified.
        Any existing 'age_category' column is overwritten.
    """
    df = df.copy()

    def _categorize(age):
        # A missing age fails every comparison below and would otherwise fall
        # through to 'geriatric'; use the same default as a missing column.
        if pd.isna(age):
            return 'adult'
        if age <= 0.083:
            return 'neonate'      # up to ~1 month
        if age <= 1:
            return 'infant'       # ~1 month up to and including 1 year
        if age < 5:
            return 'child'        # over 1 and under 5 years
        if age < 13:
            return 'adolescent'   # 5 up to (not including) 13 years
        if age < 65:
            return 'adult'        # 13 up to (not including) 65 years
        return 'geriatric'        # 65 years and over

    if 'age' in df.columns:
        df['age_category'] = df['age'].apply(_categorize)
    else:
        # Default to adult if age is not provided
        df['age_category'] = 'adult'

    return df

In [4]:
def apply_vital_range_flags(df):
    """
    Flag abnormal vital signs using absolute and age-specific thresholds.

    The frame is first passed through assign_age_category(); each reading is
    then compared with the AGE_THRESHOLDS entry for its row's age category.
    Comparisons are vectorised: the per-row threshold is looked up once per
    flag instead of calling a Python lambda for every row, which also keeps
    the function working on an empty frame.

    Parameters:
        df : DataFrame with columns 'spo2', 'temperature', 'resp_rate',
             'heart_rate', 'sbp' and 'dbp' (and optionally 'age')

    Returns:
        DataFrame : a copy of the input with these columns added, in order:
        age_category, flag_spo2_low, flag_spo2_critical, flag_temp_high,
        flag_temp_low, flag_rr_low, flag_rr_high, flag_hr_low, flag_hr_high,
        shock_index, flag_si_warning, flag_si_critical, flag_sbp_low,
        flag_sbp_high, flag_dbp_low, flag_dbp_high.

    Notes:
        "low" flags use a strict ``<`` comparison and "high" flags use ``>=``.
        Missing readings compare as False, so they raise no flag.
    """
    # assign_age_category() already returns a copy, so the caller's frame is untouched.
    df = assign_age_category(df)
    category = df['age_category']

    def _row_threshold(key, factor=None):
        # Series of the threshold `key` (optionally scaled) for each row's age category.
        lookup = {
            name: (limits[key] if factor is None else limits[key] * factor)
            for name, limits in AGE_THRESHOLDS.items()
        }
        return category.map(lookup)

    # SpO₂ flags (Absolute threshold)
    df['flag_spo2_low'] = df['spo2'] < THRESH_SPO2_LOW
    df['flag_spo2_critical'] = df['spo2'] < THRESH_SPO2_CRITICAL

    # Temperature flags
    df['flag_temp_high'] = df['temperature'] >= _row_threshold('temp_high')
    df['flag_temp_low'] = df['temperature'] < _row_threshold('temp_low')

    # Respiratory Rate flags
    df['flag_rr_low'] = df['resp_rate'] < _row_threshold('rr_low')
    df['flag_rr_high'] = df['resp_rate'] >= _row_threshold('rr_high')

    # Heart Rate flags
    df['flag_hr_low'] = df['heart_rate'] < _row_threshold('hr_low')
    df['flag_hr_high'] = df['heart_rate'] >= _row_threshold('hr_high')

    # Shock Index = heart rate / systolic BP. SBP is clipped to a minimum of 1
    # so a zero (or negative) reading cannot cause a division by zero.
    df['shock_index'] = df['heart_rate'] / np.clip(df['sbp'], a_min=1, a_max=None)
    df['flag_si_warning'] = df['shock_index'] >= THRESH_SHOCK_INDEX_WARNING
    df['flag_si_critical'] = df['shock_index'] >= THRESH_SHOCK_INDEX_CRITICAL

    # Blood Pressure flags
    df['flag_sbp_low'] = df['sbp'] < _row_threshold('sbp_low')
    df['flag_sbp_high'] = df['sbp'] >= _row_threshold('sbp_high')
    # There are no dedicated DBP thresholds; the limits are estimated as 0.6 x the SBP limits.
    df['flag_dbp_low'] = df['dbp'] < _row_threshold('sbp_low', factor=0.6)
    df['flag_dbp_high'] = df['dbp'] >= _row_threshold('sbp_high', factor=0.6)

    return df

In [5]:
def compute_recent_trends_delta(df):
    """
    Summarise the short-term trend of each vital sign.

    Rows are sorted by 'timestamp' and the last TREND_WINDOW readings are
    analysed. For each vital the trend is the mean difference between
    consecutive non-missing readings; the latest reading is compared with
    the age-specific limits (or THRESH_SPO2_LOW for SpO2) to pick a
    descriptive flag. The age category of the most recent row selects the
    limits for the whole window.

    Parameters:
        df : DataFrame with a 'timestamp' column and any of 'resp_rate',
             'heart_rate', 'sbp', 'temperature', 'spo2'. If 'age_category'
             is absent it is derived via assign_age_category().

    Returns:
        dict : for every vital with at least two non-missing readings,
        '<vital>_trend' (mean delta rounded to 3 decimals) and
        '<vital>_trend_flag' (text). When both 'heart_rate' and 'sbp' are
        present, also 'shock_index_trend' and 'shock_index_trend_flag'.
        An empty frame yields an empty dict.

    Notes:
        Inside the normal range any negative mean delta is reported as
        "Normal but deteriorating", regardless of the vital.
        A latest value exactly equal to the high limit counts as normal here
        (strict ``>``).
    """
    df = df.copy().sort_values("timestamp").reset_index(drop=True)

    # No readings means no trend; without this guard .iloc[-1] below raises IndexError.
    if df.empty:
        return {}

    if 'age_category' not in df.columns:
        df = assign_age_category(df)

    trends = {}
    recent = df.tail(TREND_WINDOW)
    age_group = recent['age_category'].iloc[-1]
    thresholds = AGE_THRESHOLDS[age_group]

    # (low key, high key) in AGE_THRESHOLDS for each age-dependent vital.
    # SpO2 is absent on purpose: it uses an absolute threshold instead.
    limit_keys = {
        'resp_rate': ('rr_low', 'rr_high'),
        'heart_rate': ('hr_low', 'hr_high'),
        'sbp': ('sbp_low', 'sbp_high'),
        'temperature': ('temp_low', 'temp_high'),
    }

    for vital in ['resp_rate', 'heart_rate', 'sbp', 'temperature', 'spo2']:
        if vital not in recent.columns or recent[vital].isnull().all():
            continue

        y = recent[vital].dropna().values
        if len(y) < 2:
            continue  # a single reading has no delta

        avg_delta = np.mean(np.diff(y))
        latest = y[-1]
        trends[f"{vital}_trend"] = round(avg_delta, 3)

        if vital == 'spo2':
            # For SpO2 only low values are abnormal, so "improving" means rising.
            if latest < THRESH_SPO2_LOW:
                if avg_delta > 0:
                    flag = "Still abnormal — but improving"
                elif avg_delta < 0:
                    flag = "Abnormal and worsening"
                else:
                    flag = "Abnormal and flat"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"

        else:
            low_key, high_key = limit_keys[vital]
            low = thresholds[low_key]
            high = thresholds[high_key]

            if latest < low or latest > high:
                # "Improving" means moving back towards the normal band.
                if (latest > high and avg_delta < 0) or (latest < low and avg_delta > 0):
                    flag = "Still abnormal — but improving"
                else:
                    flag = "Abnormal and worsening"
            else:
                if avg_delta < 0:
                    flag = "Normal but deteriorating"
                else:
                    flag = "Normal and stable"

        trends[f"{vital}_trend_flag"] = flag

    # Shock Index trend
    if all(col in recent.columns for col in ['heart_rate', 'sbp']):
        # Keep only rows where both readings exist: a NaN would otherwise
        # propagate into the mean and be mislabelled "Normal but rising".
        paired = recent[['heart_rate', 'sbp']].dropna()
        hr = paired['heart_rate'].values
        sbp = np.clip(paired['sbp'].values, a_min=1, a_max=None)  # avoid division by zero
        si = hr / sbp

        if len(si) >= 2:
            avg_si_delta = np.mean(np.diff(si))
            trends['shock_index_trend'] = round(avg_si_delta, 3)

            latest_si = si[-1]
            if latest_si >= THRESH_SHOCK_INDEX_WARNING:
                flag = "Shock Index high — improving" if avg_si_delta < 0 else "Shock Index high — worsening"
            else:
                flag = "Normal but improving" if avg_si_delta < 0 else "Normal but rising"

            trends['shock_index_trend_flag'] = flag

    return trends

In [6]:
def compute_rmssd(arr):
    """
    Compute RMSSD (Root Mean Square of Successive Differences).

    Missing values (NaN) are dropped before differencing, so the remaining
    readings are treated as consecutive.

    Parameters:
        arr : pandas Series, NumPy array or other 1-D numeric sequence

    Returns:
        float : sqrt(mean(diff(arr) ** 2)), in the same units as ``arr``;
        0.0 when fewer than two non-missing values are available.
    """
    values = np.asarray(arr, dtype=float)
    values = values[~np.isnan(values)]
    diffs = np.diff(values)
    # One successive difference is enough for a defined RMSSD; only the
    # no-difference case needs a fallback (the mean of an empty array is NaN).
    if diffs.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(diffs ** 2)))

In [7]:
def detect_af_suspect(df, hr_col='heart_rate', time_col='timestamp',
                     rmssd_thresh=40, hr_tachy_thresh=120,
                     persistence_window_minutes=240, # 4 hours for irregularity
                     tachycardia_minutes=30):        # 30 min for tachycardia
    """
    AF suspect detection based on heart-rate variability and tachycardia.

    A rolling RMSSD (window of 5 readings, at least 3 required) is computed
    over the heart-rate column. Within the last
    ``persistence_window_minutes`` before the newest timestamp, the number of
    readings flagged irregular (RMSSD above threshold) and tachycardic (HR at
    or above threshold) is counted. These counts are reported as minutes,
    which assumes one reading per minute. The counts are totals over the
    window; the flagged readings need not be consecutive.

    Parameters:
        df : DataFrame with timestamp and HR columns (if ``time_col`` is not
             a column, the existing index is used as the time axis)
        hr_col : str, column name for heart rate
        time_col : str, column name for timestamp
        rmssd_thresh : float, RMSSD threshold for irregularity. RMSSD is
             computed from successive ``hr_col`` values, so it is expressed
             in the units of that column.
        hr_tachy_thresh : float, HR threshold for tachycardia (bpm)
        persistence_window_minutes : int, time window to analyze for persistence (minutes)
        tachycardia_minutes : int, required duration of tachycardia (minutes)

    Returns:
        (result, result_df)
        result : dict with 'af_suspect' (bool), 'flags', 'status' and, when
            data is available, 'metrics'. AF is suspected only when both
            irregularity (flagged for at least half of the window) and
            tachycardia (flagged for at least ``tachycardia_minutes``) persist.
        result_df : DataFrame indexed by time with the columns ``hr_col``,
            'RMSSD', 'flag_irregular' and 'flag_tachycardia'.
    """

    # Work on a copy and use the timestamp as the index
    df = df.copy()
    if time_col in df.columns:
        df = df.set_index(time_col)

    # Calculate rolling RMSSD
    df['RMSSD'] = df[hr_col].rolling(window=5, min_periods=3).apply(compute_rmssd, raw=False)

    # Per-reading flags; a missing RMSSD or HR compares as False
    df['flag_irregular'] = df['RMSSD'] > rmssd_thresh
    df['flag_tachycardia'] = df[hr_col] >= hr_tachy_thresh

    # Return only the necessary columns to avoid large DataFrames.
    # Uses hr_col (not a hardcoded name) so custom column names work.
    result_df = df[[hr_col, 'RMSSD', 'flag_irregular', 'flag_tachycardia']].copy()

    # Get the most recent data within our persistence window
    latest_time = df.index.max()
    analysis_start = latest_time - pd.Timedelta(minutes=persistence_window_minutes)
    recent_data = df[df.index >= analysis_start]

    if recent_data.empty:
        return {'af_suspect': False, 'flags': {}, 'status': 'Insufficient data'}, result_df

    # Number of flagged readings in the window (one reading is treated as one minute)
    irregular_persistence_min = int(recent_data['flag_irregular'].sum())
    tachy_persistence_min = int(recent_data['flag_tachycardia'].sum())

    high_hr_persistent = bool(tachy_persistence_min >= tachycardia_minutes)
    # Irregularity must be present for at least half of the analysis window
    irregularity_persistent = bool(
        irregular_persistence_min >= (persistence_window_minutes * 0.5)
    )

    result = {
        'af_suspect': False,
        'flags': {
            'high_hr_persistent': high_hr_persistent,
            'irregularity_persistent': irregularity_persistent,
            'tachycardia_duration_min': tachy_persistence_min,
            'irregularity_duration_min': irregular_persistence_min
        },
        'status': 'No AF suspicion',
        'metrics': {
            'avg_hr': recent_data[hr_col].mean(),
            'avg_rmssd': recent_data['RMSSD'].mean(),
            'analysis_window_min': persistence_window_minutes,
            'data_points_in_window': len(recent_data)
        }
    }

    # Combination logic: both criteria -> suspect; one alone -> informational status
    if irregularity_persistent and high_hr_persistent:
        result['af_suspect'] = True
        result['status'] = f"High AF suspicion: Irregular for {irregular_persistence_min}min, tachy for {tachy_persistence_min}min"

    elif irregularity_persistent:
        result['status'] = f"Irregular rhythm for {irregular_persistence_min}min — possible controlled AF"

    elif high_hr_persistent:
        result['status'] = f"Sustained tachycardia for {tachy_persistence_min}min — check for causes"

    return result, result_df

In [8]:
# Sample data: 300 readings at one-minute spacing, age fixed at 70, heart rate
# drawn from a normal distribution with mean 130 and standard deviation 15.
# NOTE(review): np.random.normal uses NumPy's global RNG and no seed is set in
# the visible cells, so the printed result can differ between runs — consider seeding.
sample_data = pd.DataFrame({
    'timestamp': pd.date_range('2023-01-01', periods=300, freq='min'),
    'heart_rate': np.random.normal(130, 15, 300),  # high mean, moderate spread
    'age': [70] * 300
})

# Run AF detection with the default thresholds
af_result, processed_data = detect_af_suspect(sample_data)

print(f"AF Suspect: {af_result['af_suspect']}")
print(f"Status: {af_result['status']}")

AF Suspect: False
Status: Sustained tachycardia for 178min — check for causes


In [9]:
# Test data intended to trigger an AF suspect result: same mean heart rate as the
# previous sample (130) but twice the standard deviation (30).
# NOTE(review): the draw is unseeded, so whether the irregularity criterion is met
# depends on the random sample — the outcome is not guaranteed on every run.
test_af_data = pd.DataFrame({
    'timestamp': pd.date_range('2023-01-01', periods=300, freq='min'),
    'heart_rate': np.random.normal(130, 30, 300),  # High mean + high variability
    'age': [70] * 300
})

# Only the result dict is needed here; the per-reading frame is discarded.
af_result, _ = detect_af_suspect(test_af_data)
print(f"AF Suspect: {af_result['af_suspect']}")
print(f"Status: {af_result['status']}")

AF Suspect: True
Status: High AF suspicion: Irregular for 141min, tachy for 161min
