In [42]:
import pickle as pkl
import numpy as np
import wfdb
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from collections import defaultdict

In [43]:
def dynamic_replace(df):
    curr_aux = df['aux'].loc[0]
    for idx,x in enumerate(df['aux']):
        if x != '':
            curr_aux = x
        df.loc[idx, 'aux'] = curr_aux
    return df

In [44]:
def create_index_df(desired_segment_len=3600, basic_arr_path="data/mit-bih-arrhythmia-database-1.0.0"):
    desired_segment_len = 3600
    basic_arr_path = "data/mit-bih-arrhythmia-database-1.0.0"
    arr_db = wfdb.get_record_list('mitdb')
    num_samples_in_record = 30 * 60 * 360
    
    # for selection and sampling 
    segment_dict_ann = {}
    record_count = 0
    
    
    for _, record_id in enumerate(arr_db):
        record_path = os.path.join(basic_arr_path, str(record_id))

        ann = wfdb.rdann(record_path, 'atr', sampto=num_samples_in_record, return_label_elements=['description',
                                                                                                  'symbol', 'label_store'])
        df = pd.DataFrame({'description': ann.description, 'sample': ann.sample, 'symbol': ann.symbol,
                           'label_store': ann.label_store, 'aux': ann.aux_note})
        df = dynamic_replace(df)
        counter = 0
        reset_flag = True
        allowed_labels = ['Normal beat']
        allowed_symbols = ['N']

        normal_counter = 0
        for i in range(1, df.shape[0] - 1):
            curr_label, curr_sample, curr_symbol = df.loc[i, ['description', 'sample', 'symbol']]
            if curr_label == 'Normal beat':
                normal_counter += 1
            if reset_flag:
                start_sample = curr_sample
                ann_num_start = i
                normal_counter = 0
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']
            next_label, next_sample, next_symbol = df.loc[i + 1, ['description', 'sample', 'symbol']]
            if curr_label == next_label or next_label in allowed_labels or len(allowed_labels) < 2:

                if next_label not in allowed_labels:
                    allowed_labels.append(next_label)
                    allowed_symbols.append(next_symbol)
                ann_num_end = i+1
                counter += next_sample - curr_sample
                reset_flag = False
                if counter > desired_segment_len:
                    counter = 0
                    reset_flag = True
                    signal = wfdb.rdsamp(record_path, sampfrom=start_sample, sampto=start_sample + 3600)[0][:, 0]
                    normal_ratio = normal_counter / (ann_num_end - ann_num_start)
                    
                    if df.loc[ann_num_start:ann_num_end]['aux'].unique().shape[0] == 1:
                        aux_seg = df.loc[ann_num_start]['aux']
                    else:
                        aux_seg = 'invalid'
                    
                    segment_dict_ann[record_count] = [record_id,  allowed_labels[-1], signal, normal_ratio,
                                                       allowed_symbols[-1], aux_seg]
                    record_count = record_count + 1
            else:
                counter = 0
                normal_counter = 0
                reset_flag = True
                allowed_labels = ['Normal beat']
                allowed_symbols = ['N']


    return segment_dict_ann

In [4]:
d_dict = create_index_df()
# basic_arr_path = "data/mit-bih-arrhythmia-database-1.0.0"
# record_id = 102
# record_path = os.path.join(basic_arr_path, str(record_id))
# num_samples_in_record = 30 * 60 * 360
# ann = wfdb.rdann(record_path, 'atr', sampto=num_samples_in_record, return_label_elements=['description'])
# df = pd.DataFrame({'description': ann.description, 'sample': ann.sample,  'symbol': ann.symbol,
#                            'label_store': ann.label_store, 'aux': ann.aux_note})
# df['aux'].value_counts()
# df = dynamic_replace(df)
# df

In [5]:
seg_df = pd.DataFrame.from_dict(d_dict, orient='index')
seg_df.rename(columns={0: 'record_id', 1:'label', 2:'signal', 3:'normal_ratio',4:'symbol', 5: 'aux'}, inplace= True)


