In [1]:
import os
import h5py
import mne
import pandas as pd
import numpy as np
from fractions import Fraction
from scipy.signal import resample_poly
from datetime import datetime, time
import matplotlib.pyplot as plt
from scipy.signal import filtfilt, find_peaks

import sys
project_root = os.path.dirname(os.getcwd())
src_path = os.path.join(project_root, "src")
sys.path.append(str(src_path))
from pipeline_io.read_annot import read_annot
from analysis.process_sleep_stages import process_sleep_stages
from analysis.process_signals import process_signals
from analysis.selecting_windows import selecting_windows
from features.extract_features import extract_features
from pipeline_io.save_features import save_features, save_features_wide

  import pkg_resources


### Raw code 

In [14]:
def _scan_desaturation_events(spo2, duration_max):
    """Run the desaturation state machine over ``spo2`` and return plain lists.

    Returns ``(onsets, durations, levels)``: onset sample index, duration in
    samples, and depth (running maximum minus running minimum) of every
    recorded event. Kept separate from the public entry point so that the
    recursive split of over-long events works on raw lists rather than on the
    caller-facing return type.
    """
    spo2_des_min_thre = 2   # minimum drop from the running maximum that opens an event
    spo2_des_max_thre = 50  # drops larger than this are treated as motion artifacts
    duration_min = 5        # events shorter than this many samples are discarded

    onsets, durations, levels = [], [], []

    spo2_max = spo2[0]        # running maximum that drops are measured against
    spo2_max_index = 0        # index of that maximum
    spo2_min = 100            # running minimum inside the current event
    des_onset_pred_point = 0  # onset candidate of the current event
    des_flag = False          # a desaturation is in progress
    ma_flag = False           # a motion artifact is in progress
    prob_end = []             # indices where the signal rose again (candidate end points)

    for i, current_value in enumerate(spo2):
        des_percent = spo2_max - current_value

        # Leaving a motion artifact: keep whatever event was open, then restart here.
        if ma_flag and des_percent < spo2_des_max_thre:
            if des_flag and prob_end:
                onsets.append(des_onset_pred_point)
                durations.append(prob_end[-1] - des_onset_pred_point)
                levels.append(spo2_max - spo2_min)
            spo2_max, spo2_max_index = current_value, i
            ma_flag = False
            des_flag = False
            spo2_min = 100
            prob_end = []
            continue

        if des_percent >= spo2_des_min_thre:
            if des_percent > spo2_des_max_thre:
                ma_flag = True
            else:
                des_onset_pred_point = spo2_max_index
                des_flag = True
                if current_value < spo2_min:
                    spo2_min = current_value

        if current_value >= spo2_max and not des_flag:
            spo2_max, spo2_max_index = current_value, i

        elif des_flag:
            if current_value > spo2_min:
                if current_value > spo2[i - 1]:
                    prob_end.append(i)

                # Two consecutive non-rising steps after a recovery close the event.
                # `i >= 2` avoids negative-index wrap-around and `prob_end` avoids
                # indexing an empty list.
                falling_again = i >= 2 and current_value <= spo2[i - 1] < spo2[i - 2]
                if not (falling_again and prob_end):
                    continue

                spo2_des_duration = prob_end[-1] - spo2_max_index

                if spo2_des_duration >= duration_min:
                    if spo2_des_duration <= duration_max:
                        onsets.append(des_onset_pred_point)
                        durations.append(spo2_des_duration)
                        levels.append(spo2_max - spo2_min)
                    else:
                        # Too long for one event: keep the part up to the first
                        # recovery point, then re-scan the remainder on its own.
                        split_at = prob_end[0]
                        onsets.append(des_onset_pred_point)
                        durations.append(split_at - des_onset_pred_point)
                        levels.append(spo2_max - spo2_min)

                        sub_onsets, sub_durations, sub_levels = _scan_desaturation_events(
                            spo2[split_at:i + 1], duration_max
                        )
                        onsets.extend(onset + split_at for onset in sub_onsets)
                        durations.extend(sub_durations)
                        levels.extend(sub_levels)

                # Whether the event was kept or was too short, restart two samples back.
                spo2_max = spo2[i - 2]
                spo2_max_index = i - 2
                spo2_min = 100
                des_flag = False
                prob_end = []

    return onsets, durations, levels


