In [None]:
import numpy as np

from keras import callbacks

from sklearn.model_selection import StratifiedGroupKFold

from codes.set_gpu import set_gpu
from codes.load_seg_data import load_seg_data
from codes.cnn import cnn
from codes.print_history import print_history
from codes.print_metric import print_metric
from codes.print_predict import print_predict

# Configure the GPU through the project helper before any model is built.
set_gpu()

# Load the PSD segments of each cohort. The second argument (1 / 0) is
# presumably the class label attached to every segment - TODO confirm
# against load_seg_data.
x_HC, y_HC, sbj_HC, seg_HC = load_seg_data("/home/unixuser/cdproject/public_data/segments_PSD/HC", 0)
x_ADHD, y_ADHD, sbj_ADHD, seg_ADHD = load_seg_data("/home/unixuser/cdproject/public_data/segments_PSD/ADHD", 1)

# x   = PSD values of all segments, ADHD first, then HC. A trailing depth
#       axis of size 1 is appended and axes 1 and 2 are swapped; per the
#       original note the resulting layout is
#       (batch size, frequency, channel, depth) - TODO confirm the axis order
#       returned by load_seg_data.
# y   = label of each entry of x (ADHD + HC)
# sbj = subject of each entry of x (ADHD + HC)
# seg = segment identifier of each entry of x (ADHD + HC)
x = np.concatenate((x_ADHD, x_HC), axis=0)
x = x[..., np.newaxis]
x = np.transpose(x, (0, 2, 1, 3))

y = np.concatenate((y_ADHD, y_HC), axis=0)
sbj = np.concatenate((sbj_ADHD, sbj_HC), axis=0)
seg = np.concatenate((seg_ADHD, seg_HC), axis=0)

# Subject-wise stratified 10-fold cross-validation: sbj is passed as the group
# key, so all segments of one subject fall on the same side of every split.
sgkf = StratifiedGroupKFold(n_splits=10)
for fold_num, (train_index, test_index) in enumerate(sgkf.split(x, y, sbj), start=1):
    # Held-out test fold: used only for the final evaluate()/predict() below.
    x_test = x[test_index]
    y_test = y[test_index]
    sbj_test = sbj[test_index]
    seg_test = seg[test_index]

    # Carve a validation subset out of the training fold, again grouped by
    # subject and stratified by label. The callbacks below pick the learning
    # rate, the stopping epoch and the restored weights from val_loss; feeding
    # them the test fold (as was done before) leaks test data into model
    # selection and makes the reported test metric optimistically biased.
    # Using n_splits - 1 inner splits makes the validation subset roughly the
    # size of one outer fold. Only y and the groups drive the split, so a
    # zero placeholder stands in for X to avoid copying the training data.
    # NOTE(review): this requires at least n_splits - 1 distinct subjects in
    # the training fold.
    inner_sgkf = StratifiedGroupKFold(n_splits=sgkf.get_n_splits() - 1)
    fit_index, val_index = next(
        inner_sgkf.split(np.zeros(len(train_index)), y[train_index], sbj[train_index])
    )
    # Map positions within the training fold back to indices into x / y.
    fit_index = train_index[fit_index]
    val_index = train_index[val_index]

    x_train = x[fit_index]
    y_train = y[fit_index]
    x_val = x[val_index]
    y_val = y[val_index]

    # Fresh model for every fold; the input shape is one sample's shape.
    model = cnn(x.shape[1:])

    print(f"--------------------fold {fold_num}--------------------")

    # Halve the learning rate after 5 epochs without val_loss improvement,
    # stop after 10, and restore the weights of the best epoch.
    reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.00001, verbose=1)
    early_stopping = callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    # NOTE(review): the same file name is reused by every fold, so after the
    # loop only the last fold's best model remains on disk.
    checkpoint = callbacks.ModelCheckpoint('best_model.keras', monitor='val_loss', save_best_only=True, mode='min')

    history = model.fit(
        x_train,
        y_train,
        epochs=50,
        batch_size=10,
        verbose=0,
        validation_data=(x_val, y_val),
        callbacks=[reduce_lr, early_stopping, checkpoint],
    )
    print_history(history)

    # Unbiased estimate: the test fold was never seen during fitting or
    # model selection.
    metric = model.evaluate(x_test, y_test, verbose=0)
    print(metric)

    y_predicted = model.predict(x_test)
    print_predict(y_predicted, y_test, sbj_test, seg_test, fold_num)