In [None]:
import os
import json
import numpy as np
import pandas as pd
from datetime import datetime, timedelta
from tqdm import tqdm
import xml.etree.ElementTree as ET
from scipy import signal
import glob
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error, mean_squared_error
import matplotlib.pyplot as plt
from scipy.signal import resample
import pyedflib
import shutil
from xml.dom import minidom
import lxml.etree as ET

# Helper functions

In [None]:
# Names of the respiratory event types.
resp_events = {
    "Apnea Obstructive",
    "Apnea Central",
    "Apnea Mixed",
    "Hypopnea",
}

# Names of the sleep stages.
sleep_events = {
    "Wake",
    "N1",
    "N2",
    "N3",
    "REM",
}

def str_to_datetime(time_str):
    """Parse a ``YYYY/MM/DD HH:MM:SS AM|PM`` (12-hour clock) string into a datetime.

    Raises ValueError (from ``datetime.strptime``) when the text does not
    match that layout.
    """
    clock_layout = '%Y/%m/%d %I:%M:%S %p'
    parsed = datetime.strptime(time_str, clock_layout)
    return parsed

def _iter_label_runs(labels, target):
    """Yield ``(start, end)`` half-open index pairs for every maximal run of ``target``.

    The scan has already moved past a run when that run is yielded, so the
    caller may overwrite ``labels[start:end]`` between iterations without
    disturbing the remaining scan.
    """
    n = len(labels)
    i = 0
    while i < n:
        if labels[i] == target:
            run_start = i
            while i < n and labels[i] == target:
                i += 1
            yield run_start, i
        else:
            i += 1


def SBD_event_label(predictions, stages, artifacts, desat, arousal, min_consecutive=100,
                    desat_window=450, arousal_window=50):
    """Zero out respiratory-event runs that fail the acceptance rules.

    ``predictions`` is a per-sample label sequence in which 1, 1.1 and 1.2 mark
    the three apnea types (obstructive, central, mixed) and 2 marks hypopnea;
    every other value is left untouched.

    Rules applied to each maximal run of one label:
      * apnea (1, 1.1, 1.2): kept only if the run has at least
        ``min_consecutive`` samples and at least one sample of the run has
        ``stages == 0``.
      * hypopnea (2): the same two conditions, plus no sample of the run may
        have ``artifacts == 1``, and there must be a ``desat == 1`` sample
        within ``desat_window`` samples after the run or an ``arousal == 1``
        sample within ``arousal_window`` samples after the run.

    Args:
        predictions, stages, artifacts, desat, arousal: equally long
            sequences (anything ``list()`` accepts).
        min_consecutive: minimum run length in samples; a list is accepted
            and its first element used (100 if the list is empty).
        desat_window: number of samples after the end of a hypopnea searched
            for a desaturation flag. The note carried over from the original
            code suggests 300 for MrOS1 and 450 for MrOS2.
        arousal_window: number of samples after the end of a hypopnea
            searched for an arousal flag.

    Returns:
        A new list with rejected runs replaced by 0. If ``predictions`` is
        empty or the sequence lengths differ, the (list-converted)
        predictions are returned unchanged.
    """
    predictions = list(predictions)
    stages = list(stages)
    artifacts = list(artifacts)
    desat = list(desat)
    arousal = list(arousal)

    # Tolerate callers that pass the threshold wrapped in a list.
    if isinstance(min_consecutive, list):
        min_consecutive = min_consecutive[0] if min_consecutive else 100
    min_consecutive = int(min_consecutive)
    desat_window = int(desat_window)
    arousal_window = int(arousal_window)

    total_length = len(predictions)
    companions = (stages, artifacts, desat, arousal)
    if total_length == 0 or any(len(track) != total_length for track in companions):
        # Best effort: without aligned tracks nothing can be validated.
        return predictions

    def fails_basic_rules(start, end):
        # Too short, or no sample of the run has stage value 0.
        return (end - start) < min_consecutive or 0 not in stages[start:end]

    # Apnea labels (obstructive, central, mixed) share the same two rules.
    for apnea_label in (1, 1.1, 1.2):
        for start, end in _iter_label_runs(predictions, apnea_label):
            if fails_basic_rules(start, end):
                predictions[start:end] = [0] * (end - start)

    # Hypopnea additionally needs an artifact-free run and a following
    # desaturation or arousal. Slices clip at the end of the list on their own.
    for start, end in _iter_label_runs(predictions, 2):
        rejected = fails_basic_rules(start, end) or 1 in artifacts[start:end]
        if not rejected:
            followed_by_desat = 1 in desat[end:end + desat_window]
            followed_by_arousal = 1 in arousal[end:end + arousal_window]
            rejected = not (followed_by_desat or followed_by_arousal)
        if rejected:
            predictions[start:end] = [0] * (end - start)

    return predictions

