In [None]:
pip install wfdb

In [None]:
import wfdb
import csv
from google.colab import drive
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from wfdb import processing
import numpy as np
import statistics
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
import random
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import LearningCurveDisplay
import seaborn as sns

# Seed Python's built-in `random` module so anything drawing from it is repeatable.
random.seed(4)

# Mount Google Drive (Colab) so the '/content/drive/My Drive/...' paths used below resolve.
drive.mount('/content/drive')

In [None]:
def get_path(path, patient_number):
  """Build the WFDB record path (without extension) for a patient's `_s00` segment.

  Records live on the mounted Drive as ``<root>/pGG/pGGNNN/pGGNNN_s00``, where
  ``GG`` is the two-digit, zero-padded group and ``NNN`` the three-digit,
  zero-padded patient number. The result is what the rest of the notebook
  hands to ``wfdb.rdsamp`` / ``wfdb.rdann``.

  Args:
    path: group index (int), e.g. 0 selects the ``p00`` directory.
    patient_number: patient index inside that group (int).

  Returns:
    str: the record path, e.g.
    ``'/content/drive/My Drive/Test/p00/p00007/p00007_s00'`` for ``(0, 7)``.
  """
  group = int(path)
  # Zero-padding replaces the previous hand-written branch per group, which
  # returned None for groups above 10 and produced a 7-character record id for
  # patient numbers >= 100.
  record_id = f"p{group:02d}{int(patient_number):03d}"
  return f"/content/drive/My Drive/Test/p{group:02d}/{record_id}/{record_id}_s00"

In [None]:
# Define helper functions for label categorization

def get_label_index(start_index, end_index, label_indices):
  """Find the first annotation that falls inside a sample interval.

  Args:
    start_index: first sample of the interval (inclusive).
    end_index: last sample of the interval (inclusive).
    label_indices: sequence of annotation sample positions.

  Returns:
    int: position within ``label_indices`` of the first entry lying in
    ``[start_index, end_index]``, or -1 when no entry does.
  """
  matches = (
      position
      for position, sample in enumerate(label_indices)
      if start_index <= sample <= end_index
  )
  return next(matches, -1)

def adjust_label(label):
  """Map a single annotation symbol to its coarse beat class.

  Args:
    label: annotation symbol (a one-character string).

  Returns:
    str: one of "N", "B", "S", "V", "F", "Q", or "*" when the symbol belongs
    to none of the classes (including empty, multi-character or non-string
    input).
  """
  # Order matters: the first group containing the symbol wins.
  # NOTE(review): "j" is listed in both the B and S groups, so it always maps
  # to "B" -- confirm which class was intended.
  symbol_groups = (
      ("N", "N"),
      ("B", "LRje"),
      ("S", "Aasj"),
      ("V", "!VE[]"),
      ("F", "F"),
      ("Q", "f/Q"),
  )
  # Require exactly one character: substring search (str.find) used to accept
  # "" (matching every group) and fragments such as "LR" or "[]".
  if not isinstance(label, str) or len(label) != 1:
    return "*"
  for label_class, symbols in symbol_groups:
    if label in symbols:
      return label_class
  return "*"

In [None]:
def get_sampling_frequency(path,patient_number):
  """Return the sampling frequency recorded for a patient file.

  Args:
    path: group index passed through to ``get_path``.
    patient_number: patient index passed through to ``get_path``.

  Returns:
    The ``fs`` value stored in the record header.
  """
  # Only the header is needed for `fs`; reading it avoids loading the whole
  # signal just to discard it.
  return wfdb.rdheader(get_path(path, patient_number)).fs


def get_peak_indices(path,patient_number):
  """Detect QRS peaks in channel 0 of a patient file.

  Args:
    path: group index passed through to ``get_path``.
    patient_number: patient index passed through to ``get_path``.

  Returns:
    The QRS sample indices produced by ``processing.xqrs_detect``.
  """
  record_path = get_path(path, patient_number)
  samples, metadata = wfdb.rdsamp(record_path, channels=[0])
  first_channel = samples[:, 0]
  return processing.xqrs_detect(sig=first_channel, fs=metadata['fs'])

