In [1]:
import wfdb
import os
import pandas as pd
import wfdb.processing as wp
import numpy as np
import pickle
from biosppy.signals import ecg, tools


import torch
import torch.nn as nn
import torch.utils.data as data
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch import nn, optim

import pytorch_model_summary

from sklearn.preprocessing import MinMaxScaler as mms

import matplotlib.pyplot as plt
import matplotlib

# Make CUDA kernel launches synchronous so an error is reported at the call
# that caused it rather than at some later, unrelated call.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Restrict this process to GPU 0. Environment variable names are
# case-sensitive on POSIX systems, so the key must be spelled exactly
# "CUDA_VISIBLE_DEVICES" for the CUDA runtime to honour it.
# NOTE(review): this runs after `import torch`; it only takes effect if CUDA
# has not been initialised yet - confirm, or move it above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [2]:
# Read the record names listed in the PhysioNet RECORDS index file.

input_path = "../physionet/mit-bih_arr/1.0.0/x_mitdb/"
# `with` closes the handle even if reading fails part-way through.
with open(input_path + "RECORDS", "r") as records:
    # One record name per line: strip the line ending and skip blank lines so
    # an empty string is never treated as a record name later on.
    records_list = [line.rstrip() for line in records if line.strip()]
print(records_list)

['x_108', 'x_109', 'x_111', 'x_112', 'x_113', 'x_114', 'x_115', 'x_116', 'x_117', 'x_121', 'x_122', 'x_123', 'x_124', 'x_220', 'x_221', 'x_222', 'x_223', 'x_228', 'x_230', 'x_231', 'x_232', 'x_233', 'x_234']


In [3]:
# Annotation symbols grouped into the beat classes used for labelling
# (N / S / V / F / Q). Symbol meanings follow PhysioNet's annotation code table.
NORMAL_ANN = [
    'N',  # normal beat
    'L',  # left bundle branch block beat
    'R',  # right bundle branch block beat
    'e',  # atrial escape beat
    'j',  # nodal (junctional) escape beat
]
SUPRA_ANN = [
    'A',  # atrial premature beat
    'a',  # aberrated atrial premature beat
    'J',  # nodal (junctional) premature beat
    'S',  # supraventricular premature or ectopic beat
]
VENTRI_ANN = [
    'V',  # premature ventricular contraction
    'E',  # ventricular escape beat
]
FUSION_ANN = [
    'F',  # fusion of ventricular and normal beat
]
UNCLASS_ANN = [
    '/',  # paced beat
    'f',  # fusion of paced and normal beat
    'Q',  # unclassifiable beat
]

In [4]:
window_size = 320  # number of samples in every padded beat

_SAMPLING_RATE = 360    # Hz; the beat windows below are computed at this rate
_HALF_WINDOW_SEC = 0.3  # seconds of signal kept on each side of an annotation
_OTHER_CLASS = 5        # label for any symbol that is not listed below

# Annotation symbol -> integer class id (0: N, 1: S, 2: V, 3: F, 4: Q).
_SYMBOL_TO_CLASS = {
    **dict.fromkeys(['N', 'L', 'R', 'e', 'j'], 0),
    **dict.fromkeys(['A', 'a', 'J', 'S'], 1),
    **dict.fromkeys(['V', 'E'], 2),
    **dict.fromkeys(['F'], 3),
    **dict.fromkeys(['/', 'f', 'Q'], 4),
}


def _pad_beat(beat, w_size):
    """Return `beat` as an array of exactly `w_size` samples.

    Shorter beats are zero-padded on both sides; when the missing amount is
    odd the extra zero goes on the right. Longer beats are truncated at the
    tail instead of raising.
    """
    beat = beat[:w_size]
    missing = w_size - len(beat)
    left = missing // 2
    return np.pad(beat, (left, missing - left), "constant", constant_values=0)


def _extract_beats(signal, rpeaks, w_size):
    """Cut a window of +/- 0.3 s around every sample index in `rpeaks`.

    Each window is clipped to the bounds of `signal` and then padded to
    `w_size` samples. Returns a list with one array per entry of `rpeaks`.
    """
    half = int(_HALF_WINDOW_SEC * _SAMPLING_RATE)
    beats = []
    for rp in rpeaks:
        # Clamp both ends for every peak: a negative start would be read by
        # the slice as an offset from the END of the signal and silently
        # produce an empty (all-zero) beat.
        start = max(int(rp) - half, 0)
        end = min(int(rp) + half, len(signal))
        beats.append(_pad_beat(signal[start:end], w_size))
    return beats


