In [16]:
import mne
import numpy as np
from matplotlib import pyplot as plt
import math
from scipy import stats, signal
import itertools
import pandas as pd
from datetime import datetime
import os
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score, roc_curve, precision_recall_curve


In [17]:
def results_by_thresh_type(raw, tags, detections, sr):
    """Build 1 Hz binary label vectors for tags and for detections, per threshold type.

    The recording is divided into ``int(raw.n_times / sr)`` one-second bins.
    A bin is set to 1 in ``y_actual`` if a tag falls into it, in ``y_pred`` if
    any detection falls into it, and in ``amp`` / ``grad`` / ``env`` if a
    detection in that bin has the corresponding substring in its
    ``threshold_type``.

    Parameters
    ----------
    raw : object with an ``n_times`` attribute
        Recording whose length (in samples) defines the number of bins.
    tags : iterable of float
        Tag onsets in seconds; each is floored to its one-second bin.
    detections : pandas.DataFrame
        Must contain ``max_index`` (sample index of the detection) and
        ``threshold_type`` (string naming the threshold(s) that fired).
    sr : int or float
        Sampling rate used to convert ``max_index`` to seconds.

    Returns
    -------
    pandas.DataFrame
        Columns ``y_actual``, ``y_pred``, ``amp``, ``grad``, ``env``; one row
        per one-second bin.

    Raises
    ------
    IndexError
        If a tag onset lies at or beyond the last full second of ``raw``.
    """
    n_seconds = int(raw.n_times / sr)
    y_actual = np.zeros(n_seconds, dtype=int)
    y_pred = np.zeros(n_seconds, dtype=int)
    per_type = {name: np.zeros(n_seconds, dtype=int) for name in ('amp', 'grad', 'env')}

    # floor to 1 sec resolution
    for onset in tags:
        y_actual[math.floor(onset)] = 1

    # Filter on the floored bin index rather than on the raw time: when
    # n_times is not a multiple of sr, a detection inside the trailing partial
    # second is "before the end" but has no bin, and would raise IndexError.
    # NaN indices compare False and are dropped as well.
    seconds = np.floor((detections['max_index'] / sr).to_numpy(dtype=float))
    in_range = (seconds >= 0) & (seconds < n_seconds)
    thresh_types = detections['threshold_type'].to_numpy()[in_range]

    for second, thresh_type in zip(seconds[in_range].astype(int), thresh_types):
        y_pred[second] = 1
        # A missing threshold_type (NaN after read_csv) still counts as a
        # detection but cannot be attributed to any threshold family.
        if not isinstance(thresh_type, str):
            continue
        for name, flags in per_type.items():
            if name in thresh_type:
                flags[second] = 1

    return pd.DataFrame({'y_actual': y_actual, 'y_pred': y_pred, **per_type})

def get_cm(raw, tags, detections, sr):
    """Compute a per-second binary confusion matrix of detections against tags.

    The recording is divided into ``int(raw.n_times / sr)`` one-second bins; a
    bin is positive in ``y_actual`` if a tag falls into it and positive in
    ``y_pred`` if a detection falls into it.

    Parameters
    ----------
    raw : object with an ``n_times`` attribute
        Recording whose length (in samples) defines the number of bins.
    tags : iterable of float
        Tag onsets in seconds; each is floored to its one-second bin.
    detections : pandas.DataFrame
        Must contain ``max_index`` (sample index of the detection).
    sr : int or float
        Sampling rate used to convert ``max_index`` to seconds.

    Returns
    -------
    cm : numpy.ndarray
        2x2 confusion matrix, rows = actual, columns = predicted.
    y_actual, y_pred : numpy.ndarray
        The per-second binary vectors the matrix was built from.

    Raises
    ------
    IndexError
        If a tag onset lies at or beyond the last full second of ``raw``.
    """
    n_seconds = int(raw.n_times / sr)
    y_actual = np.zeros(n_seconds, dtype=int)
    y_pred = np.zeros(n_seconds, dtype=int)

    # floor to 1 sec resolution
    for onset in tags:
        y_actual[math.floor(onset)] = 1

    # Filter on the floored bin index rather than on the raw time: when
    # n_times is not a multiple of sr, a detection inside the trailing partial
    # second has no bin and would raise IndexError. NaN indices are dropped.
    seconds = np.floor((detections['max_index'] / sr).to_numpy(dtype=float))
    in_range = (seconds >= 0) & (seconds < n_seconds)
    y_pred[seconds[in_range].astype(int)] = 1

    # Pin the label set so the matrix is always 2x2, even when only one class
    # occurs (otherwise sklearn returns a 1x1 matrix).
    cm = confusion_matrix(y_actual, y_pred, labels=[0, 1])
    return cm, y_actual, y_pred

