In [1]:
import mne
import numpy as np
from matplotlib import pyplot as plt
import math
from scipy import stats, signal
import itertools
import pandas as pd
from datetime import datetime
import os
from sklearn.metrics import confusion_matrix, roc_auc_score, average_precision_score, roc_curve, precision_recall_curve


In [2]:
%run ./detector.ipynb

In [3]:
def results_by_thresh_type(raw, tags, detections, sr):
    """Label every whole second of a recording as tagged / detected, per threshold type.

    Parameters
    ----------
    raw : object exposing ``n_times`` (number of samples in the recording).
    tags : iterable of tag onsets in seconds; each is floored to a 1-second bin.
    detections : DataFrame with a ``max_index`` column (sample index of the
        detection) and a ``threshold_type`` column (string; a detection counts
        for 'amp', 'grad' and/or 'env' when that substring appears in it).
    sr : sampling rate in Hz, used to convert sample indices to seconds.

    Returns
    -------
    DataFrame with one row per whole second and the 0/1 columns
    ``y_actual``, ``y_pred``, ``amp``, ``grad`` and ``env``.
    """
    # make it 1Hz: one slot per *complete* second of the recording
    n_seconds = int(raw.n_times / sr)
    labels = {name: np.zeros(n_seconds, dtype=int)
              for name in ['y_actual', 'y_pred', 'amp', 'grad', 'env']}

    # floor to 1 sec resolution
    for onset in tags:
        labels['y_actual'][math.floor(onset)] = 1

    # Keep only detections whose floored second has a slot. Comparing the raw
    # sample index with n_times is not enough: when n_times is not a multiple
    # of sr, a detection in the trailing partial second floors to n_seconds
    # and would index one past the end. NaN and negative indices are dropped
    # by the same mask.
    seconds = np.floor(detections['max_index'] / sr)
    in_range = (seconds >= 0) & (seconds < n_seconds)
    kept_seconds = seconds[in_range].astype(int)
    kept_types = detections.loc[in_range, 'threshold_type']

    for second, threshold_type in zip(kept_seconds, kept_types):
        labels['y_pred'][second] = 1
        for kind in ('amp', 'grad', 'env'):
            if kind in threshold_type:
                labels[kind][second] = 1

    return pd.DataFrame(labels)

def get_cm(raw, tags, detections, sr):
    """Build 1 Hz tagged / detected vectors and their confusion matrix.

    Parameters
    ----------
    raw : object exposing ``n_times`` (number of samples in the recording).
    tags : iterable of tag onsets in seconds; each is floored to a 1-second bin.
    detections : DataFrame with a ``max_index`` column holding the sample
        index of each detection.
    sr : sampling rate in Hz, used to convert sample indices to seconds.

    Returns
    -------
    (cm, y_actual, y_pred) where ``cm`` is the 2x2 confusion matrix over the
    labels [0, 1] and the two vectors hold one 0/1 entry per whole second.
    """
    # make it 1Hz: one slot per *complete* second of the recording
    n_seconds = int(raw.n_times / sr)
    y_actual = np.zeros(n_seconds, dtype=int)
    y_pred = np.zeros(n_seconds, dtype=int)

    # floor to 1 sec resolution
    for onset in tags:
        y_actual[math.floor(onset)] = 1

    # Drop detections whose floored second has no slot. A detection in the
    # trailing partial second (n_times not a multiple of sr) floors to
    # n_seconds and would otherwise index one past the end; NaN and negative
    # indices are dropped by the same mask.
    seconds = np.floor(detections['max_index'] / sr)
    in_range = (seconds >= 0) & (seconds < n_seconds)
    y_pred[seconds[in_range].astype(int).to_numpy()] = 1

    # Pin the label set so the matrix stays 2x2 even when only one class
    # occurs in both vectors.
    cm = confusion_matrix(y_actual, y_pred, labels=[0, 1])
    return cm, y_actual, y_pred

def get_metrics(cm, beta=0.5):
    """Compute accuracy, precision, recall and F-beta from a confusion matrix.

    Parameters
    ----------
    cm : array laid out as ``[[TN, FP], [FN, TP]]``. An empty matrix is
        treated as all zeros; a matrix that squeezes to fewer than two
        dimensions (e.g. 1x1) has its single count placed in the TP cell.
    beta : weight of recall relative to precision in the F-score. The default
        0.5 reproduces the original fixed formula
        ``1.25 * P * R / (0.25 * P + R)``.

    Returns
    -------
    dict with keys ``accuracy``, ``precision``, ``recall`` and ``f_score``.
    Any ratio with a zero denominator is NaN.
    """
    if len(cm) == 0:
        cm = np.zeros((2, 2), dtype=int)
    if np.squeeze(cm).ndim < 2:
        # Degenerate single-cell matrix: its count is booked as true positives.
        padded = np.zeros((2, 2), dtype=int)
        padded[1, 1] = int(cm[0][0])
        cm = padded

    tn, fp = cm[0, 0], cm[0, 1]
    fn, tp = cm[1, 0], cm[1, 1]
    beta_sq = beta ** 2

    # 0/0 is an expected outcome here (e.g. no positive predictions); keep the
    # NaN result but do not emit a RuntimeWarning for every such file.
    with np.errstate(divide='ignore', invalid='ignore'):
        accuracy = (tn + tp) / (tn + tp + fp + fn)
        precision = tp / (tp + fp)
        recall = tp / (tp + fn)
        f_score = (1 + beta_sq) * (precision * recall) / ((beta_sq * precision) + recall)

    return {'accuracy': accuracy, 'precision': precision, 'recall': recall, 'f_score': f_score}

