In [1]:
import mne
import numpy as np
from matplotlib import pyplot as plt
import math
from scipy import stats, signal
import itertools
import pandas as pd
from datetime import datetime
import os
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score, roc_curve, precision_recall_curve


In [2]:
# Execute detector.ipynb inside this kernel so that the names it defines become
# available here. Presumably this is where `detect_subj` (referenced in the
# commented-out call of the tuning cell) comes from - TODO confirm.
%run ./detector.ipynb

396
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2
398
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2
402
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2
405
LAH1
LAH1-LAH2
406
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2
415
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2
416
RAH1
LAH1
RAH1-RAH2
LAH1-LAH2


In [88]:
def results_by_thresh_type(raw, tags, detections, sr):
    """Rasterise tags and detections onto a 1-second grid, split by threshold type.

    Parameters
    ----------
    raw : object exposing ``n_times`` (number of samples in the recording).
    tags : iterable of tag onsets, in seconds.
    detections : pandas.DataFrame with a 'max_index' column (sample index of the
        detection) and a 'threshold_type' column (string naming the thresholds
        that fired, matched by substring against 'amp', 'grad' and 'env').
    sr : sampling rate used to convert sample indices to seconds.

    Returns
    -------
    pandas.DataFrame with one row per whole second of the recording and the
    0/1 columns 'y_actual', 'y_pred', 'amp', 'grad', 'env'.

    Notes
    -----
    Detections that fall outside the whole-second grid (negative index, or in
    the trailing partial second / beyond the end of ``raw``) are ignored. The
    previous filter compared against ``raw.n_times / sr`` while the arrays have
    ``int(raw.n_times / sr)`` bins, so a detection in the trailing partial
    second raised IndexError.
    A tag onset past the end of the grid still raises IndexError, so that
    missing ground truth is never dropped silently.
    """
    # make it 1Hz
    n_seconds = int(raw.n_times / sr)
    y_actual = np.zeros(n_seconds, dtype=int)
    y_pred = np.zeros(n_seconds, dtype=int)
    by_type = {name: np.zeros(n_seconds, dtype=int) for name in ('amp', 'grad', 'env')}

    # floor to 1 sec resolution
    for onset in tags:
        y_actual[math.floor(onset)] = 1

    # Second index of every detection; keep only those that land on the grid.
    det_second = detections['max_index'] // sr
    in_range = (det_second >= 0) & (det_second < n_seconds)
    for second, thresh_type in zip(det_second[in_range].astype(int),
                                   detections.loc[in_range, 'threshold_type']):
        y_pred[second] = 1
        # A missing threshold_type (NaN after read_csv) still counts as a
        # detection but cannot be attributed to any threshold.
        if not isinstance(thresh_type, str):
            continue
        for name, bins in by_type.items():
            if name in thresh_type:
                bins[second] = 1

    return pd.DataFrame({'y_actual': y_actual, 'y_pred': y_pred, **by_type})

def get_cm(raw, tags, detections, sr):
    """Build the 1-second-resolution confusion matrix of tags vs. detections.

    Parameters
    ----------
    raw : object exposing ``n_times`` (number of samples in the recording).
    tags : iterable of tag onsets, in seconds.
    detections : pandas.DataFrame with a 'max_index' column (sample index).
    sr : sampling rate used to convert sample indices to seconds.

    Returns
    -------
    (cm, y_actual, y_pred) where ``y_actual`` / ``y_pred`` are 0/1 arrays with
    one entry per whole second and ``cm`` is the 2x2 matrix laid out as
    ``[[TN, FP], [FN, TP]]`` (rows = actual, columns = predicted).

    Notes
    -----
    ``labels=[0, 1]`` keeps the matrix 2x2 even when only one class occurs.
    Detections outside the whole-second grid are ignored; the previous filter
    let detections in the trailing partial second through, which raised
    IndexError.
    """
    # make it 1Hz
    n_seconds = int(raw.n_times / sr)
    y_actual = np.zeros(n_seconds, dtype=int)
    y_pred = np.zeros(n_seconds, dtype=int)

    # floor to 1 sec resolution
    for onset in tags:
        y_actual[math.floor(onset)] = 1

    det_second = detections['max_index'] // sr
    in_range = (det_second >= 0) & (det_second < n_seconds)
    y_pred[det_second[in_range].astype(int).to_numpy()] = 1

    cm = confusion_matrix(y_actual, y_pred, labels=[0, 1])
    return cm, y_actual, y_pred

