## 睡眠段階の定義

In [1]:
# Map hypnogram label names to integer IDs.
# The raw annotations keep sleep stages 3 and 4 separate (R&K-style naming);
# both get the same ID here so the IDs follow the AASM classification.
# Labels that are not a sleep stage (movement, unscored) get -1.
RANDK_LABEL2ID = dict([
    ('Movement time', -1),
    ('Sleep stage ?', -1),
    ('Sleep stage W', 0),
    ('Sleep stage 1', 1),
    ('Sleep stage 2', 2),
    ('Sleep stage 3', 3),
    ('Sleep stage 4', 3),
    ('Sleep stage R', 4),
])

# AASM classification: stages 3 and 4 appear under one merged label name.
LABEL2ID = dict([
    ('Movement time', -1),
    ('Sleep stage ?', -1),
    ('Sleep stage W', 0),
    ('Sleep stage 1', 1),
    ('Sleep stage 2', 2),
    ('Sleep stage 3/4', 3),
    ('Sleep stage R', 4),
])
# Inverse map. Both -1 labels collide; the later entry ('Sleep stage ?') wins.
ID2LABEL = dict(zip(LABEL2ID.values(), LABEL2ID.keys()))

## ライブラリのインポートとデータ設定

In [2]:
import os

# Project root used as the working directory for all relative paths below.
# Override with the SLEEP_STAGE_ROOT environment variable instead of editing the path.
PROJECT_ROOT = os.environ.get('SLEEP_STAGE_ROOT', '/sleep-stage-detection/')
os.chdir(PROJECT_ROOT)

In [14]:
import datetime
from pathlib import Path

import warnings
from tqdm.auto import tqdm
from typing import Dict, List

import mne
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import lightgbm as lgb

from src.config import *

warnings.filterwarnings('ignore')

## データの読み込み

In [4]:
train_record_df = pd.read_csv(PATH_TRAIN_DATA)
test_record_df = pd.read_csv(PATH_TEST_DATA)
# Parse the timestamp column by name rather than by position: the test-epoch
# construction subtracts datetimes from sample_submission_df["meas_time"].
sample_submission_df = pd.read_csv(PATH_SAMPLE_SUBMIT, parse_dates=["meas_time"])

In [5]:
# Raw inputs live under input/, the EDF recordings under input/edf_data/.
INPUT_DIR = Path("input")
EDF_DIR = INPUT_DIR.joinpath("edf_data")

In [6]:
# Turn the bare EDF file names into paths under EDF_DIR.
# Path(x).name keeps only the file name, so re-running this cell is a no-op
# instead of nesting the prefix (input/edf_data/input/edf_data/...).
train_record_df["hypnogram"] = train_record_df["hypnogram"].map(lambda x: str(EDF_DIR / Path(x).name))
train_record_df["psg"] = train_record_df["psg"].map(lambda x: str(EDF_DIR / Path(x).name))
test_record_df["psg"] = test_record_df["psg"].map(lambda x: str(EDF_DIR / Path(x).name))

In [7]:
# Quick look at the first rows of the train and test record tables
display(train_record_df.head(), test_record_df.head())

Unnamed: 0,id,subject_id,night,age,sex,lights_off,psg,hypnogram
0,3c1c5cf,07c46da,1,90,male,23:28:00,input/edf_data/3c1c5cf-PSG.edf,input/edf_data/3c1c5cf-Hypnogram.edf
1,8fbd71b,07c46da,2,90,male,01:29:00,input/edf_data/8fbd71b-PSG.edf,input/edf_data/8fbd71b-Hypnogram.edf
2,9d5e9ec,21969ff,1,51,female,23:10:00,input/edf_data/9d5e9ec-PSG.edf,input/edf_data/9d5e9ec-Hypnogram.edf
3,e0df8c0,21969ff,2,51,female,23:15:00,input/edf_data/e0df8c0-PSG.edf,input/edf_data/e0df8c0-Hypnogram.edf
4,3e404fc,22b58e8,1,51,female,22:38:00,input/edf_data/3e404fc-PSG.edf,input/edf_data/3e404fc-Hypnogram.edf