In [51]:
# Per-subject threshold tuning: score every candidate-threshold detection file
# against the manual tags, pick the best threshold per channel and threshold
# type, then run the full-night detection with those thresholds.

# subjects = ['396', '398', '402', '405', '406', '415', '416']
subjects = ['406', '415', '416']

for subj in subjects:
    print(subj)
    # One row per detection file: precision / recall / F-score / confusion
    # matrix for each of the three threshold types.
    results_dict = {'file': []}
    for thresh_type in ['amp', 'grad', 'env']:
        for stat in ['precision', 'recall', 'f_score', 'cm']:
            results_dict[f'{thresh_type}_{stat}'] = []

    # get tags and split to right and left
    raw = mne.io.read_raw_edf(f'C:\\Lilach\\{subj}_for_tag.edf')
    tags_df = pd.DataFrame(raw.annotations)
    # only annotations before the first 'END' marker are kept
    last_spike_tag = tags_df.loc[tags_df['description'] == 'END']
    i = last_spike_tag.index[0]
    tags_df = tags_df[:i]
    right_tags = tags_df[tags_df['description'].str.contains('Rt')]
    left_tags = tags_df[tags_df['description'].str.contains('Lt')]

    # The cropped recording depends only on the subject's tags, so it is built
    # once (on first use) instead of once per detection file.
    curr_raw = None

    # run over detection files and calc metrics
    detections_path = f'C:\\analysis\\{subj}\\thresh_tuning'
    for file in os.listdir(detections_path):
        # anything that is not a detections CSV cannot be read or have its
        # threshold parsed from the name below
        if not file.endswith('.csv'):
            continue
        pred_df = pd.read_csv(os.path.join(detections_path, file))
        tags = right_tags if 'RA' in file else left_tags
        if tags.empty:
            continue

        if curr_raw is None:
            # crop raw according to the last spike tag (+ buffer)
            curr_raw = raw.copy().crop(tmin=0, tmax=tags_df['onset'].iloc[-1] + 3)

        data_df = results_by_thresh_type(curr_raw, tags['onset'], pred_df, int(raw.info['sfreq']))
        # labels are pinned so each matrix is 2x2 even if a column is all one class
        cm_amp = confusion_matrix(data_df['y_actual'], data_df['amp'], labels=[0, 1])
        cm_grad = confusion_matrix(data_df['y_actual'], data_df['grad'], labels=[0, 1])
        cm_env = confusion_matrix(data_df['y_actual'], data_df['env'], labels=[0, 1])

        amp_metrics = get_metrics(cm_amp)
        grad_metrics = get_metrics(cm_grad)
        env_metrics = get_metrics(cm_env)

        results_dict['file'].append(file)
        for thresh_type, cm, metrics in [('amp', cm_amp, amp_metrics),
                                         ('grad', cm_grad, grad_metrics),
                                         ('env', cm_env, env_metrics)]:
            for stat in ['precision', 'recall', 'f_score']:
                results_dict[f'{thresh_type}_{stat}'].append(metrics[stat])
            results_dict[f'{thresh_type}_cm'].append(cm)

    thresh_metrics_df = pd.DataFrame(results_dict)
    # The threshold is the trailing '_t<N>.csv' part of the file name. Mapping
    # over the column (rather than a row-wise apply) returns a Series even
    # when no file was scored.
    thresh_metrics_df['thresh'] = thresh_metrics_df['file'].map(
        lambda name: int(name.split('_')[-1].replace('t', '').replace('.csv', '')))
    thresh_cz_ref = thresh_metrics_df[~thresh_metrics_df['file'].str.contains('AH2')]
    thresh_bi_ref = thresh_metrics_df[thresh_metrics_df['file'].str.contains('AH2')]

    # each df contains all thresholds for 1 channel- right/left, cz/bi
    thresh_cz_right = thresh_cz_ref[thresh_cz_ref['file'].str.contains('RA')]
    thresh_cz_left = thresh_cz_ref[~thresh_cz_ref['file'].str.contains('RA')]
    thresh_bi_right = thresh_bi_ref[thresh_bi_ref['file'].str.contains('RA')]
    thresh_bi_left = thresh_bi_ref[~thresh_bi_ref['file'].str.contains('RA')]
    channels_dict = {'channel': ['RAH1', 'LAH1', 'RAH1-RAH2', 'LAH1-LAH2'], 'amp': [], 'grad': [], 'env': []}

    # get the best thresh for each prop
    for curr_df in [thresh_cz_right, thresh_cz_left, thresh_bi_right, thresh_bi_left]:
        for thresh_type in ['amp', 'grad', 'env']:
            if curr_df.empty:
                # 0 marks "nothing scored for this channel"; the detection
                # loop below skips the channel when its amp value is 0
                best_thresh = 0
            else:
                f_scores = curr_df[f'{thresh_type}_f_score']
                # idxmax cannot pick a row when every F-score is NaN, so
                # fall back to 5 in that case
                best_thresh = 5 if f_scores.isnull().all() else curr_df.loc[f_scores.idxmax()]['thresh']
            channels_dict[thresh_type].append(best_thresh)

    edf = f'C:\\UCLA\\{subj}_cz+bi_full.edf'

    # run full night detection with the best thresholds
    for i, row in pd.DataFrame(channels_dict).iterrows():
        channel, amp, grad, env = row['channel'], row['amp'], row['grad'], row['env']
        if amp != 0:
            detect_subj(edf, channel, amp, grad, env,
                        f'C:\\analysis\\{subj}\\{subj}_{channel}_t{amp}_{grad}_{env}_no_filter2.csv')
        print(channel, amp, grad, env)