def detect_oxygen_desaturation(spo2, is_plot=False, duration_max=120, return_type='pd',
                               od_start=None, od_duration=None):
    """Detect oxygen desaturation events in an SpO2 trace.

    An event opens when the signal falls at least 2 below the running maximum
    and closes when, after a recovery, the signal starts falling again. Events
    shorter than 5 samples are dropped, drops of more than 50 are treated as
    motion artifacts, and events longer than ``duration_max`` are split and
    re-scanned.

    Parameters
    ----------
    spo2 : array-like
        One-dimensional SpO2 samples. Indices and durations are expressed in
        samples (i.e. seconds only when the trace is sampled at 1 Hz).
    is_plot : bool
        If True, draw the signal with the detected events (requires matplotlib).
    duration_max : int
        Longest duration, in samples, accepted as a single event.
    return_type : {'pd', 'tuple'}
        'pd' (default) returns a DataFrame with columns ``onset``, ``duration``
        and ``desaturation``; 'tuple' returns the three NumPy arrays directly.
    od_start, od_duration : sequence of int, optional
        Reference ("ground truth") onsets and durations overlaid on the first
        panel when plotting. Ignored unless both are given.

    Returns
    -------
    pandas.DataFrame or tuple of numpy.ndarray
    """
    spo2 = np.asarray(spo2)
    if spo2.size == 0:
        onsets, durations, levels = [], [], []
    else:
        onsets, durations, levels = _scan_desaturation_events(spo2, duration_max)

    des_onset_pred_set = np.array(onsets, dtype=int)
    des_duration_pred_set = np.array(durations, dtype=int)
    des_level_set = np.array(levels, dtype=float)

    if is_plot:
        fig, [ax1, ax2] = plt.subplots(2, 1, sharex=True)
        # ground truth annotation
        ax1.plot(spo2, 'b')
        ax1.set_title('ground truth')
        if od_start is not None and od_duration is not None:
            for s, d in zip(od_start, od_duration):
                e = s + d
                se = list(range(s, e + 1))
                ax1.plot(s, spo2[s], color='r', linestyle='none', marker='o')
                ax1.plot(se, spo2[se], color='r', linestyle='-')
                ax1.plot(e, spo2[e], color='r', linestyle='none', marker='*')

        # predicted annotation
        ax2.plot(spo2, 'b')
        ax2.set_title('prediction')
        for s, d in zip(des_onset_pred_set, des_duration_pred_set):
            e = s + d
            se = np.arange(s, e + 1)
            ax2.plot(s, spo2[s], color='r', linestyle='none', marker='o')
            ax2.plot(se, spo2[se], color='r', linestyle='-')
            ax2.plot(e, spo2[e], color='r', linestyle='none', marker='*')

        plt.show()

    if return_type == 'tuple':
        return des_onset_pred_set, des_duration_pred_set, des_level_set
    return pd.DataFrame(data={'onset': des_onset_pred_set,
                              'duration': des_duration_pred_set,
                              'desaturation': des_level_set})

In [15]:
# Numerator coefficients handed to `filtfilt(B, 1, ...)`: 31 symmetric taps whose
# sum is ~1. The design parameters (cut-off, window) are not recorded in this notebook.
B = [0.000109398212241, 0.000514594526374, 0.001350397179936, 0.002341700062534,
     0.002485940327008, 0.000207543145171, -0.005659450344228, -0.014258087808069,
     -0.021415481383353, -0.019969417749860, -0.002425120103463, 0.034794452821365,
     0.087695691366900, 0.144171828095816, 0.187717212244959, 0.204101948813338,
     0.187717212244959, 0.144171828095816, 0.087695691366900, 0.034794452821365,
     -0.002425120103463, -0.019969417749860, -0.021415481383353, -0.014258087808069,
     -0.005659450344228, 0.000207543145171, 0.002485940327008, 0.002341700062534,
     0.001350397179936, 0.000514594526374, 0.000109398212241]

# SpO2 samples below this value are treated as abnormal readings.
BAD_SPO2_THRESHOLD = 80