def extract_from_sbj(idx):
    """Extract fixed-length, labelled beats from one record.

    Reads channel 0 of record `records_list[idx]` under `input_path` together
    with its "atr" annotations, cuts a window around every annotated sample
    and pads it to `window_size` samples.

    Class ids written to "symbol":
        0: N, 1: S, 2: V, 3: F, 4: Q, 5: any other annotation symbol

    Parameters
    ----------
    idx : int
        Index into the module-level `records_list`.

    Returns
    -------
    (max_len, min_len, beat)
        `max_len` / `min_len` are the longest / shortest beat lengths
        (-1 / 9999 act as bounds, e.g. when there are no beats) and `beat`
        is a dict with keys "Beats" (list of arrays), "rpeaks" (annotated
        sample indices) and "symbol" (integer class ids).
    """
    record_path = input_path + records_list[idx]
    signal = wfdb.rdsamp(record_path, channels=[0], sampfrom=0)[0][:, 0]
    annotation = wfdb.rdann(record_path, "atr", sampfrom=0)

    # The first annotation is dropped from both lists.
    # NOTE(review): presumably it is a record-leading non-beat marker - confirm.
    record_ann = list(annotation.sample)[1:]
    record_ann_sym = [
        _SYMBOL_TO_CLASS.get(sym, _OTHER_CLASS)
        for sym in list(annotation.symbol)[1:]
    ]

    # The windows are centred on the annotated samples.
    beat = {
        "Beats": _extract_beats(signal, record_ann, window_size),
        "rpeaks": record_ann,
        "symbol": record_ann_sym,
    }

    print("Beats, Rpeaks Len, ",len(beat["Beats"]),len(beat["rpeaks"]), len(record_ann))

    lengths = [len(b) for b in beat["Beats"]]
    max_len = max([-1, *lengths])
    min_len = min([9999, *lengths])
    return max_len, min_len, beat

# Extraction results for every record, keyed by record name.
datas = {}

# Run the beat extraction once per entry of records_list and keep the returned
# beat dict; the longest/shortest beat lengths are only printed.
for i,rec in enumerate(records_list):
    print("Record",rec)
    maxs, mins, beats = extract_from_sbj(i)
    datas[rec]=beats
    print(maxs,mins)

Record x_108
Beats, Rpeaks Len,  581 581 581
320 320
Record x_109
Beats, Rpeaks Len,  857 857 857
320 320
Record x_111
Beats, Rpeaks Len,  702 702 702
320 320
Record x_112
Beats, Rpeaks Len,  854 854 854
320 320
Record x_113
Beats, Rpeaks Len,  581 581 581
320 320
Record x_114
Beats, Rpeaks Len,  559 559 559
320 320
Record x_115
Beats, Rpeaks Len,  635 635 635
320 320
Record x_116
Beats, Rpeaks Len,  797 797 797
320 320
Record x_117
Beats, Rpeaks Len,  505 505 505
320 320
Record x_121
Beats, Rpeaks Len,  611 611 611
320 320
Record x_122
Beats, Rpeaks Len,  839 839 839
320 320
Record x_123
Beats, Rpeaks Len,  506 506 506
320 320
Record x_124
Beats, Rpeaks Len,  530 530 530
320 320
Record x_220
Beats, Rpeaks Len,  711 711 711
320 320
Record x_221
Beats, Rpeaks Len,  837 837 837
320 320
Record x_222
Beats, Rpeaks Len,  750 750 750
320 320
Record x_223
Beats, Rpeaks Len,  846 846 846
320 320
Record x_228
Beats, Rpeaks Len,  740 740 740
320 320
Record x_230
Beats, Rpeaks Len,  820 820 820
3

In [5]:
# Save each record's extracted beats to its own pickle file.
pickle_path = "./pickle/test"

# open(..., "wb") does not create missing parent directories, so make sure the
# output directory exists before writing (no-op when it is already there).
os.makedirs(pickle_path, exist_ok=True)

for rec in records_list:
    pickle_data = pickle_path + "/" + rec + ".pkl"
    with open(pickle_data,"wb") as f:
        pickle.dump(datas[rec], f, pickle.HIGHEST_PROTOCOL)

### Load the Trained Model and Evaluate