406




RAH1 4 5 8
LAH1 5 4 7
RAH1-RAH2 5 4 10
LAH1-LAH2 5 4 7
415




RAH1 5 4 5
LAH1 4 4 5
RAH1-RAH2 4 4 6
LAH1-LAH2 4 4 4
416




RAH1 4 4 5
LAH1 5 4 5
RAH1-RAH2 4 4 7
LAH1-LAH2 4 5 6


In [47]:
# Debug view of the per-channel thresholds built by the tuning cell.
# NOTE(review): this cell's execution count (47) is lower than the tuning
# cell's (51), so the saved output presumably comes from an earlier run.
channels_dict

{'channel': ['RAH1', 'LAH1', 'RAH1-RAH2', 'LAH1-LAH2'],
 'amp': [4, 0, 4, 4],
 'grad': [4, 4, 4, 5],
 'env': [5, 5, 7, 6]}

415




KeyError: nan

In [49]:
# Debug view of the metrics subset last bound to `curr_df` in the
# best-threshold loop of the tuning cell.
# NOTE(review): the saved output shows every amp_f_score as NaN; presumably
# this was inspected after the `KeyError: nan` shown above - confirm.
curr_df

Unnamed: 0,file,amp_precision,amp_recall,amp_f_score,amp_cm,grad_precision,grad_recall,grad_f_score,grad_cm,env_precision,env_recall,env_f_score,env_cm,thresh
21,415_RAH1_t10.csv,,0.0,,"[[180, 0], [22, 0]]",,0.0,,"[[180, 0], [22, 0]]",,0.0,,"[[180, 0], [22, 0]]",10
22,415_RAH1_t4.csv,0.0,0.0,,"[[177, 3], [22, 0]]",0.666667,0.272727,0.517241,"[[177, 3], [16, 6]]",0.363636,0.363636,0.363636,"[[166, 14], [14, 8]]",4
23,415_RAH1_t5.csv,,0.0,,"[[180, 0], [22, 0]]",0.5,0.045455,0.166667,"[[179, 1], [21, 1]]",0.833333,0.227273,0.543478,"[[179, 1], [17, 5]]",5
24,415_RAH1_t6.csv,,0.0,,"[[180, 0], [22, 0]]",0.5,0.045455,0.166667,"[[179, 1], [21, 1]]",0.75,0.136364,0.394737,"[[179, 1], [19, 3]]",6
25,415_RAH1_t7.csv,,0.0,,"[[180, 0], [22, 0]]",0.0,0.0,,"[[179, 1], [22, 0]]",0.0,0.0,,"[[179, 1], [22, 0]]",7
26,415_RAH1_t8.csv,,0.0,,"[[180, 0], [22, 0]]",,0.0,,"[[180, 0], [22, 0]]",0.0,0.0,,"[[179, 1], [22, 0]]",8
27,415_RAH1_t9.csv,,0.0,,"[[180, 0], [22, 0]]",,0.0,,"[[180, 0], [22, 0]]",,0.0,,"[[180, 0], [22, 0]]",9


In [33]:
# Debug view of the get_metrics() result for the 'grad' column of the most
# recently scored detection file.
grad_metrics

{'accuracy': 0.8489583333333334,
 'precision': 1.0,
 'recall': 0.12121212121212122,
 'f_score': 0.40816326530612246}

In [26]:
# Tabular view of the chosen thresholds, one row per channel; a 0 means no
# detection files were scored for that channel.
pd.DataFrame(channels_dict)

Unnamed: 0,channel,amp,grad,env
0,RAH1,0,0,0
1,LAH1,5,7,9
2,RAH1-RAH2,0,0,0
3,LAH1-LAH2,5,7,9