def _safe_ratio(numerator, denominator):
    """Return numerator / denominator as a float, or NaN when the ratio is undefined."""
    numerator, denominator = float(numerator), float(denominator)
    # A zero denominator means the metric is undefined; report NaN explicitly
    # instead of relying on NumPy's 0/0 RuntimeWarning.
    return numerator / denominator if denominator != 0 else float('nan')


def get_metrics(cm):
    """Derive accuracy, precision, recall and F0.5 from a binary confusion matrix.

    Parameters
    ----------
    cm : array-like laid out as ``[[TN, FP], [FN, TP]]``. An empty input is
        treated as an all-zero 2x2 matrix. A 1x1 matrix is interpreted as
        "every sample is a true positive".
        NOTE(review): a 1x1 matrix does not say which class it describes, so an
        all-negative input would be misread; build the matrix with
        ``labels=[0, 1]`` to avoid this case.

    Returns
    -------
    dict with keys 'accuracy', 'precision', 'recall', 'f_score'; each value is
    a string formatted with ``"%0.2f"`` ('nan' when the metric is undefined).
    The F-score is F-beta with beta**2 = 0.25 (beta = 0.5), which weights
    precision more heavily than recall.
    """
    if len(cm) == 0:
        cm = np.zeros((2, 2), dtype=int)
    cm = np.asarray(cm)
    if np.squeeze(cm).ndim < 2:
        single_count = int(cm[0][0])
        cm = np.zeros((2, 2), dtype=int)
        cm[1, 1] = single_count

    tn, fp, fn, tp = cm[0, 0], cm[0, 1], cm[1, 0], cm[1, 1]
    accuracy = _safe_ratio(tn + tp, tn + tp + fp + fn)
    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    # F_beta = (1 + beta^2) * P * R / (beta^2 * P + R) with beta^2 = 0.25
    f_score = _safe_ratio(1.25 * precision * recall, (0.25 * precision) + recall)
    return {'accuracy': "%0.2f" % accuracy, 'precision': "%0.2f" % precision,
            'recall': "%0.2f" % recall, 'f_score': "%0.2f" % f_score}

In [65]:
# Threshold tuning: for every subject, score each detection file in
# `thresh_tuning` against the manual tags (separately for the amp / grad / env
# threshold types) and pick, per channel, the threshold with the best F-score.
subjects = ['396', '398', '402', '406', '415', '416']
# subjects = ['396', '398', '402', '406']

TUNING_KINDS = ('amp', 'grad', 'env')
TUNING_METRICS = ('precision', 'recall', 'f_score')
NO_DATA_THRESH = 0   # reported when a channel has no scored detection files
DEFAULT_THRESH = 5   # reported when no threshold yields a defined F-score

