In [7]:
# import all the relevant libraries
import wfdb
import mne
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, hamming_loss
from sklearn.pipeline import make_pipeline,Pipeline
from sklearn.preprocessing import FunctionTransformer,StandardScaler
from mne.datasets import sample
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score, KFold
from mne_features.feature_extraction import extract_features
from sklearn.datasets import load_iris
from sklearn.feature_selection import SelectKBest,f_classif

  warn('{}. Your code will be slower.'.format(err))


In [8]:
# This function takes in the record path of a PSG file and outputs the epochs. It renames, changes types and removes channels so that they
# are consistent between different patients. Some patients had more/less data collected through different manners which had to be fixed.
def create_epoch(record_path):
    record = wfdb.rdrecord(record_path)
    annotation = wfdb.rdann(record_path, 'st')
    
    annotation_mne = mne.Annotations(
        onset=annotation.sample / record.fs,
        duration=30,
        description=annotation.aux_note
    )
    
    signals = record.p_signal  # Signal data as NumPy array
    sampling_rate = record.fs  # Sampling frequency (e.g., 250 Hz)
    channel_names = record.sig_name  # Channel names
    seen_categories = set()
    
    # Assign types while ensuring only the first occurrence is assigned correctly
    channel_types = []
    channel_names_updated = []
    
    for ch in channel_names:
        if "EEG" in ch and "EEG" not in seen_categories:
            channel_types.append("eeg")
            channel_names_updated.append("EEG")
            seen_categories.add("EEG")
        elif "Resp" in ch and "Resp" not in seen_categories:
            channel_types.append("resp")
            channel_names_updated.append("Respiratory")
            seen_categories.add("Resp")
        elif "ECG" in ch and "ECG" not in seen_categories:
            channel_types.append("ecg")
            channel_names_updated.append("ECG")
            seen_categories.add("ECG")
        else:
            channel_types.append("misc")
            channel_names_updated.append(ch)
    # Create MNE Info object
    info = mne.create_info(
        ch_names=channel_names_updated,
        sfreq=sampling_rate,
        ch_types=channel_types
    )
    # Convert signals to MNE RawArray
    raw = mne.io.RawArray(signals.T, info)
    
    # Add annotations to the raw object
    raw.set_annotations(annotation_mne)
    misc_channels = [ch for ch, ch_type in zip(channel_names_updated, channel_types) if "misc" in ch_type]
    raw.drop_channels(misc_channels)
    event_labels = {
    1: "Hypopnea",
    2: "Obstructive Apnea",
    3: "Central Apnea",
    4: "No Apnea Event"
    }
    new_id = {}
    for desc in raw.annotations.description:
        if "H" in desc or "HA" in desc:  # Hypopnea
            new_id[desc] = 1
        elif " A" in desc or "X" in desc:  # Obstructive Apnea
            new_id[desc] = 2
        elif "CA" in desc or "CAA" in desc:  # Central Apnea
            new_id[desc] = 3
        else:  # Other
            new_id[desc] = 4
    
    # Generate events with the new mapping
    events, event_id = mne.events_from_annotations(raw, event_id=new_id)
    
    new_annotations = mne.Annotations(
        onset=events[:, 0] / raw.info["sfreq"],  # Convert sample index to seconds
        duration=[30] * len(events),  # Assume each event lasts 30 seconds
        description=[event_labels[e] for e in events[:, 2]]  # Convert IDs back to labels
    )
    tmax = 30.0 - 1.0 / raw.info["sfreq"]  # tmax in included
    epochs = mne.Epochs(
        raw,
        events=events,
        tmin=0.0,
        tmax=tmax,
        baseline=None,
        picks=["ECG",'Respiratory',"EEG"]
    )
    return epochs

In [15]:
all_epochs=[]
List_of_subjects=open(r"\Users\piotr\Desktop\PSG data\List of subjects.txt","r").read().split("\n")
for i in List_of_subjects:
    record_path = fr"C:\Users\piotr\Desktop\PSG data\MIT Data\{i}"
    all_epochs.append(create_epoch(record_path))

Creating RawArray with float64 data, n_channels=4, n_times=1800000
    Range : 0 ... 1799999 =      0.000 ...  7199.996 secs