def get_start_index(qrs_indices):
  """Compute beat boundaries for a whole record from its QRS positions.

  The first boundary is sample 0; every later boundary is the (truncated)
  midpoint between two consecutive QRS indices.

  Args:
    qrs_indices: sequence of QRS sample indices.

  Returns:
    list[int]: boundary sample indices, one per QRS index (``[0]`` when
    fewer than two QRS indices are given).
  """
  midpoints = [
      int((following + current) / 2)
      for current, following in zip(qrs_indices, qrs_indices[1:])
  ]
  return [0] + midpoints

In [None]:
def get_bpm(start_index, end_index, frequency):
  """Convert the length of one beat into beats per minute.

  Args:
    start_index: first sample of the beat.
    end_index: sample at which the beat ends.
    frequency: sampling frequency in samples per second.

  Returns:
    float: ``60 * frequency`` divided by the beat length in samples.
  """
  samples_per_minute = 60 * frequency
  samples_in_beat = end_index - start_index
  return samples_per_minute / samples_in_beat

def get_RS_amplitude(signal):
  """Return the peak-to-trough amplitude of a single beat.

  Args:
    signal: samples of one beat (list or 1D array).

  Returns:
    The largest sample minus the smallest sample.
  """
  highest = max(signal)
  lowest = min(signal)
  return highest - lowest

def get_RS_interval(signal, frequency):
  """Return the time between the maximum and the minimum sample of a beat.

  Args:
    signal: samples of one beat (list, tuple or 1D numpy array).
    frequency: sampling frequency in samples per second.

  Returns:
    float: (index of first minimum - index of first maximum) / frequency, in
    seconds. The value is negative when the minimum precedes the maximum.
  """
  # Work on a list copy: numpy arrays have no `.index`, so the documented
  # "1D np array" input used to raise AttributeError.
  samples = list(signal)
  trough_position = samples.index(min(samples))
  peak_position = samples.index(max(samples))
  return (trough_position - peak_position) / frequency

def get_TS_amplitude(signal):
  """Return the rise from the beat's minimum to the largest sample after it.

  Args:
    signal: samples of one beat (list, tuple or 1D numpy array).

  Returns:
    The maximum of the samples from the first minimum onwards, minus that
    minimum (0 when the minimum is the last sample).
  """
  # Work on a list copy: numpy arrays have no `.index`, so the documented
  # "1D np array" input used to raise AttributeError.
  samples = list(signal)
  trough = min(samples)
  after_trough = samples[samples.index(trough):]
  return max(after_trough) - trough

def get_TS_interval(signal, frequency):
  """Return the time from the beat's minimum to the largest sample after it.

  Args:
    signal: samples of one beat (list, tuple or 1D numpy array).
    frequency: sampling frequency in samples per second.

  Returns:
    float: number of samples between the first minimum and the first maximum
    of the samples that follow it, divided by ``frequency`` (seconds). Always
    >= 0; 0 when the minimum is the last sample.
  """
  # Work on a list copy: numpy arrays have no `.index`.
  samples = list(signal)
  trough_position = samples.index(min(samples))
  after_trough = samples[trough_position:]
  # The index inside the slice is already the offset from the trough.
  # Subtracting the trough position again (as before) gave wrong, often
  # negative, intervals.
  offset_to_peak = after_trough.index(max(after_trough))
  return offset_to_peak / frequency

In [None]:
# Sanity check: extract and print every feature for the first beat of
# group 0 / patient 0.
annotation = wfdb.rdann(get_path(0,0), extension='atr')
ann_symbol = annotation.symbol[1:] # In order to remove starting label +
ann_indices =  annotation.sample[1:]
sampling_frequency = get_sampling_frequency(0,0)
qrs_index = get_peak_indices(0,0)
start_ind = get_start_index(qrs_index)
signal, fields = wfdb.rdsamp(get_path(0,0),channels=[0])
# Flatten the (n_samples, 1) slice into a plain list for the feature helpers.
cropped_signal = np.array(signal[start_ind[0]:start_ind[1]]).flatten().tolist()