def filter_spo2(spo2_arr, spo2_sfreq, event_end_time, verbose=False, time_span=120, h5_path=None):
    """Clean and smooth an SpO2 window, returning it at 1 Hz with 0.5 resolution.

    Steps: samples below ``BAD_SPO2_THRESHOLD`` are replaced by the mean of the
    remaining samples, the window is resampled to 1 Hz when ``spo2_sfreq`` is
    not 1, it is filtered forward and backward with the ``B`` coefficients, and
    finally rounded to the nearest 0.5.

    The input array is never modified: callers typically pass a slice (view) of
    the full-night signal, and writing into it would corrupt that signal.

    Parameters
    ----------
    spo2_arr : array-like
        SpO2 samples. With the default padding, ``filtfilt`` needs more than
        93 samples (3 x 31 taps).
    spo2_sfreq : number
        Sampling frequency of ``spo2_arr`` in Hz.
    event_end_time : number
        Event time in seconds; only used to cut the respiration traces shown
        when ``verbose`` is True.
    verbose : bool
        If True, plot the window next to the abdominal and airflow signals.
    time_span : int
        Half-width of the window in seconds (used for plotting only).
    h5_path : str, optional
        HDF5 file holding the respiration signals for the verbose plot. When
        omitted, the notebook-level ``row["h5_path"]`` is used, as before.

    Returns
    -------
    numpy.ndarray
        The filtered SpO2 window.
    """
    # Private float copy: protects the caller's data and avoids truncating the
    # replacement mean when the input has an integer dtype.
    spo2_clean = np.array(spo2_arr, dtype=float)

    abnormal = spo2_clean < BAD_SPO2_THRESHOLD
    if abnormal.any():
        spo2_clean[abnormal] = np.mean(spo2_clean[spo2_clean >= BAD_SPO2_THRESHOLD])

    if spo2_sfreq != 1:
        # Target rate is 1 Hz, so up/down = 1 / spo2_sfreq.
        ratio = Fraction(float(spo2_sfreq)).limit_denominator(1000)
        spo2_clean = resample_poly(spo2_clean, ratio.denominator, ratio.numerator, padtype='line')
        spo2_sfreq = 1

    # Reduce SpO2 jitter and adjust SpO2 resolution to 0.5
    spo2_filtered = filtfilt(B, 1, spo2_clean, axis=0, padtype='odd')
    spo2_filtered = np.round(spo2_filtered * 2) / 2

    if verbose:
        if h5_path is None:
            h5_path = row["h5_path"]  # notebook-level `row`; pass h5_path to avoid relying on it
        with h5py.File(h5_path, "r") as f:
            abd_signal = np.array(f["signals/RESP/RESP_ABDOMINAL"])
            sfreq_abd = f["signals/RESP/RESP_ABDOMINAL"].attrs.get('fs', None)
            airflow_signal = np.array(f["signals/RESP/RESP_AIRFLOW"])
            sfreq_airflow = f["signals/RESP/RESP_AIRFLOW"].attrs.get('fs', None)

        start = int(event_end_time - time_span)
        end = int(event_end_time + time_span)
        # Slice bounds must be integers even when the stored rate is a float.
        abd = abd_signal[int(start * sfreq_abd): int(end * sfreq_abd)]
        af = airflow_signal[int(start * sfreq_airflow): int(end * sfreq_airflow)]
        tt = np.arange(len(spo2_clean)) / spo2_sfreq
        tt_abd = np.arange(len(abd)) / sfreq_abd
        tt_af = np.arange(len(af)) / sfreq_airflow

        plt.close()
        fig = plt.figure()
        ax0 = fig.add_subplot(311)
        ax0.plot(tt, spo2_clean)
        ax0.plot(tt, spo2_filtered)
        ax0.axvline(x=time_span / spo2_sfreq)
        ax = fig.add_subplot(312, sharex=ax0)
        ax.plot(tt_abd, abd)
        ax = fig.add_subplot(313, sharex=ax0)
        ax.plot(tt_af, af)

        plt.show()
    return spo2_filtered

    