In [None]:
def prettify_event(elem):
    """Serialise ``elem`` to XML text with one tag per line and no indentation.

    The element is dumped as UTF-8, re-parsed with ``minidom`` and
    pretty-printed with an empty indent; each resulting line is then stripped
    and blank lines are dropped. The lines emitted by ``toprettyxml`` (which
    begin with its XML declaration) are joined with ``"\\n"``.
    """
    serialized = ET.tostring(elem, encoding='utf-8').decode('utf-8')
    pretty_text = minidom.parseString(serialized).toprettyxml(indent="")

    kept_lines = []
    for raw_line in pretty_text.splitlines():
        trimmed = raw_line.strip()
        if trimmed:
            kept_lines.append(trimmed)
    return "\n".join(kept_lines)

def get_start_time(event):
    """Return the numeric ``Start`` of a ScoredEvent element, for use as a sort key.

    If the element has no ``Start`` child, the child has no text, or the text
    is not a valid float, a message is printed and ``float('inf')`` is
    returned so that such events sort last.
    """
    sort_last = float('inf')
    try:
        start_node = event.find('Start')
        if start_node is None or start_node.text is None:
            print(f"Warning: Event found without valid Start time: {ET.tostring(event, encoding='unicode')}")
            return sort_last
        return float(start_node.text)
    except ValueError as err:
        print(f"Error parsing start time: {err}")
        return sort_last

def find_continuous_sequences(lst):
    """Locate runs of identical event codes in ``lst``.

    Only the codes 1.0, 1.1, 1.2 and 2.0 are considered. A run ends when the
    value changes (to another code or to anything else) or the list ends.
    Runs shorter than two samples are discarded.

    Returns:
        A list of ``(start, end, value)`` tuples in order of appearance, where
        ``start`` and ``end`` are the inclusive first and last indices of the
        run and ``value`` is the run's first element.
    """
    event_codes = (1.0, 1.1, 1.2, 2.0)
    runs = []
    length = len(lst)
    pos = 0

    while pos < length:
        code = lst[pos]
        if code not in event_codes:
            pos += 1
            continue

        # Consume the whole run of this code.
        run_start = pos
        pos += 1
        while pos < length and lst[pos] == code:
            pos += 1

        if pos - run_start >= 2:
            runs.append((run_start, pos - 1, code))

    return runs

# Conversion code

In [None]:
"""
Following conversion code will produce two outputs:
  1. 1s apnea & hypopnea label, which may be useful for practical analysis
  2. Adjusted xml annotation file leaving only those apnea and hypopnea that satisfy the AASM guideline
"""

# Define directory for original annotation file and edf file
anno_dir = '{dir}/polysomnography/annotations-events-nsrr/visit1' # visit1 or visit2 depending on your dataset
edf_dir = '{dir}/polysomnography/edfs/visit1' # visit1 or visit2 depending on your dataset

# Define directory to save the adjusted annotation file as necessary
save_dir = ''

## In MrOS1, following patients have corrupted edf or annotation file:
problem_list = ['AA2159', 'AA2180', 'AA4156', 'AA0236', 'AA0425', 'AA0461', 'AA1477', 'AA2123', 'AA2339', 'AA2608', 'AA3354', 'AA3386', 'AA4712', 'AA4858', 'AA4860',
                'AA2033', 'AA3905', 'AA4225', 'AA4398', 'AA4928', 'AA5622', 'AA5660', 'AA1407', 'AA2359', 'AA4490', 'AA5311', 'AA5760', 'AA2930', 'AA4998']

## Conversion code
patient_ID = '' # ID of choice

## XML file processing: change visit1 or visit2 as necessary
xml_path = f'{anno_dir}/mros-visit1-{patient_ID}-nsrr.xml' # Change path as necessary
edf_path = f'{edf_dir}/mros-visit1-{patient_ID}.edf' # Change path as necessary

tree = ET.parse(xml_path)
root = tree.getroot()

stages = []
stages_sleep = []
total_duration = 0
start_time = None

for event in root.findall('.//ScoredEvent'):
    event_type = event.find('EventType').text
    event_concept = event.find('EventConcept').text
    start = float(event.find('Start').text)
    duration = float(event.find('Duration').text)
    clock_time = event.find('ClockTime').text if event.find('ClockTime') is not None else None

    if event_concept == 'Recording Start Time':
        start_time = clock_time

    if 'Stage' in event_concept or "Wake" in event_concept or "REM" in event_concept:
        stage_number = int(event_concept.split('|')[-1])
        num_epochs = int(duration / 30)
        stages.extend([stage_number] * num_epochs)
        total_duration += duration


