In [10]:
import numpy as np
from matplotlib import pyplot as plt
from sklearn import preprocessing
import wfdb
import copy as cp
import scipy.signal as signal
import scipy.stats as stats
from sklearn import preprocessing
from tqdm import tqdm
import os
import pathlib
import re
import pandas as pd
import pickle
import csv
import statistics

In [23]:
import timeit

import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats("retina")

from sklearn.model_selection import cross_validate
from sklearn.model_selection import LeaveOneGroupOut

from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.inspection import permutation_importance

import catboost as cb
from catboost import CatBoostClassifier

In [12]:
parent_path = str(pathlib.Path(os.path.normpath('C:\\Users\\arisi\\Documents\\GitHub\\2022-svsm-afib-group1\\')))

rlist = []
records = os.path.normpath(parent_path + '/mit-bih-dataframes/subject_list.csv')
with open(records) as rfile:
    recordreader = csv.reader(rfile, delimiter=' ', quotechar='|')
    for row in recordreader:
        rlist.append(row[0])

In [13]:
performance_dict = {
    "Model name": [],
    "Avg Accuracy": [],
    "Std Accuracy": [],
    "Sensitivity": [],
    "Specificity": [],
    "Precision": [],
    "F1 score": [],
    "Run time": [],
    "TPS": []
}

moving_accuracy = {}

In [14]:
def score_reporter(initial_results):
    initial_columns = initial_results.axes[0].tolist()
    results = initial_results.dropna()
    changed_columns = results.axes[0].tolist()
    
    dropped_cols = list(set(initial_columns).difference(changed_columns))
    
    print(dropped_cols)

    acc_scores = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_accuracy'
        if col_name not in dropped_cols:
            acc_scores.append(results[col_name])

    spec_scores = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_specificity'
        if col_name not in dropped_cols:
            spec_scores.append(results[col_name])

    sens_scores = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_sensitivity'
        if col_name not in dropped_cols:
            sens_scores.append(results[col_name])


    prec_scores = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_precision'
        if col_name not in dropped_cols:
            prec_scores.append(results[col_name])

    f1_scores = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_f1_score'
        if col_name not in dropped_cols:
            f1_scores.append(results[col_name])
        
    elapsed_times = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_elapsed'
        if col_name not in dropped_cols:
            elapsed_times.append(results[col_name])

    eps_times = []
    for x in range(len(rlist)):
        col_name = 'split'+str(x)+'_test_eps'
        if col_name not in dropped_cols:
            eps_times.append(results[col_name])

    print('---Run time of each fold: \n {}'.format(elapsed_times))
    print("Avg run time: {}".format(np.mean(elapsed_times)))
    print('---Run time per subset of each fold is: \n {}'.format(eps_times))
    print("Avg run time per subset: {}".format(np.mean(eps_times)))
    print()
    print('Accuracy of each fold: \n {}'.format(acc_scores))
    print("Avg accuracy: {}".format(np.mean(acc_scores)))
    print('Std of accuracy : \n{}'.format(np.std(acc_scores)))
    print()
    print('Specificity of each fold: \n {}'.format(spec_scores))
    print("Avg specificity: {}".format(np.mean(spec_scores)))
    print('Std of specificity: \n{}'.format(np.std(spec_scores)))
    print()
    print('Sensitivity of each fold: \n {}'.format(sens_scores))
    print("Avg sensitivity: {}".format(np.mean(sens_scores)))
    print('Std of sensitivity: \n{}'.format(np.std(sens_scores)))
    print()
    print('Precision of each fold: \n {}'.format(prec_scores))
    print("Avg precision: {}".format(np.mean(prec_scores)))
    print('Std of precision : \n{}'.format(np.std(prec_scores)))
    print()
    print('F1-scores of each fold: \n {}'.format(f1_scores))
    print("Avg F1-scores: {}".format(np.mean(f1_scores)))
    print('Std of F1-scores : \n{}'.format(np.std(f1_scores)))

