In [None]:
pip install wfdb

In [None]:
import wfdb
import csv
from google.colab import drive
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from wfdb import processing
import numpy as np
import statistics
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset
import random
from sklearn.model_selection import train_test_split, learning_curve
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.model_selection import LearningCurveDisplay
import seaborn as sns
# Seed every random number generator this notebook imports. Seeding only
# Python's `random` leaves NumPy and PyTorch unseeded.
SEED = 4
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Mount Google Drive; the dataset paths used below live under /content/drive.
drive.mount('/content/drive')

In [None]:
# Function that returns path to the relevant patient files
# Input: patient number (type=int)
#        base_dir (type=string, optional): directory holding the record files;
#        defaults to the Google Drive location this notebook was written for
# Output: path (type=string), i.e. base_dir joined with the patient number
def get_path(patient_number, base_dir='/content/drive/My Drive/ECG Project Files/mit-bih-arrhythmia-database-1.0.0/'):
  directory = str(base_dir)
  # Add the separator only when it is missing, so the default (which already
  # ends with '/') yields exactly the same path as before.
  if directory and not directory.endswith('/'):
    directory += '/'
  return directory + str(patient_number)

In [None]:
# Read channels 0 and 1 of record 100, starting at sample 3000, and print
# both the samples and the accompanying metadata.
record_path = get_path(100)
signals, fields = wfdb.rdsamp(record_path, sampfrom=3000, channels=[0,1])
for caption, value in (("Signal:", signals), ("Fields:", fields)):
  print(caption, value)

In [None]:
# Data visualization
# The record number is kept in one variable so the plot title always matches
# the record that is loaded (the title used to say 100 while 205 was plotted).
temp_record_number = 205
temp_record = wfdb.rdrecord(get_path(temp_record_number), sampto=3000)
temp_annotation = wfdb.rdann(get_path(temp_record_number), extension='atr', sampto=3000)
wfdb.plot_wfdb(record=temp_record, annotation=temp_annotation, plot_sym=True, time_units='seconds', title='MIT-BIH Record '+str(temp_record_number), figsize=(10,4), ecg_grids='all')

In [None]:
# Print wfdb's table of annotation labels.
# NOTE(review): this looks like the library's full standard label set, not
# only the labels that occur in these records -- confirm against the wfdb docs.
wfdb.show_ann_labels()

In [None]:
# Function that finds the first annotation lying inside an interval of samples
# Input: start index (type=int), end index (type=int),
#        label indices (sequence of annotation sample positions)
# Output: position in label_indices of the first annotation with
#         start_index <= sample <= end_index (type=int), or -1 if there is none
def get_label_index(start_index, end_index, label_indices):
  for position, sample in enumerate(label_indices):
    # Both interval ends are inclusive.
    if start_index <= sample and end_index >= sample:
      return position
  return -1

# Function that returns the label class
# Input: label (type=str), a single annotation symbol
# Output: label_class (type=str), one of "N", "B", "S", "V", "F", "Q",
#         or "*" when the symbol belongs to none of the classes
def adjust_label(label):
  # Only an exact one-character symbol can belong to a class. The previous
  # str.find() checks matched substrings, so "" (found at position 0 of any
  # string) or a value such as "LR" was assigned a class by accident.
  if not isinstance(label, str) or len(label) != 1:
    return "*"
  # Classes are checked in this order and the first one listing the symbol wins.
  # NOTE(review): "j" is listed under both B and S, so it always resolves to
  # "B"; confirm whether the S symbols were meant to be upper-case "S"/"J".
  class_symbols = (
    ("N", "N"),
    ("B", "LRje"),
    ("S", "Aasj"),
    ("V", "!VE[]"),
    ("F", "F"),
    ("Q", "f/Q"),
  )
  for label_class, symbols in class_symbols:
    if label in symbols:
      return label_class
  return "*"

In [None]:
# Function that returns sampling frequency for a patient file
# Input: patient number (type=int)
# Output: sampling frequency, as stored in the record header
def get_sampling_frequency(patient_number):
  # The sampling frequency is header metadata, so read only the header
  # instead of loading the entire signal just to discard it.
  return wfdb.rdheader(get_path(patient_number)).fs


# Function that returns QRS peak indices for a patient file
# Input: patient number (type=int)
# Output: peak indices, exactly as returned by processing.xqrs_detect
def get_peak_indices(patient_number):
  record_signal, record_fields = wfdb.rdsamp(get_path(patient_number), channels=[0])
  # Only channel 0 was requested; take its column as a 1-D signal.
  first_channel = record_signal[:,0]
  return processing.xqrs_detect(sig=first_channel, fs=record_fields['fs'])

# Function that returns start indices for beats for the whole patient file
# Input: qrs indices (sequence of peak positions)
# Output: start indices (type=list): 0 followed by the truncated midpoint of
#         every pair of consecutive peaks
def get_start_index(qrs_indices):
  midpoints = [
    int((following + current) / 2)
    for current, following in zip(qrs_indices, qrs_indices[1:])
  ]
  # The first beat is taken to start at the very first sample.
  return [0] + midpoints

In [None]:
# Detect the QRS peaks of patient 105 and derive beat boundaries from them
temp_qrs_index = get_peak_indices(105) #patient specific list
temp_start_index = get_start_index(temp_qrs_index)
temp_signal, temp_fields = wfdb.rdsamp(get_path(105), channels=[0])

# Sample visualization: plot the first two beat segments
for i in (0, 1):
  temp_start, temp_end = temp_start_index[i], temp_start_index[i+1]
  beat_segment = temp_signal[temp_start:temp_end]
  wfdb.plot.plot_items(signal=beat_segment)

In [None]:
# Create CSV: one row per detected beat, holding its class label and the
# channel-0 samples between consecutive beat boundaries.
filename = "processed_dataset(1_signal).csv"
count = 0

# The context manager flushes and closes the file even if a record fails
# part-way through; previously the handle was never closed, so the tail of the
# CSV could be left unwritten. newline="" lets the csv module control line
# endings itself, as its documentation requires.
with open(filename, "w", newline="") as outfile:
  out_csv = csv.writer(outfile)

  # Write the CSV header (column names)
  out_csv.writerow(["Label", "Signal"])

  for patient_number in range(100,235):
    try:
      annotation = wfdb.rdann(get_path(patient_number), extension='atr')
      label_index_list = annotation.sample
      label_symbol_list = annotation.symbol
      signal, fields = wfdb.rdsamp(get_path(patient_number),channels=[0])
      # Take fs from the read above instead of re-reading the record for it.
      sampling_frequency = fields['fs']
      qrs_index = get_peak_indices(patient_number)
      start_ind = get_start_index(qrs_index)
      print("Obtaining data from patient #", patient_number)

      for i in range(len(qrs_index)-1):
        segment_start = start_ind[i]
        segment_end = start_ind[i+1]
        # NOTE(review): get_label_index returns the FIRST annotation in the
        # segment; if that is a non-beat symbol the whole beat is skipped even
        # when a beat annotation follows in the same segment.
        label_index = get_label_index(start_index=segment_start, end_index=segment_end, label_indices=label_index_list)
        if (label_index==-1):
          continue
        label = adjust_label(label_symbol_list[label_index])
        if (label=="*"):
          continue
        # Build the sample list only for beats that are actually written.
        cropped_signal = np.array(signal[segment_start:segment_end]).flatten().tolist()
        out_csv.writerow([label, cropped_signal])
    except FileNotFoundError:
      # Not every number in the range has a record file; skip the missing ones.
      continue