Unnamed: 0,id,subject_id,night,age,sex,lights_off,psg
0,53c1555,17ca2cd,1,91,female,00:15:00,input/edf_data/53c1555-PSG.edf
1,29ef1d5,17ca2cd,2,91,female,23:39:00,input/edf_data/29ef1d5-PSG.edf
2,c90b6e7,2c77159,1,56,female,23:55:00,input/edf_data/c90b6e7-PSG.edf
3,a61e635,2c77159,2,56,female,00:13:00,input/edf_data/a61e635-PSG.edf
4,2cb6860,40dc0bc,1,52,male,23:03:00,input/edf_data/2cb6860-PSG.edf


## 関数定義

In [12]:
def epoch_to_df(epoch: "mne.epochs.Epochs") -> pd.DataFrame:
    """Convert an ``mne.Epochs`` object to a long DataFrame with absolute timestamps.

    ``meas_time`` starts at the recording's ``meas_date`` shifted by
    ``epoch.info["temp"]["truncate_start_point"]`` seconds and advances by one
    sample period (``1 / sfreq``) per row. The rows are therefore treated as one
    uninterrupted sample stream, which only holds when the epochs tile the
    recording contiguously.

    Parameters
    ----------
    epoch : mne.Epochs
        Epochs whose ``info["temp"]`` dict carries ``truncate_start_point``
        (seconds from recording start to the first epoch), as attached by
        ``read_and_set_annoation``.

    Returns
    -------
    pd.DataFrame
        ``epoch.to_data_frame()`` plus a timezone-naive ``meas_time`` column.

    Raises
    ------
    ValueError
        If ``epoch.info["meas_date"]`` is not set.
    """
    info = epoch.info
    truncate_start_point = info["temp"]["truncate_start_point"]
    if info["meas_date"] is None:
        raise ValueError("epoch.info['meas_date'] is not set; cannot build meas_time")

    df = epoch.to_data_frame(verbose=False)

    # meas_date is timezone-aware; drop tzinfo so it compares with the naive submission timestamps.
    start = info["meas_date"].replace(tzinfo=None) + datetime.timedelta(seconds=truncate_start_point)
    # Step by the real sampling period instead of assuming 100 Hz.
    sample_period = pd.Timedelta(seconds=1.0 / info["sfreq"])
    df["meas_time"] = pd.date_range(start=start, periods=len(df), freq=sample_period)

    return df

In [16]:
def epoch_to_sub_df(epoch_df:pd.DataFrame, id, is_train:bool) -> pd.DataFrame:
    cols = ["id", "meas_time"]
    if is_train:
        # 訓練セットの場合はラベルを追加
        cols.append("condition")
    
    label_df = epoch_df.loc[epoch_df.groupby("epoch")["time"].idxmin()].reset_index(drop=True)
    label_df["id"] = id
    
    return label_df[cols]

