In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import os


#from utils import print_signal_qrs, print_signal, calcul_f1, perf


import tensorflow as tf

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from keras.optimizers import SGD
from keras.callbacks import ReduceLROnPlateau
from scipy.signal import resample
from scipy.interpolate import UnivariateSpline

import pickle
import joblib

# Le preprocessing dans l'article est un peu bizarre alors j'ai juste demandé à gpt de me faire un truc à peu près

In [2]:
def preprocessing(signal, fs):
    clean_baseline = baseline_wander_removal(signal)
    normalization = normalize_signal(clean_baseline)
    return normalization

def baseline_wander_removal(ecg_signal, window_size=4, sampling_rate=360, subsample_rate=200):
    # Convert window size to number of samples
    window_samples = window_size * sampling_rate
    
    # Initialize an empty array to store the corrected signal
    corrected_signal = np.zeros_like(ecg_signal)
    
    # Process the signal in windows
    for start in range(0, len(ecg_signal), window_samples):
        end = min(start + window_samples, len(ecg_signal))
        segment = ecg_signal[start:end]
        
        # Resample segment to reduce computational load
        resampled_segment = resample(segment, subsample_rate)
        
        # Perform LOESS regression
        x = np.linspace(0, len(resampled_segment) - 1, len(resampled_segment))
        spline = UnivariateSpline(x, resampled_segment, s=len(resampled_segment))
        baseline = spline(x)
        
        # Upsample the baseline back to the original sampling rate
        baseline_full = resample(baseline, len(segment))
        
        # Subtract the baseline from the original segment
        corrected_signal[start:end] = segment - baseline_full
    
    return corrected_signal

def normalize_signal(ecg_signal):
    mean_val = np.mean(ecg_signal)
    std_val = np.std(ecg_signal)
    
    # Subtract mean and divide by standard deviation
    normalized_signal = (ecg_signal - mean_val) / std_val
    
    return normalized_signal

In [3]:
def create_windows(ecg_signal, qrs_positions, fs, num_negative_samples=3):
    points_before = int(100 * fs / 1000)
    points_after = int(300 * fs / 1000)
    total_points = points_before + points_after + 1

    data_windows = []
    labels = []
    
    tolerance = int(40 * fs / 1000)
    
    for qrs in qrs_positions:
        start = qrs - points_before
        end = qrs + points_after + 1
        if start >= 0 and end <= len(ecg_signal):
            window = ecg_signal[start:end]
            if len(window) != total_points:
                print(f"Taille incorrecte: {len(window)} au lieu de {total_points}, start={start}, end={end}")
                continue
            data_windows.append(window)
            labels.append(1)
    
    num_qrs_positions = len(qrs_positions)
    signal_length = len(ecg_signal)
    negative_count = 0
    
    
    data_windows_array = np.array([np.array(window) for window in data_windows if len(window) == total_points])
    
    return data_windows_array, np.array(labels)

In [4]:
def create_model(input_shape):
    model = Sequential()
    model.add(Conv1D(filters=32, kernel_size=3, activation='relu', input_shape=input_shape))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Conv1D(filters=64, kernel_size=3, activation='relu'))
    model.add(MaxPooling1D(pool_size=2))
    model.add(Flatten())
    model.add(Dense(100, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(1, activation='sigmoid'))
    
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [5]:
# Créer le modèle
input_shape = (145, 1)  # 145 points par fenêtre, 1 canal
model = create_model(input_shape)
model.summary()

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [6]:
X_train_all, X_test_all = [], []
y_train_all, y_test_all = [], []

for file in ['101', '104', '107', '113', '116', '121', '201', '207', '209', '212', '215', '219', '228', '233']:
    print(file)
    df = pd.read_csv(f'data_csv/mit_bih_Arrhythmia/{file}.csv') #207
    ecg_signal = None
    if file == "104":
        ecg_signal = np.array(df["V2"], dtype=np.float32)#[:10000]
    else:
        ecg_signal = np.array(df["MLII"], dtype=np.float32)#[:10000]
    fs = 360
    QRS = df["labels"].dropna().astype(int).tolist()
    labels = np.zeros(len(ecg_signal))

    cleaned_ecg = preprocessing(ecg_signal, fs)
    
    X, y = create_windows(cleaned_ecg, QRS, fs)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.001, random_state=42)

    X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))
    X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))
    
    X_train_all.extend(X_train)
    X_test_all.extend(X_test)
    y_train_all.extend(y_train)
    y_test_all.extend(y_test)

101
104
107
113
116
121
201
207
209
212
215
219
228
233


In [7]:
X_train = X_train.reshape((X_train.shape[0], X_train.shape[1], 1))
X_test = X_test.reshape((X_test.shape[0], X_test.shape[1], 1))

input_shape = (145, 1)  # 145 points par fenêtre, 1 canal
model = create_model(input_shape)
history = model.fit(np.array(X_train_all), np.array(y_train_all), epochs=5, batch_size=64)

Epoch 1/5


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


[1m526/526[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 3ms/step - accuracy: 0.9953 - loss: 0.0158
Epoch 2/5
[1m526/526[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 1.0000 - loss: 8.2880e-07
Epoch 3/5
[1m526/526[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 1.0000 - loss: 2.7140e-07
Epoch 4/5
[1m526/526[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 1.0000 - loss: 6.9615e-08
Epoch 5/5
[1m526/526[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 3ms/step - accuracy: 1.0000 - loss: 6.8248e-08


In [None]:
model.save("../benchmark_qrs_detectors/model_CNN_arrhythmia.h5")