## EDF load to extract SpO2 channel, which is necessary for determining whether particular hypopnea event is associated with desaturation (nadir within 30-45 seconds of hypopnea termination)
f = pyedflib.EdfReader(edf_path)
signal_labels = [f.getLabel(i) for i in range(f.signals_in_file)]

## SaO2 label for visit1 and visit2
SaO2_label = 'SaO2' # Visit1
#SaO2_label = ['SaO2', 'SpO2'] # Visit1

sao2_idx = signal_labels.index(SaO2_label)
sao2_freq = f.getSampleFrequency(sao2_idx)

total_sao2_samples = int(total_duration * sao2_freq)
sao2_raw = f.readSignal(sao2_idx)[:total_sao2_samples]

# Multiplication by 10 because original annotation records time at a decimal level. For exact precision, 
sao2 = np.array([item for item in sao2_raw for _ in range(10)])
total_duration = total_duration * 10

## Empty dataframe to store the time index of respiratory event, sleep stage, artifact, desaturation, and arousal
df_resp = pd.DataFrame({
    'Index': range(int(total_duration)),
    'Resp_Event': [0] * int(total_duration)
})

df_sleep_stage = pd.DataFrame({
    'Index': range(int(total_duration)),
    'Stage': [0] * int(total_duration)
})

df_artifact = pd.DataFrame({
    'Index': range(int(total_duration)),
    'artifact': [0] * int(total_duration)
})

df_desat = pd.DataFrame({
    'Index': range(int(total_duration)),
    'desat': [0] * int(total_duration)
})

df_arousal = pd.DataFrame({
    'Index': range(int(total_duration)),
    'arousal': [0] * int(total_duration)
})

## Processing of original XML file to fill in time index
for event in root.findall('.//ScoredEvent'):
    event_type = event.find('EventType').text
    event_concept = event.find('EventConcept').text
    start = float(event.find('Start').text)
    duration = float(event.find('Duration').text)

    start_time = int(10 * start)
    end_time = int(10 * (start + duration))

    if "apnea" in event_concept.lower():
        if duration >= 10:
            mask = (df_resp['Index'] >= start_time) & (df_resp['Index'] <= end_time)
            # Store each type of apnea with different index
            if "obstructive" in event_concept.lower():
                df_resp.loc[mask, 'Resp_Event'] = 1.0
            if "central" in event_concept.lower():
                df_resp.loc[mask, 'Resp_Event'] = 1.1
            if "mixed" in event_concept.lower():
                df_resp.loc[mask, 'Resp_Event'] = 1.2

    if "hypopnea" in event_concept.lower() or "unsure" in event_concept.lower():
        if duration >= 10:
            mask = (df_resp['Index'] >= start_time) & (df_resp['Index'] <= end_time)
            df_resp.loc[mask, 'Resp_Event'] = 2

    if "Wake" in event_concept:
        mask = (df_sleep_stage['Index'] >= start_time) & (df_sleep_stage['Index'] <= end_time)
        df_sleep_stage.loc[mask, 'Stage'] = 1 # 1 if wake, otherwise 0

    if "artifact" in event_concept:
        mask = (df_artifact['Index'] >= start_time) & (df_artifact['Index'] <= end_time)
        df_artifact.loc[mask, 'artifact'] = 0

    if "Arousal" in event_concept:
        mask = (df_arousal['Index'] >= start_time) & (df_arousal['Index'] <= end_time)
        df_arousal.loc[mask, 'arousal'] = 1 # 1 if Arousal, otherwise 0

    if "desaturation" in event_concept.lower():

        spo2_nadir_report = float(event.find('SpO2Nadir').text)
        spo2_baseline_report = float(event.find('SpO2Baseline').text)
        
        # Change baseline SpO2 to 100 if the reported value exceed 100
        if spo2_baseline_report >= 100:
            spo2_baseline_report = 100

        if start_time > len(sao2):
            continue
        
        # Mark nadir time 
        if (spo2_baseline_report - spo2_nadir_report) >= 3:
            if duration < 2:
                df_desat.iloc[start_time, df_desat.columns.get_loc('desat')] = 1
            else:
                filtered_indices = np.where(sao2[start_time:end_time] > 10)[0]  
                if filtered_indices.size > 0:
                    min_index_within_slice = filtered_indices[sao2[start_time:end_time][filtered_indices].argmin()]
                    min_index = start_time + min_index_within_slice  # Adjusting index to match the original array
                else:
                    min_index = None  # No values > 10 found
                df_desat.iloc[min_index, df_desat.columns.get_loc('desat')] = 1