In [8]:
def read_and_set_annoation(record_df:pd.DataFrame, include=None, is_test=False, submission_df:pd.DataFrame=None) -> List[mne.epochs.Epochs]:
    """Read each PSG recording in ``record_df`` and cut it into 30-second epochs.

    Training records (``is_test=False``): the hypnogram annotations are read and
    cropped to drop the first and last 5 hours. Each remaining 30 s chunk then
    becomes one epoch, with its event ID taken from ``RANDK_LABEL2ID`` and its
    names from ``LABEL2ID``.

    Test records (``is_test=True``): one epoch is placed every 30 s over the
    window spanned by this record's ``meas_time`` rows in the submission
    template. Every test epoch gets the dummy label ``'Sleep stage W'``.

    Parameters
    ----------
    record_df : pd.DataFrame
        One row per recording with ``psg`` (and, for training, ``hypnogram``)
        file paths plus ``id``, ``subject_id``, ``night``, ``age``, ``sex``.
    include : list[str] | None
        Channel names passed through to ``mne.io.read_raw_edf``.
    is_test : bool
        Selects the test branch described above.
    submission_df : pd.DataFrame | None
        Submission template with ``id`` and datetime ``meas_time`` columns.
        Only used when ``is_test``; defaults to the notebook-level
        ``sample_submission_df``.

    Returns
    -------
    list[mne.Epochs]
        One Epochs object per record. ``epoch.info["temp"]`` holds the record
        metadata and ``truncate_start_point`` (seconds from recording start to
        the first epoch).

    Raises
    ------
    AssertionError
        If the epochs do not tile the truncated window exactly (data dropped).
    """
    epoch_sec = 30  # length of one epoch in seconds
    trim_sec = 3600 * 5  # training only: seconds cut from each end of the recording

    if is_test and submission_df is None:
        submission_df = sample_submission_df

    whole_epoch_data = []

    for _, row in tqdm(record_df.iterrows(), total=len(record_df)):
        # Read the PSG signal file.
        psg_edf = mne.io.read_raw_edf(row["psg"], include=include, verbose=False)
        sfreq = psg_edf.info["sfreq"]

        if not is_test:
            # Training data: read the hypnogram (annotation file).
            annot = mne.read_annotations(row["hypnogram"])

            # Keep only [trim_sec, duration - trim_sec]; the duration uses the
            # real sampling rate rather than a hard-coded 100 Hz.
            truncate_start_point = trim_sec
            truncate_end_point = (len(psg_edf) / sfreq) - trim_sec
            annot.crop(truncate_start_point, truncate_end_point, verbose=False)

            # Attach the cropped annotations and emit one event per 30 s chunk.
            psg_edf.set_annotations(annot, emit_warning=False)
            events, _ = mne.events_from_annotations(psg_edf, event_id=RANDK_LABEL2ID, chunk_duration=float(epoch_sec), verbose=False)

            event_id = LABEL2ID
        else:
            # Test data: the window comes from this record's rows in the submission template.
            start_psg_date = psg_edf.info["meas_date"].replace(tzinfo=None)

            record_times = submission_df.loc[submission_df["id"] == row["id"], "meas_time"]
            truncate_start_point = int((record_times.min() - start_psg_date).total_seconds())
            # The last timestamp is the start of the final epoch, so add one epoch length.
            truncate_end_point = int((record_times.max() - start_psg_date).total_seconds()) + epoch_sec

            # Event sample index = onset seconds * sfreq (previously assumed 100 Hz).
            onsets_sec = np.arange(truncate_start_point, truncate_end_point, epoch_sec)
            events = np.zeros((len(onsets_sec), 3), dtype=int)
            events[:, 0] = np.round(onsets_sec * sfreq).astype(int)

            # Dummy label: the test set has no ground truth.
            event_id = {'Sleep stage W': 0}

        # One epoch per 30 s; tmax is inclusive, so stop one sample short.
        tmax = epoch_sec - 1. / sfreq
        epoch = mne.Epochs(raw=psg_edf, events=events, event_id=event_id, tmin=0, tmax=tmax, baseline=None, verbose=False, on_missing='ignore')

        # Check that no epoch was dropped inside the window.
        assert len(epoch.events) * epoch_sec == truncate_end_point - truncate_start_point

        # Attach record metadata for the later feature/submission steps.
        epoch.info["temp"] = {
            "id":row["id"],
            "subject_id":row["subject_id"],
            "night":row["night"],
            "age":row["age"],
            "sex":row["sex"],
            "truncate_start_point":truncate_start_point
        }

        whole_epoch_data.append(epoch)

    return whole_epoch_data

In [9]:
# Only the EEG Fpz-Cz channel is loaded to keep processing simple.
# Processing the whole training set takes a while, so a subset of 50 recordings
# (about half) is used here. A fixed random_state makes the subset, and
# everything downstream of it, reproducible across kernel restarts.
N_TRAIN_RECORDS = 50
SUBSET_SEED = 42
train_record_subset_df = train_record_df.sample(n=N_TRAIN_RECORDS, random_state=SUBSET_SEED).reset_index(drop=True)