def calc_hypoxic_burden(event_times, spo2_arr, sfreq_spo2, verbose=False, time_span=120):
    """Compute a per-event hypoxic burden from SpO2 windows centred on each event.

    For every event a window of ``time_span`` seconds on each side is cut from
    ``spo2_arr`` and filtered with ``filter_spo2``. The filtered windows are
    averaged; the search interval is bounded by the last local maximum of that
    average before the event time and the first local maximum after it. Each
    event's burden is the sum, over the search interval, of (baseline - SpO2)
    divided by 60, where the baseline is the maximum of the window over the
    (up to) 100 samples preceding the event time.

    Parameters
    ----------
    event_times : sequence of float
        Event times in seconds.
    spo2_arr : array-like
        SpO2 signal; must be sampled at 1 Hz.
    sfreq_spo2 : number
        Sampling frequency of ``spo2_arr``; anything other than 1 raises
        ``AssertionError``.
    verbose : bool
        Forwarded to ``filter_spo2`` and enables the plot of the average curve.
    time_span : int
        Half-width of the window in seconds. With the default ``filtfilt``
        padding, the window (2 * time_span samples) must exceed 93 samples.

    Returns
    -------
    pandas.DataFrame
        Columns ``EventTime`` and ``HB``. ``HB`` is NaN for events whose window
        is incomplete or has more than 30 % of samples below
        ``BAD_SPO2_THRESHOLD``, and for every event when no search interval can
        be determined.
    """
    # The indexing below mixes seconds and samples, which is only valid at 1 Hz.
    assert sfreq_spo2 == 1, f"Unexpected SpO₂ frequency: {sfreq_spo2}"
    fs = int(sfreq_spo2)
    spo2_arr = np.asarray(spo2_arr)
    window_len = 2 * time_span * fs

    res = pd.DataFrame(data={'EventTime': event_times})
    res['HB'] = np.nan

    # Assume the duration of a respiratory event is 10–120 s; the maximum delay of
    # hypoxemia caused by a respiratory event is 120 s
    all_ah_related_spo2 = []
    good_event_ids = []
    for ei, et in enumerate(event_times):
        if not np.isfinite(et):
            continue
        start = int(et - time_span) * fs
        stop = int(et + time_span) * fs
        if start < 0:
            # Window begins before the recording; a negative index would wrap around.
            continue
        nearby_spo2 = spo2_arr[start:stop]
        if len(nearby_spo2) < window_len or np.mean(nearby_spo2 < BAD_SPO2_THRESHOLD) > 0.3:
            continue
        # Pass a copy so the full-night signal can never be altered by the filter.
        filtered_spo2 = filter_spo2(np.array(nearby_spo2, dtype=float), sfreq_spo2, et, verbose, time_span)
        all_ah_related_spo2.append(filtered_spo2)
        good_event_ids.append(ei)

    if not all_ah_related_spo2:
        return res

    # Get average drop (start and end on average curve)
    all_spo2_dest = np.array(all_ah_related_spo2)
    avg_spo2 = filtfilt(B, 1, all_spo2_dest.mean(axis=0), axis=0, padtype='odd')
    peaks, _ = find_peaks(avg_spo2)
    peaks_before = peaks[peaks < time_span]
    peaks_after = peaks[peaks > time_span]
    if peaks_before.size == 0 or peaks_after.size == 0:
        # No local maximum on one side of the event: the search interval is undefined.
        return res
    start_secs = peaks_before[-1]
    end_secs = peaks_after[0]

    if verbose:
        x = np.arange(len(avg_spo2))
        plt.close()
        plt.plot(x, avg_spo2)
        plt.plot(x[start_secs], avg_spo2[start_secs], "o")
        plt.plot(x[end_secs], avg_spo2[end_secs], "*", markersize=10)
        plt.axvline(x=time_span)
        plt.title("Average SpO2 around events")
        plt.show()

    # Clamp so a time_span below 100 does not produce a negative (wrapping) start.
    baseline_from = max(0, time_span - 100)
    baseline_spo2 = all_spo2_dest[:, baseline_from:time_span].max(axis=1)
    interest_spo2 = all_spo2_dest[:, start_secs:end_secs]
    burdens = (baseline_spo2[:, None] - interest_spo2).sum(axis=1) / 60

    res.loc[good_event_ids, 'HB'] = burdens
    return res


### Main
Apply the full pipeline to local data (mastersheet with 2 subjects). The cells below process only the first row of the mastersheet.