for subj in subjects:
    print(subj)
    results_dict = {'file': [], 'amp_precision': [], 'amp_recall': [], 'amp_f_score': [], 'amp_cm': [],
               'grad_precision': [], 'grad_recall': [], 'grad_f_score': [], 'grad_cm': [],
               'env_precision': [], 'env_recall': [], 'env_f_score': [], 'env_cm': []}
    # get tags and split to right and left
    raw = mne.io.read_raw_edf(f'C:\\Lilach\\{subj}_for_tag_filtered.edf')
    tags_df = pd.DataFrame(raw.annotations)
    # Only annotations before the first 'END' marker are used as tags;
    # index[0] raises IndexError when the recording has no 'END' annotation.
    last_spike_tag = tags_df.loc[tags_df['description'] == 'END']
    i = last_spike_tag.index[0]
    tags_df = tags_df[:i]
    right_tags = tags_df[tags_df['description'].str.contains('Rt')]
    left_tags = tags_df[tags_df['description'].str.contains('Lt')]

    # The cropped recording is identical for every file of a subject, so it is
    # built lazily once instead of once per detection file.
    curr_raw = None

    # run over detection files and calc metrics
    detections_path = f'C:\\analysis\\{subj}\\thresh_tuning'
    # sorted(): os.listdir order is arbitrary and idxmax below keeps the first
    # of several tied maxima, so the order must be deterministic.
    for file in sorted(name for name in os.listdir(detections_path) if name.endswith('.csv')):
        if 'RA' in file:
            tags = right_tags
        else:
            tags = left_tags

        if tags.empty:
            continue

        pred_df = pd.read_csv(os.path.join(detections_path, file))
        if curr_raw is None:
            # crop raw according to the last spike tag (+ buffer)
            curr_raw = raw.copy().crop(tmin=0, tmax=tags_df['onset'].iloc[-1] + 3)

        data_df = results_by_thresh_type(curr_raw, tags['onset'], pred_df, int(raw.info['sfreq']))
        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent
        cm_amp = confusion_matrix(data_df['y_actual'], data_df['amp'], labels=[0, 1])
        cm_grad = confusion_matrix(data_df['y_actual'], data_df['grad'], labels=[0, 1])
        cm_env = confusion_matrix(data_df['y_actual'], data_df['env'], labels=[0, 1])

        amp_metrics = get_metrics(cm_amp)
        grad_metrics = get_metrics(cm_grad)
        env_metrics = get_metrics(cm_env)

        results_dict['file'].append(file)
        for kind, kind_metrics, kind_cm in (('amp', amp_metrics, cm_amp),
                                            ('grad', grad_metrics, cm_grad),
                                            ('env', env_metrics, cm_env)):
            for metric in TUNING_METRICS:
                results_dict[f'{kind}_{metric}'].append(kind_metrics[metric])
            results_dict[f'{kind}_cm'].append(kind_cm)

    thresh_metrics_df = pd.DataFrame(results_dict)
    # get_metrics returns formatted strings ('0.45', 'nan'); convert them back
    # to floats so that isnull() / idxmax() below operate on numbers.
    for kind in TUNING_KINDS:
        for metric in TUNING_METRICS:
            col = f'{kind}_{metric}'
            thresh_metrics_df[col] = pd.to_numeric(thresh_metrics_df[col], errors='coerce')
    # threshold is encoded in the file name suffix, e.g. '416_LAH1_t10.csv' -> 10
    thresh_metrics_df['thresh'] = thresh_metrics_df['file'].map(
        lambda name: int(name.split('_')[-1].replace('t', '').replace('.csv', '')))
    thresh_cz_ref = thresh_metrics_df[~thresh_metrics_df['file'].str.contains('AH2')]
    thresh_bi_ref = thresh_metrics_df[thresh_metrics_df['file'].str.contains('AH2')]

    # each df contains all thresholds for 1 channel- right/left, cz/bi
    thresh_cz_right = thresh_cz_ref[thresh_cz_ref['file'].str.contains('RA')]
    thresh_cz_left = thresh_cz_ref[~thresh_cz_ref['file'].str.contains('RA')]
    thresh_bi_right = thresh_bi_ref[thresh_bi_ref['file'].str.contains('RA')]
    thresh_bi_left = thresh_bi_ref[~thresh_bi_ref['file'].str.contains('RA')]
    channels_dict = {'channel': ['RAH1', 'LAH1', 'RAH1-RAH2', 'LAH1-LAH2'], 'amp': [], 'grad': [], 'env': []}

    # order must match channels_dict['channel']
    for curr_df in [thresh_cz_right, thresh_cz_left, thresh_bi_right, thresh_bi_left]:
        for kind in TUNING_KINDS:
            f_scores = curr_df[f'{kind}_f_score']
            if curr_df.empty:
                best_thresh = NO_DATA_THRESH
            elif f_scores.isnull().all():
                best_thresh = DEFAULT_THRESH
            else:
                best_thresh = curr_df.loc[f_scores.idxmax(), 'thresh']
            channels_dict[kind].append(best_thresh)

