In [1]:
import os
import re
import sys
from collections import deque

import lightning as L
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from tqdm import tqdm

# Make the project-local packages (Model, Utils) importable. The relative path
# presumably assumes the notebook is run from its own directory — TODO confirm.
sys.path.append('../../')
from Model import CNN as cn
from Model import CommonBlock as cb
import Utils.Preprocess as ut

# Project root: four directory levels above the current working directory.
project_path = os.path.abspath(os.path.relpath('../../../../', os.getcwd()))
# Paths are built from separate components so they are valid on both Windows
# and POSIX ('\' inside a plain string literal is an escape character in Python,
# and is not a path separator outside Windows).
data_dir = os.path.join(project_path, 'BilinearNetwork', 'Data', 'PreprocessedData', 'CHB-MIT', 'detection')
constrain_path = os.path.join(project_path, 'BilinearNetwork', 'Data', 'Constraint', 'Detection')
chb_root_path = os.path.join(project_path, 'Dataset', 'CHB-MIT')
# Seizure table; later cells read its 'File Name', 'Start Time' and 'End Time' columns.
seizure_table_original = pd.read_csv(os.path.join(constrain_path, 'seizure_ictal_summary.csv'))

## Detection Latency

In [2]:

def DetectionLatency(model,file_path,seizure_start,seizure_stop,threshold=0.5,queue_size=10,threshold_count=7):
    raw_T=ut.PreprocessTool(file_path).do_preprocess()
    def get_latency(_data_stft):
        data_ds=TensorDataset(torch.Tensor(_data_stft))
        del _data_stft
        data_loader=DataLoader(data_ds,batch_size=1)
        one_sample_delay=1/raw_T.raw.info['sfreq']
        model.eval()
        y_pred_queue = deque(maxlen=queue_size)
        for i,data in enumerate(tqdm(data_loader)):
            data=data[0]
            y_pred=model(data).view(-1)
            y_pred_queue.append(y_pred.item())
            if sum(1 for y in y_pred_queue if y > threshold) >= threshold_count:
                current_delay = i * one_sample_delay
                model.train()
                return current_delay
        model.train()
        return -1
    whole_start,whole_stop=seizure_start-5,seizure_stop+5
    partial_start,partial_stop=whole_start,(int(seizure_stop+5)/4)
    
    data_stft=raw_T.overlap_events_slice_all(start=partial_start,stop=partial_stop).cut_epochs().get_epochs_stft()
    latency=get_latency(data_stft)
    if(latency==-1):
        del data_stft
        data_stft=raw_T.overlap_events_slice_all(start=whole_start,stop=whole_stop).cut_epochs().get_epochs_stft()
        latency=get_latency(data_stft)
    return latency
    
    
    
def DetectionLatency_overall(chb_root_path,Net,seizure_table,patient_id,leave_out_id):
    """Run ``DetectionLatency`` on one seizure of one patient and print the result.

    The patient's rows are the rows of ``seizure_table`` whose 'File Name' starts
    with ``chbNN`` (NN = zero-padded ``patient_id``). ``leave_out_id`` selects one
    of those rows by position. The row's file, 'Start Time' and 'End Time' are
    forwarded to ``DetectionLatency``, and its return value is printed and returned.
    """
    patient_name = "chb{}".format(str(patient_id).zfill(2))
    patient_rows = seizure_table[seizure_table['File Name'].str.startswith(patient_name)]
    seizure = patient_rows.iloc[leave_out_id, :]
    recording_path = os.path.join(chb_root_path, "{}/{}".format(patient_name, seizure['File Name']))
    onset, offset = seizure['Start Time'], seizure['End Time']
    latency = DetectionLatency(
        model=Net,
        file_path=recording_path,
        seizure_start=onset,
        seizure_stop=offset,
    )
    print("Latency:{}".format(latency))
    return latency