Ready.
Used Annotations descriptions: [np.str_('1 LA'), np.str_('2'), np.str_('2 H'), np.str_('2 H H'), np.str_('2 H LA'), np.str_('2 HA'), np.str_('2 L'), np.str_('2 L HA'), np.str_('2 L LA'), np.str_('2 LA'), np.str_('2 LA H'), np.str_('2 LA HA'), np.str_('2 LA L'), np.str_('2 LA LA'), np.str_('3'), np.str_('3 H'), np.str_('3 H LA'), np.str_('3 HA'), np.str_('3 L'), np.str_('3 L HA LA'), np.str_('3 L L'), np.str_('3 L LA'), np.str_('3 LA'), np.str_('3 LA L'), np.str_('3 LA LA'), np.str_('4'), np.str_('4 L'), np.str_('4 L L'), np.str_('4 L L L'), np.str_('4 L LA'), np.str_('4 LA'), np.str_('4 LA HA LA'), np.str_('4 LA L'), np.str_('4 LA LA'), np.str_('MT'), np.str_('R'), np.str_('R H'), np.str_('R HA'), np.str_('W'), np.str_('W HA'), np.str_('W LA')]
Not setting metadata
240 matching events found
No baseline correction applied
0 projection items activated
Creati

In [21]:
epoch=mne.concatenate_epochs(all_epochs)

Using data from preloaded Raw for 240 events and 7500 original time points ...
Using data from preloaded Raw for 360 events and 7500 original time points ...
Using data from preloaded Raw for 360 events and 7500 original time points ...


  epoch=mne.concatenate_epochs(all_epochs)


Using data from preloaded Raw for 270 events and 7500 original time points ...
Using data from preloaded Raw for 720 events and 7500 original time points ...
Using data from preloaded Raw for 720 events and 7500 original time points ...
Using data from preloaded Raw for 714 events and 7500 original time points ...
Using data from preloaded Raw for 694 events and 7500 original time points ...
Using data from preloaded Raw for 640 events and 7500 original time points ...
Using data from preloaded Raw for 698 events and 7500 original time points ...
Using data from preloaded Raw for 780 events and 7500 original time points ...
Using data from preloaded Raw for 760 events and 7500 original time points ...
Using data from preloaded Raw for 760 events and 7500 original time points ...
Using data from preloaded Raw for 458 events and 7500 original time points ...
Using data from preloaded Raw for 710 events and 7500 original time points ...
Using data from preloaded Raw for 720 events and 750

In [29]:
def eeg_power_band(epochs):
    """EEG relative power band feature extraction.

    This function takes an ``mne.Epochs`` object and creates EEG features based
    on relative power in specific frequency bands that are compatible with
    scikit-learn.

    Parameters
    ----------
    epochs : Epochs
        The data.

    Returns
    -------
    X : numpy array of shape [n_samples, 5 * n_channels]
        Transformed data.
    """
    # specific frequency bands
    FREQ_BANDS = {
        "delta": [0.5, 4.5],
        "theta": [4.5, 8.5],
        "alpha": [8.5, 11.5],
        "sigma": [11.5, 15.5],
        "beta": [15.5, 30],
    }

    spectrum = epochs.compute_psd(picks="eeg", fmin=0.5, fmax=30.0)
    psds, freqs = spectrum.get_data(return_freqs=True)
    # Normalize the PSDs
    psds /= np.sum(psds, axis=-1, keepdims=True)

    X = []
    for fmin, fmax in FREQ_BANDS.values():
        psds_band = psds[:, :, (freqs >= fmin) & (freqs < fmax)].mean(axis=-1)
        X.append(psds_band.reshape(len(psds), -1))

    return np.concatenate(X, axis=1)

In [45]:
pipe = make_pipeline(
    RandomForestClassifier(n_estimators=100, random_state=42),
)

# Train
y= epoch.events[:, 2]
cross_val_score(pipe,epoch,y)

KeyboardInterrupt: 

In [39]:
X_data=eeg_power_band(epoch)

    Using multitaper spectrum estimation with 7 DPSS windows


In [53]:
pipe.fit(X_data,y)

In [55]:
np.mean(cross_val_score(pipe,X_data,y))

np.float64(0.6196812643645001)

In [63]:
selected_funcs = {"pow_freq_bands"}
data=epoch.get_data()
y_data = epoch.events[:, 2]
X_data = extract_features(data, epoch.info['sfreq'], selected_funcs)

In [64]:
np.mean(cross_val_score(pipe,X_data,y_data))

np.float64(0.5474051582378906)