In [None]:
import numpy as np
import tensorflow as tf
from keras.preprocessing import image
import fnmatch
import matplotlib.pyplot as plt
import mne, os

from cnn_prepare_dataset import file_keeper, crop_set_annotations


In [None]:
def create_hypnograms(events, times_events, sfreq, img_folder_path, cnn, annotation_stage_id, training_set):
    """Draw the annotated and the CNN-predicted hypnogram of one recording and report accuracy.

    For every pair of consecutive rows in ``times_events`` the segment between
    them is labelled with the stage of the earlier row.  The image whose file
    name ends with the segment's start second is classified by ``cnn``, and the
    outcome is compared with the annotated stage.  A two-panel hypnogram figure,
    a confusion-matrix figure and a plain-text accuracy report are written to
    ``<img_folder_path>/hypnogram_info<record name>/``.

    Args:
        events: Not used by this function; kept so existing callers keep working.
        times_events: Indexable rows where column 0 is divided by ``sfreq`` to get
            the start time and column 2 is the stage id.  The last row only marks
            the end of the previous segment and is never classified itself.
        sfreq: Divisor that converts column 0 of ``times_events`` to seconds.
        img_folder_path: Folder holding the ``*.png`` images of one recording.
        cnn: Model exposing ``predict``; it receives one (1, 800, 800, 1) batch per image.
        annotation_stage_id: Mapping of stage name -> stage id used in ``times_events``.
        training_set: Object exposing ``class_indices`` (class name -> output index).
            The class names are expected to be the same stage names as in
            ``annotation_stage_id`` -- TODO confirm against the training notebook.

    Returns:
        None.  Results are written to disk and printed.
    """
    # Vertical plot position of every stage ("Sleep stage 4" at the bottom, wake on top).
    stages_names_for_plot = {
        "Sleep stage 4": 0,
        "Sleep stage 3": 1,
        "Sleep stage 2": 2,
        "Sleep stage 1": 3,
        "Sleep stage R": 4,
        "Sleep stage W": 5
    }
    positions_of_labels = list(stages_names_for_plot.values())
    stage_tick_labels = list(stages_names_for_plot)

    # Folder name of the recording.  The slices [-6:] / [-7:] used previously
    # disagreed with each other and [-7:] included the path separator.
    record_name = os.path.basename(os.path.normpath(img_folder_path))

    # Reverse lookups, built once instead of once per epoch.
    stage_name_by_id = {}
    for stage_name, stage_id in annotation_stage_id.items():
        # setdefault keeps the first name when several names share an id
        stage_name_by_id.setdefault(stage_id, stage_name)
    class_name_by_index = {index: name for name, index in training_set.class_indices.items()}

    # Per-stage counters, keyed by name so they do not rely on ids being 1..6.
    true_counts = dict.fromkeys(annotation_stage_id, 0)
    correct_counts = dict.fromkeys(annotation_stage_id, 0)

    # Step-plot coordinates (two points per segment) for the true and predicted hypnogram.
    x, y = [], []
    x_pred, y_pred = [], []

    # One entry per classified epoch; these (not the doubled plot lists) feed the confusion matrix.
    matrix_true, matrix_pred = [], []

    folder_files = os.listdir(img_folder_path)

    for iteration in range(1, len(times_events)):
        start = int(times_events[iteration - 1][0] / sfreq)
        end = int(times_events[iteration][0] / sfreq)
        current_stage_name = stage_name_by_id[int(times_events[iteration - 1][2])]

        x.extend((start, end))
        y.extend((current_stage_name, current_stage_name))

        # NOTE(review): "*<start>.png" also matches files whose number merely ends
        # with the same digits (e.g. "*30.png" matches "x_130.png") -- confirm the
        # image naming scheme makes this unambiguous.
        image_name = "*" + str(start) + ".png"
        matching_images = fnmatch.filter(folder_files, image_name)

        if not matching_images:
            # Checked before indexing; indexing [0] first raised IndexError instead.
            print(f"Possible image file {image_name} hasn't been found.")
            continue

        true_counts[current_stage_name] += 1

        path_possible_img = img_folder_path + "/" + matching_images[0]

        expan_dim_image = image.load_img(path_possible_img, color_mode='grayscale', target_size=(800, 800))
        expan_dim_image = image.img_to_array(expan_dim_image)
        expan_dim_image = np.expand_dims(expan_dim_image, axis=0)  # add the batch axis
        predict_res = cnn.predict(expan_dim_image)

        predicted_stage_name = class_name_by_index[int(np.argmax(predict_res[0]))]

        if predicted_stage_name == current_stage_name:
            correct_counts[current_stage_name] += 1

        x_pred.extend((start, end))
        y_pred.extend((predicted_stage_name, predicted_stage_name))

        # Confusion-matrix labels are plot level + 1, i.e. 1 = "Sleep stage 4" ... 6 = "Sleep stage W".
        matrix_true.append(stages_names_for_plot[current_stage_name] + 1)
        matrix_pred.append(stages_names_for_plot[predicted_stage_name] + 1)

    print("======== Coordinates for true and predicted hypnogram have been received. ========")

    # Hypnogram plotting

    y_true_data_plot = [stages_names_for_plot[elem] for elem in y]
    y_pred_data_plot = [stages_names_for_plot[elem] for elem in y_pred]

    plt.rcParams['font.family'] = 'Times New Roman'
    plt.rcParams['font.size'] = 15

    fig, (ax_true, ax_pred) = plt.subplots(2, 1, figsize=(18, 10))

    ax_true.plot(x, y_true_data_plot, '-b', linewidth=1, label="True hypnogram")
    ax_true.set_title(f"Hypnograms for {record_name}")
    ax_pred.plot(x_pred, y_pred_data_plot, '-g', linewidth=1, label="Predicted hypnogram")

    for axis in (ax_true, ax_pred):
        axis.set_yticks(positions_of_labels)
        axis.set_yticklabels(stage_tick_labels)
        axis.set_xticks([])
        axis.set_xlabel('Time')
        axis.set_ylabel('Sleep stage')
        axis.legend()

    hypnogram_folder = img_folder_path + "/" + "hypnogram_info" + record_name
    os.makedirs(hypnogram_folder, exist_ok=True)

    path_save_img = hypnogram_folder + f"/hypnogram_{record_name}"
    fig.savefig(path_save_img, bbox_inches='tight', pad_inches=0)
    plt.close(fig)

    # Creating confusion map

    if matrix_true:
        path_save_img = hypnogram_folder + f"/confusion_map_{record_name}"
        create_confusion_matrix(matrix_true, matrix_pred, path_save_img)
    else:
        print("No epoch has been classified; confusion map has been skipped.")

    # Percentage of predictions

    total_true = sum(true_counts.values())
    total_correct = sum(correct_counts.values())
    percent = round(total_correct / total_true * 100, 2) if total_true else 0

    report_lines = [
        f"All prediction results for {record_name}: {total_correct}/{total_true} ({percent}%)\n"
    ]

    for stage_name in annotation_stage_id:
        stage_true = true_counts[stage_name]
        stage_correct = correct_counts[stage_name]
        if stage_true:
            percent = round(stage_correct / stage_true * 100, 2)
        else:
            # A stage that never occurred is reported as 0/0 (100%), as before.
            percent = 100
        report_lines.append(
            f"Prediction results for {stage_name}: {stage_correct}/{stage_true} ({percent}%)\n"
        )

    path_for_txt = hypnogram_folder + f"/predict_info_{record_name}.txt"
    with open(path_for_txt, "w") as file:
        file.writelines(report_lines)

    for info in report_lines:
        print(info)

    print(f"======== Hypnogram {record_name} has been created. ========")