In [16]:
def hb_per_stages(df_hb, full_sleep_stages, sfreq_global):
    """Convert per-event hypoxic burden into hourly rates: overall, NREM and REM.

    Stage labels 1-4 count as sleep; 1, 2, 3 are pooled as NREM and 4 is REM.
    The sleep period runs from the first to the last sample labelled as sleep.

    Parameters
    ----------
    df_hb : pandas.DataFrame
        Needs columns ``EventTime`` (seconds) and ``HB``. A ``Stage`` column
        holding the stage at each event is added to this frame in place.
    full_sleep_stages : array-like
        Stage label per sample, sampled at ``sfreq_global``.
    sfreq_global : number
        Sampling frequency of ``full_sleep_stages`` in Hz.

    Returns
    -------
    tuple
        ``(HB_per_hour, HB_NREM_per_hour, HB_REM_per_hour)``. The overall rate
        divides the burden of events inside the sleep period by the length of
        that period; the stage rates divide by the time spent in the stage
        within the period. A stage rate is 0 when no event falls in that stage.
        All three are NaN when no sleep is scored.
    """
    no_sleep = (np.nan, np.nan, np.nan)

    full_sleep_stages = np.asarray(full_sleep_stages)
    full_sleep_stages_sec = full_sleep_stages[::int(sfreq_global)]
    if full_sleep_stages_sec.size == 0:
        return no_sleep

    # Add Sleep Stages
    ids = np.clip(df_hb['EventTime'].astype(int).values, 0, len(full_sleep_stages_sec) - 1)
    df_hb['Stage'] = full_sleep_stages_sec[ids]

    # Start and end sleep in sec
    sleep_ids = np.flatnonzero(np.isin(full_sleep_stages, [1, 2, 3, 4]))
    if sleep_ids.size == 0:
        return no_sleep
    sleep_start = int(sleep_ids[0] / sfreq_global)
    sleep_end = int((sleep_ids[-1] + 1) / sfreq_global)
    total_sleep_hours = (sleep_end - sleep_start) / 3600
    if total_sleep_hours <= 0:
        return no_sleep

    # Total HB during sleep (per hour)
    in_sleep = (df_hb['EventTime'] >= sleep_start) & (df_hb['EventTime'] < sleep_end)
    HB_per_hour = df_hb['HB'][in_sleep].sum() / total_sleep_hours

    stages_in_sleep = full_sleep_stages_sec[sleep_start:sleep_end]

    def _stage_rate(labels):
        # Burden of events scored in `labels`, per hour spent in those stages.
        event_mask = np.isin(df_hb['Stage'], labels)
        if event_mask.sum() == 0:
            return 0
        stage_hours = np.isin(stages_in_sleep, labels).sum() / 3600
        if stage_hours == 0:
            return np.nan
        return df_hb['HB'][event_mask].sum() / stage_hours

    HB_NREM_per_hour = _stage_rate([1, 2, 3])
    HB_REM_per_hour = _stage_rate([4])

    return HB_per_hour, HB_NREM_per_hour, HB_REM_per_hour
    