# get_label_index returns -1 when no annotation lies inside the beat; indexing
# with -1 would silently report the *last* annotation symbol instead.
first_label_index = get_label_index(start_index=start_ind[0], end_index=start_ind[1], label_indices=ann_indices)
first_label = ann_symbol[first_label_index] if first_label_index != -1 else "(no label in this beat)"

print()
print("Extracted Parameters")
print("--------------------")
print("Beat #: 1")
print("Label:", first_label)
print("RS amplitude:", get_RS_amplitude(cropped_signal))
print("RS interval:", get_RS_interval(signal=cropped_signal, frequency=sampling_frequency))
print("TS amplitude:", get_TS_amplitude(signal=cropped_signal))
print("TS interval:", get_TS_interval(signal=cropped_signal, frequency=sampling_frequency))
print("BPM:", get_bpm(frequency=sampling_frequency, start_index=start_ind[0], end_index=start_ind[1]))
print("Mean:", statistics.mean(cropped_signal))
print("Standard deviation:", statistics.stdev(cropped_signal))

In [None]:
# Create CSV: one row of features per labelled beat.
# NOTE(review): the modelling cell reads 'baseline_dataset.csv' from
# '/content/drive/My Drive/ECG Project Files/', while this cell writes to the
# current working directory -- copy the file there or align the two paths.
filename = "baseline_dataset.csv"

# The path helpers take (group, patient_number); this loop used to pass only
# the patient number, which raised TypeError on the first iteration.
# NOTE(review): group 0 matches the single-beat demo cell -- confirm whether
# other groups should be exported too.
path_group = 0

count = 0
# `with` guarantees the file is flushed and closed; newline="" is what the csv
# module expects so no blank lines appear between rows on some platforms.
with open(filename, "w", newline="") as outfile:
  out_csv = csv.writer(outfile)

  # Write CSV header with the feature names
  out_csv.writerow(["Label", "RS_amplitude:", "RS_interval", "TS_amplitude", "TS interval", "BPM", "Mean", "STD"])

  for patient_number in range(0,10):
    try:
      annotation = wfdb.rdann(get_path(path_group, patient_number), extension='atr')
      label_index_list = annotation.sample
      label_symbol_list = annotation.symbol
      sampling_frequency = get_sampling_frequency(path_group, patient_number)
      qrs_index = get_peak_indices(path_group, patient_number)
      start_ind = get_start_index(qrs_index)
      signal, fields = wfdb.rdsamp(get_path(path_group, patient_number),channels=[0])
    except FileNotFoundError:
      # Not every patient number has a record on disk; skip the missing ones.
      continue
    print("Obtaining data from patient #", patient_number)

    for i in range(len(qrs_index)-1):
      beat_start, beat_end = start_ind[i], start_ind[i+1]
      cropped_signal = np.array(signal[beat_start:beat_end]).flatten().tolist()

      # min/max need at least one sample and statistics.stdev needs two.
      if len(cropped_signal) < 2:
        print("Oops. This beat is too short to extract features...")
        continue

      label_index = get_label_index(start_index=beat_start, end_index=beat_end, label_indices=label_index_list)
      if (label_index==-1):
        print("Oops. No label is present for this ECG signal...")
        continue
      label = adjust_label(label_symbol_list[label_index])
      if (label=="*"):
        print("Oops. This is a useless label...")
        continue

      out_csv.writerow([
          label,
          get_RS_amplitude(cropped_signal),
          get_RS_interval(signal=cropped_signal, frequency=sampling_frequency),
          get_TS_amplitude(signal=cropped_signal),
          get_TS_interval(signal=cropped_signal, frequency=sampling_frequency),
          get_bpm(frequency=sampling_frequency, start_index=beat_start, end_index=beat_end),
          statistics.mean(cropped_signal),
          statistics.stdev(cropped_signal),
      ])