In [21]:
def cv_scorer(clf, X, y):
    global moving_acc

    start_time = timeit.default_timer()
    y_pred = clf.predict(X)
    elapsed = timeit.default_timer() - start_time

    total_seen = 0
    total_correct = 0
    subject_acc = []
    for idx, pred in enumerate(y_pred):
        total_seen+=1
        if pred==y.iloc[idx]:
            total_correct+=1
        subject_acc.append(total_correct/total_seen)
    moving_acc.append(subject_acc)

    fold_size = len(X)

    cm = confusion_matrix(y, y_pred)

    sensitivity = cm[0][0]/(cm[0][0]+cm[0][1])
    specificity = cm[1][1]/(cm[1][0]+cm[1][1])
    precision = (cm[0][0])/(cm[0][0]+cm[1][0])
    f1_score = (2*precision*sensitivity)/(precision+sensitivity)

    return {'accuracy': accuracy_score(y, y_pred), 
            'sensitivity': sensitivity, 'specificity': specificity,
            'precision': precision, 'f1_score': f1_score,
            #'auc_score': roc_auc_score(y, clf.predict_proba(X)[:, 1]),
            'elapsed': elapsed, 'eps': elapsed/fold_size}

In [22]:
def pi_scorer(clf, X, y):
    global moving_acc

    y_pred = clf.predict(X)

    cm = confusion_matrix(y, y_pred)

    sensitivity = cm[0][0]/(cm[0][0]+cm[0][1])
    specificity = cm[1][1]/(cm[1][0]+cm[1][1])
    precision = (cm[0][0])/(cm[0][0]+cm[1][0])
    f1_score = (2*precision*sensitivity)/(precision+sensitivity)

    return {'accuracy': accuracy_score(y, y_pred), 
            'sensitivity': sensitivity, 'specificity': specificity,
            'precision': precision, 'f1_score': f1_score}
            #'auc_score': roc_auc_score(y, clf.predict_proba(X)[:, 1])

In [17]:
feature_dfs = {}
for record in tqdm(rlist):
    feature_dfs[record] = pd.read_parquet(os.path.normpath(parent_path + '/mit-bih-time-features/'+record+'.parquet'))

combined_features = pd.concat([feature_dfs[key][1:] for key in feature_dfs])

100%|██████████| 23/23 [00:00<00:00, 41.67it/s]


In [18]:
X = combined_features[['StoS', 'StoR', 'StoL', 'RtoS', 'RtoR', 'RtoL', 'LtoS', 'LtoR', 'LtoL', 'rmssd', 'iqr', 'mad', 'cov']]
y = combined_features['mappedLabel'].map({"Non-Afib": 0, "Afib": 1})
groups = combined_features['subjectID'].astype('int64')

logo = LeaveOneGroupOut()
splits = list(logo.split(X, y, groups=groups))

In [24]:
# CatBoost
moving_accuracy['catboost'] = []

model = CatBoostClassifier(learning_rate=0.1, loss_function='Logloss', verbose=None, max_depth=8, iterations=400)

moving_acc = []
scores = cross_validate(model, X, y, scoring=cv_scorer, cv=splits, return_estimator=True)
moving_accuracy['catboost'] = moving_acc

importances = []

for estimator in scores['estimator']:
    importances.append(permutation_importance(estimator, X, y, n_repeats=15, random_state=1))


'''
print(randomsearch.best_estimator_.feature_importances_)
Importance = pd.DataFrame({'Importance':(randomsearch.best_estimator_.feature_importances_*100)[0:12]}, 
                      index = (X.columns)[0:12])
Importance.sort_values(by = 'Importance', 
                   axis = 0, 
                   ascending = True).plot(kind = 'barh', 
                                          color = 'r')
plt.xlabel('Variable Importance')
plt.gca().legend_ = None

results = pd.DataFrame(randomsearch.cv_results_)
results.to_csv(os.path.normpath(parent_path + '/models/subsets/results/catboost_results.csv'))
print(results)

score_reporter(results.iloc[randomsearch.best_index_])

print(randomsearch.best_params_)
'''

print(scores)

