In [None]:
# %pip install mne==1.4.2  # uncomment on a fresh environment; %pip installs into the running kernel's environment

In [2]:
# Sanity check: report how many GPUs TensorFlow can see before any training starts.
import tensorflow
print(
    "Num GPUs Available: ",
    len(tensorflow.config.list_physical_devices(device_type='GPU')),
)

2024-12-17 10:40:47.006462: I tensorflow/core/platform/cpu_feature_guard.cc:194] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE3 SSE4.1 SSE4.2 AVX
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


Num GPUs Available:  4


In [3]:
import pandas as pd
import mne
import numpy as np
import os
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings("ignore")
#from openpyxl import load_workbook
import pickle
import random
from collections import defaultdict
from mne.decoding import Scaler
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.callbacks import ModelCheckpoint
import sys
sys.path.append('./arl-eegmodels')
from EEGModels import EEGNet
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score

In [4]:
# Widen pandas' display limits so result tables are not truncated in cell output.
pd.options.display.max_columns = 100  # or 1000
pd.options.display.max_rows = 100  # or 1000
pd.options.display.max_colwidth = None  # no truncation of long cell contents

In [5]:
# Load the list of subject IDs used for the cross-subject analysis.
# The context manager guarantees the file handle is closed even if unpickling fails.
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
with open('/data/sleep_germandata_reactivation-1/eeg_sub_128_fastspin/Sub_list.pkl','rb') as file:
    sub_list=pickle.load(file)

## Parameters for model

In [6]:
# parameters for EEGNet (read as module-level globals by cross_sub_losocv)
nb_classes=2
dropoutRate=0.17 # change from 0.25 to 0.17 (cross-subject)
kernLength=100
F1=16
D=2
F2=F1*D
norm_rate=0.25
dropoutType = 'Dropout'

# training parameters
batch_size =16
epoch_no=200   # upper bound for first-stage training; early stopping may end it sooner

# set random seeds
random.seed(42)     # drives random.sample (validation-subject selection)
np.random.seed(42)  # drives np.random.shuffle (label-shuffle control); random.seed alone does not cover it
# NOTE(review): TensorFlow weight initialisation / batch shuffling is not seeded here,
# so individual training runs are still not bit-for-bit reproducible.

## Cross-subject classification function

