In [1]:
import re
import os
import Loader
import numpy as np
from mne import viz
from scipy import io
from matplotlib import colors, pyplot as plt
from sklearn import preprocessing, model_selection, ensemble

In [2]:
CAF_DOSE = 200

PROJECT_PATH = 'E:\\Cafeine_data'

STAGES = ['AWA', 'N1', 'N2', 'N3', 'REM']
BANDS = ['delta', 'theta', 'alpha', 'sigma', 'beta', 'low gamma']

In [3]:
sensor_pos = io.loadmat(os.path.join(PROJECT_PATH, 'Coo_caf'))['Cor'].T
sensor_pos = np.array([sensor_pos[1], sensor_pos[0]]).T

In [4]:
subject_labels = Loader.load_labels(CAF_DOSE)
psd = Loader.load_feature('PSD', CAF_DOSE)

In [5]:
data = {}
labels = {}
groups = {}
group_indices = {}

for stage in STAGES:
    curr_data = []
    labels[stage] = []
    groups[stage] = []
    group_indices[stage] = {}
    
    for subject_id, subject in psd[stage].items():
        if subject.size == 0:
            continue
            
        curr_data.append(subject)
        labels[stage] += [subject_labels[subject_id]] * subject.shape[1]
        
        subject_short = re.match('\S\d+', subject_id)[0]
        if len(group_indices[stage]) == 0:
            group_indices[stage][subject_short] = 0
        elif not subject_short in group_indices[stage]:
            group_indices[stage][subject_short] = np.max(list(group_indices[stage].values())) + 1
        
        groups[stage] += [group_indices[stage][subject_short]] * subject.shape[1]
    
    concatenated = np.concatenate(curr_data, axis=1)

    data[stage] = {}
    for i, band in enumerate(BANDS):
        data[stage][band] = concatenated[:,:,i].T

In [49]:
scores = {}
rfs = {}
for stage in STAGES:
    print(f'Sleep stage {stage}')
    scores[stage] = {}
    rfs[stage] = {}
    for band in BANDS:
        print(f'   Training for frequency band {band}...')
        rf = ensemble.RandomForestClassifier(n_estimators=10)
        rfs[stage][band] = rf
        scores[stage][band] = model_selection.permutation_test_score(rf, data[stage][band], labels[stage], groups[stage], n_permutations=100, cv=10, n_jobs=6)

Sleep stage AWA
   Training for frequency band delta...
   Training for frequency band theta...
   Training for frequency band alpha...
   Training for frequency band sigma...
   Training for frequency band beta...
   Training for frequency band low gamma...
Sleep stage N1
   Training for frequency band delta...
   Training for frequency band theta...
   Training for frequency band alpha...
   Training for frequency band sigma...
   Training for frequency band beta...
   Training for frequency band low gamma...
Sleep stage N2
   Training for frequency band delta...
   Training for frequency band theta...
   Training for frequency band alpha...
   Training for frequency band sigma...
   Training for frequency band beta...
   Training for frequency band low gamma...
Sleep stage N3
   Training for frequency band delta...
   Training for frequency band theta...
   Training for frequency band alpha...
   Training for frequency band sigma...
   Training for frequency band beta...
   Training

In [50]:
for stage in STAGES:
    print(f'Sleep stage {stage}')
    for band in BANDS:
        print(f'    Frequency band {band:10}: score {scores[stage][band][0]:.3f}, p-value {scores[stage][band][2]:.3f}')
    print()

Sleep stage AWA
    Frequency band delta     : score 0.706, p-value 0.970
    Frequency band theta     : score 0.706, p-value 1.000
    Frequency band alpha     : score 0.706, p-value 1.000
    Frequency band sigma     : score 0.706, p-value 1.000
    Frequency band beta      : score 0.706, p-value 1.000
    Frequency band low gamma : score 0.706, p-value 1.000

Sleep stage N1
    Frequency band delta     : score 0.522, p-value 1.000
    Frequency band theta     : score 0.522, p-value 1.000
    Frequency band alpha     : score 0.522, p-value 1.000
    Frequency band sigma     : score 0.522, p-value 1.000
    Frequency band beta      : score 0.522, p-value 1.000
    Frequency band low gamma : score 0.522, p-value 1.000

Sleep stage N2
    Frequency band delta     : score 0.539, p-value 1.000
    Frequency band theta     : score 0.539, p-value 1.000
    Frequency band alpha     : score 0.539, p-value 1.000
    Frequency band sigma     : score 0.539, p-value 1.000
    Frequency band beta 

In [None]:
for stage in STAGES:
    for band in BANDS:
        scaler = preprocessing.StandardScaler()
        scaler.fit(x_train[stage][band])

        x_train[stage][band] = scaler.transform(x_train[stage][band])
        x_test[stage][band] = scaler.transform(x_test[stage][band])

    print(stage)
    print(f'   Train samples: {x_train[stage][BANDS[0]].shape[0]}')
    percentage = x_test[stage][BANDS[0]].shape[0] / x_train[stage][BANDS[0]].shape[0] * 100
    print(f'   Test samples: {x_test[stage][BANDS[0]].shape[0]} ({percentage:.3f}%)')

In [None]:
rfs = {}
for stage in STAGES:
    rfs[stage] = {}
    print(f'Sleep stage {stage}')
    for band in BANDS:
        print(f'   Training random for frequency band {band}...')
        rfs[stage][band] = ensemble.RandomForestClassifier(n_estimators=200)
        rfs[stage][band].fit(x_train[stage][band], y_train[stage])

In [51]:
for stage in STAGES:
    print(f'Sleep stage {stage}')
    for band in BANDS:
        acc_train = rfs[stage][band].score(x_train[stage][band], y_train[stage]) * 100
        acc_test = rfs[stage][band].score(x_test[stage][band], y_test[stage]) * 100
        print(f'   {band} accuracy for training data: {acc_train:.3f}%, testing data: {acc_test:.3f}%')

Sleep stage AWA


NameError: name 'x_train' is not defined

In [52]:
colormap = 'jet'

vmin = np.min([[rf.feature_importances_.min() for rf in rfs[stage].values()] for stage in STAGES])
vmax = np.max([[rf.feature_importances_.max() for rf in rfs[stage].values()] for stage in STAGES])

for stage in STAGES:
    plt.figure(figsize=(18, 6))
    plt.suptitle(f'{stage} sleep stage feature importances', y=0.8, fontsize=15)
    
    axes = []
    for i, band in enumerate(BANDS):
        ax = plt.subplot(1, len(BANDS), i + 1)
        axes.append(ax)
        plt.title(f'PSD {band}')
        importance = rfs[stage][band].feature_importances_
        viz.plot_topomap(importance, sensor_pos, cmap=colormap, vmin=vmin, vmax=vmax, contours=False, show=False)
        
    norm = colors.Normalize(vmin=vmin,vmax=vmax)
    sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ax=axes, shrink=0.3, aspect=15)
    plt.show()

NotFittedError: This RandomForestClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this method.