#     edf = f'C:\\UCLA\\{subj}_cz+bi_full_filtered.edf'
    edf = f'C:\\Lilach\\{subj}_for_tag_filtered.edf'
    # run full night detection with the best thresholds
    for i, row in pd.DataFrame(channels_dict).iterrows():
        channel, amp, grad, env = row['channel'], row['amp'], row['grad'], row['env']
#         if amp != 0:
#             detect_subj(edf, channel, amp, grad, env, 
#                    f'C:\\analysis\\{subj}\\best_thresh_detect\\{subj}_{channel}_t{amp}_{grad}_{env}_for_tag2.csv')
        print(channel, amp, grad, env)

396




RAH1 0 0 0
LAH1 5 7 9
RAH1-RAH2 0 0 0
LAH1-LAH2 5 7 10
398




RAH1 4 10 7
LAH1 4 8 9
RAH1-RAH2 4 8 8
LAH1-LAH2 4 8 8
402




RAH1 5 6 9
LAH1 5 6 10
RAH1-RAH2 4 7 9
LAH1-LAH2 4 10 8
406




RAH1 4 5 8
LAH1 5 4 7
RAH1-RAH2 5 4 10
LAH1-LAH2 5 4 7
415




RAH1 5 4 5
LAH1 4 4 5
RAH1-RAH2 4 4 6
LAH1-LAH2 4 4 4
416




RAH1 4 4 5
LAH1 5 4 5
RAH1-RAH2 4 4 7
LAH1-LAH2 4 5 6




In [68]:
# Inspect the per-threshold metrics of the files whose name does not contain
# 'AH2' (the "cz" reference channels per the tuning loop's comments).
# thresh_cz_ref is reassigned on every iteration of the tuning loop, so this
# shows only the last subject that was processed.
pd.DataFrame(thresh_cz_ref)

Unnamed: 0,file,amp_precision,amp_recall,amp_f_score,amp_cm,grad_precision,grad_recall,grad_f_score,grad_cm,env_precision,env_recall,env_f_score,env_cm,thresh
7,416_LAH1_t10.csv,,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",10
8,416_LAH1_t4.csv,0.0,0.0,,"[[176, 2], [25, 0]]",0.428571,0.12,0.283019,"[[174, 4], [22, 3]]",0.433333,0.52,0.448276,"[[161, 17], [12, 13]]",4
9,416_LAH1_t5.csv,,0.0,,"[[178, 0], [25, 0]]",1.0,0.04,0.172414,"[[178, 0], [24, 1]]",0.538462,0.28,0.454545,"[[172, 6], [18, 7]]",5
10,416_LAH1_t6.csv,,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",0.428571,0.12,0.283019,"[[174, 4], [22, 3]]",6
11,416_LAH1_t7.csv,,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",0.25,0.04,0.121951,"[[175, 3], [24, 1]]",7
12,416_LAH1_t8.csv,,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",0.0,0.0,,"[[177, 1], [25, 0]]",8
13,416_LAH1_t9.csv,,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",,0.0,,"[[178, 0], [25, 0]]",9
21,416_RAH1_t10.csv,,0.0,,"[[169, 0], [34, 0]]",,0.0,,"[[169, 0], [34, 0]]",,0.0,,"[[169, 0], [34, 0]]",10
22,416_RAH1_t4.csv,0.857143,0.176471,0.483871,"[[168, 1], [28, 6]]",0.714286,0.294118,0.555556,"[[165, 4], [24, 10]]",0.488372,0.617647,0.509709,"[[147, 22], [13, 21]]",4
23,416_RAH1_t5.csv,1.0,0.029412,0.131579,"[[169, 0], [33, 1]]",0.777778,0.205882,0.5,"[[167, 2], [27, 7]]",0.733333,0.323529,0.585106,"[[165, 4], [23, 11]]",5