In [None]:
# Baseline model: StandardScaler -> PCA(3) -> polynomial SVM on the beat features.
df = pd.read_csv('/content/drive/My Drive/ECG Project Files/baseline_dataset.csv')

# Encode the class letters (first column) as integers. Replacing the column by
# name gives it an integer dtype; assigning through `.iloc[:, 0] = ...` can keep
# the original object dtype, which scikit-learn rejects as an unknown label type.
label_encoder = LabelEncoder()
label_column = df.columns[0]
df[label_column] = label_encoder.fit_transform(df[label_column])

# Cap every class at 3000 rows to limit class imbalance. Collect the pieces and
# concatenate once instead of growing a DataFrame inside the loop.
sampled_parts = []
for class_label in df[label_column].unique():
    classes = df[df[label_column] == class_label]
    sampled_parts.append(classes.sample(n = min(len(classes), 3000), random_state=42))
samples = pd.concat(sampled_parts, axis=0)

X = samples.iloc[:, 1:].values
y = samples.iloc[:, 0].values.astype(int)

# Data splitting: 70% train, 15% validation, 15% test
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Pipeline
pipeline = Pipeline([('scaler', StandardScaler()),('pca', PCA(n_components=3)),('svm', SVC(kernel='poly', degree=3))])
pipeline.fit(X_train, y_train)

# Learning curve: compute the cross-validated scores once and plot those
# results (previously the curve was computed here and then recomputed from
# scratch by LearningCurveDisplay.from_estimator).
fig, ax = plt.subplots(figsize=(8, 6))
train_sizes, train_scores, test_scores, fit_times, _ = learning_curve(pipeline, X_train, y_train, cv=5, n_jobs=-1, train_sizes=np.linspace(.1, 1.0, 5),return_times=True)
LearningCurveDisplay(train_sizes=train_sizes, train_scores=train_scores, test_scores=test_scores).plot(ax=ax)
ax.set_title('Learning Curve for SVM with PCA')
ax.set_xlabel('Training Examples')
ax.set_ylabel('Accuracy')
plt.show()
print("")

# PCA visualization: project all sampled rows with the fitted scaler + PCA and
# colour each point by its true (encoded) class.
X_transformed = pipeline.named_steps['pca'].transform(pipeline.named_steps['scaler'].transform(X))
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
scatter = ax.scatter(X_transformed[:, 0], X_transformed[:, 1], X_transformed[:, 2], c=y, cmap='viridis', edgecolor='k')
legend1 = ax.legend(*scatter.legend_elements(), loc="upper right", title="Classes")
ax.add_artist(legend1)
ax.set_xlabel('Principal Component 1')
ax.set_ylabel('Principal Component 2')
ax.set_zlabel('Principal Component 3')
plt.title('3D PCA projection with SVM classification')
plt.show()
print("")

# Model accuracy
y_train_pred = pipeline.predict(X_train)
y_val_pred = pipeline.predict(X_val)
train_accuracy = accuracy_score(y_train, y_train_pred)
val_accuracy = accuracy_score(y_val, y_val_pred)
print(f"Training Accuracy: {train_accuracy:.2f}")
print(f"Validation Accuracy: {val_accuracy:.2f}")
print("")

# Confusion matrix. Pass every encoded class explicitly so the rows/columns
# always line up with label_encoder.classes_, even if a class happens to be
# absent from the validation split.
class_ids = np.arange(len(label_encoder.classes_))
cm = confusion_matrix(y_val, y_val_pred, labels=class_ids)
print("Confusion Matrix Data for the Validation Dataset:")
print(cm)
print("")
print("Classification Report for the Validation Dataset:")
print(classification_report(y_val, y_val_pred))
print("")

# Confusion matrix heatmap (the tick labels previously referenced an undefined
# name `le`, which raised NameError).
plt.figure(figsize=(10, 7))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=label_encoder.classes_, yticklabels=label_encoder.classes_)
plt.title('Confusion Matrix')
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.show()