In [None]:
def make_predict_create_hypns(cnn, dir_edf_files_predict, training_set, model_path='models/exit_model.keras'):
    """Create hypnograms for every recording folder found in ``dataset_predict``.

    For each sub-folder of ``dataset_predict`` the EDF files whose names start
    with the folder name are looked up in ``dir_edf_files_predict``.  A file
    whose part after the first "-" starts with "H" is taken as the hypnogram
    file, any other file containing "-" as the PSG file.  The pair is loaded
    with ``file_keeper``, cut into 30 s epochs and handed to ``create_hypnograms``.

    Args:
        cnn: Model used for prediction.  Any falsy value (``None``, ``''``)
            makes the function load the model stored at ``model_path``.
        dir_edf_files_predict: Directory with the PSG / hypnogram EDF files.
        training_set: Forwarded unchanged to ``create_hypnograms``.
        model_path: Path of the saved model.  It used to be read from a
            notebook-global variable of the same name; the default equals the
            path that was hard-coded for loading.

    Returns:
        None.
    """
    dir_predict_dataset = "dataset_predict"

    if not os.path.isdir(dir_predict_dataset):
        print("Prediction dataset hasn't been found. Impossible to make a prediction.")
        return

    # Truthiness (not `is None`) is kept on purpose so an empty placeholder still triggers loading.
    if not cnn:
        cnn = tf.keras.models.load_model(model_path)

    print(f"Summary for cnn model '{model_path}':\n")
    cnn.summary()

    # The EDF directory does not change while looping, so list it only once.
    edf_file_names = os.listdir(dir_edf_files_predict)

    for folder in os.listdir(dir_predict_dataset):

        curr_path = dir_predict_dataset + "/" + folder
        if not os.path.isdir(curr_path):
            # Stray files next to the recording folders are not recordings.
            continue

        psg_file_path = ""
        hyp_file_path = ""
        for file in fnmatch.filter(edf_file_names, folder + "*"):
            _, dash, kind = file.partition('-')
            if not dash:
                # No "-<kind>" part, so the file cannot be classified as PSG or hypnogram.
                continue
            if kind.startswith("H"):
                hyp_file_path = dir_edf_files_predict + "/" + file
            else:
                psg_file_path = dir_edf_files_predict + "/" + file

        if not psg_file_path or not hyp_file_path:
            print(f"PSG or hypnogram file for '{folder}' hasn't been found in '{dir_edf_files_predict}'. Skipping.")
            continue

        data, annotations = file_keeper(psg_file_path, hyp_file_path)
        sfreq = data.info["sfreq"]
        events_from_the_file, event_id_info, annotations_stage_id = crop_set_annotations(data, annotations)

        # 30 s epochs; the last sample is excluded so that neighbouring epochs do not overlap.
        tmax = 30.0 - 1.0 / sfreq
        epochs_from_the_file = mne.Epochs(
            raw=data,
            events=events_from_the_file,
            event_id=event_id_info,
            tmin=0.0,
            tmax=tmax,
            baseline=None,
            verbose=False
        )
        events = epochs_from_the_file.get_data(picks=[0])
        times_events = epochs_from_the_file.events

        print(f"Current path: {curr_path}")
        create_hypnograms(events, times_events, sfreq, curr_path, cnn, annotations_stage_id, training_set)

    print("All hypnograms have been created.")