In [69]:
# Best threshold per channel and threshold type; channels_dict is rebuilt for
# each subject in the tuning loop, so this is the last subject's selection.
pd.DataFrame(channels_dict)

Unnamed: 0,channel,amp,grad,env
0,RAH1,4,4,5
1,LAH1,5,4,5
2,RAH1-RAH2,4,4,7
3,LAH1-LAH2,4,5,6


In [24]:
# Per-threshold metrics for the left "cz" channel (file name contains neither
# 'AH2' nor 'RA'). The variable is leftover loop state, so the subject shown
# depends on which tuning run was executed last.
thresh_cz_left

Unnamed: 0,file,amp_precision,amp_recall,amp_f_score,amp_cm,grad_precision,grad_recall,grad_f_score,grad_cm,env_precision,env_recall,env_f_score,env_cm,thresh
7,396_LAH1_t10.csv,,0.0,,"[[176, 0], [17, 0]]",1.0,0.117647,0.4,"[[176, 0], [15, 2]]",0.666667,0.470588,0.615385,"[[172, 4], [9, 8]]",10
8,396_LAH1_t4.csv,0.75,0.176471,0.454545,"[[175, 1], [14, 3]]",0.608696,0.823529,0.642202,"[[167, 9], [3, 14]]",0.378378,0.823529,0.424242,"[[153, 23], [3, 14]]",4
9,396_LAH1_t5.csv,1.0,0.176471,0.517241,"[[176, 0], [14, 3]]",0.823529,0.823529,0.823529,"[[173, 3], [3, 14]]",0.48,0.705882,0.512821,"[[163, 13], [5, 12]]",5
10,396_LAH1_t6.csv,1.0,0.058824,0.238095,"[[176, 0], [16, 1]]",0.846154,0.647059,0.797101,"[[174, 2], [6, 11]]",0.571429,0.705882,0.594059,"[[167, 9], [5, 12]]",6
11,396_LAH1_t7.csv,,0.0,,"[[176, 0], [17, 0]]",1.0,0.529412,0.849057,"[[176, 0], [8, 9]]",0.526316,0.588235,0.537634,"[[167, 9], [7, 10]]",7
12,396_LAH1_t8.csv,,0.0,,"[[176, 0], [17, 0]]",1.0,0.294118,0.675676,"[[176, 0], [12, 5]]",0.588235,0.588235,0.588235,"[[169, 7], [7, 10]]",8
13,396_LAH1_t9.csv,,0.0,,"[[176, 0], [17, 0]]",1.0,0.176471,0.517241,"[[176, 0], [14, 3]]",0.692308,0.529412,0.652174,"[[172, 4], [8, 9]]",9


In [26]:
# Same inspection as the earlier channels_dict cell: the selected thresholds of
# whichever subject the tuning loop processed last when this cell was run.
pd.DataFrame(channels_dict)

Unnamed: 0,channel,amp,grad,env
0,RAH1,0,0,0
1,LAH1,5,7,9
2,RAH1-RAH2,0,0,0
3,LAH1-LAH2,5,7,9


In [107]:
# Evaluation of the best-threshold detections: for every subject and channel,
# score the detections that crossed 2 or more thresholds against the manual
# tags, overall and per threshold type, and collect one row per file.
subjects = ['396', '398', '402', '406', '415', '416']

results_dict = {'subject': [], 'channel': [], 'amp': [], 'grad': [], 'env': [], 
                'tags_count':[], 'precision': [], 'recall': [], 'f_score': [], 'cm': [], 
                'amp_precision': [], 'amp_recall': [], 'amp_f_score': [], 'amp_cm': [],
               'grad_precision': [], 'grad_recall': [], 'grad_f_score': [], 'grad_cm': [],
               'env_precision': [], 'env_recall': [], 'env_f_score': [], 'env_cm': []}