In [7]:
def cross_sub_losocv(model_name,subjs,data_path,checkpoint_path,region,sfreq,ts_start,ts_dur,night,
                   shuf,repeats,flag,results_path):
    """
    Perform Leave-One-Subject-Out cross-validation with two-stage training.

    For every subject in `subjs` (the held-out test subject) and every repeat:
      1. 4 of the remaining subjects are drawn at random as the validation set,
         all other subjects form the training set.
      2. Stage 1: the model is trained on the training set, monitored on the
         validation set (early stopping + best-weights checkpoint).
      3. Stage 2: the model is trained for 50 more epochs on the validation
         subjects, then evaluated on the test subject.

    The EEGNet hyper-parameters and `batch_size` / `epoch_no` are read from
    module-level globals (nb_classes, dropoutRate, kernLength, F1, D, F2,
    norm_rate, dropoutType, batch_size, epoch_no).

    Parameters:
    - model_name: model name (only 'eegnet' is supported)
    - subjs: list of subject ids (strings; they are concatenated into file names)
    - data_path: path prefix of the per-subject '.npy' files
    - checkpoint_path: path prefix for the '.h5' weight checkpoints
    - region: 'central', 'front', 'post' or 'all'
    - sfreq: sampling frequency (e.g., 128 Hz)
    - ts_start: start of the time window in seconds
    - ts_dur: duration of the time window in seconds
    - night: 'D' or 'M'
    - shuf: 1 (shuffle the training labels) or 0 (no shuffle)
    - repeats: number of repetitions per test subject
    - flag: 1 (save results) or 0 (do not save)
    - results_path: path prefix for the pickled results

    Returns:
    a result dictionary keyed by subject id, each entry holding:
    - 'model': model name
    - '<night>_auc': auc list (one value per repeat)
    - '<night>_accuracy': accuracy list
    - '<night>_f1': f1 list
    - '<night>_recall': recall list
    - '<night>_precision': precision list

    Raises:
    - ValueError for an unknown region, night or model_name
      (random.sample also raises ValueError if fewer than 4 subjects remain
      after removing the test subject).
    """

    # channels for region (1-based channel numbers, converted to 0-based indices below)
    region_channels={
        'central': [6,7,13,20,29,30,31,36,37,42,53,54,55,79,80,86,87,93,104,105,
                    106,111,112,118],
        'front': [2,3,4,5,9,10,11,12,15,16,18,19,22,23,24,26,27,123,124],   # frontal
        'post': [59,60,61,62,65,66,67,70,71,72,75,76,77,78,83,84,85,90,91], # posterior
        # np.arange excludes the stop value, so this is channels 1..127.
        # NOTE(review): if the recordings hold 128 channels, channel 128 is left out - confirm intended.
        'all': np.arange(1,128),
    }
    # Validate up front so a typo fails immediately instead of as a NameError mid-run.
    if region not in region_channels:
        raise ValueError(f"Unknown region {region!r}; expected one of {sorted(region_channels)}")
    if night not in ('D','M'):
        raise ValueError(f"Unknown night {night!r}; expected 'D' or 'M'")
    if model_name!='eegnet':
        raise ValueError(f"Unknown model_name {model_name!r}; only 'eegnet' is supported")

    channels = [x - 1 for x in region_channels[region]]

    # sample no. calculation for timestamp ##########################################################
    ts_end=ts_start+ts_dur # timing in second
    start=ts_start*sfreq
    end=start+(ts_dur*sfreq+1)   # +1 keeps the sample at ts_end inside the window

    # D or M night: file-name prefixes of the data and label arrays
    X_night='X'+night+'_'
    y_night='y_encoded'+night+'_'

    numsub_val= 4   # randomly select n subject for valid set

    # Create the output directories now rather than failing after the first training epoch.
    checkpoint_dir=os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    results_dir=os.path.dirname(results_path)
    if flag==1 and results_dir:
        os.makedirs(results_dir, exist_ok=True)

    def _load_subject(subject_id):
        """Load one subject's data cropped to the selected channels/time window, plus its labels."""
        X=np.load(data_path+X_night+subject_id+'.npy')
        X=X[:,channels,start:end]
        y_encoded=np.load(data_path+y_night+subject_id+'.npy')
        return X, y_encoded

    # Loop through each subject in the subject list ######################################
    results=defaultdict(dict)
    for s in subjs:
        print(' ')
        print(f"Testing Subject {s}:")
        auc_list=[]
        accuracy_list = []
        f1_list=[]
        recall_list=[]
        precision_list=[]

        for r in range(repeats):
            # test subject############################################
            X_test, y_encoded_test = _load_subject(s)
            print(f"X test: {X_test.shape}")
            print(f"y test: {y_encoded_test.shape}")

            # val subjects##############################################
            # random.sample needs a sequence (sets raise TypeError on Python >= 3.11);
            # sorting also makes the draw depend only on the seed, not on set ordering.
            candidates=sorted(set(subjs)-set([s]))
            val_subj = random.sample(candidates, numsub_val)
            val_parts=[_load_subject(s_val) for s_val in val_subj]
            X_val=np.concatenate([X for X, _ in val_parts], axis=0)
            y_encoded_val=np.concatenate([y for _, y in val_parts], axis=0)

            print(f"X val: {X_val.shape}")
            print(f"y val: {y_encoded_val.shape}")

            # train subjects###################################################
            train_subj=sorted(set(subjs)-set(val_subj)-set([s]))
            X_parts=[]
            y_parts=[]
            for s_train in train_subj:
                X, y_encoded = _load_subject(s_train)
                if shuf==1:
                    # permutation control: break the data/label pairing within this subject
                    np.random.shuffle(y_encoded)
                X_parts.append(X)
                y_parts.append(y_encoded)

            X_train=np.concatenate(X_parts, axis=0)
            y_encoded_train=np.concatenate(y_parts, axis=0)

            print(f"X train: {X_train.shape}")
            print(f"y train: {y_encoded_train.shape}")

    ##### Deep learning ############################################################################################

            kernels, chans, samples = 1, X_train.shape[1], X_train.shape[2]

            # normalize inputs: the scaler is fitted on the training data only and then
            # applied unchanged to validation and test data
            normalize_obj = Scaler(scalings='median').fit(epochs_data=X_train)
            X_train = normalize_obj.transform(X_train)
            X_val = normalize_obj.transform(X_val)
            X_test = normalize_obj.transform(X_test)

            # add the trailing "kernels" axis expected by the model input
            X_train=X_train.reshape(X_train.shape[0], chans, samples, kernels)
            X_val=X_val.reshape(X_val.shape[0], chans, samples, kernels)
            X_test=X_test.reshape(X_test.shape[0], chans, samples, kernels)

            # callbacks
            early_stopping = EarlyStopping(monitor='val_accuracy', mode='max',patience=50,
                                           restore_best_weights=True)
            checkpoint_fullpath=checkpoint_path+'sub'+s+'_'+str(r)+'_'+region+'_'+str(ts_start)+'-'+str(ts_end)+'_'+night+'.h5'
            # first stage training: keep only the weights with the best validation accuracy
            checkpointer1 = ModelCheckpoint(filepath=checkpoint_fullpath, verbose=1,save_weights_only=True,
                                           monitor='val_accuracy',mode='max',save_best_only=True)

            # second stage training: save_best_only=False, so the same file is rewritten on
            # every save and ends up holding the most recent weights
            checkpointer2 = ModelCheckpoint(filepath=checkpoint_fullpath, verbose=1,save_weights_only=True,
                                           monitor='accuracy',mode='max',save_best_only=False)

            # Model fitting ###############################################################################
            class_weights = {0:1, 1:1}
            model=EEGNet(nb_classes=nb_classes,Chans=chans,Samples=samples,dropoutRate=dropoutRate,
                        kernLength=kernLength,F1=F1,D=D,F2=F2,norm_rate=norm_rate,dropoutType=dropoutType)

            model.compile(loss = 'categorical_crossentropy', optimizer = 'adam',metrics=['accuracy'])

            # first stage training
            model.fit(X_train,y_encoded_train,batch_size = batch_size, epochs=epoch_no,
                      validation_data=(X_val, y_encoded_val),
                      callbacks=[checkpointer1,early_stopping], class_weight = class_weights)
            model.load_weights(checkpoint_fullpath) # load optimal weights

            # second stage training (on the validation subjects, no validation data)
            model.fit(X_val,y_encoded_val,batch_size = batch_size, epochs=50,
                      callbacks=[checkpointer2], class_weight = class_weights)
            model.load_weights(checkpoint_fullpath) # reload the weights written by checkpointer2

            # Evaluation on the held-out subject; column 1 of the label/prediction arrays
            # is used as the positive class
            y_predict=model.predict(X_test)
            y_class=np.argmax(y_predict, axis=1)
            roc_auc=roc_auc_score(y_encoded_test[:,1], y_predict[:,1])
            accuracy= np.mean(y_class == y_encoded_test.argmax(axis=-1))
            f1=f1_score(y_encoded_test[:,1],y_class)
            recall=recall_score(y_encoded_test[:,1],y_class)
            precision=precision_score(y_encoded_test[:,1],y_class)

            # print results
            print(y_encoded_test[:,1])
            print(y_class)
            print(f"acc= {accuracy}, auc= {roc_auc}")

            auc_list.append(roc_auc)
            accuracy_list.append(accuracy)
            f1_list.append(f1)
            recall_list.append(recall)
            precision_list.append(precision)

        # store this subject's metrics under night-prefixed keys ('D_auc', 'M_auc', ...)
        results[s]['model']=model_name
        results[s][night+'_auc']=auc_list
        results[s][night+'_accuracy']=accuracy_list
        results[s][night+'_f1']=f1_list
        results[s][night+'_recall']=recall_list
        results[s][night+'_precision']=precision_list

        # save results #####################
        # Written after every subject so partial results survive an interrupted run.
        if flag==1 and shuf in (0,1):
            suffix='_shuf' if shuf==1 else ''
            results_file=(results_path+model_name+'_'+night+'_'+region+'_'+str(ts_start)+'-'+str(ts_end)
                          +'s_crosssub_2stager5'+suffix+'.pkl')
            with open(results_file, 'wb') as f:
                pickle.dump(results,f)

    print('Done!')
    print(' ')
    return results

