In [58]:
import wfdb
import os
import pandas as pd
import wfdb.processing as wp
import numpy as np
import pickle
from biosppy.signals import ecg, tools

import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch import nn, optim

import pytorch_model_summary

from sklearn.preprocessing import MinMaxScaler as mms

import matplotlib.pyplot as plt
import matplotlib
import matplotlib.patches as patches
from matplotlib.patheffects import withStroke
from matplotlib.gridspec import GridSpec

# CUDA runtime configuration via environment variables.
# NOTE(review): these are set after `import torch`; presumably they still take
# effect because no CUDA call has been made yet at this point — confirm.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # run kernels synchronously so errors surface at the failing call
# Fixed typo: the key was spelled "CUDA_VISIbLE_DEVICES", a name CUDA does not read,
# so the intended restriction to GPU 0 was never applied.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# Data Extract From Physionet *.dat
# Read the record names (one per line) listed in the database's RECORDS file.

input_path = "../physionet/mit-bih_arr/1.0.0/"
# `with` guarantees the handle is closed even if reading raises part-way through.
with open(input_path+"RECORDS","r") as records:
    # Strip trailing whitespace/newlines; skip blank lines, which would otherwise
    # become empty record names and fail later when appended to `input_path`.
    records_list = [line.rstrip() for line in records if line.strip()]
print(records_list)

# Annotation symbols grouped by beat class; each group is a list of
# single-character symbols, built here by splitting a string into characters.
NORMAL_ANN = list("NLR")
SUPRA_ANN = list("ejAaJS")
VENTRI_ANN = list("VE")
FUSION_ANN = list("F")
UNCLASS_ANN = list("/fQ")
# Class index -> class label. Index 5 ("-") is the label for symbols that
# belong to none of the groups above.
ANN_DICT = dict(enumerate("NSVFQ-"))

['100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115', '116', '117', '118', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '207', '208', '209', '210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230', '231', '232', '233', '234']


In [4]:
def extract_from_sbj(idx, sampfrom, sampto):
    """Load channel 0 of one record together with its class-encoded annotations.

    Parameters
    ----------
    idx : int
        Index into the module-level ``records_list`` selecting the record.
    sampfrom, sampto : int
        Sample range forwarded unchanged to ``wfdb.rdsamp`` and ``wfdb.rdann``.

    Returns
    -------
    record_sig
        The value returned by ``wfdb.rdsamp`` for channel 0.
    record_ann : list
        Sample positions of the annotations in the record's ``atr`` file.
    record_ann_sym : list of int
        One class index per annotation:
        0 : N | 1 : S | 2 : V | 3 : F | 4 : Q | 5 : symbol in none of the groups.
    """
    record_path = input_path + records_list[idx]
    record_sig = wfdb.rdsamp(record_path, channels=[0], sampfrom=sampfrom, sampto=sampto)

    # Parse the annotation file once and take both fields from the same object
    # (it used to be read twice, plus a third time for an unused RR-interval array).
    # NOTE(review): rdann is called without `shift_samps`, so positions are
    # presumably relative to sample 0 of the record, not to `sampfrom` — confirm
    # before calling this with sampfrom > 0 and indexing into `record_sig`.
    annotation = wfdb.rdann(record_path, "atr", sampfrom=sampfrom, sampto=sampto)
    record_ann = list(annotation.sample)  # R Peak x_position

    # Group order defines the class index (position in this tuple).
    symbol_groups = (NORMAL_ANN, SUPRA_ANN, VENTRI_ANN, FUSION_ANN, UNCLASS_ANN)

    def class_of(sym):
        # First group containing the symbol wins; 5 marks "not in any group".
        for class_idx, group in enumerate(symbol_groups):
            if sym in group:
                return class_idx
        return 5

    record_ann_sym = [class_of(sym) for sym in annotation.symbol]  # R Peak Symbol

    return record_sig, record_ann, record_ann_sym

