In [1]:
import time
from collections import Counter
import wfdb
import numpy as np
from scipy.signal import resample
import pywt
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from collections import Counter

In [2]:
def WTfilt_1d(sig):
    """
    Filter a 1-d signal with a discrete wavelet transform.

    The signal is decomposed with the 'db6' wavelet to 9 levels. The two
    finest detail bands and the level-9 approximation band are zeroed, and
    the signal is rebuilt from the remaining bands (presumably to suppress
    high-frequency noise and slow baseline drift -- confirm with the author).

    :param sig: input signal, 1-d array
    :return: wavelet-filtered signal, 1-d array of the same length as ``sig``
    """
    coeffs = pywt.wavedec(sig, 'db6', level=9)
    # wavedec returns [cA9, cD9, ..., cD2, cD1]: index 0 is the coarsest
    # approximation, the last two entries are the finest detail bands.
    for band in (0, -1, -2):
        coeffs[band] = np.zeros_like(coeffs[band])
    sig_filt = pywt.waverec(coeffs, 'db6')
    # waverec may return one extra trailing sample when the input length is
    # odd; trim so sample indices of the input stay valid for the output.
    return sig_filt[:len(sig)]

# def Z_ScoreNormalization(data):
#     mean = np.mean(data, axis=0)
#     std_dev = np.std(data, axis=0)
#     normalized_data = (data - mean) / std_dev
#
#     return normalized_data

def min_max_normalization(signal, new_min=0, new_max=1):
    """
    Linearly rescale a signal into the range [new_min, new_max].

    :param signal: array-like of numbers
    :param new_min: value the smallest sample is mapped to (default 0)
    :param new_max: value the largest sample is mapped to (default 1)
    :return: rescaled ndarray with the same shape as ``signal``. An empty
             input yields an empty float array; a constant input yields a
             float array filled with ``new_min``.
    """
    signal = np.array(signal)

    # np.min / np.max raise on zero-size input, so handle it up front.
    if signal.size == 0:
        return signal.astype(float)

    min_val = np.min(signal)
    max_val = np.max(signal)

    if max_val == min_val:
        # A flat signal has no range to divide by. Map every sample to the
        # lower bound so the result still lies inside [new_min, new_max].
        return np.full(signal.shape, new_min, dtype=float)

    unit_signal = (signal - min_val) / (max_val - min_val)  # scaled to [0, 1]
    return unit_signal * (new_max - new_min) + new_min  # scaled to [new_min, new_max]

In [3]:
# Read the record names listed in the database's RECORDS index file.
data_folder = 'mit-bih-arrhythmia-database-1.0.0'
records_file = data_folder + '/RECORDS'

with open(records_file, 'r') as file:
    # One record name per line. Strip surrounding whitespace and skip blank
    # lines so that an empty string is never used as a record name later on.
    dat_files = [line.strip() for line in file if line.strip()]

print(dat_files)