## Input for function

In [8]:
# --- what to run -----------------------------------------------------------
model_name = 'eegnet'
subjs = sub_list                      # subject ids loaded from Sub_list.pkl
repeats = 5                           # repetitions per held-out subject

# --- sweep axes (every combination is evaluated by the loop below) ---------
region = ['all', 'front', 'central', 'post']
ts_dur = [2, 4, 7]                    # window durations in seconds
night = ['D', 'M']
shuf = [0, 1]                         # 0 = true labels, 1 = shuffled training labels

# --- time window -----------------------------------------------------------
sfreq = 128                           # sampling frequency in Hz
ts_start = 0                          # window start in seconds

# --- input / output locations ----------------------------------------------
data_path = '/data/sleep_germandata_reactivation-1/eeg_sub_128_raw/'
checkpoint_path = './checkpoint_raw_crosssub/'
results_path = './Results_raw(0.5-20)_crosssub/'
flag = 1                              # 1 = save results to results_path

In [9]:
# Run the leave-one-subject-out evaluation for every
# region x duration x night x shuffle combination.
for reg in region:
    print(f'\nregion: {reg}')
    for ts in ts_dur:
        print(f'\nduration: {ts}')
        for n in night:
            print(f'night: {n}')
            for s in shuf:
                print(f'shuf: {s}')
                cross_sub_losocv(
                    model_name=model_name,
                    subjs=subjs,
                    data_path=data_path,
                    checkpoint_path=checkpoint_path,
                    region=reg,
                    sfreq=sfreq,
                    ts_start=ts_start,
                    ts_dur=ts,
                    night=n,
                    shuf=s,
                    repeats=repeats,
                    flag=flag,
                    results_path=results_path,
                )


region: all

duration: 2
night: D
shuf: 0
 
Testing Subject 25:


FileNotFoundError: [Errno 2] No such file or directory: '/data/sleep_germandata_reactivation-1/eeg_sub_128_raw/XD_25.npy'