In [None]:
import os
import xml.etree.ElementTree as ET
from collections import Counter, defaultdict
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from pathlib import Path
from typing import Iterable

import intervaltree
from dateutil import parser
from tqdm import tqdm

# Every non-empty sourceName seen by read_apple_health, accumulated across all calls;
# presumably kept to look up the exact spelling of a source (some contain \xa0).
sns = set()


def read_apple_health(path: Path, source_names: Iterable[str]) -> dict[str, list[ET.Element]]:
    """Parse an Apple Health XML export and group its elements by ``sourceName``.

    The whole document is parsed into memory with ``ElementTree`` and every element
    is visited (with a tqdm progress bar). Elements whose ``sourceName`` is one of
    ``source_names`` are collected; all other elements are skipped.

    Side effect: every non-empty ``sourceName`` encountered is added to the
    module-level ``sns`` set.

    Args:
        path: The export file (anything ``ET.parse`` accepts).
        source_names: Source names to collect. Any iterable works, including a
            one-shot generator: it is consumed exactly once.

    Returns:
        A dict with one key per requested source name (present even when nothing
        matched), each mapping to the matching elements in document order.
    """
    tree = ET.parse(path)
    result: dict[str, list[ET.Element]] = {sn: [] for sn in source_names}
    # list(...) so tqdm knows the total number of elements.
    for elem in tqdm(list(tree.getroot().iter())):
        sn = elem.attrib.get("sourceName")
        if sn:
            sns.add(sn)
        # Check against `result`, not `source_names`: the latter may be an iterator
        # already exhausted by the comprehension above (and a dict lookup is O(1)).
        if sn in result:
            result[sn].append(elem)
    return result


# Directory with the exported data. Set the SLEEP_DATA_DIR environment variable to
# run on another machine; the default is the original author's local checkout.
root_dir = Path(os.environ.get("SLEEP_DATA_DIR", "/Users/vzuev/Documents/git/gh_zuevval/video_eeg/data/ashih"))
apple_health_sleep_routine_path: Path = root_dir / "apple_export_sleep_routine.xml"
apple_health_iphone_13_path: Path = root_dir / "apple_export_iph_13.xml"
fitbit_path: Path = root_dir / "fitbit-sleep-2025-05-03.json"

sleep_entries = read_apple_health(apple_health_sleep_routine_path, ["Sleep Routine"])
# Source names must match the export exactly: the Apple Watch name uses
# non-breaking spaces (\xa0), not ordinary spaces.
iph13_entries = read_apple_health(apple_health_iphone_13_path, ["Sleep Routine", "Apple\xa0Watch\xa0— Maria"])

In [None]:
# Analysis window as naive datetimes (no tzinfo), to be compared against the
# tz-stripped timestamps: 19:00 on 27 May 2025 until 10:00 on 28 May 2025.
datetime_start, datetime_end = (
    datetime(2025, 5, 27, 19),
    datetime(2025, 5, 28, 10),
)

In [None]:
# Sleep-stage vocabulary shared by every source, built with the functional Enum API.
# Members, in order: awake, light, deep, rem; each member's value equals its name.
State = Enum("State", [(stage_name, stage_name) for stage_name in ("awake", "light", "deep", "rem")])


@dataclass()
class SleepStage:
    """One scored segment from a single source.

    Field order is positional: instances are built as ``SleepStage(state, interval)``.
    """

    state: State  # stage assigned to this segment
    interval: intervaltree.Interval  # time span of the segment (begin/end)

# HealthKit sleep-analysis category value -> shared State. "Core" sleep is treated
# as light sleep. InBed is intentionally absent; apple_to_sleep_stages skips it.
apple_sleep_stages_map: dict[str, State] = dict(
    HKCategoryValueSleepAnalysisAwake=State.awake,
    HKCategoryValueSleepAnalysisAsleepCore=State.light,
    HKCategoryValueSleepAnalysisAsleepDeep=State.deep,
    HKCategoryValueSleepAnalysisAsleepREM=State.rem,
)

def apple_to_sleep_stages(xml_elems: Iterable[ET.Element]) -> list[SleepStage]:
    """Convert Apple Health sleep-analysis elements into ``SleepStage`` objects.

    Only ``HKCategoryTypeIdentifierSleepAnalysis`` elements are used. Elements without
    a ``type`` attribute (an export also attaches ``sourceName`` to non-Record
    elements, e.g. presumably ``Workout`` entries) are skipped instead of raising.

    Records are kept if they touch the window ``[datetime_start, datetime_end]``;
    overlapping records are kept whole, not clipped to the window.

    Timestamps are made naive by dropping the UTC offset (not converting it), so they
    stay in the wall-clock time written in the export. NOTE(review): this presumably
    matches Fitbit's local timestamps; verify if recordings span time zones.

    ``InBed`` records (total time in bed) are skipped silently. Any other value missing
    from ``apple_sleep_stages_map`` is reported with ``print`` and skipped.

    Args:
        xml_elems: Elements as returned by ``read_apple_health`` for one source.

    Returns:
        The recognized stages in input order.
    """
    result: list[SleepStage] = []
    for entry in xml_elems:
        # .get(): not every element carrying a sourceName has a "type" attribute.
        if entry.attrib.get("type") != "HKCategoryTypeIdentifierSleepAnalysis":
            continue
        start = parser.parse(entry.attrib["startDate"]).replace(tzinfo=None)
        end = parser.parse(entry.attrib["endDate"]).replace(tzinfo=None)
        if end < datetime_start or start > datetime_end:
            continue
        stage_str = entry.attrib["value"]
        stage = apple_sleep_stages_map.get(stage_str)
        if stage is not None:
            result.append(SleepStage(stage, intervaltree.Interval(start, end)))
        elif stage_str != "HKCategoryValueSleepAnalysisInBed":
            print(f"unrecognized category: {stage_str}")
    return result


stages_sleep_routine = apple_to_sleep_stages(sleep_entries["Sleep Routine"])
print(f"Sleep Routine: {len(stages_sleep_routine)} intervals total")  # 44 when last run
stages_sleep_routine_iph13 = apple_to_sleep_stages(iph13_entries["Sleep Routine"])
print(f"Sleep Routine - iPh13: {len(stages_sleep_routine_iph13)}")  # 42 when last run
stages_watches = apple_to_sleep_stages(iph13_entries["Apple\xa0Watch\xa0— Maria"])
# Backward-compatible alias for the original misspelled name, still used by later cells.
stages_wathces = stages_watches
print(f"Apple Watches entries: {len(stages_watches)}")

In [None]:
import json
from datetime import datetime, timedelta

# Fitbit sleep export: a JSON list whose entries each carry a ["levels"]["data"]
# list of segments (dateTime, seconds, level).
with fitbit_path.open() as json_file:
    fitbit_json = json.load(json_file)

# Fitbit level name -> shared State.
fitbit_state_map = dict(
    awake=State.awake,  # TODO inspect: any differences between wake and awake?
    wake=State.awake,
    light=State.light,
    deep=State.deep,
    rem=State.rem,
)

stages_fitbit: list[SleepStage] = []
for day_entry in fitbit_json:
    for entry in day_entry["levels"]["data"]:
        # Segment span: start timestamp (offset dropped) plus its duration.
        start = parser.parse(entry["dateTime"]).replace(tzinfo=None)
        end = start + timedelta(seconds=entry["seconds"])
        interval = intervaltree.Interval(start, end)
        level: str = entry["level"]
        state = fitbit_state_map.get(level)
        if not state:
            # restless, asleep - stages when Fitbit fails to calculate sleep stages properly
            if level not in ("restless", "asleep"):
                print(f"unrecognized level: {level}")
            continue
        stages_fitbit.append(SleepStage(state, interval))

print(f"Fitbit: total {len(stages_fitbit)} intervals")  # 24

Copilot prompt:

<blockquote>

Write a code which finds all common sub-intervals for stages_sleep_routine, stages_sleep_routine_iph13, stages_watches and stages_fitbit, i.e. all longest possible time intervals in which all four sleep stage markups agree on the sleep stage. For example, if one list contains an interval of type Awake from 3:30 to 4:50, another from 2:50 to 4:20 and two remaining - from 2:45 to 4:10, then the result should contain an interval from 2:50 to 4:20 of type awake.
You are encouraged to use intervaltree if necessary

</blockquote>

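The Copilot response below was slightly modified. Note that the example in the prompt is inconsistent: the common part of 3:30–4:50, 2:50–4:20 and 2:45–4:10 is 3:30–4:10, not 2:50–4:20. The function below computes the actual common sub-interval (here 3:30–4:10).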



In [None]:
def _agreed_state(active):
    """Return the one state every source currently votes for, or None if they don't agree.

    ``active[i]`` maps state -> number of source i's intervals covering the current
    sweep position. A source votes only when it is covered by intervals of exactly
    one state.
    """
    agreed = set()
    for counts in active:
        if len(counts) != 1:  # uncovered, or conflicting overlapping intervals
            return None
        agreed.update(counts)
        if len(agreed) > 1:
            return None
    return next(iter(agreed)) if agreed else None


def find_common_sleep_intervals(*stage_lists):
    """Find the maximal time spans on which every source agrees on the sleep stage.

    Each argument is one source's markup: an iterable of objects with a ``state``
    and an ``interval`` exposing ``begin``/``end`` (e.g. ``SleepStage``). Intervals are
    treated as half-open ``[begin, end)``. Each iterable is consumed exactly once.

    At time t a source votes for a state when all of its intervals covering t carry
    that single state. There is agreement at t only when every source votes, and all
    votes are for the same state. If a source has no interval at t, or has overlapping
    intervals with different states, there is no agreement at t. Adjacent agreeing
    spans with the same state are merged. A change of agreed state closes the current
    span and opens a new one.

    Zero-length or inverted intervals (``begin >= end``) are ignored.

    Args:
        *stage_lists: One iterable of stages per source.

    Returns:
        A list of non-overlapping ``(start, end, state)`` tuples in increasing time
        order. Empty if no sources are given or they never agree.
    """
    n_sources = len(stage_lists)
    if n_sources == 0:
        return []

    # Sweep-line events: time -> [(source index, state, +1 at begin / -1 at end)].
    events = defaultdict(list)
    for source_idx, stages in enumerate(stage_lists):
        for stage in stages:
            begin, end = stage.interval.begin, stage.interval.end
            if not begin < end:
                continue
            events[begin].append((source_idx, stage.state, 1))
            events[end].append((source_idx, stage.state, -1))

    # active[i]: per-state count of source i's intervals covering the sweep position.
    active = [Counter() for _ in range(n_sources)]
    result = []
    run_start = run_state = None
    for t in sorted(events):
        # Apply every event at t before evaluating, so abutting intervals don't leave a gap.
        for source_idx, state, delta in events[t]:
            counts = active[source_idx]
            counts[state] += delta
            if counts[state] == 0:
                del counts[state]
        # The agreed state is constant on [t, next event time).
        agreed = _agreed_state(active)
        if agreed != run_state:
            if run_state is not None:
                result.append((run_start, t, run_state))
            run_start = t if agreed is not None else None
            run_state = agreed
    # The last event time is the largest interval end, after which nothing is active,
    # so any open run was closed in the final iteration.
    return result

# Spans on which both Sleep Routine exports, the Apple Watch and Fitbit all agree.
common_intervals = find_common_sleep_intervals(
    stages_sleep_routine, stages_sleep_routine_iph13, stages_wathces, stages_fitbit
)

for start, end, state in common_intervals:
    print(f"{start} - {end}: {state.name}")