In [None]:
import mne
import json
from datetime import datetime
from pathlib import Path

# --- Load and crop the EDF file ---
# The first 20 s of the recording are discarded (crop modifies `raw` in place).
edf_path = Path("data/test_real/test.edf")
raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
raw.crop(tmin=20)

# --- Save the cropped data into a new EDF ---
# Written next to the source file; an existing file of that name is replaced.
output_path = edf_path.parent / "test_from_20s.edf"
raw.export(output_path, fmt="edf", overwrite=True)
print(f"Cropped EDF saved to: {output_path}")

# --- Extract metadata ---
info = raw.info
start_time = raw.info['meas_date']  # datetime object or None
if start_time is None:
    # fallback to "unknown" date if EDF header has no time
    start_time = datetime(1970, 1, 1)

duration = float(raw.times[-1])  # time of the last sample, in seconds
end_time = start_time.timestamp() + duration  # POSIX timestamp of the last sample

# --- Format times as ISO strings ---
# Rebuild the end datetime with the SAME tzinfo as the start. Without `tz=`,
# fromtimestamp() returns naive local time, so a tz-aware start would be paired
# with an end string in a different zone and without a UTC offset. When
# start_time is naive (fallback above), tzinfo is None and both values stay
# naive local time, which is still self-consistent.
start_time_iso = start_time.isoformat()
end_time_iso = datetime.fromtimestamp(end_time, tz=start_time.tzinfo).isoformat()

# --- Build dictionary ---
# Project description that later cells extend with annotation tracks.
project_dict = {
    "projectName": "real data test",
    "currentTime": 0,
    "signals": [
        {
            "signalName": "EEG",
            "startTime": start_time_iso,
            "endTime": end_time_iso,
            # path relative to the project file's own directory, i.e. the bare file name
            "edfFile": output_path.name,
            "channels": info['ch_names'],
            "visible": True
        }
    ],
    "annotations": []
}


# EEG-only sleep staging with explicit spindle, K-complex, and slow-wave detectors (AASM-inspired)

Uses explicit detectors:
- Spindles: 11–16 Hz envelope thresholding with duration check.
- K-complexes: 0.5–4 Hz negative peak with rebound, ≥75 µV, ≥0.5 s.
- Slow waves: 0.5–2 Hz full-cycle p2p ≥75 µV, 0.5–2.0 s; N3 if coverage ≥20% of epoch.

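Channel selection: the first frontal channel found (F3, F4, Fz, then any `F<digit>`) is used for slow waves and the first central channel (C3, C4, Cz, then any `C<digit>`) for spindles; if no matching channel exists, the mean across all EEG channels is used instead.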

In [None]:
import re
import json
import numpy as np
import mne
from scipy.signal import butter, filtfilt, hilbert, find_peaks
from tqdm import tqdm

# ----------------------- Config -----------------------
# Input and preprocessing
edf_file = "your_file.edf"     # <- set path (not referenced by the other visible cells)
ref = "average"                # "average" or None
hp_lp = (0.3, 35.0)            # (high-pass, low-pass) prefilter edges in Hz
target_sfreq = 200.0           # resample for stable detectors
epoch_len = 30.0               # scoring epoch length in seconds

# Amplitude criteria, expressed in microvolts (MNE stores Volts internally)
K_MIN_P2P_UV  = 75.0           # K-complex minimum peak-to-peak
SW_MIN_P2P_UV = 75.0           # slow wave minimum peak-to-peak (AASM: >=75 uV)

# Duration criteria, in seconds
SP_MIN_DUR = 0.5               # spindle, shortest accepted
SP_MAX_DUR = 2.0               # spindle, longest accepted
K_MIN_DUR  = 0.5               # K-complex, negative peak to following positive peak
SW_MIN_DUR = 0.5               # slow wave full cycle, shortest (0.5-2 Hz band)
SW_MAX_DUR = 2.0               # slow wave full cycle, longest

# Staging rule
N3_MIN_COVERAGE = 0.20         # N3 when slow waves cover >=20% of the epoch

# --------------------- Helpers ------------------------
def bp_filter(x, sf, lo, hi, order=4):
    """Zero-phase Butterworth band-pass of ``x`` between ``lo`` and ``hi`` Hz.

    Parameters
    ----------
    x : array_like
        Signal to filter (filtered along its last axis).
    sf : float
        Sampling frequency in Hz.
    lo, hi : float
        Lower and upper pass-band edges in Hz.
    order : int
        Butterworth prototype order (the band-pass has twice this order).

    Returns
    -------
    ndarray
        Forward-backward filtered signal, same shape as ``x``.
    """
    # Local import: sosfiltfilt is not part of the cell's top-level imports.
    from scipy.signal import sosfiltfilt

    nyq = sf / 2
    # Second-order sections instead of (b, a): the transfer-function form of a
    # narrow, low-frequency band-pass (e.g. 0.5-2 Hz at 200 Hz) is numerically
    # ill-conditioned, which can distort or destabilise the filter.
    sos = butter(order, [lo / nyq, hi / nyq], btype="band", output="sos")
    return sosfiltfilt(sos, x)

def lp_filter(x, sf, hi, order=4):
    """Zero-phase Butterworth low-pass of ``x`` with cutoff ``hi`` Hz.

    ``sf`` is the sampling frequency in Hz and ``order`` the filter order.
    The signal is filtered forward and backward, so no phase shift is added.
    """
    nyquist = sf / 2
    numerator, denominator = butter(order, hi / nyquist, btype="low")
    return filtfilt(numerator, denominator, x)

def pick_channel(raw, patterns):
    """Return the first channel name of ``raw`` matching any of ``patterns``.

    Patterns are tried in the given order (earlier patterns take priority);
    for each pattern the channels are scanned in ``raw.ch_names`` order with a
    case-insensitive ``re.search``. Returns ``None`` when nothing matches.
    """
    channel_names = raw.ch_names
    for pattern in patterns:
        hit = next(
            (name for name in channel_names if re.search(pattern, name, flags=re.I)),
            None,
        )
        if hit is not None:
            return hit
    return None

def uV(x):
    """Scale a value (scalar or array) expressed in volts to microvolts."""
    microvolts_per_volt = 1e6
    return x * microvolts_per_volt

def merge_intervals(intervals, gap=0.1):
    """Merge intervals that overlap or lie within ``gap`` of each other.

    Parameters
    ----------
    intervals : sequence of (start, end) pairs
        Pairs may be lists or tuples, in any order. The input is not modified.
    gap : float
        Two intervals are fused when the next start is at most ``gap`` after
        the end of the interval being built.

    Returns
    -------
    list of [start, end] lists
        Sorted by start; empty when ``intervals`` is empty.
    """
    if not intervals:
        return []

    # Sort a copy (the original sorted the caller's list in place) and compare
    # on the two endpoints so lists and tuples can be mixed.
    ordered = sorted(intervals, key=lambda iv: (iv[0], iv[1]))

    merged = []
    for start, end in ordered:
        if merged and start - merged[-1][1] <= gap:
            merged[-1][1] = max(merged[-1][1], end)
        else:
            # Fresh list per output interval: never alias (and later mutate)
            # the caller's own interval objects.
            merged.append([start, end])
    return merged

# ---------------- Feature Detectors -------------------
def detect_spindles(sig, sf):
    """Find candidate sleep spindles in ``sig`` (sampled at ``sf`` Hz).

    The signal is band-passed to 11-16 Hz and its Hilbert envelope is compared
    against ``mean + 1.5 * std`` of that envelope. Contiguous supra-threshold
    runs lasting between SP_MIN_DUR and SP_MAX_DUR seconds are kept, and runs
    closer than 0.1 s are merged. Returns a list of ``[start_s, end_s]`` pairs
    in seconds from the start of ``sig``.
    """
    sigma_band = bp_filter(sig, sf, 11, 16)
    env = np.abs(hilbert(sigma_band))
    above = env > (np.mean(env) + 1.5*np.std(env))
    if not above.any():
        return []

    # Pad with False on both sides so every run has a rising and a falling edge.
    padded = np.concatenate(([False], above, [False]))
    edges = np.diff(padded.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)    # first sample of each run
    run_stops = np.flatnonzero(edges == -1)    # one past the last sample of each run

    sp = []
    for first, past_last in zip(run_starts, run_stops):
        s, e = first/sf, past_last/sf
        if SP_MIN_DUR <= e - s <= SP_MAX_DUR:
            sp.append([s, e])
    return merge_intervals(sp, gap=0.1)

def detect_k_complexes(sig, sf):
    """Find candidate K-complexes in ``sig`` (sampled at ``sf`` Hz).

    Works on the 0.5-4 Hz band-passed signal. Each negative peak (at least
    0.25 s apart) is examined in a window from 0.2 s before to ~1.0 s after it:
    the swing between the window minimum and the maximum after the peak must
    reach K_MIN_P2P_UV microvolts, and the first positive peak after the
    negative one must come at least K_MIN_DUR seconds later. Returns merged
    ``[neg_peak_s, pos_peak_s]`` pairs (merge gap 0.15 s).
    """
    x = bp_filter(sig, sf, 0.5, 4)
    troughs, _ = find_peaks(-x, distance=int(0.25*sf))  # avoid clustering

    # Window sizes in samples, fixed for the whole signal.
    lookahead = int(1.0*sf)
    lookback = int(0.2*sf)
    min_window = int(0.3*sf)
    min_post = int(0.2*sf)
    last_index = len(x) - 1

    kc = []
    for trough in troughs:
        lo = max(0, trough - lookback)
        hi = min(last_index, trough + lookahead)

        window = x[lo:hi]
        if window.size < min_window:
            continue
        deepest = np.min(window)

        # positive rebound after the negative peak
        after = x[trough:hi]
        if after.size < min_post:
            continue

        swing = uV(np.max(after) - deepest)
        if not (swing >= K_MIN_P2P_UV):
            continue

        # duration = from the negative peak to the first following positive peak
        rebounds, _ = find_peaks(after)
        if not rebounds.size:
            continue
        peak = rebounds[0] + trough
        if (peak - trough)/sf >= K_MIN_DUR:
            kc.append([trough/sf, peak/sf])
    return merge_intervals(kc, gap=0.15)

def detect_slow_waves(sig, sf):
    """Find slow waves in ``sig`` (sampled at ``sf`` Hz).

    The signal is band-passed to 0.5-2 Hz and split at its zero crossings.
    Every span between a crossing and the second crossing after it (one full
    cycle) is kept when it lasts SW_MIN_DUR..SW_MAX_DUR seconds and its
    peak-to-peak amplitude reaches SW_MIN_P2P_UV microvolts. Spans closer than
    0.2 s are merged. Returns ``[start_s, end_s]`` pairs in seconds.
    """
    x = bp_filter(sig, sf, 0.5, 2.0)
    # indices where the sign changes between consecutive samples
    crossings = np.flatnonzero(np.diff(np.signbit(x)))

    sw = []
    # pair each crossing with the next same-phase crossing -> full cycle
    for begin, finish in zip(crossings[:-2], crossings[2:]):
        cycle_len = (finish - begin)/sf
        if not (SW_MIN_DUR <= cycle_len <= SW_MAX_DUR):
            continue
        cycle = x[begin:finish]
        p2p = uV(np.max(cycle) - np.min(cycle))
        if p2p >= SW_MIN_P2P_UV:
            sw.append([begin/sf, finish/sf])
    return merge_intervals(sw, gap=0.2)

def time_covered(intervals, t0, t1):
    """Total time within ``[t0, t1]`` covered by ``intervals``.

    Intervals are clipped to the window, overlapping pieces are merged so
    nothing is counted twice, and the summed length is returned (0.0 when
    ``intervals`` is empty).
    """
    if not intervals:
        return 0.0

    inside = []
    for start, stop in intervals:
        # keep only intervals that actually intersect the window
        if stop > t0 and start < t1:
            inside.append([max(start, t0), min(stop, t1)])

    pieces = merge_intervals(inside, gap=0.0)
    return sum(stop - start for start, stop in pieces)

# -------------------- Pipeline ------------------------
# Preprocess a COPY for staging. Working on `raw` itself would leak into later
# cells: non-EEG channels (e.g. ECG) would be gone, and the data would already
# be low-passed at hp_lp[1] Hz and resampled before any artifact detection.
raw_stage = raw.copy().pick_types(eeg=True)
raw_stage.filter(hp_lp[0], hp_lp[1], fir_design="firwin", verbose=False)
if target_sfreq:
    raw_stage.resample(target_sfreq, verbose=False)

# choose channels per physiology (fallbacks to mean if missing)
ch_sw = pick_channel(raw_stage, [r"^F3", r"^F4", r"^Fz", r"F\d"])
ch_sp = pick_channel(raw_stage, [r"^C3", r"^C4", r"^Cz", r"C\d"])

# The cross-channel mean must be taken BEFORE average re-referencing: once the
# channel average has been subtracted from every channel, the mean across
# channels is zero at every sample and useless as a fallback signal.
fallback_sig = None
if ch_sw is None or ch_sp is None:
    fallback_sig = raw_stage.get_data().mean(axis=0)

# Re-referencing is a linear combination across channels, so applying it after
# the (per-channel, linear) filter and resampling gives the same channel data
# as applying it first.
if ref == "average":
    raw_stage.set_eeg_reference("average")

sf = raw_stage.info["sfreq"]
n_samp = raw_stage.n_times
dur = n_samp / sf

data = raw_stage.get_data()  # shape: (n_ch, n_times)
if ch_sw is not None:
    sw_sig = data[raw_stage.ch_names.index(ch_sw)]
else:
    sw_sig = fallback_sig

if ch_sp is not None:
    sp_sig = data[raw_stage.ch_names.index(ch_sp)]
else:
    sp_sig = fallback_sig

# alpha/theta power (for W/N1 support)
def rel_bandpower(x, sf, bands=((0.5,4,'delta'),(4,8,'theta'),(8,13,'alpha'),(11,16,'sigma'))):
    """Relative power of ``x`` in each of ``bands``, from a Welch PSD.

    Parameters
    ----------
    x : 1-D array_like
        Signal segment.
    sf : float
        Sampling frequency in Hz.
    bands : iterable of (lo_hz, hi_hz, name)
        Band edges are inclusive on both sides.

    Returns
    -------
    dict
        ``name -> band power / sum of all band powers``.

    Notes
    -----
    The normaliser is the sum of the listed band powers, not the total signal
    power. With the default bands, alpha (8-13 Hz) and sigma (11-16 Hz)
    overlap, so 11-13 Hz is counted twice in that sum.
    """
    from scipy.signal import welch
    # np.trapz is deprecated since NumPy 2.0; SciPy's trapezoid is the same
    # rule and works on both old and new NumPy.
    from scipy.integrate import trapezoid

    f, Pxx = welch(x, sf, nperseg=int(sf*2))  # 2 s segments -> 0.5 Hz resolution
    out = {}
    tot = 0.0
    for lo, hi, name in bands:
        mask = (f >= lo) & (f <= hi)
        p = trapezoid(Pxx[mask], f[mask])
        out[name] = p
        tot += p
    # small constant guards against division by zero for an all-zero segment
    return {name: p / (tot + 1e-12) for name, p in out.items()}

# epoching + staging
# The recording is cut into consecutive, non-overlapping epochs of epoch_len
# seconds; a trailing partial epoch is ignored.
n_epochs = int(dur // epoch_len)
results = []

# After average re-referencing, the mean across EEG channels is zero at every
# sample (up to rounding). Detecting K-complexes or computing band power on
# that mean would only analyse numerical noise, so it is avoided in that case.
mean_is_degenerate = (ref == "average")

for i in tqdm(range(n_epochs)):
    t0 = i*epoch_len
    t1 = (i+1)*epoch_len
    s0, s1 = int(t0*sf), int(t1*sf)

    sw_ep = sw_sig[s0:s1]
    sp_ep = sp_sig[s0:s1]
    ep_data = data[:, s0:s1]

    if mean_is_degenerate:
        # K-complexes: use the same trace as the slow-wave detector.
        kc_ep = sw_ep
        # Spectral support: relative band power per channel, then averaged.
        per_channel = [rel_bandpower(ch, sf) for ch in ep_data]
        rel = {
            band: float(np.mean([p[band] for p in per_channel]))
            for band in per_channel[0]
        }
    else:
        mean_ep = ep_data.mean(axis=0)
        kc_ep = mean_ep  # K-complex can be broad; use mean
        rel = rel_bandpower(mean_ep, sf)

    sp_intervals = detect_spindles(sp_ep, sf)
    kc_intervals = detect_k_complexes(kc_ep, sf)
    sw_intervals = detect_slow_waves(sw_ep, sf)

    # coverage for N3 (detector times are relative to the epoch start)
    sw_cov = time_covered(sw_intervals, 0, epoch_len) / epoch_len

    # Decision rules (EEG-only, AASM-inspired), first match wins
    if sw_cov >= N3_MIN_COVERAGE:
        label = "N3"
    elif (len(sp_intervals) > 0) or (len(kc_intervals) > 0):
        label = "N2"
    elif rel["alpha"] > 0.35:
        label = "W"
    elif rel["theta"] > 0.25:
        label = "N1"
    else:
        label = "REM/N1?"

    results.append({"label": label, "startSec": float(t0), "endSec": float(t1)})

# Collapse runs of consecutive epochs that share a label into one interval
# spanning from the first epoch's start to the last epoch's end.
merged_results = []
for stage in results:
    same_as_previous = bool(merged_results) and merged_results[-1]["label"] == stage["label"]
    if same_as_previous:
        # extend the interval currently being built
        merged_results[-1]["endSec"] = stage["endSec"]
    else:
        # start a new interval (copy so `results` is left untouched)
        merged_results.append(stage.copy())

# preview: first eight merged intervals and the total count
merged_results[:8], len(merged_results)


In [None]:
# Add the merged sleep-stage intervals to the project as one annotation track.
sleep_stage_track = {
    "name": "Sleep stages",
    "opacity": .2,
    "visible": True,
    "events": merged_results,
}
project_dict["annotations"].append(sleep_stage_track)

# Simple artifacts detector

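- Eye movements/blinks (EOG): No dedicated EOG channel is assumed; the frontal leads (FP1/FP2/FPZ) are band-passed to 1–10 Hz, averaged, and thresholded at mean $+\,4\sigma$ as an EOG surrogate.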
- Heartbeats (ECG): Uses `find_ecg_events` if ECG channel exists.
- Muscle activity (EMG, incl. swallowing/chewing): Detects broadband high-frequency bursts (channel-averaged 30–100 Hz Hilbert envelope exceeding mean $+\,5\sigma$; the upper band edge is capped just below the Nyquist frequency).

In [None]:
import mne
import numpy as np

# -----------------------
# 1. Detect eye blinks / movements from frontal channels
# -----------------------
# Use frontal electrodes as EOG surrogates
frontal_chs = ["EEG FP1-A1", "EEG FP2-A2", "EEG FPZ-A1"]
# Only request channels that exist in this recording.
frontal_present = [ch for ch in frontal_chs if ch in raw.ch_names]

eog_artifacts = []
if frontal_present:
    # Filter a COPY of the frontal channels only. Band-passing `raw` itself to
    # 0.5-45 Hz would remove the 30-100 Hz content that the muscle-artifact
    # step relies on.
    raw_frontal = raw.copy().pick_channels(frontal_present)
    raw_frontal.filter(0.5, 45., fir_design="firwin")  # bandpass for artifact detection
    # Band-pass filter to blink-dominant range (~1-10 Hz)
    raw_frontal.filter(1., 10., fir_design="firwin")

    # Mean frontal trace; only positive excursions above mean + 4 SD are
    # flagged (the signal is NOT rectified).
    data = raw_frontal.get_data().mean(axis=0)
    threshold = data.mean() + 4 * data.std()

    # Vectorised run detection: pad with False so each run has both edges.
    above_thr = np.concatenate(([False], data > threshold, [False]))
    edges = np.diff(above_thr.astype(np.int8))
    run_starts = np.flatnonzero(edges == 1)   # first sample above threshold
    run_stops = np.flatnonzero(edges == -1)   # first sample back at/below threshold
    # A run still above threshold at the end of the recording is closed at the
    # last sample instead of being silently dropped.
    run_stops = np.minimum(run_stops, data.size - 1)

    eog_artifacts = [
        {"label": "EOG", "startSec": float(raw.times[a]), "endSec": float(raw.times[b])}
        for a, b in zip(run_starts, run_stops)
    ]
else:
    print("No frontal channels found for EOG detection; expected one of:", frontal_chs)

# -----------------------
# 2. Detect cardiac artifacts (if ECG channel available)
# -----------------------
# Each detected heartbeat is marked as a fixed 0.6 s event starting at the
# detected sample. Any failure of the detector leaves the list empty.
ecg_artifacts = []
try:
    ecg_events, _, _ = mne.preprocessing.find_ecg_events(raw)

    for onset, _, _ in ecg_events:
        # MNE event sample numbers include raw.first_samp, which is non-zero
        # after cropping; subtract it before indexing raw.times, otherwise the
        # times are shifted (or the index runs past the end of the array).
        start = float(raw.times[int(onset) - raw.first_samp])
        ecg_artifacts.append({"label": "ECG", "startSec": start, "endSec": start + 0.6})
except Exception as e:
    ecg_artifacts = []
    print("No ECG channel detected:", e)

# -----------------------
# 3. Detect muscle artifacts (heuristic: high-frequency power > threshold)
# -----------------------
# Use 30–100 Hz band to detect EMG bursts

sfreq = raw.info["sfreq"]
nyq = sfreq / 2.0

# Target EMG band ~30–100 Hz, but ensure high_cut < Nyquist and > low_cut
low_cut = 30.0
high_cut = min(100.0, nyq - 1.0)  # keep a 1-Hz safety margin

# If sampling rate is low and the band collapses, widen downwards a bit
if high_cut <= low_cut + 1.0:
    # Keep at least some high-frequency range; fall back to ~0.3*Nyq–(Nyq-1)
    high_cut = max(nyq - 1.0, 5.0)
    low_cut = max(5.0, min(0.3 * nyq, high_cut - 5.0))

raw_emg = (
    raw.copy()
       .pick_types(eeg=True)                 # only EEG channels
       .filter(low_cut, high_cut, fir_design="firwin")
)

# Hilbert envelope (per channel), then average across channels
emg_data = raw_emg.get_data()                # shape: (n_channels, n_times)
envelope = np.abs(hilbert(emg_data, axis=1)).mean(axis=0)

# Threshold at mean + 5 SD of the envelope (not outlier-robust: large bursts
# raise both the mean and the SD)
thr = envelope.mean() + 5 * envelope.std()
above = envelope > thr

# Extract continuous segments above threshold (vectorised edge detection;
# padding with False gives every run a rising and a falling edge)
padded = np.concatenate(([False], above, [False]))
edges = np.diff(padded.astype(np.int8))
run_starts = np.flatnonzero(edges == 1)      # first sample above threshold
run_stops = np.flatnonzero(edges == -1)      # first sample back below threshold
# A burst still running at the end of the recording is closed at the last
# sample instead of being dropped.
run_stops = np.minimum(run_stops, above.size - 1)

muscle_artifacts = [
    {"label": "EMG", "startSec": float(raw.times[a]), "endSec": float(raw.times[b])}
    for a, b in zip(run_starts, run_stops)
]

# Merge segments separated by short gaps, then drop tiny detections
max_gap = 0.05  # seconds; segments closer than this are fused
min_len = 0.05  # seconds; shorter merged segments are discarded
merged = []
for seg in muscle_artifacts:
    if not merged or seg["startSec"] - merged[-1]["endSec"] > max_gap:
        merged.append(seg)
    else:
        merged[-1]["endSec"] = seg["endSec"]
muscle_artifacts = [s for s in merged if (s["endSec"] - s["startSec"]) >= min_len]


# -----------------------
# 4. Combine results
# -----------------------
# One flat event list, in detector order: EOG, then ECG, then EMG.
artifacts = [*eog_artifacts, *ecg_artifacts, *muscle_artifacts]

artifact_track = {
    "name": "artifacts",
    "opacity": .2,
    "visible": True,
    "events": artifacts
}
project_dict["annotations"].append(artifact_track)

In [None]:
# --- Save to file ---
# The project description is written as indented JSON next to the source EDF.
# Note that `output_path` is rebound here from the exported EDF to the JSON file.
output_path = edf_path.parent / "real_test.vembproj.json"
with open(output_path, "w", encoding="utf-8") as handle:
    handle.write(json.dumps(project_dict, indent=2))

print(f"Saved project file to {output_path}")