In [1045]:
import numpy as np
import mne
from mne.datasets.sleep_physionet.age import fetch_data

## Loading Data

In [1048]:
# Subject IDs excluded up front (presumably those with no usable first-night
# recording in the Sleep Physionet "age" cohort -- confirm against the dataset docs).
missing_subjects = [36, 52, 39, 68, 69, 78, 79]

# Subjects 0-3 are used for training and subjects 4-5 for testing.
all_subjects_train = [sid for sid in range(4) if sid not in missing_subjects]
all_subjects_test = [sid for sid in range(4, 6) if sid not in missing_subjects]
subjects = [*all_subjects_train, *all_subjects_test]

# Download (or reuse the cached copy of) recording 1 for every selected subject.
data = fetch_data(subjects=subjects, recording=[1])


Using default location ~/mne_data for PHYSIONET_SLEEP...


## Reading data

In [1051]:
def read_data(fname):
    """Load one EDF recording fully into memory.

    Parameters
    ----------
    fname : path-like
        Path of the EDF file to read.

    Returns
    -------
    The object returned by ``mne.io.read_raw_edf``, read with
    ``preload=True``, "Event marker" as the stim channel, channel types
    inferred, and logging restricted to errors.
    """
    edf_options = {
        "stim_channel": "Event marker",
        "infer_types": True,
        "preload": True,
        "verbose": "error",
    }
    return mne.io.read_raw_edf(fname, **edf_options)
    

## Extracting training epochs
We will work with only five sleep stages: Wake (W), Stage 1, Stage 2, Stage 3/4 (stages 3 and 4 merged into one class), and REM sleep (R).


In [1054]:
# Integer event code for each hypnogram annotation label.
# "Sleep stage 3" and "Sleep stage 4" deliberately share code 4 so that both
# end up in the same class.
annotation_desc_2_event_id = {
    "Sleep stage W": 1,
    "Sleep stage 1": 2,
    "Sleep stage 2": 3,
    "Sleep stage 3": 4,
    "Sleep stage 4": 4,
    "Sleep stage R": 5,
}
# Class-name -> code mapping used when epoching: same codes as above, with the
# two labels that map to 4 replaced by a single "Sleep stage 3/4" entry.
# Key order matters: it is reused as the class-name order in the final report.
event_id = {
    "Sleep stage W": 1,
    "Sleep stage 1": 2,
    "Sleep stage 2": 3,
    "Sleep stage 3/4": 4,
    "Sleep stage R": 5,
}

In [1056]:
# EEG frequency bands used as spectral features, as [low, high] edges in Hz.
# eeg_power_band treats each band as the half-open interval low <= f < high,
# so adjacent bands do not overlap.
FREQ_BANDS = {
    "delta": [0.5, 4.5],
    "theta": [4.5, 8.5],
    "alpha": [8.5, 11.5],
    "sigma": [11.5, 15.5],
    "beta": [15.5, 30],
}

In [1058]:
def extract_events(raw):
    """Turn the sleep-stage annotations of ``raw`` into an events array.

    Annotations are mapped to integer codes with
    ``annotation_desc_2_event_id`` and split into 30-second chunks
    (``chunk_duration=30.0``). Only the events array is returned; the
    event-id mapping that MNE returns alongside it is discarded.
    """
    events_and_mapping = mne.events_from_annotations(
        raw,
        event_id=annotation_desc_2_event_id,
        chunk_duration=30.0,
    )
    return events_and_mapping[0]