['100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '111', '112', '113', '114', '115', '116', '117', '118', '119', '121', '122', '123', '124', '200', '201', '202', '203', '205', '207', '208', '209', '210', '212', '213', '214', '215', '217', '219', '220', '221', '222', '223', '228', '230', '231', '232', '233', '234']


In [4]:

def data_seg(signal, r_peaks, ann_label, sam_rate):
    """
    Cut a record into fixed-length segments centred on annotated samples.

    The signal is flattened and filtered with ``WTfilt_1d``. For every
    annotation whose symbol is one of 'L', 'N', 'R', 'V', a window of
    ``sam_rate // 2`` samples on each side of the annotated sample is taken
    and min-max normalised to [0, 1]. Windows that would extend past either
    end of the signal are skipped.

    :param signal: samples of one channel; flattened to 1-d before use
    :param r_peaks: sample indices of the annotations
    :param ann_label: annotation symbols, aligned with ``r_peaks``
    :param sam_rate: sampling rate; each window is 2 * (sam_rate // 2) samples
    :return: (segments, labels) -- a list of 1-d arrays and the list of the
             matching annotation symbols
    """
    AAMI_MIT = ('L', 'N', 'R', 'V')

    # int() keeps the slice bounds integral even if the rate arrives as a float.
    half_window = int(sam_rate // 2)

    segments = []
    segment_labels = []

    signal = WTfilt_1d(signal.flatten())
    n_samples = len(signal)

    for peak, symbol in zip(r_peaks, ann_label):
        # Filter on the label first so no work is spent on discarded beats.
        if symbol not in AAMI_MIT:
            continue

        start = int(peak) - half_window
        end = int(peak) + half_window
        if start < 0 or end > n_samples:
            continue

        segments.append(min_max_normalization(signal=signal[start:end]))
        # append() stores the symbol as one item; ``+=`` would iterate over
        # the string and add it character by character.
        segment_labels.append(symbol)

    return segments, segment_labels

In [5]:
# Extract labelled segments from every record that has an 'MLII' channel.
data_ls = []
label_ls = []
for record_name in dat_files:
    record_path = data_folder + '/' + record_name
    ecg_hea = wfdb.rdheader(record_path, rd_segments=False)
    fs = ecg_hea.fs  # sampling rate
    ecg_channel = np.where(np.array(ecg_hea.sig_name) == 'MLII')[0]

    # Records without an 'MLII' channel are skipped.
    if ecg_channel.size == 0:
        continue

    record = wfdb.rdrecord(record_path, sampfrom=0, physical=True, channels=ecg_channel.tolist())
    ann = wfdb.rdann(record_path, 'atr')

    signal = record.p_signal
    r_peaks = ann.sample
    labels = ann.symbol

    segments_data, segments_label = data_seg(signal, r_peaks, labels, fs)

    # extend() grows the lists in place; ``data_ls + segments_data`` would
    # copy everything accumulated so far on every record.
    data_ls.extend(segments_data)
    label_ls.extend(segments_label)


In [6]:
# Turn the accumulated lists into arrays and report their shapes.
data_ls, label_ls = np.array(data_ls), np.array(label_ls)
print(data_ls.shape, label_ls.shape)

(97198, 360) (97198,)


In [7]:
from collections import Counter

# Tally how many segments carry each label.
count_dict = Counter(label_ls)
print(count_dict)

Counter({'N': 74749, 'L': 8071, 'R': 7255, 'V': 7123})


In [8]:
# Write the segment array and the label array to .npy files.
for npy_path, npy_array in (('mit_adb_data.npy', data_ls), ('mit_adb_label.npy', label_ls)):
    np.save(npy_path, npy_array)

In [9]:
from datetime import datetime

# Take the current local time and print it as a formatted timestamp.
current_time = datetime.now()
timestamp = current_time.strftime("%Y-%m-%d %H:%M:%S")
print("当前时间：", timestamp)


当前时间： 2025-03-10 12:44:35


In [None]:
# Reload the saved segments and their symbol labels from disk.
adb_data = np.load('mit_adb_data.npy')
adb_label = np.load('mit_adb_label.npy')

counter_adb = Counter(adb_label)
print('adb count: ', counter_adb)

# Integer class index assigned to each annotation symbol.
label_map = {'L': 0, 'N': 1, 'R': 2, 'V': 3}

# Index the dict directly so an unexpected symbol raises a KeyError naming it,
# instead of dict.get quietly turning it into None.
adb_label = np.array([label_map[symbol] for symbol in adb_label])
print('adb count: ', Counter(adb_label))

In [None]:
# Build a class-balanced subset: the same number of randomly chosen segments
# is drawn, without replacement, from each class.
N_CLASSES = 4              # integer labels 0..3
SAMPLES_PER_CLASS = 7000   # rng.choice raises ValueError if a class has fewer
SEED = 2025

# A seeded generator makes the selected subset identical on every run.
rng = np.random.default_rng(SEED)

data = []
labels = []
for i in range(N_CLASSES):
    indices = np.where(adb_label == i)[0]
    sel_indices = rng.choice(indices, size=SAMPLES_PER_CLASS, replace=False)
    class_data = adb_data[sel_indices]
    class_label = adb_label[sel_indices]
    data.append(class_data)
    labels.append(class_label)

data = np.concatenate(data, axis=0)
labels = np.concatenate(labels)
print(data.shape, labels.shape)

In [None]:
# Shuffle samples and labels together so each sample keeps its own label.
x_shuffled, y_shuffled = shuffle(data, labels, random_state=2025)
for shuffled_array in (x_shuffled, y_shuffled):
    print(shuffled_array.shape)

# Split off 10% of the shuffled data as the test set.
split_parts = train_test_split(x_shuffled, y_shuffled, test_size=0.1, random_state=2025)
x_train, x_test, y_train, y_test = split_parts
for subset_labels in (y_train, y_test):
    print(Counter(subset_labels))

In [None]:
# Write the train/test arrays to .npy files, one file per array.
split_outputs = (
    ('adb_train_data.npy', x_train),
    ('adb_train_label.npy', y_train),
    ('adb_test_data.npy', x_test),
    ('adb_test_label.npy', y_test),
)
for npy_path, npy_array in split_outputs:
    np.save(npy_path, npy_array)

In [None]:
from datetime import datetime

# Take the current local time and print it as a formatted timestamp.
current_time = datetime.now()
timestamp = current_time.strftime("%Y-%m-%d %H:%M:%S")
print("当前时间：", timestamp)