# channels_dict = {'subject':[], 'channel': [], 'amp': [], 'grad': [], 'env': [], 
#                      'amp_precision': [], 'amp_recall': [], 'amp_f_score': [], 'grad_precision': [], 'grad_recall': [], 
#                      'grad_f_score': [], 'env_precision': [], 'env_recall': [], 'env_f_score': []}
for subj in subjects:
    print(subj)
    # get tags and split to right and left
    raw = mne.io.read_raw_edf(f'C:\\Lilach\\{subj}_for_tag_filtered.edf')
    tags_df = pd.DataFrame(raw.annotations)
    # Only annotations before the first 'END' marker are used as tags;
    # index[0] raises IndexError when the recording has no 'END' annotation.
    last_spike_tag = tags_df.loc[tags_df['description'] == 'END']
    i = last_spike_tag.index[0]
    tags_df = tags_df[:i]
    right_tags = tags_df[tags_df['description'].str.contains('Rt')]
    left_tags = tags_df[tags_df['description'].str.contains('Lt')]

    # The cropped recording is identical for every file of a subject, so it is
    # built lazily once instead of once per detection file.
    curr_raw = None

    # run over detection files and calc metrics
    detections_path = f'C:\\analysis\\{subj}\\best_thresh_detect'
    # sorted(): os.listdir order is arbitrary; sorting makes the row order of
    # the resulting table reproducible.
    for file in sorted(x for x in os.listdir(detections_path) if 'tag' in x):
        pred_df = pd.read_csv(os.path.join(detections_path, file))
        # 2 or more thresholds (na=False: rows without a threshold_type are dropped
        # instead of producing a NaN mask that cannot be used for indexing)
        pred_df = pred_df[pred_df['threshold_type'].str.contains(',', na=False)]
        if 'RA' in file:
            tags = right_tags
        else:
            tags = left_tags

        if tags.empty:
            continue

        if curr_raw is None:
            # crop raw according to the last spike tag (+ buffer)
            curr_raw = raw.copy().crop(tmin=0, tmax=tags_df['onset'].iloc[-1] + 3)

        # general metrics
        cm, y_actual, y_pred = get_cm(curr_raw, tags['onset'], pred_df, int(raw.info['sfreq']))
        metrics = get_metrics(cm)
        # File names look like '<subj>_<channel>_t<amp>_<grad>_<env>_...'; split on
        # '_t' (not on every 't') so other letters 't' in the name cannot interfere.
        thresh_from_file = file.split('_t')[1].split('_')
        results_dict['subject'].append(subj)
        results_dict['channel'].append(file.split('_')[1])
        results_dict['amp'].append(int(thresh_from_file[0]))
        results_dict['grad'].append(int(thresh_from_file[1]))
        results_dict['env'].append(int(thresh_from_file[2]))
        results_dict['tags_count'].append(len(tags))
        results_dict['precision'].append(metrics['precision'])
        results_dict['recall'].append(metrics['recall'])
        results_dict['f_score'].append(metrics['f_score'])
        results_dict['cm'].append(cm)

        # metrics per thresh
        data_df = results_by_thresh_type(curr_raw, tags['onset'], pred_df, int(raw.info['sfreq']))
        # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent
        cm_amp = confusion_matrix(data_df['y_actual'], data_df['amp'], labels=[0, 1])
        cm_grad = confusion_matrix(data_df['y_actual'], data_df['grad'], labels=[0, 1])
        cm_env = confusion_matrix(data_df['y_actual'], data_df['env'], labels=[0, 1])
        amp_metrics = get_metrics(cm_amp)
        grad_metrics = get_metrics(cm_grad)
        env_metrics = get_metrics(cm_env)
        for kind, kind_metrics, kind_cm in (('amp', amp_metrics, cm_amp),
                                            ('grad', grad_metrics, cm_grad),
                                            ('env', env_metrics, cm_env)):
            for metric in ('precision', 'recall', 'f_score'):
                results_dict[f'{kind}_{metric}'].append(kind_metrics[metric])
            results_dict[f'{kind}_cm'].append(kind_cm)
        