In [1060]:
def eeg_power_band(epochs):
    """Compute relative band-power features for every epoch.

    The power spectral density of the EEG channels is estimated between
    0.5 and 30 Hz, normalised so that each channel's spectrum sums to 1
    within an epoch, and then averaged inside each band of ``FREQ_BANDS``
    (``low <= f < high``).

    Parameters
    ----------
    epochs : object exposing ``compute_psd`` (e.g. ``mne.Epochs``)
        The epochs to featurise.

    Returns
    -------
    numpy.ndarray
        2-D array with one row per epoch and one column per
        (band, EEG channel) pair, grouped band by band -- i.e.
        ``len(FREQ_BANDS) * n_eeg_channels`` columns, not one per band.
    """
    spectrum = epochs.compute_psd(picks="eeg", fmin=0.5, fmax=30.0)
    psds, freqs = spectrum.get_data(return_freqs=True)

    # Convert absolute power to relative power (in place).
    psds /= psds.sum(axis=-1, keepdims=True)

    n_epochs = len(psds)
    band_features = [
        psds[:, :, (freqs >= low) & (freqs < high)].mean(axis=-1).reshape(n_epochs, -1)
        for low, high in FREQ_BANDS.values()
    ]
    return np.concatenate(band_features, axis=1)

In [1062]:
# Per-subject band-power feature matrices and stage labels for the training set.
# (annotation_desc_2_event_id is defined once, earlier in the notebook; it is
# intentionally not redefined here.)
all_epochs_train = []
all_y_train = []

# fetch_data is expected to return one [PSG, hypnogram] path pair per requested
# subject, in request order. Look recordings up by subject ID instead of using
# the ID as a list position, which silently picks the wrong files as soon as a
# subject is skipped.
assert len(data) == len(subjects), "expected exactly one recording per requested subject"
recordings_by_subject = dict(zip(subjects, data))

for sid in all_subjects_train:
    # Reading data
    raw_train = read_data(recordings_by_subject[sid][0])
    # Setting annotations: keep only the span from 30 min before the second
    # annotation to 30 min after the second-to-last one.
    annot_train = mne.read_annotations(recordings_by_subject[sid][1])
    annot_train.crop(annot_train[1]["onset"] - 30 * 60, annot_train[-2]["onset"] + 30 * 60)
    raw_train.set_annotations(annot_train, emit_warning=False)
    # Extracting events
    events_train = extract_events(raw_train)
    # Extracting 30 s epochs; tmax stops one sample short of 30 s so that
    # consecutive epochs do not share a sample.
    tmax = 30.0 - 1.0 / raw_train.info["sfreq"]
    epochs_train = mne.Epochs(
        raw=raw_train,
        events=events_train,
        event_id=event_id,
        tmin=0.0,
        tmax=tmax,
        baseline=None,
    )

    # Compute the features first: the epochs are loaded lazily, so any epoch
    # dropped on load is only reflected in `.events` afterwards. Reading the
    # labels second keeps features and labels row-aligned.
    epochs = eeg_power_band(epochs_train)
    y = epochs_train.events[:, 2]
    assert len(epochs) == len(y), f"subject {sid}: feature/label count mismatch"

    all_epochs_train.append(epochs)
    all_y_train.append(y)

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
841 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 841 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
1103 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1103 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
1025 matching events found
No baseline 

## Extracting testing epochs

In [1064]:
# Per-subject band-power feature matrices and stage labels for the test set.
all_epochs_test = []
all_y_test = []

# fetch_data is expected to return one [PSG, hypnogram] path pair per requested
# subject, in request order. Look recordings up by subject ID instead of using
# the ID as a list position, which silently picks the wrong files as soon as a
# subject is skipped.
assert len(data) == len(subjects), "expected exactly one recording per requested subject"
recordings_by_subject = dict(zip(subjects, data))

