In [4]:
import sys
sys.path.append('..')

import numpy as np
import torch
from tqdm.notebook import tqdm
import pandas as pd
import torch.nn.functional as F
from lib.CustomDataset import TimeSeriesHDF5Dataset
from torch.utils.data import DataLoader
from lib.VAE import VAE	
from lib.Utilities import *
import torch.optim as optim
import time
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import yaml
import os

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier

from lib.FE_ExtractFeatures import ExtractFeatures

In [5]:
# --- Segmentation / data-source settings used by the feature-extraction cell ---
# NOTE(review): `config` is not defined in any visible cell; presumably it is
# provided by `from lib.Utilities import *` or an earlier cell -- TODO confirm.

segment_length_sec = 30                    # forwarded to the dataset as `segment_len`
sampling_rate = config['sampling_rate']
overlap = 0.95                             # forwarded to the dataset as `overlap`
directory_path = config['hdf5_file_dir']   # directory the HDF5 file names are joined onto

# Channel names tried, in this order, for every file.
mode = ['ABP', 'ART']

# HDF5 recordings to process.
hdf5_files = [
    '4_Patient_2022-02-05_08:59.h5',
]

In [10]:
# Collect hand-crafted features for every segment labelled 1.
# Each appended row is: [datafile path, channel name, label, *feature values].
features_all = []

for filename in tqdm(hdf5_files):
	log_info(f"Processing {filename}")
	datafile = os.path.join(directory_path, filename)

	# One dataset per channel name listed in `mode`.
	for m in mode:
		dataset = TimeSeriesHDF5Dataset(datafile, m, segment_len=segment_length_sec, overlap=overlap, phase="train", smoothen=False)

		if len(dataset) == 0:
			print("No data to process, continuing...")
			continue

		dataloader = DataLoader(dataset, batch_size=32, num_workers=4, shuffle=False, pin_memory=True)

		# Only `artifact_count` is updated below; the other two counters are
		# initialised but never incremented in this cell. They are kept so that
		# any later cell reading them still finds the names defined.
		artifact_count, non_artifact_count = 0, 0
		total_count = 0

		for start_i, data, lbl, ts in tqdm(dataloader):
			# Named `keep_mask` rather than `filter` so the Python builtin is not
			# shadowed for the rest of the notebook session.
			keep_mask = filter_abp_batch_scae(data)

			# Apply the same mask to every per-segment field so they stay aligned.
			start_i = start_i[keep_mask]
			data = data[keep_mask]
			lbl = lbl[keep_mask]
			ts = ts[keep_mask]

			# An empty batch simply yields zero iterations here.
			for b_n in range(len(start_i)):
				label = lbl[b_n]
				signal_data = data[b_n]

				# Features are extracted for label == 1 segments only.
				if label == 1:
					artifact_count += 1
					# ExtractFeatures is given the segment with a leading axis of size 1.
					input_data = signal_data.unsqueeze(dim=0).numpy()
					features = ExtractFeatures(input_data).get_features().squeeze()

					per_segment_features = [datafile, m, label.item()] + features.tolist()
					features_all.append(per_segment_features)

			

  0%|          | 0/1 [00:00<?, ?it/s]

14:53:08 :	  Processing 4_Patient_2022-02-05_08:59.h5 

14:53:08 :	  No Waveforms/ABP_na in the hdf5 file: <HDF5 file "4_Patient_2022-02-05_08:59.h5" (mode r)>. 

No data to process, continuing...




In [18]:
# File names used to select rows of the features table for training and testing.
# NOTE(review): the same recording is listed twice for training and once more
# as the test file, so test metrics are computed on training data -- confirm
# this is intended.
train_files = [
    '4_Patient_2022-02-05_08:59.h5',
    '4_Patient_2022-02-05_08:59.h5',
]
test_file = [
    '4_Patient_2022-02-05_08:59.h5',
]


In [19]:
# Load the pre-computed feature table and split it into train / test arrays.
# NOTE(review): this is an absolute, user-specific path; consider deriving it
# from a configurable data directory.
features_csv_file = '/home/ms5267@drexel.edu/moberg-precicecap/ArtifactDetectionEval/data/FE_features_train.csv'
mode = ['ABP','ART']

# The CSV is read without a header row, so columns are addressed by position:
# column 0 is matched against file names, column 1 against `mode`,
# column 2 is used as the label and columns 3+ as the feature matrix.
df = pd.read_csv(features_csv_file, header=None)

# Combine both conditions into one boolean mask and select with `.loc`.
# Chaining `df[mask_a][mask_b]` applies a mask built on the full frame to an
# already-filtered frame, which makes pandas reindex the mask and emit
# "Boolean Series key will be reindexed to match DataFrame index".
channel_mask = df[1].isin(mode)
train_features = df.loc[df[0].isin(train_files) & channel_mask]
test_features = df.loc[df[0].isin(test_file) & channel_mask]

# Training data
train_labels = train_features.iloc[:, 2].to_numpy()
X_train = train_features.iloc[:, 3:].to_numpy()

# Test data
test_labels = test_features.iloc[:, 2].to_numpy()
X_test = test_features.iloc[:, 3:].to_numpy()


In [20]:

def train_and_eval_SVM(X_train, y_train, X_test, y_test):
    """Fit an RBF-kernel SVC, log its train/test metrics and save the model.

    Args:
        X_train: Training feature matrix passed to ``SVC.fit``.
        y_train: Training labels.
        X_test: Test feature matrix.
        y_test: Test labels.

    Returns:
        None. Metrics are emitted through ``log_info`` and the fitted model is
        handed to ``save_model`` with the path
        ``models/svm_classifier_afib.pkl``.
    """
    log_info("Training with SVM")
    model = SVC(kernel='rbf')
    model.fit(X_train, y_train)

    log_info("Evaluating the SVM classifier")

    # Train split: accuracy, report and confusion matrix go out in one message.
    train_pred = model.predict(X_train)
    train_acc = accuracy_score(y_train, train_pred)
    train_report = classification_report(y_train, train_pred)
    train_cm = confusion_matrix(y_train, train_pred)
    log_info(f"Train Accuracy: {train_acc}\n{train_report}\n{train_cm}")

    # Test split: same three metrics, again as a single message.
    test_pred = model.predict(X_test)
    test_acc = accuracy_score(y_test, test_pred)
    test_report = classification_report(y_test, test_pred)
    test_cm = confusion_matrix(y_test, test_pred)
    log_info(f"Test Accuracy: {test_acc}\n{test_report}\n{test_cm}")

    log_info("Saving the trained SVM model")
    save_model(model, 'models/svm_classifier_afib.pkl')


def train_and_eval_KNN(X_train, y_train, X_test, y_test, n_neighbors=5):
    """Fit a k-nearest-neighbours classifier, log its metrics and save it.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels.
        X_test: Test feature matrix.
        y_test: Test labels.
        n_neighbors: Forwarded to ``KNeighborsClassifier`` (default 5).

    Returns:
        None. Accuracy, classification report and confusion matrix are each
        logged as a separate ``log_info`` message for the train and test
        splits, and the fitted model is handed to ``save_model`` with the path
        ``models/knn_classifier_afib.pkl``.
    """
    log_info("Training with KNN")
    model = KNeighborsClassifier(n_neighbors=n_neighbors)
    model.fit(X_train, y_train)

    log_info("Evaluating the KNN classifier")

    # Train split metrics.
    train_pred = model.predict(X_train)
    train_acc = accuracy_score(y_train, train_pred)
    log_info(f"Train Accuracy: {train_acc}")
    train_report = classification_report(y_train, train_pred)
    log_info(f"{train_report}")
    train_cm = confusion_matrix(y_train, train_pred)
    log_info(f"{train_cm}")

    # Test split metrics.
    test_pred = model.predict(X_test)
    test_acc = accuracy_score(y_test, test_pred)
    log_info(f"Test Accuracy: {test_acc}")
    test_report = classification_report(y_test, test_pred)
    log_info(f"{test_report}")
    test_cm = confusion_matrix(y_test, test_pred)
    log_info(f"{test_cm}")

    log_info("Saving the trained KNN model")
    save_model(model, 'models/knn_classifier_afib.pkl')


def train_and_eval_DT(X_train, y_train, X_test, y_test, max_depth=None, criterion='gini'):
    """Fit a decision-tree classifier, log its metrics and save it.

    Args:
        X_train: Training feature matrix.
        y_train: Training labels.
        X_test: Test feature matrix.
        y_test: Test labels.
        max_depth: Forwarded to ``DecisionTreeClassifier`` (default ``None``).
        criterion: Forwarded to ``DecisionTreeClassifier`` (default ``'gini'``).

    Returns:
        None. Accuracy, classification report and confusion matrix are each
        logged as a separate ``log_info`` message, first for the train split
        and then for the test split; the fitted model is handed to
        ``save_model`` with the path ``models/dt_classifier_afib.pkl``.
    """
    log_info("Training with Decision Tree")
    tree = DecisionTreeClassifier(max_depth=max_depth, criterion=criterion)
    tree.fit(X_train, y_train)

    log_info("Evaluating the Decision Tree classifier")

    # Same three-message report for each split, train first and then test.
    splits = (
        ("Train", X_train, y_train),
        ("Test", X_test, y_test),
    )
    for split_name, features, targets in splits:
        predictions = tree.predict(features)
        log_info(f"{split_name} Accuracy: {accuracy_score(targets, predictions)}")
        log_info(f"{classification_report(targets, predictions)}")
        log_info(f"{confusion_matrix(targets, predictions)}")

    log_info("Saving the trained Decision Tree model")
    save_model(tree, 'models/dt_classifier_afib.pkl')

array([[72.31704712, 65.875     , 20.60549927, ...,  0.296     ,
         1.176     , 46.22127151],
       [72.0506134 , 65.1875    , 20.69827652, ...,  0.296     ,
         1.176     , 43.9915657 ],
       [72.15585327, 65.8125    , 20.51646042, ...,  0.296     ,
         1.176     , 43.42429733],
       ...,
       [37.56868362, 36.5625    ,  4.36484241, ...,  0.296     ,
         0.776     , 14.41522408],
       [62.59261703, 58.        , 15.92419529, ...,  0.304     ,
         0.984     , 32.14121246],
       [56.2902832 , 52.1875    , 13.82932854, ...,  0.296     ,
         1.152     , 44.97563934]])