## Create 1s annotation of apneas and hypopneas 
list_resp = df_resp['Resp_Event']
list_stage = df_sleep_stage['Stage']
list_artifact = df_artifact['artifact']
list_desat = df_desat['desat']
list_arousal = df_arousal['arousal']

SBD_event = SBD_event_label(list_resp, list_stage, list_artifact, list_desat, list_arousal) # 1s annotation

######################### Creating new XML file from SBD_event #########################
time_info = find_continuous_sequences(SBD_event)   
    
## Make copy of original NSRR file
xml_copy = f'{save_dir}/mros-visit1-{patient_ID}-nsrr-adjusted.xml' # Change path as necessary
shutil.copy2(xml_path, xml_copy)

tree = ET.parse(xml_copy)
root = tree.getroot()

stages = []
stages_sleep = []
total_duration = 0
start_time = None

for event in root.findall('.//ScoredEvent'):
    event_type = event.find('EventType').text
    event_concept = event.find('EventConcept').text
    start = float(event.find('Start').text)
    duration = float(event.find('Duration').text)
    clock_time = event.find('ClockTime').text if event.find('ClockTime') is not None else None

    if event_concept == 'Recording Start Time':
        start_time = clock_time

    if 'Stage' in event_concept or "Wake" in event_concept or "REM" in event_concept:
        stage_number = int(event_concept.split('|')[-1])
        num_epochs = int(duration / 30)
        stages.extend([stage_number] * num_epochs)
        total_duration += duration
    
    # Replace Unsure with Hypopnea
    if event_concept == "Unsure|Unsure":
        event.find('EventConcept').text = "Hypopnea|Hypopnea"

# Initially remove all SBD events annotation from the file
for parent in root.findall('.//ScoredEvent/..'):  # Find parent of ScoredEvent
    for event in parent.findall('.//ScoredEvent'):
        event_concept = event.find('EventConcept').text if event.find('EventConcept') is not None else None
        start = float(event.find('Start').text)
        duration = float(event.find('Duration').text)

        if event_concept and ("apnea" in event_concept.lower() or "hypopnea" in event_concept.lower() or "unsure" in event_concept.lower()):
            parent.remove(event)

## Fill in updated SBD events into the xml file            
new_events = []
for start, end, val in time_info:
    # Map value to the appropriate event concept
    if val == 1.0:
        concept = 'Obstructive apnea|Obstructive Apnea'
    elif val == 1.1:
        concept = 'Central apnea|Central Apnea'
    elif val == 1.2:
        concept = 'Mixed apnea|Mixed Apnea'
    elif val == 2.0:
        concept = 'Hypopnea|Hypopnea'

    # Create a new event dictionary
    new_event = {
        "EventType": "Respiratory|Respiratory",
        "EventConcept": concept,
        "Start": start / 10,  # Convert start time to correct scale
        "Duration": (end - start) / 10,  # Calculate and scale duration
    }

    # Append the new event to the new_events list
    new_events.append(new_event)

for event in new_events:
    scored_event = ET.Element('ScoredEvent')

    # Create XML structure for each event
    for key, value in event.items():
        element = ET.SubElement(scored_event, key)
        element.text = str(value)

    # Convert the single scored event to a string with proper indentation
    formatted_event = prettify_event(scored_event)

    # Append the prettified event directly to the XML file
    root.append(ET.fromstring(formatted_event))

# Add newline between each ScoredEvent in the XML tree
for i, event in enumerate(root.findall('ScoredEvent')[:-1]):
    next_event = root.findall('ScoredEvent')[i + 1]
    event.tail = "\n"  # Adds a newline between </ScoredEvent> and <ScoredEvent>

########## Rewriting xml files ##########

# Parse the XML file
all_scored_events = root.findall('.//ScoredEvent')
#print(f"Found {len(all_scored_events)} total events to sort")


# Sort all events
all_scored_events.sort(key=get_start_time)

# Remove all existing events from their parent elements
for event in all_scored_events:
    parent = event.getparent()  # or use parent = root if direct children of root
    if parent is not None:
        parent.remove(event)

# Re-append all sorted events to the root
for event in all_scored_events:
    root.append(event)
    event.tail = "\n"  # Add newline for formatting

# Save the sorted XML file
tree.write(xml_copy, encoding='utf-8', xml_declaration=True)