#     thresh_metrics_df = pd.DataFrame(results_dict)
#     thresh_metrics_df['thresh'] = thresh_metrics_df.apply(lambda row: int(row.file.split('_')[-1].replace('t', '').replace('.csv', '')), axis=1)
#     thresh_cz_ref = thresh_metrics_df[~thresh_metrics_df['file'].str.contains('AH2')]
#     thresh_bi_ref = thresh_metrics_df[thresh_metrics_df['file'].str.contains('AH2')]
    
#     # each df contains all thresholds for 1 channel- right/left, cz/bi 
#     thresh_cz_right = thresh_cz_ref[thresh_cz_ref['file'].str.contains('RA')]
#     thresh_cz_left = thresh_cz_ref[~thresh_cz_ref['file'].str.contains('RA')]
#     thresh_bi_right = thresh_bi_ref[thresh_bi_ref['file'].str.contains('RA')]
#     thresh_bi_left = thresh_bi_ref[~thresh_bi_ref['file'].str.contains('RA')]

#     # init df with the best thresh and their metrics
#     for curr_df, curr_channel in zip([thresh_cz_right, thresh_cz_left, thresh_bi_right, thresh_bi_left], ['RAH1', 'LAH1', 'RAH1-RAH2', 'LAH1-LAH2']):
#         if not curr_df.empty:
#             channels_dict['subject'].append(subj)
#             channels_dict['channel'].append(curr_channel)
#             amp_best = None if curr_df['amp_f_score'].isnull().all() else curr_df.loc[curr_df['amp_f_score'].idxmax()]
#             channels_dict['amp'].append(None if amp_best is None else amp_best['thresh'])
#             channels_dict['amp_precision'].append(None if amp_best is None else amp_best['amp_precision'])
#             channels_dict['amp_recall'].append(None if amp_best is None else amp_best['amp_recall'])
#             channels_dict['amp_f_score'].append(None if amp_best is None else amp_best['amp_f_score'])
            
#             grad_best = None if curr_df['grad_f_score'].isnull().all() else curr_df.loc[curr_df['grad_f_score'].idxmax()]
#             channels_dict['grad'].append(None if grad_best is None else grad_best['thresh'])
#             channels_dict['grad_precision'].append(None if grad_best is None else grad_best['grad_precision'])
#             channels_dict['grad_recall'].append(None if grad_best is None else grad_best['grad_recall'])
#             channels_dict['grad_f_score'].append(None if grad_best is None else grad_best['grad_f_score'])
            
#             env_best = None if curr_df['env_f_score'].isnull().all() else curr_df.loc[curr_df['env_f_score'].idxmax()]
#             channels_dict['env'].append(None if env_best is None else env_best['thresh'])
#             channels_dict['env_precision'].append(None if env_best is None else env_best['env_precision'])
#             channels_dict['env_recall'].append(None if env_best is None else env_best['env_recall'])
#             channels_dict['env_f_score'].append(None if env_best is None else env_best['env_f_score'])
            

396
398




402
406
415




416




In [108]:
# Collect the rows gathered in results_dict by the evaluation loop into one
# table and write it to 'combined.csv' in the current working directory
# (the DataFrame index is written as the first, unnamed column).
final = pd.DataFrame(results_dict)
final.to_csv('combined.csv')

In [105]:
# Scratch check: keep only detections whose threshold_type contains a comma
# ("2 or more thresholds" per the evaluation loop's comment).
# NOTE(review): pred_df is whatever the last executed loop iteration left
# behind, so this cell depends on execution order.
combined_df = pred_df[pred_df['threshold_type'].str.contains(',')]

In [104]:
# Scratch check: number of detections in the leftover pred_df.
len(pred_df)

101

In [106]:
# Scratch check: number of detections that remain after the comma filter.
len(combined_df)

24