def get_metrics(cm, beta=0.5):
    """Derive accuracy, precision, recall and an F-beta score from a confusion matrix.

    Parameters
    ----------
    cm : array-like
        Binary confusion matrix laid out as ``[[TN, FP], [FN, TP]]``
        (``cm[1, 1]`` is used as the true-positive count). An empty input is
        treated as an all-zero 2x2 matrix. A single-cell matrix is interpreted
        as containing only true positives.
        NOTE(review): a 1x1 matrix can also arise when only the negative class
        is present; callers should build the matrix with fixed labels to avoid
        that ambiguity.
    beta : float, optional
        Weight of recall relative to precision in the F-score. The default of
        0.5 reproduces the original ``1.25 * P * R / (0.25 * P + R)`` formula.

    Returns
    -------
    dict
        Keys ``'accuracy'``, ``'precision'``, ``'recall'``, ``'f_score'``.
        A metric whose denominator is zero is returned as NaN.
    """
    def _ratio(num, den):
        # Undefined ratios (0/0) are reported as NaN without triggering a
        # numpy divide RuntimeWarning.
        return num / den if den else np.nan

    cm = np.asarray(cm)
    if cm.size == 0:
        cm = np.zeros((2, 2), dtype=int)
    if np.squeeze(cm).ndim < 2:
        new_cm = np.zeros((2, 2), dtype=int)
        new_cm[1, 1] = int(cm[0][0])
        cm = new_cm

    tn, fp, fn, tp = cm[0, 0], cm[0, 1], cm[1, 0], cm[1, 1]

    accuracy = _ratio(tn + tp, tn + tp + fp + fn)
    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)

    beta_sq = beta ** 2
    numerator = precision * recall
    denominator = (beta_sq * precision) + recall
    f_score = _ratio((1 + beta_sq) * numerator, denominator)
    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f_score': f_score}

In [39]:
subjects = ['396', '398', '402', '405', '406', '415', '416']

results_dict = {'file': [], 'amp_precision': [], 'amp_recall': [], 'amp_f_score': [], 'amp_cm': [],
               'grad_precision': [], 'grad_recall': [], 'grad_f_score': [], 'grad_cm': [],
               'env_precision': [], 'env_recall': [], 'env_f_score': [], 'env_cm': []}

# Threshold families scored separately; each is a column of the frame returned
# by results_by_thresh_type and a key prefix in results_dict.
thresh_kinds = ('amp', 'grad', 'env')
# Seconds of recording kept after the last tagged spike when cropping.
tail_buffer_sec = 3

# for subj in subjects:
for subj in [402]:
    # get tags and split to right and left
    raw = mne.io.read_raw_edf(f'C:\\Lilach\\{subj}_for_tag.edf')
    sr = int(raw.info['sfreq'])
    tags_df = pd.DataFrame(raw.annotations)
    # keep only the annotations that precede the first 'END' marker
    i = tags_df.loc[tags_df['description'] == 'END'].index[0]
    tags_df = tags_df[:i]
    right_tags = tags_df[tags_df['description'].str.contains('Rt')]
    left_tags = tags_df[tags_df['description'].str.contains('Lt')]
    # run over detection files
    detections_path = f'C:\\analysis\\{subj}\\thresh_tuning'
    # sorted() makes the row order reproducible regardless of the OS listing order
    for file in sorted(os.listdir(detections_path)):
        # skip anything that is not a detections CSV (sub-folders, notes, ...)
        if not file.lower().endswith('.csv'):
            continue
        pred_df = pd.read_csv(os.path.join(detections_path, file))
        # file names containing 'RA' are scored against right-side tags
        tags = right_tags if 'RA' in file else left_tags
        # crop objects according to the last spike tag (+ buffer)
        last_spike_tag = tags.tail(1).iloc[0]['onset'] + tail_buffer_sec
        curr_raw = raw.copy().crop(tmin=0, tmax=last_spike_tag)

        data_df = results_by_thresh_type(curr_raw, tags['onset'], pred_df, sr)

        results_dict['file'].append(file)
        for kind in thresh_kinds:
            # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent
            cm = confusion_matrix(data_df['y_actual'], data_df[kind], labels=[0, 1])
            metrics = get_metrics(cm)
            results_dict[f'{kind}_precision'].append(metrics['precision'])
            results_dict[f'{kind}_recall'].append(metrics['recall'])
            results_dict[f'{kind}_f_score'].append(metrics['f_score'])
            results_dict[f'{kind}_cm'].append(cm)


#         cm, y_actual, y_pred = get_cm(raw, tags['onset'], pred_df, int(raw.info['sfreq']))
#         metrics = get_metrics(cm)
#         results_dict['file'].append(file)
#         results_dict['accuracy'].append(metrics['accuracy'])
#         results_dict['precision'].append(metrics['precision'])
#         results_dict['recall'].append(metrics['recall'])
#         results_dict['f_score'].append(metrics['f_score'])
#         results_dict['cm'].append(cm)
#         results_dict['tags_count'].append(len(tags))
#         results_dict['roc_auc'].append(roc_auc_score(y_actual, y_pred))
#         results_dict['avg_precision'].append(average_precision_score(y_actual, y_pred))