for sid in all_subjects_test:
    # Reading data
    raw_test = read_data(recordings_by_subject[sid][0])
    # Setting annotations: keep only the span from 30 min before the second
    # annotation to 30 min after the second-to-last one.
    annot_test = mne.read_annotations(recordings_by_subject[sid][1])
    annot_test.crop(annot_test[1]["onset"] - 30 * 60, annot_test[-2]["onset"] + 30 * 60)
    raw_test.set_annotations(annot_test, emit_warning=False)
    # Extracting events
    events_test = extract_events(raw_test)
    # Extracting 30 s epochs; tmax stops one sample short of 30 s so that
    # consecutive epochs do not share a sample.
    tmax = 30.0 - 1.0 / raw_test.info["sfreq"]
    epochs_test = mne.Epochs(
        raw=raw_test,
        events=events_test,
        event_id=event_id,
        tmin=0.0,
        tmax=tmax,
        baseline=None,
    )

    # Compute the features first: the epochs are loaded lazily, so any epoch
    # dropped on load is only reflected in `.events` afterwards. Reading the
    # labels second keeps features and labels row-aligned.
    epochs = eeg_power_band(epochs_test)
    y = epochs_test.events[:, 2]
    assert len(epochs) == len(y), f"subject {sid}: feature/label count mismatch"

    all_epochs_test.append(epochs)
    all_y_test.append(y)

Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
1235 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 1235 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows
Used Annotations descriptions: ['Sleep stage 1', 'Sleep stage 2', 'Sleep stage 3', 'Sleep stage 4', 'Sleep stage R', 'Sleep stage W']
Not setting metadata
672 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 672 events and 3000 original time points ...
0 bad epochs dropped
    Using multitaper spectrum estimation with 7 DPSS windows


## Training and evaluating a scikit-learn classifier

In [1084]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

In [1086]:
# Stack the per-subject feature matrices row-wise into one training matrix.
X_train = np.vstack(all_epochs_train)
# Keep the labels 1-D, shape (n_samples,): scikit-learn classifiers expect that
# layout and emit a DataConversionWarning when given a column vector.
y_train = np.hstack(all_y_train).ravel()
assert len(X_train) == len(y_train), "training features and labels must have one row per epoch"

In [1088]:
# Stack the per-subject feature matrices row-wise into one test matrix.
X_test = np.vstack(all_epochs_test)
# Keep the labels 1-D, shape (n_samples,), matching the shape of the
# predictions they are compared against.
y_test = np.hstack(all_y_test).ravel()
assert len(X_test) == len(y_test), "test features and labels must have one row per epoch"

In [1090]:
# Single-step pipeline: a 100-tree random forest with a fixed seed so that
# repeated runs give the same model.
pipe = make_pipeline(RandomForestClassifier(random_state=42, n_estimators=100))

# Train on the stacked band-power features of the training subjects.
pipe.fit(X_train, y_train)


  return fit_method(estimator, *args, **kwargs)


In [1091]:
# Predict a sleep-stage code for every epoch of the held-out test subjects.
y_pred = pipe.predict(X_test)

In [1092]:
# Overall accuracy: the fraction of test epochs whose predicted stage code
# equals the annotated one.
acc = accuracy_score(y_true=y_test, y_pred=y_pred)
print(f"Accuracy score: {acc}")

Accuracy score: 0.7430519140010488


In [1096]:
# Rows are true stages and columns are predicted stages, both in the order of
# the event_id codes. Passing `labels` pins that order and keeps the matrix
# 5x5 even if some stage never occurs in the test data.
print(confusion_matrix(y_test, y_pred, labels=list(event_id.values())))

[[312  74   5   1  16]
 [ 29 101  11   0  69]
 [  4  42 731   2  58]
 [  0   0 118  69   1]
 [  7  45   8   0 204]]


In [1098]:
# Pass the label codes together with their names so each report row is tied to
# the right stage explicitly, instead of relying on the sorted labels found in
# the data happening to line up with the order of event_id.
print(
    classification_report(
        y_test,
        y_pred,
        labels=list(event_id.values()),
        target_names=list(event_id.keys()),
    )
)

                 precision    recall  f1-score   support

  Sleep stage W       0.89      0.76      0.82       408
  Sleep stage 1       0.39      0.48      0.43       210
  Sleep stage 2       0.84      0.87      0.85       837
Sleep stage 3/4       0.96      0.37      0.53       188
  Sleep stage R       0.59      0.77      0.67       264

       accuracy                           0.74      1907
      macro avg       0.73      0.65      0.66      1907
   weighted avg       0.78      0.74      0.74      1907

