In [None]:
pip install wfdb

WFDB Documentation: https://wfdb.readthedocs.io/en/stable/io.html

WFDB ECG Library Documentation: https://wfdb.readthedocs.io/en/latest/processing.html

Scikit-learn Documentation: https://scikit-learn.org/stable/

In [None]:
import wfdb
import csv
from google.colab import drive
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from wfdb import processing
import numpy as np
import statistics
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
import random
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import LearningCurveDisplay
import seaborn as sns

# Seed every random number generator this notebook imports so that a
# "Restart & Run All" reproduces the same results (not only Python's `random`).
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Mount Google Drive; the record paths built by get_path live under this mount point.
drive.mount('/content/drive')

In [None]:
def get_path(path, patient_number, base_dir='/content/drive/My Drive/Test'):
  """Build the WFDB record path (without file extension) for a patient's `_s00` segment.

  The layout is `<base_dir>/pGG/pGGNNN/pGGNNN_s00`, where `GG` is the two-digit
  group number and `NNN` the three-digit patient number inside that group.

  Input:
    path (int): group number, 0-10.
    patient_number (int): patient number within the group; zero-padded to
      three digits so values of 100 and above keep the five-digit patient ID.
    base_dir (str): directory that contains the `pGG` group folders.
  Output:
    str: the record path, or None when `path` is not one of the groups 0-10.
  """
  # Unsupported groups yield None, as before.
  if path not in range(11):
    return None

  group = f"p{int(path):02d}"
  patient_id = f"{group}{int(patient_number):03d}"
  return f"{base_dir}/{group}/{patient_id}/{patient_id}_s00"



In [None]:
# Printing signal and fields data from the dataset
# (channel 0 of group 9 / patient 2, starting at sample 3000)
signals, fields = wfdb.rdsamp(get_path(9,2), sampfrom=3000, channels=[0])
print("Signal:", signals)
print("Fields:", fields)

In [None]:
# Visualizing data: the first 3000 samples of group 1 / patient 2,
# overlaid with the annotation symbols from the matching .atr file
demo_record_path = get_path(1,2)
temp_record = wfdb.rdrecord(demo_record_path, sampto=3000)
temp_annotation = wfdb.rdann(demo_record_path, extension='atr', sampto=3000)
wfdb.plot_wfdb(
  record=temp_record,
  annotation=temp_annotation,
  plot_sym=True,
  time_units='seconds',
  title='Icentia11k Single Lead Continuous Raw Electrocardiogram Dataset',
  figsize=(10,4),
  ecg_grids='all',
)

In [None]:
# Show the annotation labels that WFDB knows about.
# NOTE(review): per the wfdb docs this prints the library's standard annotation
# label table rather than scanning this dataset's annotation files - confirm.
wfdb.show_ann_labels()

In [None]:
# Define helper functions for label categorization

def get_label_index(start_index, end_index, label_indices):
  """Find the first annotation that falls inside a sample interval.

  Input:
    start_index (int): first sample of the interval (inclusive).
    end_index (int): last sample of the interval (inclusive).
    label_indices (sequence of int): sample positions of the annotations.
  Output:
    int: position within `label_indices` of the first entry lying in
      [start_index, end_index], or -1 when no entry does.
  """
  return next(
    (
      position
      for position, sample in enumerate(label_indices)
      if start_index <= sample <= end_index
    ),
    -1,
  )

# Annotation symbols grouped by label class, listed in matching priority order.
# "j" appears under both "B" and "S"; the first match wins, so it maps to "B".
_LABEL_CLASSES = (
  ("N", frozenset("N")),
  ("B", frozenset("LRje")),
  ("S", frozenset("Aasj")),
  ("V", frozenset("!VE[]")),
  ("F", frozenset("F")),
  ("Q", frozenset("f/Q")),
)

def adjust_label(label):
  """Map an annotation symbol to its label class.

  Input:
    label (str): a single annotation symbol.
  Output:
    str: one of "N", "B", "S", "V", "F", "Q", or "*" when the symbol belongs
      to none of the classes.

  The symbol must match exactly: the empty string and multi-character strings
  (for example "je") return "*" instead of being matched as substrings.
  """
  for label_class, symbols in _LABEL_CLASSES:
    if label in symbols:
      return label_class
  return "*"