def analyze_delay(arr):
    """Summarise detection latencies, treating -1 as "not detected".

    Parameters
    ----------
    arr : iterable of numbers (list, tuple, numpy array, ...)
        Latencies as returned by ``DetectionLatency``; -1 marks an undetected seizure.

    Returns
    -------
    dict with keys
        "mean", "std" : mean / population std of the detected latencies
                        (NaN when nothing was detected),
        "un_detect"   : number of entries equal to -1,
        "detect"      : number of other entries.
    """
    # Materialise once so any iterable works (the original relied on list.count).
    values = list(arr)
    valid_numbers = [num for num in values if num != -1]

    mean_value = np.mean(valid_numbers) if valid_numbers else float('nan')
    std_value = np.std(valid_numbers) if valid_numbers else float('nan')

    return {
        "mean": mean_value,
        "std": std_value,
        "un_detect": len(values) - len(valid_numbers),
        "detect": len(valid_numbers),
    }
def Analyze_write_delay_to_csv(file_path, patient_id, delays):
    """
    Summarise ``delays`` with analyze_delay and upsert one row for ``patient_id``
    into the workbook at ``file_path``, creating the file if it does not exist.

    Columns written: 'patient_id', 'delay' (formatted "mean(std)") and 'detected'
    (formatted "detected/total"). If rows with this patient_id already exist, their
    'delay' and 'detected' cells are overwritten; otherwise a new row is appended.
    Despite the function name, the file is read and written as Excel
    (pd.read_excel / DataFrame.to_excel), not as CSV.
    """
    summary = analyze_delay(delays)
    delay_str = "{}({})".format(summary['mean'], summary['std'])
    total = summary['detect'] + summary['un_detect']
    detected_str = "{}/{}".format(summary['detect'], total)
    new_row = {'patient_id': [patient_id], 'delay': [delay_str], 'detected': [detected_str]}

    if not os.path.exists(file_path):
        table = pd.DataFrame(new_row)
    else:
        table = pd.read_excel(file_path)
        is_patient = table['patient_id'] == patient_id
        if is_patient.any():
            table.loc[is_patient, 'delay'] = delay_str
            table.loc[is_patient, 'detected'] = detected_str
        else:
            table = pd.concat([table, pd.DataFrame(new_row)], ignore_index=True)
    table.to_excel(file_path, index=False)
    

In [3]:
# Net = cb.CommonNet(encoder=cn.ConvNetBlock_small(), lr=0.0003)
# DetectionLatency_overall(chb_root_path=chb_root_path,seizure_table=seizure_table_original,Net=Net,patient_id=1,leave_out_id=0)
# --
# Analyze_write_delay_to_csv(r'E:\Research\BilinearNetwork\Data\Result\detection\detect_seizure\1.xlsx',1,[1,1,-1,-1,-1,6,1.3,5])

## False Detection Rate (FDR)

In [4]:
def get_interictal_files_name(chb_root_path,constrain_path,patient_id):
    """Select candidate interictal EDF files for one CHB-MIT patient.

    Steps:
      1. Collect every ``*.edf`` file name under ``chb_root_path``, sorted.
      2. Drop files of patients listed in ``exclude_Patient.csv`` (column '0')
         and files listed in ``exclude_File.csv`` (column 'File Name').
      3. Scan consecutive remaining files. Wherever the file number jumps by more
         than 1, keep the file before the jump, plus the two files numbered just
         before it if they are still present.
      4. Add three hard-coded chb19 files, then keep only this patient's files.

    Returns
    -------
    (files, others) : tuple of pd.Series
        files  : the files selected in steps 3-4 for this patient (named 'file').
        others : this patient's remaining files that are not in ``files``.

    NOTE(review): file numbers are parsed as ``int(name.split('_')[1].split('.')[0])``,
    so a name such as 'chbNN_MM+.edf' raises ValueError unless it is excluded.
    Consecutive pairs are compared across patient boundaries, so each patient's
    last file usually counts as "before a jump". The very last file is never
    tested. Confirm that both behaviours are intended.
    """
    def find_all_files(root_path):
        # Sorted because the gap scan below needs file-number order, and
        # os.walk's listing order depends on the filesystem.
        return sorted(
            name
            for _, _, names in os.walk(root_path)
            for name in names
            if name.endswith('.edf')
        )

    def file_number(name):
        return int(name.split('_')[1].split('.')[0])

    exclude_file = pd.read_csv(os.path.join(constrain_path, 'exclude_File.csv'))
    exclude_patient = pd.read_csv(os.path.join(constrain_path, 'exclude_Patient.csv'))
    excluded_files = set(exclude_file['File Name'])
    excluded_patients = set(exclude_patient['0'])

    kept = [
        name for name in find_all_files(chb_root_path)
        if name.split('_')[0] not in excluded_patients and name not in excluded_files
    ]
    # dtype=object keeps the .str accessor usable even when nothing is kept.
    select_files = pd.Series(kept, dtype=object)
    # Membership must be tested on the values. `name in pd.Series` checks the
    # index labels, which silently never matched here.
    available = set(kept)

    select_file_final = []
    for current_name, next_name in zip(kept, kept[1:]):
        current_num = file_number(current_name)
        prefix = current_name.split('_')[0]
        next_num = file_number(next_name)
        if abs(next_num - current_num) > 1:
            select_file_final.append(current_name)
            for offset in (1, 2):
                candidate = prefix + '_' + str(current_num - offset).zfill(2) + '.edf'
                if candidate in available:
                    select_file_final.append(candidate)
    # NOTE(review): these chb19 files are always added; the reason is not visible here.
    select_file_final.extend(['chb19_07.edf', 'chb19_10.edf', 'chb19_15.edf'])
    select_file_final.sort()

    patient_name = "chb" + str(patient_id).zfill(2)
    files = pd.Series(select_file_final, name='file')
    files = files[files.str.startswith(patient_name)].drop_duplicates()

    s_files = select_files[select_files.str.startswith(patient_name)].drop_duplicates()
    return files, s_files[~s_files.isin(files)]

