# Legacy Sleep Wave Project Code

#### Initial Setup, Imports and Visuals

Using "Fpz-Cz": the voltage difference between the electrode at the center of the forehead (Fpz) and the electrode at the top of the head (Cz). This can be swapped for a different electrode pair.
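"Fpz-Cz" is a bipolar derivation: one electrode's potential minus the other's. Any signal that both electrodes pick up equally therefore cancels out. Below is a minimal, idealized NumPy sketch with synthetic signals. The sampling rate, frequencies and amplitudes are made up for illustration.

```python
import numpy as np

sfreq = 100.0                                  # assumed sampling rate (Hz)
t = np.arange(int(30 * sfreq)) / sfreq         # 30 s of sample times

# Synthetic monopolar potentials (arbitrary units, not real EEG)
common = 0.5 * np.sin(2 * np.pi * 0.2 * t)     # interference seen by both electrodes
local_fpz = np.sin(2 * np.pi * 2 * t)          # activity picked up only at "Fpz"
local_cz = 0.3 * np.sin(2 * np.pi * 10 * t)    # activity picked up only at "Cz"
fpz = local_fpz + common
cz = local_cz + common

# Bipolar derivation: the voltage difference between the two electrodes
fpz_cz = fpz - cz

# The shared interference cancels; only the difference of local activity remains
print(np.allclose(fpz_cz, local_fpz - local_cz))  # True
```

Swapping to another electrode pair changes which local activity survives the subtraction.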

In [None]:
import numpy as np
import pandas as pd
import mne 
import matplotlib.pyplot as plt

psg_file = "../data/sleep_waves/ST7011J0-PSG.edf"
hypnogram = "../data/sleep_waves/ST7011JP-Hypnogram.edf"

# Loading new EEG data in
raw = mne.io.read_raw(psg_file, preload=True)
annotations = mne.read_annotations(hypnogram)
raw.set_annotations(annotations) 

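# Focus on one EEG channel:
# Fpz-Cz = voltage difference between the electrode at the forehead center
# and the one at the top of the head. Sleep-EDF has two EEG channels
# (Fpz-Cz and Pz-Oz); Fpz-Cz is a common choice for sleep staging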
raw.pick(['EEG Fpz-Cz'])

raw.plot(duration=60, n_channels=1, title='Raw EEG');

# The warnings that come up here are not a concern

#### Preprocessing the Signal

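EEG contains a huge range of frequencies, but sleep-related brain activity is concentrated in the slow to moderate frequencies. Band-pass filtering to 0.5-40 Hz removes much of the non-sleep-related signal outside that range. This includes slow drifts below 0.5 Hz (e.g. from sweating or electrode movement) and high-frequency noise above 40 Hz (e.g. much of the muscle activity).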

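We're using `fir_design='firwin'`, which designs an FIR (Finite Impulse Response) filter with the window method. FIR filters can have exactly linear phase, and MNE applies them with the filter delay compensated (zero phase). The timing of waveforms is therefore not distorted, which makes this a good fit for isolating clean frequency bands.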
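To make the window method concrete, here is a minimal NumPy sketch of a 0.5-40 Hz linear-phase band-pass. It is built as the difference of two Hamming-windowed sinc low-pass filters and applied with its delay compensated. This only approximates the idea behind MNE's filtering. The 100 Hz sampling rate and the 661-tap length are assumptions chosen for illustration, not MNE's actual parameters.

```python
import numpy as np

def windowed_sinc_lowpass(cutoff: float, sfreq: float, n_taps: int) -> np.ndarray:
    """Low-pass FIR taps via the window method (Hamming), normalized to unit DC gain."""
    n = np.arange(n_taps) - (n_taps - 1) / 2   # tap indices centered on 0
    h = np.sinc(2 * cutoff / sfreq * n)        # ideal low-pass response (up to scale)
    h *= np.hamming(n_taps)                    # window tapers the infinite sinc
    return h / h.sum()                         # unit gain at 0 Hz

def bandpass_taps(l_freq: float, h_freq: float, sfreq: float, n_taps: int) -> np.ndarray:
    """Band-pass as the difference of two low-pass filters."""
    return (windowed_sinc_lowpass(h_freq, sfreq, n_taps)
            - windowed_sinc_lowpass(l_freq, sfreq, n_taps))

def gain(h: np.ndarray, freq: float, sfreq: float) -> float:
    """|H(f)|: magnitude of the filter's frequency response at one frequency."""
    n = np.arange(len(h))
    return float(np.abs(np.sum(h * np.exp(-2j * np.pi * freq / sfreq * n))))

sfreq, n_taps = 100.0, 661   # assumed values; odd length keeps the delay an integer
h = bandpass_taps(0.5, 40.0, sfreq, n_taps)

t = np.arange(int(30 * sfreq)) / sfreq        # one 30 s stretch
alpha = np.sin(2 * np.pi * 10 * t)            # in-band component to keep
x = (2 * np.sin(2 * np.pi * 0.1 * t)          # slow drift (below 0.5 Hz)
     + alpha
     + 0.5 * np.sin(2 * np.pi * 45 * t))      # high-frequency noise (above 40 Hz)

# mode='same' centers the symmetric kernel on each sample, compensating the
# (n_taps - 1) / 2 sample delay, i.e. zero-phase filtering
y = np.convolve(x, h, mode='same')

for f in (0.1, 10.0, 45.0):
    print(f"{f:>5} Hz: gain {gain(h, f, sfreq):.3f}")
```

The symmetric taps give linear phase. Away from the edges of the recording, the filtered output `y` closely tracks the 10 Hz component, while the drift and the 45 Hz noise are strongly attenuated.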

In [None]:
raw.filter(0.5, 40., fir_design='firwin')

#### Segment into 30s Epochs (Sleep Scoring Standard)

30-second epochs have long been the clinical standard for scoring sleep stages. This holds for both the older Rechtschaffen & Kales rules (used to score Sleep-EDF) and the current AASM manual.

For clarity's sake: sleep stages do *not* change every second; they tend to last at least a few minutes. Still, 30s windows are short enough to catch transitions between stages with minimal lag.
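At the 100 Hz sampling rate of the Sleep-EDF EEG channels, one 30s scoring epoch is 30 × 100 = 3000 samples. Below is a minimal NumPy sketch of cutting a continuous signal into non-overlapping 30s epochs. Synthetic noise stands in for EEG, and the trailing partial epoch is simply dropped.

```python
import numpy as np

sfreq = 100.0                                # Sleep-EDF EEG sampling rate (Hz)
epoch_sec = 30.0                             # scoring epoch length (s)
samples_per_epoch = int(epoch_sec * sfreq)   # 30 * 100 = 3000

# Synthetic continuous recording: 95 s of noise (not real EEG)
rng = np.random.default_rng(0)
signal = rng.standard_normal(int(95 * sfreq))

# Keep only whole epochs, then reshape to (n_epochs, samples_per_epoch)
n_epochs = len(signal) // samples_per_epoch
epochs_arr = signal[:n_epochs * samples_per_epoch].reshape(n_epochs, samples_per_epoch)

print(epochs_arr.shape)  # (3, 3000)
```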

**A thought**: if we run into difficulty with classification, we could experiment with 60s windows, which would give more stable features. BUT that runs a real risk of accidentally mixing two stages in one window.
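The mixing risk is easy to see on a toy hypnogram. This sketch uses made-up stage codes, one per 30s epoch. It pairs consecutive epochs into 60s windows and flags windows whose two halves carry different stages.

```python
import numpy as np

# Toy per-30s stage labels (made up for illustration)
labels_30s = np.array([0, 0, 1, 1, 2, 2, 2, 3])

# Pair consecutive 30s epochs into 60s windows
windows_60s = labels_30s.reshape(-1, 2)

# A window is "mixed" if its two halves carry different stages
mixed = windows_60s[:, 0] != windows_60s[:, 1]
print(windows_60s.tolist())  # [[0, 0], [1, 1], [2, 2], [2, 3]]
print(int(mixed.sum()))      # 1
```

Whether a transition gets split depends on where it falls relative to the 60s grid. Here the 2 → 3 change lands inside a window, while the earlier changes happen to line up with window boundaries.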

In [None]:
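# Sleep-EDF hypnograms store each run of identical stages as ONE annotation
# (often many minutes long), so chunk_duration=30. splits every annotation
# into consecutive 30s events: one event per scoring epoch (industry standard)
events, event_id = mne.events_from_annotations(raw, chunk_duration=30.)

# tmax is inclusive, so stop one sample short of 30s to get exactly 30s
# (3000 samples at 100 Hz) per epoch
tmax = 30. - 1. / raw.info['sfreq']
epochs = mne.Epochs(raw, events=events, event_id=event_id, tmin=0., tmax=tmax,
                    baseline=None, preload=True,
                    on_missing='warn')  # some labels may yield no full 30s event

epochs[10].plot(title='Epoch 10: Filtered EEG');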
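Sleep-EDF hypnograms store each run of identical stages as one long annotation. The annotations therefore have to be cut into 30s pieces to get one event per scoring epoch, and `mne.events_from_annotations` exposes a `chunk_duration` argument for this. The pure-NumPy sketch below only illustrates the idea. `chunk_annotations` is a hypothetical helper (not an MNE function), the annotation values are toy numbers, and leftover pieces shorter than 30s are simply dropped.

```python
import numpy as np

def chunk_annotations(onsets, durations, codes, sfreq, chunk_sec=30.0):
    """Hypothetical helper: split each (onset, duration) annotation into
    back-to-back chunk_sec events, returned as MNE-style rows
    [sample, 0, code]. Leftover pieces shorter than chunk_sec are dropped."""
    rows = []
    for onset, dur, code in zip(onsets, durations, codes):
        n_chunks = int(dur // chunk_sec)             # whole chunks only
        for k in range(n_chunks):
            start_sec = onset + k * chunk_sec
            rows.append([int(round(start_sec * sfreq)), 0, code])
    return np.array(rows, dtype=int).reshape(-1, 3)

# Toy hypnogram: 90 s of Wake (code 1), then 60 s of Stage 1 (code 2)
events = chunk_annotations(onsets=[0.0, 90.0], durations=[90.0, 60.0],
                           codes=[1, 2], sfreq=100.0)
print(events.tolist())
# [[0, 0, 1], [3000, 0, 1], [6000, 0, 1], [9000, 0, 2], [12000, 0, 2]]
```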

#### Extract Sleep Stage Labels

Mapping MNE event codes (created from the hypnogram annotations) to the standard numeric sleep-stage labels:

In [None]:
def extract_labels_from_events(events: np.ndarray, event_id: dict) -> list:
    '''
    Maps MNE event integer codes to standardized numeric sleep labels

    0: Wake
    1: N1
    2: N2
    3: N3 (includes Stage 4)
    4: REM

    Args: 
        events (np.ndarray): array of shape (n_epochs, 3) from MNE Epochs.events
        event_id (dict): Mapping from annotation strings to event codes

    Returns:
        list[int]: One label per epoch; -1 marks unknown/invalid stages
    '''

    # Invert event_id so each integer code maps back to its annotation string
    reverse_event_map = {v: k for k, v in event_id.items()}

    stage_map = {
        'Sleep stage W': 0,  # Wake
        'Sleep stage 1': 1,  # N1
        'Sleep stage 2': 2,  # N2
        'Sleep stage 3': 3,  # N3
        'Sleep stage 4': 3,  # N3 (merged with Stage 3)
        'Sleep stage R': 4,  # REM
    }

    labels = []
    for e in events:
        # The last column of each event row holds the event code
        raw_label = reverse_event_map.get(e[-1])
        if raw_label in stage_map:
            labels.append(stage_map[raw_label])
        else:
            # Mark unknown/invalid stages (e.g. 'Sleep stage ?', 'Movement time')
            # with -1 so they can be dropped later
            labels.append(-1)

    return labels
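The same code-to-label mapping can also be done without a Python loop, using a lookup table indexed by event code. A minimal sketch follows. The `event_id` dictionary and events array below are made-up toy values (real ones come from `mne.events_from_annotations`), and `lut` is a hypothetical name.

```python
import numpy as np

# Toy event_id and events (made up; real ones come from the hypnogram)
event_id = {'Sleep stage W': 1, 'Sleep stage 1': 2, 'Sleep stage 2': 3,
            'Sleep stage 3': 4, 'Sleep stage 4': 5, 'Sleep stage R': 6,
            'Sleep stage ?': 7}
events = np.array([[0, 0, 1], [3000, 0, 2], [6000, 0, 5],
                   [9000, 0, 6], [12000, 0, 7]])

stage_map = {'Sleep stage W': 0, 'Sleep stage 1': 1, 'Sleep stage 2': 2,
             'Sleep stage 3': 3, 'Sleep stage 4': 3, 'Sleep stage R': 4}

# Lookup table indexed by event code; -1 by default for unknown/invalid stages
lut = np.full(max(event_id.values()) + 1, -1, dtype=int)
for desc, code in event_id.items():
    if desc in stage_map:
        lut[code] = stage_map[desc]

labels = lut[events[:, 2]]
print(labels.tolist())  # [0, 1, 3, 4, -1]
```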


#### Feature Extraction (Band Power)

Computing the average power in key frequency bands for each epoch:

In [None]:
from mne.time_frequency import psd_array_welch

# Get raw epoch data as numpy array
epoch_data = epochs.get_data()  # shape: (n_epochs, n_channels, n_times)
sfreq = epochs.info['sfreq']    # typically 100 Hz for Sleep-EDF

# Compute Power Spectral Density
psds, freqs = psd_array_welch(epoch_data, sfreq=sfreq, fmin=0.5, fmax=40, n_fft=256)

def bandpower(psds: np.ndarray, freqs: np.ndarray, band: tuple) -> np.ndarray:
    ''' 
    Computes the mean power in a specific freq band for all epochs

    Args: 
        psds (np.ndarray): Power spectral densities, shape (n_epochs, n_channels, n_freqs)
        freqs (np.ndarray): Array of freq values (Hz)
        band (tuple): Lower and upper bounds (Hz) of the freq band, within the 0.5-40 Hz PSD range

    Returns:
        np.ndarray: Mean band power per epoch, shape (n_epochs,)
    '''

    low, high = band
    idx = (freqs >= low) & (freqs <= high)

    # Average over the band's frequency bins, then over channels
    return psds[:, :, idx].mean(axis=-1).mean(axis=1)
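To sanity-check the band-power idea without loading any data, here is a NumPy-only sketch of Welch's method applied to a synthetic delta-dominated epoch. It uses Hamming-windowed 256-sample segments with no overlap, and averages them. With sfreq = 100 Hz and n_fft = 256, the frequency resolution is 100 / 256 ≈ 0.39 Hz, and a 3000-sample epoch yields 3000 // 256 = 11 segments. This is an illustrative re-implementation with simplified choices (window, overlap, scaling are assumptions), not `psd_array_welch` itself. The signal and its amplitudes are made up.

```python
import numpy as np

def welch_psd(x: np.ndarray, sfreq: float, n_fft: int = 256):
    """Simplified Welch PSD: Hamming-windowed, non-overlapping segments of
    n_fft samples, one-sided periodograms averaged."""
    win = np.hamming(n_fft)
    n_seg = len(x) // n_fft                       # leftover samples are ignored
    segs = x[:n_seg * n_fft].reshape(n_seg, n_fft) * win
    spec = np.abs(np.fft.rfft(segs, axis=-1)) ** 2 / (sfreq * np.sum(win ** 2))
    spec[:, 1:-1] *= 2                            # one-sided: fold negative freqs (n_fft even)
    return spec.mean(axis=0), np.fft.rfftfreq(n_fft, d=1 / sfreq)

def band_mean(psd: np.ndarray, freqs: np.ndarray, band: tuple) -> float:
    """Mean PSD over the bins inside [low, high] Hz."""
    low, high = band
    idx = (freqs >= low) & (freqs <= high)
    return float(psd[idx].mean())

sfreq = 100.0
t = np.arange(3000) / sfreq                       # one 30 s epoch
# Synthetic "deep sleep"-like epoch: strong 2 Hz delta, weak 10 Hz alpha
x = 3.0 * np.sin(2 * np.pi * 2 * t) + 0.5 * np.sin(2 * np.pi * 10 * t)

psd, freqs = welch_psd(x, sfreq)
bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 13), 'beta': (13, 30)}
powers = {name: band_mean(psd, freqs, b) for name, b in bands.items()}
print(max(powers, key=powers.get))  # delta
```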


In [None]:
features = pd.DataFrame({
    'delta': bandpower(psds, freqs, (0.5, 4)), 
    'theta': bandpower(psds, freqs, (4, 8)), 
    'alpha': bandpower(psds, freqs, (8, 13)), 
    'beta': bandpower(psds, freqs, (13, 30))
})

labels = extract_labels_from_events(epochs.events, event_id)
features['label'] = labels
features = features[features['label'] != -1]
features