train_subset_epoch = read_and_set_annoation(train_record_subset_df, include=["EEG Fpz-Cz"], is_test=False)
test_whole_epoch = read_and_set_annoation(test_record_df, include=["EEG Fpz-Cz"], is_test=True)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/45 [00:00<?, ?it/s]

## 特徴量エンジニアリング

In [10]:
def eeg_power_band(epochs, freq_bands=None):
    """EEG relative power band feature extraction.

    Computes the PSD of the EEG channels of ``epochs`` over the frequency span
    covered by ``freq_bands``. Each (epoch, channel) spectrum is normalised to
    sum to 1, and the normalised power is then averaged inside each band.

    Parameters
    ----------
    epochs : Epochs
        The data. Only channels selected by ``picks='eeg'`` are used.
    freq_bands : dict[str, sequence[float, float]] | None
        Mapping of band name -> ``[fmin, fmax)`` in Hz. Defaults to the
        delta/theta/alpha/sigma/beta bands below. The PSD is computed from the
        smallest lower edge to the largest upper edge of the bands.

    Returns
    -------
    X : numpy array of shape (n_epochs, n_bands * n_channels)
        Band features in band-major order: every channel for the first band,
        then every channel for the second band, and so on. With one EEG
        channel and the default bands this is (n_epochs, 5).

    Raises
    ------
    ValueError
        If ``freq_bands`` is empty.
    """
    if freq_bands is None:
        freq_bands = {"delta": [0.5, 4.5],
                      "theta": [4.5, 8.5],
                      "alpha": [8.5, 11.5],
                      "sigma": [11.5, 15.5],
                      "beta": [15.5, 30]}
    if not freq_bands:
        raise ValueError("freq_bands must contain at least one band")

    # Request exactly the span the bands need (0.5-30 Hz for the defaults).
    fmin = float(min(lo for lo, _ in freq_bands.values()))
    fmax = float(max(hi for _, hi in freq_bands.values()))

    spectrum = epochs.compute_psd(picks='eeg', fmin=fmin, fmax=fmax, verbose=False)
    psds, freqs = spectrum.get_data(return_freqs=True)
    # Relative power: normalise each (epoch, channel) spectrum to sum to 1.
    # Done out of place so the array returned by get_data is not modified.
    psds = psds / np.sum(psds, axis=-1, keepdims=True)

    features = []
    for lo, hi in freq_bands.values():
        in_band = (freqs >= lo) & (freqs < hi)
        band_power = psds[:, :, in_band].mean(axis=-1)  # (n_epochs, n_channels)
        features.append(band_power.reshape(len(psds), -1))

    return np.concatenate(features, axis=1)

In [17]:
# One feature row per 30 s epoch for every training recording:
# id / meas_time / condition come from the raw samples, band powers from the PSD.
frames = []
for epoch in tqdm(train_subset_epoch):
    epoch_df = epoch_to_df(epoch)  # per-sample waveform with absolute timestamps
    sub_df = epoch_to_sub_df(epoch_df, epoch.info["temp"]["id"], is_train=True)  # one submission-style row per epoch
    feature_df = pd.DataFrame(eeg_power_band(epoch))  # relative band power per epoch

    _df = pd.concat([sub_df, feature_df], axis=1)
    # Drop epochs whose label is not an actual sleep stage.
    is_unscored = _df["condition"].isin(["Sleep stage ?", "Movement time"])
    _df = _df[~is_unscored]
    frames.append(_df)

train_df = pd.concat(frames).reset_index(drop=True)

  0%|          | 0/50 [00:00<?, ?it/s]

In [18]:
# Number of epochs per sleep-stage label (string labels, before the ID mapping)
train_df.condition.value_counts()

Sleep stage W      35487
Sleep stage 2      21455
Sleep stage R       7288
Sleep stage 1       6970
Sleep stage 3/4     3876
Name: condition, dtype: int64

In [19]:
# Convert the string labels to integer IDs via LABEL2ID.
# Guarded so that re-running the cell is a no-op: mapping the already-numeric
# column a second time would turn every label into NaN.
if not pd.api.types.is_numeric_dtype(train_df["condition"]):
    train_df["condition"] = train_df["condition"].map(LABEL2ID)