In [42]:
# Collect the per-file metrics into one table.
thresh_metrics_df = pd.DataFrame(results_dict)

# The threshold value is taken from the last underscore-separated token of the
# file name after stripping every 't' and the '.csv' suffix.
thresh_metrics_df['thresh'] = thresh_metrics_df['file'].map(
    lambda name: int(name.split('_')[-1].replace('t', '').replace('.csv', ''))
)

# Split by whether the file name mentions 'AH2', then order each subset by
# ascending grad F-score.
has_ah2 = thresh_metrics_df['file'].str.contains('AH2')
thresh_cz_ref = thresh_metrics_df[~has_ah2].sort_values('grad_f_score')
thresh_bi_ref = thresh_metrics_df[has_ah2].sort_values('grad_f_score')


In [43]:
# Display the metrics for files whose name does not contain 'AH2',
# ordered by ascending grad_f_score.
thresh_cz_ref

Unnamed: 0,file,amp_precision,amp_recall,amp_f_score,amp_cm,grad_precision,grad_recall,grad_f_score,grad_cm,env_precision,env_recall,env_f_score,env_cm,thresh
8,402_LAH1_t4.csv,0.625,0.5,0.595238,"[[209, 3], [5, 5]]",0.5,0.9,0.54878,"[[203, 9], [1, 9]]",0.16129,1.0,0.193798,"[[160, 52], [0, 10]]",4
22,402_RAH1_t4.csv,0.285714,0.142857,0.238095,"[[204, 5], [12, 2]]",0.5,0.928571,0.550847,"[[196, 13], [1, 13]]",0.26,0.928571,0.303738,"[[172, 37], [1, 13]]",4
9,402_LAH1_t5.csv,1.0,0.4,0.769231,"[[212, 0], [6, 4]]",0.666667,0.8,0.689655,"[[208, 4], [2, 8]]",0.236842,0.9,0.277778,"[[183, 29], [1, 9]]",5
23,402_RAH1_t5.csv,1.0,0.142857,0.454545,"[[209, 0], [12, 2]]",0.666667,0.857143,0.697674,"[[203, 6], [2, 12]]",0.285714,0.714286,0.324675,"[[184, 25], [4, 10]]",5
7,402_LAH1_t10.csv,,0.0,,"[[212, 0], [10, 0]]",0.833333,0.5,0.735294,"[[211, 1], [5, 5]]",0.714286,0.5,0.657895,"[[210, 2], [5, 5]]",10
13,402_LAH1_t9.csv,,0.0,,"[[212, 0], [10, 0]]",0.857143,0.6,0.789474,"[[211, 1], [4, 6]]",0.7,0.7,0.7,"[[209, 3], [3, 7]]",9
10,402_LAH1_t6.csv,1.0,0.2,0.555556,"[[212, 0], [8, 2]]",0.875,0.7,0.833333,"[[211, 1], [3, 7]]",0.36,0.9,0.409091,"[[196, 16], [1, 9]]",6
11,402_LAH1_t7.csv,1.0,0.1,0.357143,"[[212, 0], [9, 1]]",0.875,0.7,0.833333,"[[211, 1], [3, 7]]",0.470588,0.8,0.512821,"[[203, 9], [2, 8]]",7
12,402_LAH1_t8.csv,1.0,0.1,0.357143,"[[212, 0], [9, 1]]",0.875,0.7,0.833333,"[[211, 1], [3, 7]]",0.7,0.7,0.7,"[[209, 3], [3, 7]]",8
21,402_RAH1_t10.csv,,0.0,,"[[209, 0], [14, 0]]",1.0,0.5,0.833333,"[[209, 0], [7, 7]]",1.0,0.357143,0.735294,"[[209, 0], [9, 5]]",10


In [22]:
# Display the annotations stored on `raw` as a table.
pd.DataFrame(raw.annotations)

Unnamed: 0,onset,duration,description,orig_time
0,41.1777,0.0,Rt spike,2007-09-23 23:30:00+00:00
1,61.0798,0.0,Lt spike,2007-09-23 23:30:00+00:00
2,61.1191,0.0,Lt spike,2007-09-23 23:30:00+00:00
3,61.8968,0.0,Rt spike,2007-09-23 23:30:00+00:00
4,97.1905,0.0,Rt spike,2007-09-23 23:30:00+00:00
5,121.2242,0.0,Lt spike,2007-09-23 23:30:00+00:00
6,121.4411,0.0,Rt spike,2007-09-23 23:30:00+00:00
7,129.6388,0.0,Lt spike,2007-09-23 23:30:00+00:00
8,138.496,0.0,Rt spike,2007-09-23 23:30:00+00:00
9,140.2353,0.0,Rt spike,2007-09-23 23:30:00+00:00