def FDR_for_one_ictal_file(model,file_path,threshold=0.5,threshold_count=7):
    raw_T=ut.PreprocessTool(file_path).do_preprocess(truncate_time=3600)
    duration=raw_T.raw.n_times / raw_T.raw.info['sfreq']
    _data_stft=raw_T.create_group_slicing_event(group_interval=20,num_events_per_group=10,duration=5).cut_epochs().get_epochs_stft()
    data_ds=TensorDataset(torch.Tensor(_data_stft))
    del _data_stft
    data_loader=DataLoader(data_ds,batch_size=10)
    alarm_count=0
    model.eval()
    for i,data in enumerate(data_loader):
        data=data[0]
        y_pred=model(data).view(-1)
        if sum(1 for y in y_pred if y > threshold) >= threshold_count:
            alarm_count+=1
    model.train()
    FDR_per_second=alarm_count/duration
    return FDR_per_second
    


def FDR_for_patient_ictal_file(chb_root_path,constrain_path,model,patient_id,**kwargs):
    """Average false detection rate of ``model`` over one patient's interictal files, per hour.

    For each file in the first Series returned by ``get_interictal_files_name``,
    ``FDR_for_one_ictal_file`` gives alarms per second. The mean over files is
    scaled by 3600. Extra keyword arguments (e.g. ``threshold``,
    ``threshold_count``) are forwarded to ``FDR_for_one_ictal_file``.
    Returns NaN when the patient has no selected files.
    """
    print("Start evaluate the FDR for algorithm")
    # get_interictal_files_name returns (selected_files, remaining_files). Iterating
    # the tuple itself would yield two Series, not file names.
    # NOTE(review): only the selected files are evaluated; confirm this is the intended set.
    files_plan, _ = get_interictal_files_name(chb_root_path=chb_root_path,constrain_path=constrain_path,patient_id=patient_id)
    patient_name = "chb" + str(patient_id).zfill(2)
    FDR_per_seconds = []
    for ictal_file_name in files_plan:
        file_path = os.path.join(chb_root_path, "{}/{}".format(patient_name, ictal_file_name))
        # kwargs must be unpacked; passing the dict positionally bound it to `threshold`.
        FDR_per_seconds.append(FDR_for_one_ictal_file(model, file_path, **kwargs))
    if not FDR_per_seconds:
        return float('nan')
    return np.mean(FDR_per_seconds) * 3600
# FDR_for_patient_ictal_file(chb_root_path=chb_root_path,constrain_path=constrain_path,model=cb.CommonNet(cn.ConvNetBlock_small()),patient_id=1)




In [6]:
# Inspect the interictal file selection for patient chb18: (selected files, remaining files).
get_interictal_files_name(chb_root_path=chb_root_path, constrain_path=constrain_path, patient_id=18)