In [2]:
from sklearn.metrics import confusion_matrix
import seaborn as sn
import pandas as pd

def create_confusion_matrix(true_array, prediction_array, path_save_img=''):
    """Compute, print and draw the confusion matrix for stage labels 1..6.

    Args:
        true_array: Sequence of true labels; values outside 1..6 are not
            counted because the label set is fixed.
        prediction_array: Sequence of predicted labels, same length as ``true_array``.
        path_save_img: Where to save the heat map.  When empty, the figure is
            shown instead of saved.

    Returns:
        The 6x6 confusion matrix (rows: true label, columns: predicted label).
    """
    labels_stages = [1, 2, 3, 4, 5, 6]
    result = confusion_matrix(true_array, prediction_array, labels=labels_stages)
    print(f"\nConfusion matrix:\n{result}\n")

    # Index the table by the real labels so the heat-map ticks read 1..6, not 0..5.
    table = pd.DataFrame(result, index=labels_stages, columns=labels_stages)
    fig, ax = plt.subplots()
    sn.heatmap(table, annot=True, fmt='g', ax=ax)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")

    if path_save_img:
        fig.savefig(path_save_img)
    else:
        plt.show()
    plt.close(fig)

    return result

# EXIT CODE

## PREDICTING

In [None]:
dir_edf_files_predict = "edf_files/x_test_y_test"
model_path = 'models/exit_model.keras'

# Pass None explicitly so the saved model is loaded from disk.  The IPython
# variable `_` (last cell output) used before depends on execution history and
# is not reproducible under "Restart Kernel -> Run All".
# NOTE(review): `training_set` is not defined in any cell visible here -- it
# must be created (or imported) earlier for this cell to run on a fresh kernel.
make_predict_create_hypns(None, dir_edf_files_predict, training_set)