seg_df['aux'].value_counts()

(N         5020
(AFIB       640
(P          370
(N          165
invalid     126
(B          114
MISSB        78
(T           66
TS           41
(AFIB        25
(PREX        24
(AFL         21
(AFL         19
(IVR         11
(SVTA        11
(VFL         11
(VT           7
(SBR          7
(AB           3
(NOD          1
Name: aux, dtype: int64

In [10]:
seg_df['class'] = 'invalid'
seg_df['aux'] = seg_df['aux'].str.rstrip('\x00')

normal_aux_idx = (seg_df.aux == '(N') & (seg_df.symbol.apply(lambda x: True if x in num_labels_we_have.keys() else False))
seg_df.loc[normal_aux_idx, 'class'] = seg_df.loc[normal_aux_idx, 'symbol']

normal_label_idx = (seg_df.symbol == 'N') & (seg_df.aux.apply(lambda x: True if x in num_labels_we_have.keys() else False))
seg_df.loc[normal_label_idx, 'class'] = seg_df.loc[normal_label_idx, 'aux']

seg_df.loc[seg_df['aux'] == '(B', 'class'] = '(B'
seg_df.loc[seg_df['aux'] == '(T', 'class'] = '(T'
seg_df.loc[seg_df['aux'] == '(IVR', 'class'] = '(IVR'
seg_df.loc[seg_df['aux'] == '(P', 'class'] = '(P'
seg_df.loc[seg_df['aux'] == '(VFL', 'class'] = '(VFL'
# seg_df['class'] = seg_df['class'].str.lstrip('(')
seg_df['class'].value_counts()


N          3093
invalid    1612
L           493
R           454
(P          370
(AFIB       284
A           207
(B          114
(T           66
(PREX        23
(AFL         22
(VFL         11
(IVR         11
Name: class, dtype: int64

In [11]:
seg_df.head()

Unnamed: 0,record_id,label,signal,normal_ratio,symbol,aux,class
0,100,Atrial premature contraction,"[0.84, 0.765, 0.52, 0.17, -0.165, -0.365, -0.4...",0.846154,A,(N,A
1,100,Normal beat,"[0.885, 0.935, 0.835, 0.525, 0.12, -0.23, -0.4...",0.923077,N,(N,N
2,100,Normal beat,"[0.88, 0.76, 0.46, 0.07, -0.24, -0.425, -0.485...",0.923077,N,(N,N
3,100,Normal beat,"[0.885, 0.88, 0.68, 0.325, -0.08, -0.355, -0.5...",0.923077,N,(N,N
4,100,Normal beat,"[0.85, 0.815, 0.6, 0.28, -0.1, -0.37, -0.495, ...",0.923077,N,(N,N


In [7]:
num_labels_dict = {
'Normal beat': 283, # N
'Left bundle branch block beat': 103, #L
'Atrial premature beat': 66, # A
'Atrial flutter': 20,  # (AFL (aux)
'Atrial fibrillation': 135, # (AFIB (aux)
'Pre-excitation (WPW)': 21, # (PREX (aux)
'Premature ventricular contraction': 133, #V
'Ventricular bigeminy': 55, #  (B (aux)
'Ventricular trigeminy': 13,  # (T (aux)
'Ventricular tachycardia': 10,  # (VT (aux)
'Idioventricular rhythm': 10,  # (IVR (aux)
'Ventricular flutter': 10,   # (VFL (aux)
'Fusion of ventricular and normal beat': 11, #F
'Second-degree heart block': 10,
'Pacemaker rhythm': 45, # / 
'Supraventricular tachyarrhythmia': 13,  # (SVTA (aux)
'Right bundle branch block beat': 62, # R
                 }

num_labels_we_have = {
'N': 283, #'Normal beat'
'L': 103, #'Left bundle branch block beat'
'A': 66, #'Atrial premature beat':
# 'V': 133, #'Premature ventricular contraction'
# '!': 10,   #'Ventricular flutter'
# 'F': 11, # 'Fusion of ventricular and normal beat'
'R': 62, # 'Right bundle branch block beat'
'(AFL'  : 20,
'(AFIB' : 135,
'(PREX' : 21,
'(B'    : 55,
'(T'    : 13,
# '(VT'   : 10,
'(IVR'  : 10,
'(VFL'  : 10,
'(P': 45
# '(SVTA' : 13
            }

In [48]:
import random
random.seed(42)


# x = seg_df.loc[seg_df[seg_df.label == 'Normal beat'].index]

def sample_per_sym(seg_df, cls, num):
    # filter
    class_df = seg_df.loc[seg_df[seg_df['class'] == cls].index]
    # sample
    print(f' len df: {len(class_df)}')
    sample_df = class_df.sample(n=num)
    
    # add the index at the begining of the signal
    signals = np.stack(sample_df.signal.values)
    return idx, sample_df, signals


sample_df = []
signals = []
idx_arr = []
for cls, num in num_labels_we_have.items():
    print(f'class:{cls}, num: {num}')
    idx, sample_df_class, signals_class = sample_per_sym(seg_df, cls, num)
    sample_df.append(sample_df_class)
    signals.append(signals_class)
    idx_arr = np.append(idx_arr,idx)
sample_df = pd.concat(sample_df)



class:N, num: 283
 len df: 3093
class:L, num: 103
 len df: 493
class:A, num: 66
 len df: 207
class:R, num: 62
 len df: 454
class:(AFL, num: 20
 len df: 22
class:(AFIB, num: 135
 len df: 284
class:(PREX, num: 21
 len df: 23
class:(B, num: 55
 len df: 114
class:(T, num: 13
 len df: 66
class:(IVR, num: 10
 len df: 11
class:(VFL, num: 10
 len df: 11
class:(P, num: 45
 len df: 370


In [50]:
sample_df.reset_index()

Unnamed: 0,index,record_id,label,signal,normal_ratio,symbol,aux,class
0,1823,113,Normal beat,"[1.945, 1.92, 1.6, 0.96, 0.32, -0.32, -0.845, ...",0.900000,N,(N,N
1,3715,201,Normal beat,"[0.87, 0.93, 0.915, 0.805, 0.61, 0.38, 0.13, -...",0.888889,N,(N,N
2,2611,117,Normal beat,"[-0.65, -0.895, -1.16, -1.38, -1.5, -1.51, -1....",0.888889,N,(N,N
3,42,100,Normal beat,"[0.905, 0.98, 0.975, 0.795, 0.44, -0.015, -0.3...",0.933333,N,(N,N
4,54,100,Normal beat,"[0.855, 0.88, 0.74, 0.43, 0.01, -0.33, -0.52, ...",0.923077,N,(N,N
5,1794,113,Normal beat,"[1.82, 1.78, 1.54, 1.03, 0.39, -0.25, -0.74, -...",0.900000,N,(N,N
6,3350,123,Normal beat,"[0.995, 1.07, 0.985, 0.725, 0.315, -0.225, -0....",0.875000,N,(N,N
7,180,101,Normal beat,"[0.975, 0.715, 0.405, 0.145, -0.01, -0.02, -0....",0.916667,N,(N,N
8,325,101,Normal beat,"[1.14, 0.94, 0.64, 0.375, 0.19, 0.08, 0.0, -0....",0.900000,N,(N,N
9,3172,122,Normal beat,"[0.835, 0.845, 0.825, 0.8, 0.75, 0.64, 0.47, 0...",0.928571,N,(N,N


In [51]:
# [print(i.shape) for i in signals]

signals_arr = np.stack(sample_df.signal.values, axis=0)

# print(idx_arr[4])

# sample_df.head(5)
np.save('signals_arr.npy', signals_arr)

ew_num_arr = np.load('signals_arr.npy')
# np.savetxt('signals_arr.txt', signals_arr, fmt='%d')
# b = np.loadtxt('signals_arr.txt', dtype=float)
sample_df.to_pickle("./sample_df_bar.pkl")