# A label missing from LABEL2ID would have become NaN.
assert train_df["condition"].notna().all(), "unmapped sleep-stage label in train_df"

In [20]:
# Assembled feature table: one row per 30 s epoch; columns 0-4 are the relative
# band powers (delta, theta, alpha, sigma, beta) of the single EEG channel.
train_df

Unnamed: 0,id,meas_time,condition,0,1,2,3,4
0,5464ae3,1989-09-04 21:05:00,0,0.006058,0.000671,0.000291,0.000243,0.000315
1,5464ae3,1989-09-04 21:05:30,0,0.007440,0.000366,0.000178,0.000087,0.000084
2,5464ae3,1989-09-04 21:06:00,0,0.006651,0.000469,0.000289,0.000197,0.000220
3,5464ae3,1989-09-04 21:06:30,0,0.006988,0.000541,0.000366,0.000124,0.000112
4,5464ae3,1989-09-04 21:07:00,0,0.006886,0.000544,0.000378,0.000192,0.000118
...,...,...,...,...,...,...,...,...
75071,a168af0,1990-03-15 09:22:30,0,0.006501,0.000882,0.000211,0.000189,0.000166
75072,a168af0,1990-03-15 09:23:00,0,0.007052,0.000561,0.000181,0.000128,0.000125
75073,a168af0,1990-03-15 09:23:30,0,0.007285,0.000294,0.000157,0.000102,0.000147
75074,a168af0,1990-03-15 09:24:00,0,0.007569,0.000426,0.000121,0.000078,0.000047


## ベースモデル構築

In [21]:
# Hold out 20% of the subjects (not individual recordings) for validation, so that
# all nights of a subject end up on the same side of the split.
# A seeded generator makes the split reproducible across kernel restarts.
# NOTE(review): subjects are drawn from the full train_record_df while train_df only
# contains the sampled subset, so the validation size in train_df varies with the subset.
VAL_FRACTION = 0.20
SPLIT_SEED = 42

val_size = int(train_record_df["subject_id"].nunique() * VAL_FRACTION)
train_all_subjects = train_record_df["subject_id"].unique()
np.random.default_rng(SPLIT_SEED).shuffle(train_all_subjects)

val_subjects = train_all_subjects[:val_size]
val_ids = train_record_df.loc[train_record_df["subject_id"].isin(val_subjects), "id"]

In [22]:
# Recordings of held-out subjects form the validation set; the rest is for training.
is_val_record = train_df["id"].isin(val_ids)
val_df = train_df[is_val_record]
trn_df = train_df[~is_val_record]

In [23]:
# Training split: recordings whose id is not in val_ids
trn_df

Unnamed: 0,id,meas_time,condition,0,1,2,3,4
0,5464ae3,1989-09-04 21:05:00,0,0.006058,0.000671,0.000291,0.000243,0.000315
1,5464ae3,1989-09-04 21:05:30,0,0.007440,0.000366,0.000178,0.000087,0.000084
2,5464ae3,1989-09-04 21:06:00,0,0.006651,0.000469,0.000289,0.000197,0.000220
3,5464ae3,1989-09-04 21:06:30,0,0.006988,0.000541,0.000366,0.000124,0.000112
4,5464ae3,1989-09-04 21:07:00,0,0.006886,0.000544,0.000378,0.000192,0.000118
...,...,...,...,...,...,...,...,...
75071,a168af0,1990-03-15 09:22:30,0,0.006501,0.000882,0.000211,0.000189,0.000166
75072,a168af0,1990-03-15 09:23:00,0,0.007052,0.000561,0.000181,0.000128,0.000125
75073,a168af0,1990-03-15 09:23:30,0,0.007285,0.000294,0.000157,0.000102,0.000147
75074,a168af0,1990-03-15 09:24:00,0,0.007569,0.000426,0.000121,0.000078,0.000047


In [24]:
# Validation split: recordings of the held-out subjects (id in val_ids)
val_df