In [None]:
def get_sampling_frequency(path,patient_number):
  """Return the sampling frequency recorded for a patient file.

  Input:
    path (int): group number passed on to get_path.
    patient_number (int): patient number passed on to get_path.
  Output:
    The `fs` value stored in the record header.

  Only the header is read; the signal samples are not needed to obtain `fs`,
  so they are no longer loaded.
  """
  header = wfdb.rdheader(get_path(path,patient_number))
  return header.fs


def get_peak_indices(path,patient_number):
  """Run XQRS peak detection on the first channel of a patient file.

  Input:
    path (int): group number passed on to get_path.
    patient_number (int): patient number passed on to get_path.
  Output:
    The peak sample indices exactly as returned by processing.xqrs_detect.
  """
  record_path = get_path(path,patient_number)
  samples, header = wfdb.rdsamp(record_path, channels=[0])
  first_channel = samples[:,0]
  return processing.xqrs_detect(sig=first_channel, fs=header['fs'])

def get_start_index(qrs_indices):
  """Compute the beat boundaries for a whole patient file.

  Input:
    qrs_indices (sequence of int): QRS peak sample indices.
  Output:
    list of int: 0 followed by the midpoint (truncated to int) between each
      pair of consecutive peaks.
  """
  midpoints = [
    int((following + current) / 2)
    for current, following in zip(qrs_indices, qrs_indices[1:])
  ]
  return [0] + midpoints

In [None]:
# Detect the QRS peaks of group 0 / patient 0 and derive the beat boundaries from them
temp_qrs_index = get_peak_indices(0,0)
temp_start_index = get_start_index(temp_qrs_index)
# Load channel 0 of the same record so that individual beats can be sliced out of it
temp_signal, temp_fields = wfdb.rdsamp(get_path(0,0),channels=[0])

# For visual demonstration purposes: plot the first 21 boundary-to-boundary segments,
# one figure each (requires at least 22 boundaries, otherwise indexing raises IndexError)
for i in range(21):
  temp_start = temp_start_index[i]
  temp_end = temp_start_index[i+1]
  wfdb.plot.plot_items(signal=temp_signal[temp_start:temp_end])

In [None]:
# .csv generation for the testing data
# Every annotated beat becomes one row: its label class plus the beat's
# samples rescaled to [0, 1].
filename = "processed_dataset(1_signal).csv"
count = 0  # not used in this cell; kept in case later cells reference it

# The `with` block guarantees the file is flushed and closed even if a record
# fails midway; newline="" is what the csv module expects for its own line endings.
with open(filename, "w", newline="") as outfile:
  out_csv = csv.writer(outfile)

  # Write CSV header with lead names
  out_csv.writerow(["Label", "Signal"])

  for path in range(11):
    print("Obtaining data from path #", path)
    for patient_number in range(3):
      record_path = get_path(path,patient_number)

      # Only the file reads are guarded, so a FileNotFoundError raised anywhere
      # else is not silently swallowed.
      try:
        annotation = wfdb.rdann(record_path, extension='atr')
        # Read the record once and reuse it for the sampling frequency,
        # the peak detection and the beat cropping.
        signal, fields = wfdb.rdsamp(record_path, channels=[0])
      except FileNotFoundError:
        # This patient's files are not present; move on to the next one.
        continue

      label_index_list = annotation.sample
      label_symbol_list = annotation.symbol
      sampling_frequency = fields['fs']
      # Same detection call that get_peak_indices makes, applied to the
      # already-loaded signal.
      qrs_index = processing.xqrs_detect(sig=signal[:,0], fs=sampling_frequency)
      start_ind = get_start_index(qrs_index)

      print("Obtaining data from patient #", patient_number)

      for i in range(len(qrs_index)-1):
        beat_start = start_ind[i]
        beat_end = start_ind[i+1]

        # Resolve the label first so beats that are discarded are never normalized.
        label_index = get_label_index(start_index=beat_start, end_index=beat_end, label_indices=label_index_list)
        if (label_index==-1):
          continue
        label = adjust_label(label_symbol_list[label_index])
        if (label=="*"):
          continue

        cropped_signal = np.array(signal[beat_start:beat_end])
        if cropped_signal.size == 0:
          # An empty segment has no min/max to normalize against.
          continue

        # data normalization step
        normalized_signal = np.interp(cropped_signal, (cropped_signal.min(), cropped_signal.max()), (0, 1))

        out_csv.writerow([label, normalized_signal.flatten().tolist()])