def extract_hypoxic_burden(psg_id, row, df_events, sleep_stages):
    """Compute hourly hypoxic-burden rates for apnea-type and desaturation events.

    Parameters
    ----------
    psg_id : str
        Recording identifier, used only in the "skipping" messages.
    row : mapping
        Must provide ``h5_path`` (HDF5 file with ``signals/SPO2/SpO2`` and its
        ``fs`` attribute) and ``sfreq_global``.
    df_events : pandas.DataFrame
        Needs columns ``event_type``, ``onset`` and ``duration``.
    sleep_stages : array-like
        Stage label per sample at ``sfreq_global``; it is handed to
        ``hb_per_stages``, which subsamples it with step ``sfreq_global``.
        NOTE(review): this function used to ignore the argument and read the
        notebook-level ``full_sleep_stages`` instead, so callers must pass the
        full-resolution array here.

    Returns
    -------
    tuple or None
        ``(HB_per_hour_apnea, HB_NREM_per_hour_apnea, HB_REM_per_hour_apnea,
        HB_per_hour_desat, HB_NREM_per_hour_desat, HB_REM_per_hour_desat)``,
        or ``None`` when fewer than two apnea-type or desaturation events exist.

    Raises
    ------
    ValueError
        If the SpO2 dataset has no ``fs`` attribute.
    """
    # Read SpO2 signal
    with h5py.File(row["h5_path"], "r") as f:
        spo2_signal = np.array(f["signals/SPO2/SpO2"])
        sfreq_spo2 = f["signals/SPO2/SpO2"].attrs.get("fs", None)
    if sfreq_spo2 is None:
        raise ValueError(f"{psg_id}: SpO2 dataset has no 'fs' attribute")
    sfreq_global = row['sfreq_global']
    # presumably the stored signal is at sfreq_global and `fs` is its native rate — TODO confirm
    spo2_signal = spo2_signal[::int(sfreq_global / sfreq_spo2)]

    def _burden_rates(event_times):
        # Per-event burden, then hourly rates by stage, using the stages passed in.
        df_hb = calc_hypoxic_burden(event_times, spo2_signal, sfreq_spo2)
        return hb_per_stages(df_hb, sleep_stages, sfreq_global)

    event_type = df_events["event_type"].astype(str)

    # Extract apnea events ("pnea" also matches hypopnea); the event time is the event end
    df_apnea = (
        df_events[event_type.str.contains("pnea", case=False, na=False)]
        .copy()
        .reset_index(drop=True)
    )
    df_apnea["end_time"] = df_apnea["onset"].astype(float) + df_apnea["duration"].astype(float)
    apnea_event_times = np.array(df_apnea["end_time"])

    if len(apnea_event_times) < 2:
        print(f"{psg_id}: len(apnea_event_times)={len(apnea_event_times)} < 2 → skipping")
        return None

    # Compute HB apnea
    HB_per_hour_apnea, HB_NREM_per_hour_apnea, HB_REM_per_hour_apnea = _burden_rates(apnea_event_times)

    # Extract desaturation events (excluding artifacts); the event time is the event midpoint
    df_desat = (
        df_events[
            event_type.str.contains("desat", case=False, na=False)
            & ~event_type.str.contains("artifac", case=False, na=False)
        ]
        .copy()
        .reset_index(drop=True)
    )
    if len(df_desat) == 0:  # If no desat events found → detect automatically
        df_desat = detect_oxygen_desaturation(spo2_signal, is_plot=False)
    df_desat["mid_time"] = df_desat["onset"].astype(float) + (df_desat["duration"].astype(float) / 2)
    desat_event_times = np.array(df_desat["mid_time"])

    if len(desat_event_times) < 2:
        print(f"{psg_id}: len(desat_event_times)={len(desat_event_times)} < 2 → skipping")
        return None

    # Compute HB desaturation
    HB_per_hour_desat, HB_NREM_per_hour_desat, HB_REM_per_hour_desat = _burden_rates(desat_event_times)

    return (HB_per_hour_apnea, HB_NREM_per_hour_apnea, HB_REM_per_hour_apnea,
            HB_per_hour_desat, HB_NREM_per_hour_desat, HB_REM_per_hour_desat)

In [17]:
# Get the row from the mastersheet. The location can be overridden with the
# MROS_MASTERSHEET_PATH environment variable; the default is the original local path.
path_mastersheet = os.environ.get(
    "MROS_MASTERSHEET_PATH",
    "/Users/alicealbrecht/Desktop/UCSF/pipeline_project/datasets/mastersheets/mros_ses01_mastersheet.csv",
)
mastersheet = pd.read_csv(path_mastersheet)
mastersheet = mastersheet.rename(columns={"subject_id":"sub_id"})
mastersheet["event_path"] = (
    mastersheet["h5_path"]
    .str.replace("hdf5", "events", regex=False)
    .str.replace("signals.h5", "events.csv", regex=False)
)
row = mastersheet.iloc[0]
psg_id = f"sub_{row['sub_id']}_ses-{row['session']}"

# Get df_events and sleep stages
full_sleep_stages, df_events = read_annot(row, "mros")
df_events.to_csv(row['event_path'], index = False)
sleep_stages, sleep_onset_time = process_sleep_stages(full_sleep_stages, row['sfreq_global'], row['start_time'], verbose=False)

# The burden-per-stage computation subsamples the stage array with step sfreq_global,
# so it needs the full-resolution stages rather than the processed ones.
hb_result = extract_hypoxic_burden(psg_id, row, df_events, full_sleep_stages)

# extract_hypoxic_burden returns None when there are too few events; unpacking
# that directly would raise TypeError.
if hb_result is None:
    print(f"{psg_id}: hypoxic burden not computed")
else:
    HB_per_hour_apnea, HB_NREM_per_hour_apnea, HB_REM_per_hour_apnea, HB_per_hour_desat, HB_NREM_per_hour_desat, HB_REM_per_hour_desat = hb_result
    print(HB_per_hour_apnea, HB_NREM_per_hour_apnea, HB_REM_per_hour_apnea, HB_per_hour_desat, HB_NREM_per_hour_desat, HB_REM_per_hour_desat)

145.62916291629162 183.45718654434248 290.1368421052632 194.82088208820883 185.30275229357798 248.1684210526316