In [None]:
# Beat segmentation: for every annotated beat, record the [start, end] sample
# window that runs from the midpoint to the previous annotation up to the
# midpoint to the next annotation, bucketed by beat class.
# NOTE(review): this cell reads `record_sig`, `record_ann` and `record_ann_sym`,
# which are assigned by a cell that appears further down — run that cell first.
seg_points = []

# The first R-peak has no predecessor, so its window starts at sample 0 and
# only the midpoint to the next R-peak is computed.
l_ann = len(record_ann)
last_points = len(record_sig[0])-1

N_Sig = []
S_Sig = []
V_Sig = []
F_Sig = []
Q_Sig = []

# Class label -> list collecting that class's [start, end] windows.
sig_by_label = {"N": N_Sig, "S": S_Sig, "V": V_Sig, "F": F_Sig, "Q": Q_Sig}

for i in range(l_ann):

    now_r = record_ann[i]
    now_sym = (ANN_DICT[record_ann_sym[i]],record_ann_sym[i])
    if now_sym[0] == "-":
        # Annotation outside every beat class: no window is stored for it.
        continue

    # The start and end of the window are decided independently, so a record
    # with a single annotation (which is both first and last) no longer raises
    # IndexError on record_ann[i+1].
    if i == 0:  # first annotation: window starts at the beginning of the signal
        prev_points = 0
    else:
        prev_r = record_ann[i-1]
        prev_points = (prev_r + now_r) // 2

    if i == l_ann-1:  # last annotation: window ends at the final sample
        next_points = last_points
    else:
        next_r = record_ann[i+1]
        next_points = (next_r + now_r) // 2

    # Beat Segmentation
    target = sig_by_label.get(now_sym[0])
    if target is not None:
        target.append([prev_points,next_points])
    

In [22]:
idx = 1
sampfrom = 0
sampto = 650000
record_sig, record_ann, record_ann_sym = extract_from_sbj(idx,sampfrom,sampto)

%matplotlib notebook
fig = plt.figure(figsize=(9,7))
ax1 = fig.add_subplot(1,1,1)
ax1.set_title("Subject {}".format(records_list[idx]))
ax1.plot(record_sig[0][0:500])
plt.show()

<IPython.core.display.Javascript object>

### SciPy Filtering

In [14]:
import scipy.io.wavfile
import scipy.signal

#### MIT-BIH Arrhythmia Database sample rate: 360 Hz

In [25]:
sampleRate = 360  # Hz

# NOTE(review): the name `data` rebinds the alias created by
# `import torch.utils.data as data` at the top of the notebook; after this
# cell, `data` is the signal array, not the torch module.
data = np.array(record_sig[0]).squeeze()
# Time axis in seconds: sample index divided by the sampling rate.
times = np.arange(len(data)) / sampleRate

# low pass filter: 3rd-order Butterworth; with no `fs` given, the 0.1 cutoff
# is expressed as a fraction of the Nyquist frequency.
b, a = scipy.signal.butter(N=3, Wn=0.1)
# filtfilt runs the filter forward and backward over the signal.
filtered = scipy.signal.filtfilt(b, a, x=data)

In [65]:
%matplotlib notebook
fig = plt.figure(figsize=(9.7,5))
gs = GridSpec(nrows=3, ncols=2)
ax1 = fig.add_subplot(gs[0,0])
ax1.set_title("low pass Filter Before".format(records_list[idx]))
ax1.plot(record_sig[0][0:500])
ax1.margins(0, .05)

ax2 = fig.add_subplot(gs[0,1])
ax2.set_title("low pass Filter After".format(records_list[idx]))
ax2.plot(filtered[0:500])
ax2.margins(0, .05)

ax3 = fig.add_subplot(gs[1,:])
ax3.set_title("Lowpass Filter whole data Before")
ax3.plot(record_sig[0])
ax3.margins(0, .05)

ax4 = fig.add_subplot(gs[2,:])
ax4.set_title("Lowpass Filter whole data After")
ax4.plot(filtered)
ax4.margins(0, .05)
plt.tight_layout()
plt.show()

<IPython.core.display.Javascript object>