0:	learn: 0.5124539	total: 35.5ms	remaining: 14.2s
1:	learn: 0.4045208	total: 67.5ms	remaining: 13.4s
2:	learn: 0.3381354	total: 97.7ms	remaining: 12.9s
3:	learn: 0.2892314	total: 129ms	remaining: 12.8s
4:	learn: 0.2541348	total: 162ms	remaining: 12.8s
5:	learn: 0.2286572	total: 190ms	remaining: 12.5s
6:	learn: 0.2096618	total: 218ms	remaining: 12.3s
7:	learn: 0.1949566	total: 245ms	remaining: 12s
8:	learn: 0.1838187	total: 274ms	remaining: 11.9s
9:	learn: 0.1751992	total: 302ms	remaining: 11.8s
10:	learn: 0.1683873	total: 332ms	remaining: 11.7s
11:	learn: 0.1623843	total: 360ms	remaining: 11.6s
12:	learn: 0.1578152	total: 386ms	remaining: 11.5s
13:	learn: 0.1537176	total: 412ms	remaining: 11.4s
14:	learn: 0.1497166	total: 441ms	remaining: 11.3s
15:	learn: 0.1465542	total: 471ms	remaining: 11.3s
16:	learn: 0.1438065	total: 499ms	remaining: 11.2s
17:	learn: 0.1411713	total: 530ms	remaining: 11.2s
18:	learn: 0.1390661	total: 558ms	remaining: 11.2s
19:	learn: 0.1366980	total: 588ms	remain

  sensitivity = cm[0][0]/(cm[0][0]+cm[0][1])


0:	learn: 0.5046221	total: 30.3ms	remaining: 12.1s
1:	learn: 0.3897354	total: 57.3ms	remaining: 11.4s
2:	learn: 0.3192592	total: 86.2ms	remaining: 11.4s
3:	learn: 0.2702712	total: 114ms	remaining: 11.3s
4:	learn: 0.2361437	total: 144ms	remaining: 11.4s
5:	learn: 0.2111783	total: 197ms	remaining: 12.9s
6:	learn: 0.1938464	total: 235ms	remaining: 13.2s
7:	learn: 0.1810717	total: 268ms	remaining: 13.1s
8:	learn: 0.1709763	total: 299ms	remaining: 13s
9:	learn: 0.1624025	total: 330ms	remaining: 12.9s
10:	learn: 0.1555736	total: 359ms	remaining: 12.7s
11:	learn: 0.1499610	total: 387ms	remaining: 12.5s
12:	learn: 0.1452280	total: 417ms	remaining: 12.4s
13:	learn: 0.1411097	total: 445ms	remaining: 12.3s
14:	learn: 0.1372559	total: 473ms	remaining: 12.1s
15:	learn: 0.1348393	total: 500ms	remaining: 12s
16:	learn: 0.1328320	total: 528ms	remaining: 11.9s
17:	learn: 0.1309127	total: 558ms	remaining: 11.8s
18:	learn: 0.1288527	total: 586ms	remaining: 11.8s
19:	learn: 0.1273676	total: 615ms	remainin

  sensitivity = cm[0][0]/(cm[0][0]+cm[0][1])


0:	learn: 0.5213071	total: 36ms	remaining: 14.4s
1:	learn: 0.4058022	total: 69.1ms	remaining: 13.7s
2:	learn: 0.3399938	total: 101ms	remaining: 13.4s
3:	learn: 0.2923238	total: 133ms	remaining: 13.2s
4:	learn: 0.2587184	total: 167ms	remaining: 13.2s
5:	learn: 0.2348483	total: 197ms	remaining: 12.9s
6:	learn: 0.2162904	total: 227ms	remaining: 12.7s
7:	learn: 0.2019905	total: 255ms	remaining: 12.5s
8:	learn: 0.1899446	total: 286ms	remaining: 12.4s
9:	learn: 0.1816086	total: 315ms	remaining: 12.3s
10:	learn: 0.1745957	total: 343ms	remaining: 12.1s
11:	learn: 0.1686759	total: 373ms	remaining: 12.1s
12:	learn: 0.1642143	total: 400ms	remaining: 11.9s
13:	learn: 0.1597541	total: 428ms	remaining: 11.8s
14:	learn: 0.1567039	total: 454ms	remaining: 11.7s
15:	learn: 0.1534496	total: 483ms	remaining: 11.6s
16:	learn: 0.1512293	total: 513ms	remaining: 11.6s
17:	learn: 0.1486900	total: 543ms	remaining: 11.5s
18:	learn: 0.1467564	total: 570ms	remaining: 11.4s
19:	learn: 0.1446515	total: 599ms	remaini