Unnamed: 0,id,meas_time,condition,0,1,2,3,4
1530,e0df8c0,1989-05-09 21:09:00,0,0.007059,0.000795,0.000144,0.000093,0.000077
1531,e0df8c0,1989-05-09 21:09:30,0,0.007500,0.000417,0.000144,0.000085,0.000061
1532,e0df8c0,1989-05-09 21:10:00,0,0.007296,0.000729,0.000140,0.000061,0.000039
1533,e0df8c0,1989-05-09 21:10:30,0,0.006992,0.000747,0.000152,0.000100,0.000105
1534,e0df8c0,1989-05-09 21:11:00,0,0.006589,0.000695,0.000269,0.000184,0.000182
...,...,...,...,...,...,...,...,...
57572,89ac2db,1989-04-27 10:03:30,0,0.007434,0.000654,0.000133,0.000066,0.000022
57573,89ac2db,1989-04-27 10:04:00,0,0.006235,0.001190,0.000413,0.000183,0.000115
57574,89ac2db,1989-04-27 10:04:30,0,0.006965,0.000982,0.000230,0.000095,0.000033
57575,89ac2db,1989-04-27 10:05:00,0,0.007294,0.000746,0.000151,0.000057,0.000034


### LightGBM実装

In [34]:
# Features are every column except the identifiers and the label
# (for the frames built above: the band-power columns 0-4).
# Selecting by name instead of iloc[:, 3:] survives column reordering/additions.
_non_feature_cols = {"id", "meas_time", "condition", COL_CONDITION}
feature_cols = [c for c in trn_df.columns if c not in _non_feature_cols]

train_x = trn_df[feature_cols]
train_y = trn_df[COL_CONDITION]

val_x = val_df[feature_cols]
val_y = val_df[COL_CONDITION]

trains = lgb.Dataset(train_x, train_y)
# reference=trains: the validation set reuses the training set's bin boundaries.
valids = lgb.Dataset(val_x, val_y, reference=trains)

# NOTE(review): the stage IDs (W=0, 1, 2, 3/4=3, R=4) are nominal classes; regressing on
# them treats R as "deeper" than stage 3/4. Consider objective="multiclass" with
# num_class=5 — this also changes model.predict to return per-class probabilities,
# so downstream cells consuming predictions would need updating.
params = {
    "objective": "regression",
    "metric": "mae"
}

In [35]:
# Early stopping via the callback: the `early_stopping_rounds` keyword of lgb.train was
# deprecated in LightGBM 3.3 and removed in 4.0. On LightGBM >= 4, add
# lgb.log_evaluation(period) to the callbacks if per-iteration logging is wanted.
model = lgb.train(params, trains, valid_sets=valids, num_boost_round=1000, callbacks=[lgb.early_stopping(100)])

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 1275
[LightGBM] [Info] Number of data points in the train set: 60001, number of used features: 5
[LightGBM] [Info] Start training from score 1.200697
[1]	valid_0's l1: 1.14064
Training until validation scores don't improve for 100 rounds
[2]	valid_0's l1: 1.09709
[3]	valid_0's l1: 1.05829
[4]	valid_0's l1: 1.02361
[5]	valid_0's l1: 0.992405
[6]	valid_0's l1: 0.961935
[7]	valid_0's l1: 0.936277
[8]	valid_0's l1: 0.912159
[9]	valid_0's l1: 0.892507
[10]	valid_0's l1: 0.874827
[11]	valid_0's l1: 0.858249
[12]	valid_0's l1: 0.842314
[13]	valid_0's l1: 0.82892
[14]	valid_0's l1: 0.817685
[15]	valid_0's l1: 0.805929
[16]	valid_0's l1: 0.795053
[17]	valid_0's l1: 0.786612
[18]	valid_0's l1: 0.778814
[19]	valid_0's l1: 0.77217
[20]	valid_0's l1: 0.766242
[21]	valid_0's l1: 0.760166
[22]	valid_0's l1: 0.754844
[23]	valid_0's l1: 0.750274
[24]	val