# Workspace Utility Functions

In [None]:
# Define workspace utility functions
import os
import sys
import pickle
import subprocess
from IPython.display import FileLink

# Call this function when in need to free RAM space
def check_variables(namespace=None):
    """
    Check the memory usage of variables in the workspace.
    Print variables and their memory sizes in descending order.

    Parameters:
        namespace (dict, optional): Mapping of variable names to objects to inspect.
            Defaults to the global namespace in which this function is defined
            (the notebook namespace when the function is defined in a notebook cell).

    Note:
        Sizes are computed with `sys.getsizeof`, which is shallow: it does not
        follow references into the objects contained in a container.
    """
    # `locals()` evaluated here would only see this function's own (empty) scope,
    # so the workspace variables have to be read from the global namespace instead.
    if namespace is None:
        namespace = globals()
    # Get the memory size of each variable (dunder names are skipped)
    variable_sizes = {
        name: sys.getsizeof(value)
        for name, value in list(namespace.items())
        if not name.startswith('__')
    }
    # Sort the variables based on their memory size
    sorted_variables = sorted(variable_sizes.items(), key=lambda x: x[1], reverse=True)
    # Print the variables and their memory sizes in descending order
    for var, size in sorted_variables:
        print(f"{var}: {size} bytes")

# Save anything via pickle
def save(item, name: str, path="/kaggle/working/"):
    """
    Save an item using pickle.

    Parameters:
        item: The item to be saved.
        name (str): The name of the file.
        path (str): The path where the file will be saved (default: "/kaggle/working/").
            A trailing path separator is optional.
    """
    # Join the parts instead of concatenating them, so that a directory given
    # without a trailing separator does not get fused with the file name.
    item_file = os.path.join(path, name)
    with open(item_file, 'wb') as file:
        pickle.dump(item, file)

# Download item as zip
def download_file(source_path: str, download_file_name: str, output_path="/kaggle/working/"):
    """
    Create a zip file from the specified source path and provide a download link.

    Parameters:
        source_path (str): The path to the source file or directory to be zipped.
            A relative path is resolved against `output_path`.
        download_file_name (str): The name of the zip file and download link.
        output_path (str): The output path for the zip file (default: "/kaggle/working/").

    Raises:
        RuntimeError: If the `zip` executable is missing or exits with a non-zero code.
    """
    # `display` is injected by IPython in notebooks; import it explicitly so the
    # function also works when it is called from a plain module.
    from IPython.display import display

    # Save the current working directory
    current_working_directory = os.getcwd()
    os.chdir(output_path)

    try:
        zip_name = f"{download_file_name}.zip"
        # Pass the arguments as a list (no shell): paths containing spaces or
        # shell metacharacters are handed to `zip` verbatim.
        command = ["zip", "-r", zip_name, source_path]
        try:
            result = subprocess.run(command, capture_output=True, text=True)
        except FileNotFoundError as err:
            raise RuntimeError(f"Unable to run zip command! Error: {err}") from err
        if result.returncode != 0:
            raise RuntimeError(f"Unable to run zip command! Error: {result.stderr}")

        # Displayed while still inside `output_path`, so the relative link resolves
        display(FileLink(zip_name))
    finally:
        # Restore the original working directory
        os.chdir(current_working_directory)

# Import Data

In [None]:
# List the first ten entries (sorted by name) of `input_directory`
import os

# Folder that holds the training files
input_directory = "/kaggle/working/AI project/train"
# Alphabetically sorted names of every entry in the folder
dir_list = sorted(os.listdir(input_directory))
# Last expression of the cell: the notebook displays the first ten names
dir_list[0:10]

In [None]:
from scipy.io import loadmat

def load_data_from_directory(input_directory):
    """
    Load PPG recordings, beat labels and peak locations from a directory.

    For each sampling frequency tag (128 and 250) three kinds of files are read:
    signals (`*_<fs>.mat`, key 'ppg'), labels (`*<fs>_ann.mat`, key 'labels') and
    peak locations (`*<fs>_spk.mat`, key 'speaks'). Hidden files and non-regular
    files are ignored; suffix matching is case-insensitive.

    Files of the three kinds are paired by their position in the alphabetically
    sorted listing, so this assumes a shared per-subject file-name prefix.

    Parameters:
        input_directory (str): Directory containing the .mat files.

    Returns:
        tuple: (recordings_128, recordings_250, labels_128, labels_250,
        locations_128, locations_250), each a list with one entry per subject.

    Raises:
        ValueError: If, for a sampling frequency, the numbers of signal, label
            and peak-location files differ.
    """
    print('Loading data...')
    # List the directory that was actually passed in rather than relying on a
    # module-level listing that may belong to another folder.
    file_names = sorted(os.listdir(input_directory))

    def _collect(suffix):
        """Return full paths of visible regular files whose lowercased name ends with `suffix`."""
        paths = []
        for file_name in file_names:
            full_path = os.path.join(input_directory, file_name)
            lowered = file_name.lower()
            if not lowered.startswith('.') and lowered.endswith(suffix) and os.path.isfile(full_path):
                paths.append(full_path)
        return paths

    def _load_group(fs_tag):
        """Load recordings, labels and peak locations for one sampling-frequency tag."""
        signal_files = _collect(f'_{fs_tag}.mat')
        label_files = _collect(f'{fs_tag}_ann.mat')
        peak_files = _collect(f'{fs_tag}_spk.mat')

        # Check per frequency: equal totals across frequencies could hide a mismatch
        if not (len(signal_files) == len(label_files) == len(peak_files)):
            raise ValueError(
                "Error while reading files: "
                f"fs={fs_tag}Hz has {len(signal_files)} signal, "
                f"{len(label_files)} label and {len(peak_files)} peak-location files"
            )

        recordings, labels, locations = [], [], []
        triples = zip(signal_files, label_files, peak_files)
        for i, (signal_file, label_file, peak_file) in enumerate(triples):
            recordings.append(loadmat(signal_file)['ppg'])
            labels.append(loadmat(label_file)['labels'])
            locations.append(loadmat(peak_file)['speaks'])
            # inform about loading step
            print(f"\rLoading fs={fs_tag}Hz file: {i+1}/{len(signal_files)}")
        return recordings, labels, locations

    recordings_128, labels_128, locations_128 = _load_group(128)
    recordings_250, labels_250, locations_250 = _load_group(250)

    return recordings_128, recordings_250, labels_128, labels_250, locations_128, locations_250

# Call the function with the input_directory; the returned tuple is unpacked into
# recordings, labels and peak locations, each for 128Hz and then 250Hz
recordings_128, recordings_250, labels_128, labels_250, locations_128, locations_250 = load_data_from_directory(input_directory)

# Plot label distribution

In [None]:
import numpy as np
# get counts for each label
def calculate_label_distribution(labels):
    """
    Count how many 'N', 'S' and 'V' labels occur across all label arrays.

    Args:
        labels: Sequence of label arrays, one per recording.

    Returns:
        tuple: Total counts of 'N', 'S' and 'V' labels, in that order.
    """
    totals = {'N': 0, 'S': 0, 'V': 0}
    for recording_labels in labels:
        for beat_class in totals:
            # Element-wise comparison, then count the matching entries
            totals[beat_class] += np.count_nonzero(recording_labels == beat_class)
    return totals['N'], totals['S'], totals['V']

# check label distribution in 128Hz samples (total N, S and V counts over all 128Hz recordings)
tot_count_n_128, tot_count_s_128, tot_count_v_128 = calculate_label_distribution(labels_128)

# check label distribution in 250Hz samples (total N, S and V counts over all 250Hz recordings)
tot_count_n_250, tot_count_s_250, tot_count_v_250 = calculate_label_distribution(labels_250)

In [None]:
# Check numerosity of classes: print the N/S/V beat counts per sampling frequency
print(f"Signals 128Hz: {tot_count_n_128} N beats, {tot_count_s_128} S beats, {tot_count_v_128} V beats")
print(f"Signals 250Hz: {tot_count_n_250} N beats, {tot_count_s_250} S beats, {tot_count_v_250} V beats")

In [None]:
# Check class proportions: share of each beat class over both sampling frequencies
all_beats_count = (tot_count_n_128 + tot_count_n_250
                   + tot_count_s_128 + tot_count_s_250
                   + tot_count_v_128 + tot_count_v_250)
print("Tot. proportion of N beats: ", (tot_count_n_128 + tot_count_n_250) / all_beats_count)
print("Tot. proportion of S beats: ", (tot_count_s_128 + tot_count_s_250) / all_beats_count)
print("Tot. proportion of V beats: ", (tot_count_v_128 + tot_count_v_250) / all_beats_count)

In [None]:
import matplotlib.pyplot as plt

def plot_label_distribution(values, title, labels=None):
    """
    Draw a bar chart of per-class counts.

    Args:
        values: Bar heights, one per entry of `labels`.
        title (str): Title of the figure.
        labels (list, optional): Bar names. Defaults to ['N', 'S', 'V'].
    """
    # A list literal used as a default would be shared between calls; build it per call
    if labels is None:
        labels = ['N', 'S', 'V']
    plt.bar(labels, values, color=['#add8e6', '#90ee90', '#ffb6c1'])
    plt.xlabel('Labels')
    plt.ylabel('Counts')
    plt.title(title)
    plt.show()

# Plot histogram for 128Hz (counts ordered N, S, V to match the default bar labels)
values_128 = [tot_count_n_128, tot_count_s_128, tot_count_v_128]
plot_label_distribution(values_128, 'Histogram of Labels for 128Hz recordings')

# Plot histogram for 250Hz
values_250 = [tot_count_n_250, tot_count_s_250, tot_count_v_250]
plot_label_distribution(values_250, 'Histogram of Labels for 250Hz recordings')

# Plot histogram for overall distribution (per-class sum of the two sampling frequencies)
values = [tot_count_n_250+tot_count_n_128, tot_count_s_250+tot_count_s_128, tot_count_v_250+tot_count_v_128]
plot_label_distribution(values, 'Histogram of Labels for all recordings')

# Remove patients having only N type beats

Given the large disproportion of classes, patients showing only normal beats are removed as they only carry redundant information.

In [None]:
# Look for 128Hz patients whose label array contains nothing but 'N'
only_N_128 = []

for idx, label in enumerate(labels_128):
    unique_labels = set(label)

    # Exactly one distinct label, and that label is 'N'
    if unique_labels == {'N'}:
        only_N_128.append(idx)

if not only_N_128:
    print("No patients with only 'N' labels found among 128Hz recordings.")
else:
    print("Patients with only 'N' labels found among 128Hz recordings at indices:", only_N_128)

In [None]:
# Look for 250Hz patients whose label array contains nothing but 'N'
only_N_250 = []

for idx, label in enumerate(labels_250):
    unique_labels = set(label)

    # Exactly one distinct label, and that label is 'N'
    if unique_labels == {'N'}:
        only_N_250.append(idx)

if not only_N_250:
    print("No patients with only 'N' labels found among 250Hz recordings.")
else:
    print("Patients with only 'N' labels found among 250Hz recordings at indices:", only_N_250)

In [None]:
# Drop the only-'N' patients (indices in `only_N_250`) from the three parallel 250Hz lists
locations_250 = [loc for i, loc in enumerate(locations_250) if i not in only_N_250]
recordings_250 = [rec for i, rec in enumerate(recordings_250) if i not in only_N_250]
labels_250 = [lab for i, lab in enumerate(labels_250) if i not in only_N_250]

# The three lists must still have the same length
print(f"New peaks locations dim.: {len(locations_250)}")
print(f"New recordings dim.: {len(recordings_250)}")
print(f"New labels dim.:{len(labels_250)}")

In [None]:
# Check new label distribution in 250Hz samples (overwrites the earlier 250Hz totals)
tot_count_n_250, tot_count_s_250, tot_count_v_250 = calculate_label_distribution(labels_250)
print(f"Signals 250Hz: {tot_count_n_250} N beats, {tot_count_s_250} S beats, {tot_count_v_250} V beats")

In [None]:
# Visualize the new label distribution

# Plot histogram for 250Hz using the current 250Hz totals
values_250 = [tot_count_n_250, tot_count_s_250, tot_count_v_250]
plot_label_distribution(values_250, 'Histogram of Labels for 250Hz recordings')

# Plot histogram for overall distribution (per-class sum of 250Hz and 128Hz totals)
values = [tot_count_n_250+tot_count_n_128, tot_count_s_250+tot_count_s_128, tot_count_v_250+tot_count_v_128]
plot_label_distribution(values, 'Histogram of Labels for all recordings')

# Signal Visualization

A rapid signal inspection has shown the presence of many artifacts along the recordings, for both 128Hz and 250Hz samples. A further inspection was aimed at assessing the labels associated with the peaks in these noisy portions.

In [None]:
# define function to plot signals over a given time range
def plot_signal(signal, seconds, fs, offset=0):
    """
    Plot a time window of a recording.

    Args:
        signal: Samples of the recording, indexed along the first axis.
        seconds: Length of the window in seconds.
        fs: Sampling frequency in Hz.
        offset: Start of the window in seconds. Defaults to 0.
    """
    # Convert the window bounds to integer sample indices (also valid for float inputs)
    start = int(round(offset * fs))
    stop = int(round((offset + seconds) * fs))
    segment = signal[start:stop]
    # Derive the time axis from the samples actually available, so that x and y
    # always have the same length, even when the recording ends inside the window.
    t = offset + np.arange(len(segment)) / fs
    fig, axs = plt.subplots()
    axs.plot(t, segment, color='C0')
    axs.set_xlabel("Time [s]")
    axs.set_ylabel("Amplitude [mV]")
    plt.title('PPG recording')
    plt.show()

In [None]:
# Show 128Hz signal: first 20 seconds of the first 128Hz recording
plot_signal(recordings_128[0],20,128)

In [None]:
# Show 250Hz signal: first 20 seconds of the first 250Hz recording
plot_signal(recordings_250[0],20,250)

In [None]:
# plot the signal with the corresponding peaks
def plot_signal_with_peaks(signal, peak_locations, fs):
    """
    Plot a signal and mark the given peak locations on it.

    Args:
        signal: Samples to plot; must support indexing with `peak_locations`.
        peak_locations: Sample indices of the peaks, relative to the start of `signal`.
        fs: Sampling frequency in Hz.
    """
    # Define the time axis from the sample index: a float-step `np.arange`
    # may return one element more or less than `len(signal)`.
    t = np.arange(len(signal)) / fs

    # Plot the signal
    plt.figure(figsize=(10, 6))
    plt.plot(t, signal, color='C0', label='Signal')

    # Plot the peak locations
    peak_times = np.array(peak_locations) / fs
    plt.scatter(peak_times, signal[peak_locations], color='red', label='Peak Locations')

    # Set the x-axis label and title
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.title('Signal with Peak Locations')

    # Show the legend
    plt.legend()

    # Display the plot
    plt.show()

In [None]:
# First 2000 samples of the first 250Hz recording with its first 10 peak locations
plot_signal_with_peaks(recordings_250[0][:2000], locations_250[0][:10], 250)

In [None]:
# plot the signal with the labelled peaks
def plot_signal_with_labelled_peaks(signal, peak_locations, labels, fs):
    """
    Plot a signal and mark each peak with a colour chosen by its label.

    Args:
        signal: Samples to plot.
        peak_locations: Sample indices of the peaks, relative to the start of `signal`.
        labels: Label of each peak, in the same order as `peak_locations`.
        fs: Sampling frequency in Hz.
    """
    # Define the time axis from the sample index: a float-step `np.arange`
    # may return one element more or less than `len(signal)`.
    t = np.arange(len(signal)) / fs

    # Plot the signal
    plt.figure(figsize=(10, 6))
    plt.plot(t, signal, color='blue', label='Signal')

    # Plot the peak locations with different colors based on the label
    for i, peak_loc in enumerate(peak_locations):
        if labels[i] == 'N':
            color = 'blue'
        elif labels[i] == 'V':
            color = 'red'
        elif labels[i] == 'S':
            color = 'green'
        else:
            # Any label other than N, V or S
            color = 'black'
        plt.scatter(t[peak_loc], signal[peak_loc], color=color)

    # Set the x-axis label and title
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude [mV]')
    plt.title('Signal with Labelled Peak Locations')
    print('Legend: N = blue, V = red, S = green')

    # Display the plot explicitly, like the other plotting helpers in this file
    plt.show()

In [None]:
# First 1000 samples of the second 128Hz recording with its first 10 labelled peaks
plot_signal_with_labelled_peaks(recordings_128[1][:1000], locations_128[1][:10],labels_128[1][:10], 128)

# Signal Pre-Processing

Given the different sampling frequencies, resampling is performed to equalize them. In the analysis, __downsampling__ the 250Hz signals to 128Hz is performed for three main reasons:
1. It decreases the computational complexity.
2. It lowers the memory requirements.
3. The majority of signals are sampled at 128Hz.

Note that when downsampling from 250Hz to 128Hz we need to ensure that the original signal does not contain frequencies above 64Hz, in accordance with the sampling theorem. For this reason, the 250Hz signals' periodograms have been analyzed and it was observed that on average most of the frequency content of the signals is contained between 0 and 3 Hz.

In [None]:
from scipy.signal import periodogram

# plot a signal and the corresponding periodogram
def plot_signal_and_periodogram(signal, fs):
    """
    Plot a signal (top) and its periodogram between 0 and 10 Hz (bottom).

    Args:
        signal: numpy array with the samples (it is flattened for the periodogram).
        fs: Sampling frequency in Hz.
    """
    frequencies, Pxx = periodogram(signal.flatten(), fs)

    # Plot the signal and its periodogram
    plt.figure(figsize=(12, 6))

    # Plot the signal; the time axis is built from the sample index because a
    # float-step `np.arange` may return one element more or less than `len(signal)`
    plt.subplot(2, 1, 1)
    t = np.arange(len(signal)) / fs
    plt.plot(t, signal)
    plt.title('Original Signal')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude [mV]')

    # Plot the periodogram
    # NOTE(review): Pxx is plotted as returned by `periodogram`, with no dB
    # conversion, although the axis label mentions dB/Hz - confirm the intended unit.
    plt.subplot(2, 1, 2)
    plt.plot(frequencies, Pxx)
    plt.title('Periodogram of the Signal')
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Power/Frequency (dB/Hz)')

    # Limit the x-axis to 10
    plt.xlim(0, 10)

    plt.tight_layout()
    plt.show()

In [None]:
# Check for 250Hz signals: plot every 250Hz recording together with its periodogram
fs = 250
for i, signal in enumerate(recordings_250):
  print(f"250Hz signal {i}")
  plot_signal_and_periodogram(signal, fs)

In [None]:
# downsample 250Hz signals to 128Hz
from scipy.signal import resample

# Define the original (fs) and target (target_fs) sampling frequencies
fs = 250
target_fs = 128

# Downsample the signals: the new sample count is len(signal) * 128 / 250, truncated to an integer
downsampled_signals = [resample(signal, int(len(signal) * target_fs / fs)) for signal in recordings_250]
# Modify locations_250 to match downsampled_signals: indices are scaled by the same
# ratio and rounded to the nearest integer
# NOTE(review): rounding (vs. truncation above) might map a peak on the very last
# samples to an index equal to the downsampled length - confirm this cannot happen
downsampled_locations = [np.round(location * target_fs / fs).astype(int) for location in locations_250]

In [None]:
# Check the downsampled signal with the corresponding peaks
# (first 2000 samples and first 10 rescaled peak locations of the first recording)
plot_signal_with_peaks(downsampled_signals[0][:2000], downsampled_locations[0][:10], 128)

In [None]:
# Drop the name bound to the original 250Hz recordings
del recordings_250

# Individual signal check
Visualization of all the signals composing the dataset is performed to avoid including any "outlier", meaning recordings not showing clear waveforms.

In [None]:
# Plot signals in downsampled_signals: a 20-second window starting at second 100 of each
fs = 128
for i, signal in enumerate(downsampled_signals):
  print(f"Downsampled signal {i}")
  plot_signal(signal, 20, fs, offset = 100)

In [None]:
# Plot signals in recordings_128: a 20-second window starting at second 100 of each
# NOTE: `fs` is not set in this cell; it uses whatever value a previously run cell left behind
for i, signal in enumerate(recordings_128):
  print(f"Original 128Hz signal {i}")
  plot_signal(signal, 20, fs, offset = 100)

## Filtering and artifact removal

The periodograms also prove useful to decide the cut-off frequency of filters. Given the low frequency content of PPG signals, initially a bandpass filtering approach was used to deal with artifacts. Nevertheless, as filtering alone could not reconstruct the morphology of the PPG waves in the noisy portions, a threshold-based wave removal was performed.

Initially 3Hz was chosen as the cutoff frequency, then it was raised to 5Hz as this value seemed to impact less on the morphology of the non-noisy signals. The DC component of the signal is instead filtered out using 0.5 Hz as cutoff.

In [None]:
from scipy.signal import butter, filtfilt

# Define the filter parameters
# The sampling frequency is assigned first: the Nyquist frequency below is derived
# from it and must not pick up a stale `fs` left behind by an earlier cell.
fs = 128  # Set the sampling frequency to 128 Hz
low_cutoff_frequency = 0.5  # Set the low cutoff frequency to 0.5 Hz
high_cutoff_frequency = 5  # Set the high cutoff frequency to 5 Hz
nyquist_freq = 0.5 * fs  # Nyquist frequency
filter_order = 2  # Set the filter order

# Calculate the normalized cutoff frequencies (fractions of the Nyquist frequency)
normalized_low_cutoff_frequency = low_cutoff_frequency / nyquist_freq
normalized_high_cutoff_frequency = high_cutoff_frequency / nyquist_freq

# Design the Butterworth bandpass filter
b, a = butter(filter_order, [normalized_low_cutoff_frequency, normalized_high_cutoff_frequency], btype='band', analog=False, output='ba')

# Recordings_128 is a list of 2D arrays: drop the singleton axis before filtering
flattened_signals_128 = [np.squeeze(signal) for signal in recordings_128]
# Apply the filter to the flattened signals (forward-backward filtering via filtfilt)
filtered_signals_128 = [filtfilt(b, a, signal) for signal in flattened_signals_128]

# Downsampled_signals is a list of 2D arrays: same treatment as above
flattened_downsampled_signals = [np.squeeze(signal) for signal in downsampled_signals]
filtered_signals_downsampled = [filtfilt(b, a, signal) for signal in flattened_downsampled_signals]

In [None]:
def plot_signals_overlapped(signal_1, signal_2, seconds, fs,  offset=0):
    """
    Plot the same time window of two signals on top of each other.

    Args:
        signal_1: First signal, indexed along the first axis.
        signal_2: Second signal, indexed along the first axis.
        seconds: Length of the window in seconds.
        fs: Sampling frequency in Hz (shared by both signals).
        offset: Start of the window in seconds. Defaults to 0.
    """
    # Convert the window bounds to integer sample indices (also valid for float inputs)
    start = int(round(offset * fs))
    stop = int(round((offset + seconds) * fs))
    segment_1 = signal_1[start:stop]
    segment_2 = signal_2[start:stop]
    # One time axis per segment, derived from the samples actually available, so
    # that x and y always match even when a signal ends inside the window.
    t_1 = offset + np.arange(len(segment_1)) / fs
    t_2 = offset + np.arange(len(segment_2)) / fs

    plt.figure(figsize=(10, 6))
    plt.plot(t_1, segment_1, label='Signal 1', color='C0')
    plt.plot(t_2, segment_2, label='Signal 2', color='red')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude [mV]')
    plt.title('Signal Overlap')
    plt.legend()
    plt.show()

In [None]:
# show signal vs. filtered signal for every 128Hz recording (20-second window starting at second 300)
fs = 128
for i, (signal, filtered_signal) in enumerate(zip(recordings_128,filtered_signals_128)):
  print(f"Original vs Filtered 128Hz signal {i}")
  plot_signals_overlapped(signal, filtered_signal, 20, fs, offset=300)

In [None]:
# show signal vs. filtered signal for every downsampled recording (first 20 seconds)
# NOTE: `fs` is not set in this cell; it uses whatever value a previously run cell left behind
for i, (signal, filtered_signal) in enumerate(zip(downsampled_signals,filtered_signals_downsampled)):
  print(f"Downsampled vs Downsampled and Filtered signal {i}")
  plot_signals_overlapped(signal, filtered_signal, 20, fs)

## Standardization

Standardization of the signals is performed in order to make them comparable across and within patients.

In [None]:
# Perform standardization based on the signal mean and standard deviation
def standardize_signals(signals):
    """
    Standardize each signal with its own mean and standard deviation.

    Args:
        signals: Iterable of numeric arrays.

    Returns:
        list: One array per input signal, computed as (signal - mean) / std.
    """
    return [(signal - np.mean(signal)) / np.std(signal) for signal in signals]

In [None]:
# Apply standardization to filtered signals (each signal uses its own mean and standard deviation)
standardized_128 = standardize_signals(filtered_signals_128)
standardized_downsampled = standardize_signals(filtered_signals_downsampled)

# Beat Segmentation and Feature Extraction
Segmentation is performed in two distinct ways:
1. Segmentation of individual beats
2. Segmentation to retain 2 consecutive RR intervals

In the first case, as the beats are segmented, feature extraction is performed.

In [None]:
# Define a Patient class so that train-validation-test split can be performed easily later
class Patient:
    """
    Container for one patient's segmented beats, per-beat features, labels and
    peak locations.

    Every constructor argument is optional; an argument left as None is replaced
    by a new empty list owned by the instance.
    """

    def __init__(self,
                 single_beats=None, contiguous_beats=None,
                 mean=None, std=None, amplitude=None, peak_value=None,
                 pre_PP=None, post_PP=None, avg_PP=None, width=None, FWHM=None,
                 skewness=None, pre_skew=None, post_skew=None,
                 kurtosis=None, pre_kurt=None, post_kurt=None, 
                 entropy=None, RMS=None, neg_neg_jump=None, 
                 pre_pos_pos_jump=None, post_pos_pos_jump=None,
                 rise_time=None, fall_time=None, area=None, local_hrv=None,
                 energy=None, dominant_frequency=None, 
                 labels=None, peak_locations=None):

        def or_empty(value):
            # Keep the caller's object; only None is replaced by a fresh list
            if value is None:
                return []
            return value

        # Segmented beats
        self.single_beats = or_empty(single_beats)
        self.contiguous_beats = or_empty(contiguous_beats)
        # Features
        self.mean = or_empty(mean)
        self.std = or_empty(std)
        self.amplitude = or_empty(amplitude)
        self.peak_value = or_empty(peak_value)
        self.pre_PP = or_empty(pre_PP)
        self.post_PP = or_empty(post_PP)
        self.avg_PP = or_empty(avg_PP)
        self.width = or_empty(width)
        self.FWHM = or_empty(FWHM)
        self.skewness = or_empty(skewness)
        self.pre_skew = or_empty(pre_skew)
        self.post_skew = or_empty(post_skew)
        self.kurtosis = or_empty(kurtosis)
        self.pre_kurt = or_empty(pre_kurt)
        self.post_kurt = or_empty(post_kurt)
        self.entropy = or_empty(entropy)
        self.RMS = or_empty(RMS)
        self.neg_neg_jump = or_empty(neg_neg_jump)
        self.pre_pos_pos_jump = or_empty(pre_pos_pos_jump)
        self.post_pos_pos_jump = or_empty(post_pos_pos_jump)
        self.rise_time = or_empty(rise_time)
        self.fall_time = or_empty(fall_time)
        self.area = or_empty(area)
        self.local_hrv = or_empty(local_hrv)
        self.energy = or_empty(energy)
        self.dominant_frequency = or_empty(dominant_frequency)
        # Labels and peak locations
        self.labels = or_empty(labels)
        self.peak_locations = or_empty(peak_locations)

In [None]:
# Create one empty Patient per standardized recording (128Hz and downsampled ones together)
NUM_PATIENTS = len(standardized_128) + len(standardized_downsampled)

# Each Patient() call builds its own instance with its own empty lists
patient_instances = [Patient() for _ in range(NUM_PATIENTS)]

In [None]:
# Define a function to extract single beats from a signal using a dynamic window size
def extract_single_beats_dynamically(signal, peak_locations, start_ratio=0.35, end_ratio=0.65):
    """
    Extracts beats from a signal based on the locations of the peaks.

    The window around each peak is sized from the neighbouring peak-to-peak
    distances: the following interval for the first peak, the preceding interval
    for the last peak, and the average of both for every other peak. At least
    two peaks are required, because the first peak reads its successor.

    Args:
        signal (list): The input signal.
        peak_locations (list): The locations of the peaks in the signal.
        start_ratio (float, optional): The ratio of the window size to use as the starting point of the beat extraction.
            Defaults to 0.35.
        end_ratio (float, optional): The ratio of the window size to use as the ending point of the beat extraction.
            Defaults to 0.65.

    Returns:
        list: A list of beats extracted from the signal.
        list: A list of peak positions relative to the start of the window.
    """
    beats = []
    peak_positions = []
    last_index = len(peak_locations) - 1
    for i in range(len(peak_locations)):
        peak = peak_locations[i]
        if i == 0:
            window_size = peak_locations[i+1] - peak
        elif i == last_index:
            window_size = peak - peak_locations[i-1]
        else:
            # Compute the average of the previous and the successive peak-to-peak differences
            window_size = (peak_locations[i+1] - peak + peak - peak_locations[i-1]) / 2

        # Clamp the window to the signal boundaries
        start = int(max(0, peak - window_size*start_ratio))
        end = int(min(len(signal), peak + window_size*end_ratio))
        beats.append(signal[start:end])
        # Index of the peak inside the extracted beat
        peak_positions.append(int(peak - start))

    return beats, peak_positions


# Define a function to extract single beats from a signal using a fixed window size
def extract_single_beats_statically(signal, peak_locations, window_size=100, start_ratio=0.35, end_ratio=0.65):
    """
    Extracts single beats from a given signal based on peak locations.

    Args:
        signal (array-like): The input signal.
        peak_locations (array-like): The locations of the peaks in the signal.
        window_size (int, optional): The size of the window around each peak to extract the beat.
            Defaults to 100.
        start_ratio (float, optional): The ratio of the window size to use as the starting point of the beat extraction.
            Defaults to 0.35.
        end_ratio (float, optional): The ratio of the window size to use as the ending point of the beat extraction.
            Defaults to 0.65.

    Returns:
        list: A list of extracted beats.
        list: A list of peak positions relative to the start of the window.
    """
    beats = []
    peak_positions = []

    # Segment the beats
    for peak in peak_locations:
        # Clamp the window to the signal boundaries
        start = int(max(0, peak - window_size*start_ratio))
        end = int(min(len(signal), peak + window_size*end_ratio))
        beats.append(signal[start:end])

        # Relative position of the peak within the window, measured from the
        # start index actually used: `start` is truncated to an integer, so
        # int(window_size * start_ratio) is one sample short whenever the
        # product has a fractional part.
        peak_positions.append(int(peak - start))

    return beats, peak_positions

In [None]:
# Define a function to extract contiguous beats from a signal
def extract_contiguous_beats_dynamically(signal, peak_locations):
    """
    Extracts, for each peak, a segment spanning its neighbouring intervals.

    The first segment starts at the beginning of the signal and ends at the
    second peak; the last one starts at the previous peak and ends at the end
    of the signal. For every other peak the segment extends on both sides by
    the average of the preceding and succeeding peak-to-peak distances.

    Args:
        signal (array-like): The input signal.
        peak_locations (array-like): The locations of the peaks in the signal.

    Returns:
        list: A list of extracted segments, one per peak.

    """
    segments = []
    signal_length = len(signal)
    final_index = len(peak_locations) - 1
    for i, peak in enumerate(peak_locations):
        if i == 0:
            before = peak
            after = peak_locations[i+1] - peak
        elif i == final_index:
            before = peak - peak_locations[i-1]
            after = signal_length - peak
        else:
            # NOTE(review): both sides use the same averaged interval rather than the
            # actual preceding/succeeding interval - confirm this is intended.
            before = (peak - peak_locations[i-1] + peak_locations[i+1] - peak) / 2
            after = (peak_locations[i+1] - peak + peak - peak_locations[i-1]) / 2

        # Clamp the segment to the signal boundaries
        first = int(max(0, peak - before))
        last = int(min(signal_length, peak + after))
        segments.append(signal[first:last])
    return segments

def extract_contiguous_beats_statically(signal, peak_locations, window_size=165, start_ratio=0.5, end_ratio=0.5):
    """
    Extracts a fixed-size segment around each peak of a signal.

    Args:
        signal (array-like): The input signal.
        peak_locations (array-like): The locations of the peaks in the signal.
        window_size (int, optional): The size of the window around each peak.
            Defaults to 165.
        start_ratio (float, optional): Fraction of the window placed before the peak.
            Defaults to 0.5.
        end_ratio (float, optional): Fraction of the window placed after the peak.
            Defaults to 0.5.

    Returns:
        list: A list of extracted segments, clamped to the signal boundaries.

    """
    signal_length = len(signal)
    return [
        signal[int(max(0, peak - window_size*start_ratio)):int(min(signal_length, peak + window_size*end_ratio))]
        for peak in peak_locations
    ]

In [None]:
# Define a function for the visualization of beats
def plot_beat_with_peak(beats, positions, idx=None):
    """
    Plots one beat and marks its peak with a red dot.

    Args:
        beats (array-like): The input beats.
        positions (array-like): The location of the peak inside each beat.
        idx (int, optional): The index of the beat to plot. A random beat is
            drawn when it is None. Defaults to None.
    """
    chosen = np.random.randint(len(beats)) if idx is None else idx
    beat = beats[chosen]
    peak_index = positions[chosen]

    plt.figure(figsize=(10, 6))
    plt.plot(beat)
    plt.scatter(peak_index, beat[peak_index], color='red')
    plt.xlabel('Samples')
    plt.ylabel('Amplitude [mV]')
    plt.title(f'Beat {chosen} with Peak Location')
    plt.show()

# Define a function to plot a beat of a specific label
def plot_beat_of_given_label(beats, labels, label, idx=None):
    """
    Plots a beat of a given label.

    Args:
        beats (array-like): The input beats.
        labels (array-like): The labels of the beats.
        label (str): The label to plot.
        idx (int, optional): The index of the beat to plot. When it is None, or
            when the beat at `idx` does not carry `label`, a random beat with
            that label is drawn instead. Defaults to None.

    Raises:
        ValueError: If no beat carries the requested label.
    """
    if idx is None or labels[idx] != label:
        # Draw among the matching beats only: re-drawing until a match is found
        # never terminates when the label does not occur at all.
        candidates = [i for i in range(len(beats)) if labels[i] == label]
        if not candidates:
            raise ValueError(f"No beat with label {label!r} found")
        idx = candidates[np.random.randint(len(candidates))]
    plt.figure(figsize=(10, 6))
    plt.plot(beats[idx])
    plt.xlabel('Samples')
    plt.ylabel('Amplitude [mV]')
    plt.title(f'Beat {idx} with Label {label}')
    plt.show()

In [None]:
# Feature extraction
def get_beat_mean(beats):
    """
    Computes the mean value of every beat.

    Args:
        beats (array-like): The input beats.

    Returns:
        list: One mean per beat, in input order.
    """
    return [np.mean(beat) for beat in beats]

def get_beat_std(beats):
    """
    Computes the standard deviation of every beat.

    Args:
        beats (array-like): The input beats.

    Returns:
        list: One standard deviation per beat, in input order.
    """
    return [np.std(beat) for beat in beats]

def get_beat_amplitude(beats):
    """
    Computes the amplitude (maximum minus minimum) of every beat.

    Args:
        beats (array-like): The input beats.

    Returns:
        list: One amplitude per beat, in input order.
    """
    return [np.max(beat) - np.min(beat) for beat in beats]

def get_beat_peak_value(beats):
    """
    Finds the largest sample of every beat.

    Args:
        beats (array-like): The input beats.

    Returns:
        list: One maximum per beat, in input order.
    """
    return [max(beat) for beat in beats]

def get_beat_pre_post_PP(peak_locations, fs=128):
    """
    Calculates the peak to peak distances of the beats.

    For every peak the distance to the previous peak (pre-PP) and to the next peak
    (post-PP) is computed. The first peak has no predecessor, so its pre-PP is set to
    its post-PP; the last peak has no successor, so its post-PP is set to its pre-PP.
    With a single peak both values are None because no interval exists.

    Args:
        peak_locations (array-like): The locations of the peaks, in samples.
        fs (float, optional): Sampling frequency in Hz used to convert samples to
            seconds. Defaults to 128.

    Returns:
        tuple: Two lists (pre-PPs, post-PPs) with one value per peak, in seconds.
    """
    last = len(peak_locations) - 1
    pre_PPs = []
    post_PPs = []

    for i in range(len(peak_locations)):
        pre_PP = float(peak_locations[i] - peak_locations[i - 1]) / fs if i > 0 else None
        post_PP = float(peak_locations[i + 1] - peak_locations[i]) / fs if i < last else None

        # Mirror the only available interval at the two ends of the recording
        if pre_PP is None:
            pre_PP = post_PP
        elif post_PP is None:
            post_PP = pre_PP

        pre_PPs.append(pre_PP)
        post_PPs.append(post_PP)

    return pre_PPs, post_PPs

def get_beat_width(beats, fs=128):
    """
    Calculates the duration of the beats.

    Args:
        beats (array-like): The input beats.
        fs (float, optional): Sampling frequency in Hz used to convert the number of
            samples to seconds. Defaults to 128.

    Returns:
        list: The width (number of samples divided by ``fs``) of each beat in seconds.
    """
    return [len(beat) / fs for beat in beats]

def get_beat_FWHM(beats, fs=128):
    """
    Calculates the Full Width at Half Maximum (FWHM) of the beats.

    The width is the span between the first and the last sample strictly above half of
    the beat's maximum value (both ends included).

    Args:
        beats (array-like): The input beats.
        fs (float, optional): Sampling frequency in Hz used to convert samples to
            seconds. Defaults to 128.

    Returns:
        list: The FWHM value of each beat in seconds (0 when no sample exceeds the
        half maximum).
    """
    widths = []
    for beat in beats:
        # Coerce to an array so plain Python lists support the element-wise comparison
        samples = np.asarray(beat)
        half_max = np.max(samples) / 2.
        indices = np.where(samples > half_max)[0]
        if len(indices) > 0:
            widths.append((indices[-1] - indices[0] + 1) / fs)
        else:
            widths.append(0)  # No sample above the half maximum
    return widths

def compute_rise_times(beats, fs=128):
    """
    Calculates the rise time of each beat in the list.

    The rise time is the span between the first and the last sample located before the
    beat maximum whose value lies between 10% and 90% of that maximum.

    Args:
        beats (list of array-like): The input beats.
        fs (float, optional): Sampling frequency in Hz used to convert samples to
            seconds. Defaults to 128.

    Returns:
        list of float: The rise times of the beats in seconds (0 when no sample falls
        inside the 10%-90% band before the maximum).
    """
    rise_times = []
    for beat in beats:
        # Coerce to an array so plain Python lists support the element-wise comparisons
        samples = np.asarray(beat)
        max_val = np.max(samples)
        rising_edge = samples[:np.argmax(samples)]
        low_val = 0.1 * max_val
        high_val = 0.9 * max_val

        indices = np.where((rising_edge >= low_val) & (rising_edge <= high_val))[0]
        if len(indices) > 0:
            rise_times.append((indices[-1] - indices[0] + 1) / fs)
        else:
            rise_times.append(0)  # Nothing inside the band before the maximum
    return rise_times

def compute_fall_times(beats, fs=128):
    """
    Calculates the fall time of each beat in the list.

    The fall time is the span between the first and the last sample located at or after
    the beat maximum whose value lies between 10% and 90% of that maximum.

    Args:
        beats (list of array-like): The input beats.
        fs (float, optional): Sampling frequency in Hz used to convert samples to
            seconds. Defaults to 128.

    Returns:
        list of float: The fall times of the beats in seconds (0 when no sample falls
        inside the 10%-90% band after the maximum).
    """
    fall_times = []
    for beat in beats:
        # Coerce to an array so plain Python lists support the element-wise comparisons
        samples = np.asarray(beat)
        max_val = np.max(samples)
        high_val = 0.9 * max_val
        low_val = 0.1 * max_val

        # The span (last index - first index + 1) is the same whether or not the
        # segment is reversed, so the falling edge is scanned in its natural order.
        falling_edge = samples[np.argmax(samples):]
        indices = np.where((falling_edge >= low_val) & (falling_edge <= high_val))[0]
        if len(indices) > 0:
            fall_times.append((indices[-1] - indices[0] + 1) / fs)
        else:
            fall_times.append(0)  # Nothing inside the band after the maximum
    return fall_times

def compute_negative_to_negative_peak_jump(beats):
    """
    Calculates, for each beat, the difference between the minimum found at or after the
    beat maximum and the minimum found before the beat maximum.

    When the maximum is the very first sample there is nothing before it, and the first
    sample itself is used as the "before" minimum.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: The negative to negative peak jump of each beat.
    """
    neg_jumps = []
    for beat in beats:
        # Coerce to an array so plain Python lists can be split and measured as well
        samples = np.asarray(beat)
        max_index = np.argmax(samples)
        # Split the beat around its maximum
        before_peak = samples[:max_index]
        from_peak = samples[max_index:]
        # Find the minimum value on each side
        min_before = before_peak.min() if before_peak.size > 0 else samples[0]
        min_after = from_peak.min() if from_peak.size > 0 else samples[-1]
        neg_jumps.append(min_after - min_before)

    return neg_jumps

def compute_positive_to_positive_peak_jump(beats):
    """
    Calculates the difference between successive peak values for each beat in the list.

    For every beat the pre-jump is (own peak - previous beat's peak) and the post-jump is
    (next beat's peak - own peak). The first beat has no predecessor, so its pre-jump is
    set to its post-jump; the last beat has no successor, so its post-jump is set to its
    pre-jump. With a single beat both values are None because no neighbour exists.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        tuple: Two lists (pre-jumps, post-jumps) with one value per beat.
    """
    # Compute every peak once instead of re-evaluating neighbours inside the loop
    peaks = [np.max(beat) for beat in beats]
    last = len(peaks) - 1
    pre_pos_jumps = []
    post_pos_jumps = []
    for i, peak_current in enumerate(peaks):
        pre_diff = peak_current - peaks[i - 1] if i > 0 else None
        post_diff = peaks[i + 1] - peak_current if i < last else None

        # Mirror the only available difference at the two ends of the recording
        if pre_diff is None:
            pre_diff = post_diff
        elif post_diff is None:
            post_diff = pre_diff

        pre_pos_jumps.append(pre_diff)
        post_pos_jumps.append(post_diff)

    return pre_pos_jumps, post_pos_jumps

def compute_areas(beats):
    """
    Calculates the area under each beat in the list.

    The area is obtained with the trapezoidal rule using unit spacing between samples.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: The areas under the beats.
    """
    # NumPy 2.0 renamed `trapz` to `trapezoid` and deprecated the old name;
    # prefer the new name and fall back for older NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    return [integrate(beat) for beat in beats]

def compute_energy(beats):
    """
    Computes the energy (sum of squared samples) of every beat.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: One energy value per beat, in input order.
    """
    return [np.sum(np.square(samples)) for samples in beats]

from scipy.stats import skew

def compute_skewness(beats):
    """
    Computes the skewness of every beat with ``scipy.stats.skew``.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: One skewness value per beat, in input order.
    """
    return [skew(samples) for samples in beats]

def compute_skewness_diff(beats):
    """
    Computes how the skewness of each beat differs from that of its neighbours.

    For every beat the "pre" value is (own skewness - previous beat's skewness) and the
    "post" value is (own skewness - next beat's skewness). The first beat borrows its
    "post" value as "pre", the last beat borrows its "pre" value as "post"; with a single
    beat both values are None.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        tuple: Two lists (pre differences, post differences) with one value per beat.
    """
    values = [skew(samples) for samples in beats]
    final_position = len(values) - 1
    pre_diffs, post_diffs = [], []

    for position, current in enumerate(values):
        backward = current - values[position - 1] if position != 0 else None
        forward = current - values[position + 1] if position != final_position else None

        # Borrow the neighbouring difference when one side is missing
        if backward is None:
            backward = forward
        elif forward is None:
            forward = backward

        pre_diffs.append(backward)
        post_diffs.append(forward)

    return pre_diffs, post_diffs

from scipy.stats import kurtosis

def compute_kurtosis(beats):
    """
    Computes the kurtosis of every beat with ``scipy.stats.kurtosis``.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: One kurtosis value per beat, in input order.
    """
    return [kurtosis(samples) for samples in beats]

def compute_kurtosis_diff(beats):
    """
    Computes how the kurtosis of each beat differs from that of its neighbours.

    For every beat the "pre" value is (own kurtosis - previous beat's kurtosis) and the
    "post" value is (own kurtosis - next beat's kurtosis). The first beat borrows its
    "post" value as "pre", the last beat borrows its "pre" value as "post"; with a single
    beat both values are None.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        tuple: Two lists (pre differences, post differences) with one value per beat.
    """
    values = [kurtosis(samples) for samples in beats]
    final_position = len(values) - 1
    pre_diffs, post_diffs = [], []

    for position, current in enumerate(values):
        backward = current - values[position - 1] if position != 0 else None
        forward = current - values[position + 1] if position != final_position else None

        # Borrow the neighbouring difference when one side is missing
        if backward is None:
            backward = forward
        elif forward is None:
            forward = backward

        pre_diffs.append(backward)
        post_diffs.append(forward)

    return pre_diffs, post_diffs

import nolds

def compute_entropy(beats):
    """
    Computes the sample entropy of every beat with ``nolds.sampen``.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: One sample entropy value per beat, in input order.
    """
    return [nolds.sampen(samples) for samples in beats]

def compute_rms(beats):
    """
    Computes the root mean square (RMS) of every beat.

    Args:
        beats (list of array-like): The input beats.

    Returns:
        list of float: One RMS value per beat, in input order.
    """
    return [np.sqrt(np.mean(np.square(samples))) for samples in beats]

def calculate_hrv(peak_locations, window_size=4):
    """
    Calculate Heart Rate Variability (HRV) for each peak location within a given window size.

    The HRV measure is the standard deviation of the peak-to-peak intervals (expressed in
    samples, no conversion to seconds is applied) found in a window of
    ``window_size // 2`` peaks on each side of the current peak. The window is truncated
    at the beginning and at the end of the recording.

    Parameters:
    peak_locations (array-like): Peak locations, in samples.
    window_size (int): The size of the window to consider for each peak. Default is 4.

    Returns:
    list: HRV measures for each peak (0.0 when the window contains no interval).
    """
    # Accept plain Python lists as well as NumPy arrays
    peaks = np.asarray(peak_locations)
    n_peaks = len(peaks)

    # Half window size for creating a window centred on the current peak
    half_window = window_size // 2

    hrv_measures = []
    for i in range(n_peaks):
        window_start = max(0, i - half_window)
        window_end = min(n_peaks, i + half_window + 1)

        # Differences between successive peaks inside the window are the PP intervals
        pp_intervals = np.diff(peaks[window_start:window_end].ravel())

        # A window holding a single peak has no interval: report 0.0 instead of the
        # NaN (and RuntimeWarning) np.std produces on an empty array.
        hrv_measures.append(np.std(pp_intervals) if pp_intervals.size else 0.0)

    return hrv_measures

def compute_avg_peak_to_peak_distance(peak_locations, window_size=4, fs=128):
    """
    Calculates the average peak-to-peak distance for each peak within a sliding window.

    For peak ``i`` the window covers the peaks with index in
    ``[i - window_size, i + window_size)``, truncated to the recording. Note that the
    upper bound is exclusive, so the window holds one peak fewer after the current
    peak than before it.

    Args:
        peak_locations (array-like of int): The locations of the peaks, in samples.
        window_size (int, optional): The size of the sliding window. Defaults to 4.
        fs (float, optional): Sampling frequency in Hz used to convert samples to
            seconds. Defaults to 128.

    Returns:
        list of float: The average peak-to-peak distance for each peak in seconds
        (0 when the window contains no interval).
    """
    # Accept plain Python lists as well as NumPy arrays
    peaks = np.asarray(peak_locations)
    n_peaks = len(peaks)

    avg_PP_distances = []
    for i in range(n_peaks):
        start = max(0, i - window_size)
        end = min(i + window_size, n_peaks)

        # Peak-to-peak distances inside the window, converted to seconds
        distances = np.diff(peaks[start:end].ravel()) / fs

        avg_PP_distances.append(np.mean(distances) if distances.size else 0)

    return avg_PP_distances

from scipy.fftpack import fft

def compute_dominant_frequency(beats, sample_rate):
    """
    Calculates the dominant frequency of each beat in the list.

    The dominant frequency is the frequency of the bin with the largest magnitude in the
    one-sided spectrum of the beat. The DC bin is included, so a beat dominated by its
    offset reports 0.

    Args:
        beats (list of array-like): The input beats (real-valued).
        sample_rate (float): The sample rate of the beats.

    Returns:
        list of float: The dominant frequency of the beats, in the unit of ``sample_rate``.
    """
    dominant_frequencies = []
    for beat in beats:
        samples = np.asarray(beat, dtype=float)

        # The spectrum of a real signal is mirrored around the Nyquist frequency.
        # Searching only the one-sided spectrum prevents rounding noise from selecting
        # the mirrored bin, which would report a frequency above Nyquist.
        magnitudes = np.abs(np.fft.rfft(samples))

        # Bin index -> frequency
        dominant_frequency = np.argmax(magnitudes) * sample_rate / len(samples)
        dominant_frequencies.append(dominant_frequency)
    return dominant_frequencies

In [None]:
import tqdm
# Extract beats and features from the signals
# --- 128 Hz recordings -------------------------------------------------------------
# The first len(standardized_128) entries of `patient_instances` are paired by index
# with `standardized_128`, `filtered_signals_128`, `locations_128` and `labels_128`.
print("Extracting beats from 128Hz signals...")
for i, patient_instance in tqdm.tqdm(enumerate(patient_instances[:len(standardized_128)]), total=len(standardized_128)):
    # Extract single beats (static window of 100) from the standardized signal
    single_beats, peak_locations = extract_single_beats_statically(standardized_128[i],
                                                                   locations_128[i],
                                                                   window_size=100)
    # Extract contiguous beats (static window of 200) from the standardized signal
    contiguous_beats = extract_contiguous_beats_statically(standardized_128[i],
                                                            locations_128[i],
                                                            window_size=200)
    # Store the beats and peak locations in the patient instance
    patient_instance.single_beats = single_beats
    patient_instance.contiguous_beats = contiguous_beats
    patient_instance.peak_locations = peak_locations
    # Store the labels in the patient instance
    patient_instance.labels = labels_128[i]
    # Features are computed on a separate, dynamically-windowed extraction taken from
    # `filtered_signals_128` rather than on the statically-windowed beats stored above.
    # The second return value (`peak_locations_`) is not used afterwards.
    single_beats_dynamic, peak_locations_ = extract_single_beats_dynamically(filtered_signals_128[i],
                                                                             locations_128[i])
    # Per-beat shape/statistics features
    patient_instance.mean = get_beat_mean(single_beats_dynamic)
    patient_instance.std = get_beat_std(single_beats_dynamic)
    patient_instance.amplitude = get_beat_amplitude(single_beats_dynamic)
    patient_instance.peak_value = get_beat_peak_value(single_beats_dynamic)
    # Rhythm features are derived from the labelled peak locations, not from the beats
    patient_instance.pre_PP, patient_instance.post_PP = get_beat_pre_post_PP(locations_128[i])
    patient_instance.avg_PP = compute_avg_peak_to_peak_distance(locations_128[i])
    patient_instance.width = get_beat_width(single_beats_dynamic)
    patient_instance.FWHM = get_beat_FWHM(single_beats_dynamic)
    patient_instance.rise_time = compute_rise_times(single_beats_dynamic)
    patient_instance.fall_time = compute_fall_times(single_beats_dynamic)
    patient_instance.area = compute_areas(single_beats_dynamic)
    patient_instance.skewness = compute_skewness(single_beats_dynamic)
    patient_instance.pre_skew, patient_instance.post_skew = compute_skewness_diff(single_beats_dynamic)
    patient_instance.kurtosis = compute_kurtosis(single_beats_dynamic)
    patient_instance.pre_kurt, patient_instance.post_kurt = compute_kurtosis_diff(single_beats_dynamic)
    patient_instance.entropy = compute_entropy(single_beats_dynamic)
    patient_instance.RMS = compute_rms(single_beats_dynamic)
    patient_instance.neg_neg_jump = compute_negative_to_negative_peak_jump(single_beats_dynamic)
    patient_instance.pre_pos_pos_jump, patient_instance.post_pos_pos_jump = compute_positive_to_positive_peak_jump(single_beats_dynamic)
    patient_instance.local_hrv = calculate_hrv(locations_128[i])
    patient_instance.energy = compute_energy(single_beats_dynamic)
    # 128 is the sample rate handed to the FFT-based feature
    patient_instance.dominant_frequency = compute_dominant_frequency(single_beats_dynamic, 128)


# --- Downsampled 250 Hz recordings ---------------------------------------------------
# The remaining entries of `patient_instances` (after the first len(standardized_128))
# are paired by index with `standardized_downsampled`, `filtered_signals_downsampled`,
# `downsampled_locations` and `labels_250`. Here `i` restarts from 0 and indexes those
# arrays, not `patient_instances`.
print("Extracting beats from downsampled 250Hz signals...")
for i, patient_instance in tqdm.tqdm(enumerate(patient_instances[len(standardized_128):]), total=len(standardized_downsampled)):
    # Extract single beats (static window of 100) from the standardized signal
    single_beats, peak_locations = extract_single_beats_statically(standardized_downsampled[i],
                                                                   downsampled_locations[i],
                                                                   window_size=100)
    # Disabled alternative: extract beats from the non-standardized signals for feature extraction
    # single_beats_non_standardized, peak_locations_non_standardized = extract_single_beats_statically(filtered_signals_downsampled[i],
    #                                                                                                  downsampled_locations[i],
    #                                                                                                  window_size=100)
    # Extract contiguous beats (static window of 200) from the standardized signal
    contiguous_beats = extract_contiguous_beats_statically(standardized_downsampled[i],
                                                            downsampled_locations[i],
                                                            window_size=200)
    # Store the beats and peak locations in the patient instance
    patient_instance.single_beats = single_beats
    patient_instance.contiguous_beats = contiguous_beats
    patient_instance.peak_locations = peak_locations
    # Store the labels in the patient instance
    patient_instance.labels = labels_250[i]
    # Features are computed on a separate, dynamically-windowed extraction taken from
    # `filtered_signals_downsampled` rather than on the statically-windowed beats stored
    # above. The second return value (`peak_locations_`) is not used afterwards.
    single_beats_dynamic, peak_locations_ = extract_single_beats_dynamically(filtered_signals_downsampled[i],
                                                                             downsampled_locations[i])
    # Per-beat shape/statistics features
    patient_instance.mean = get_beat_mean(single_beats_dynamic)
    patient_instance.std = get_beat_std(single_beats_dynamic)
    patient_instance.amplitude = get_beat_amplitude(single_beats_dynamic)
    patient_instance.peak_value = get_beat_peak_value(single_beats_dynamic)
    # Rhythm features are derived from the labelled peak locations, not from the beats
    patient_instance.pre_PP, patient_instance.post_PP = get_beat_pre_post_PP(downsampled_locations[i])
    patient_instance.avg_PP = compute_avg_peak_to_peak_distance(downsampled_locations[i])
    patient_instance.width = get_beat_width(single_beats_dynamic)
    patient_instance.FWHM = get_beat_FWHM(single_beats_dynamic)
    patient_instance.rise_time = compute_rise_times(single_beats_dynamic)
    patient_instance.fall_time = compute_fall_times(single_beats_dynamic)
    patient_instance.area = compute_areas(single_beats_dynamic)
    patient_instance.skewness = compute_skewness(single_beats_dynamic)
    patient_instance.pre_skew, patient_instance.post_skew = compute_skewness_diff(single_beats_dynamic)
    patient_instance.kurtosis = compute_kurtosis(single_beats_dynamic)
    patient_instance.pre_kurt, patient_instance.post_kurt = compute_kurtosis_diff(single_beats_dynamic)
    patient_instance.entropy = compute_entropy(single_beats_dynamic)
    patient_instance.RMS = compute_rms(single_beats_dynamic)
    patient_instance.neg_neg_jump = compute_negative_to_negative_peak_jump(single_beats_dynamic)
    patient_instance.pre_pos_pos_jump, patient_instance.post_pos_pos_jump = compute_positive_to_positive_peak_jump(single_beats_dynamic)
    patient_instance.local_hrv = calculate_hrv(downsampled_locations[i])
    patient_instance.energy = compute_energy(single_beats_dynamic)
    # NOTE(review): 128 is passed as the sample rate; presumably the 250 Hz signals were
    # downsampled to 128 Hz - confirm against the downsampling step.
    patient_instance.dominant_frequency = compute_dominant_frequency(single_beats_dynamic, 128)

In [None]:
# Plot one beat labelled 'N' from the first patient (no index is passed, so the helper picks the beat)
plot_beat_of_given_label(patient_instances[0].single_beats, patient_instances[0].labels, 'N')

In [None]:
# Plot one beat labelled 'S' from the first patient (no index is passed, so the helper picks the beat)
plot_beat_of_given_label(patient_instances[0].single_beats, patient_instances[0].labels, 'S')

In [None]:
# Plot one beat labelled 'V' from the first patient (no index is passed, so the helper picks the beat)
plot_beat_of_given_label(patient_instances[0].single_beats, patient_instances[0].labels, 'V')

# Beat Cleaning
The goal is to remove the noisy beats belonging to class 'N' and to reconstruct those belonging to classes 'S' and 'V' in order to maximize the amount of data available for the minority classes.

In [None]:
# Inspect beats based on amplitude threshold
def get_noisy_beats(patients, threshold=1.5):
    """
    Finds the beats whose peak-to-peak amplitude exceeds a threshold.

    Args:
        patients (list): The list of patient instances (each exposing ``single_beats``).
        threshold (float, optional): Amplitude (max - min) above which a beat is
            considered noisy. Defaults to 1.5.

    Returns:
        list: (patient index, beat index) tuples of the noisy beats, ordered by patient
        and then by beat.
    """
    flagged = []
    for patient_idx, patient in enumerate(patients):
        flagged.extend(
            (patient_idx, beat_idx)
            for beat_idx, beat in enumerate(patient.single_beats)
            if np.max(beat) - np.min(beat) > threshold
        )
    return flagged

# Plot noisy beats
def plot_noisy_beats(patients, noisy_beats, num_beats=5):
    """
    Plots the first noisy beats of the given list together with their peak.

    Args:
        patients (list): The list of patient instances.
        noisy_beats (list): (patient index, beat index) tuples of the noisy beats.
        num_beats (int, optional): How many entries of ``noisy_beats`` to plot,
            starting from the first. Defaults to 5.
    """
    selection = noisy_beats[:num_beats]
    for patient_idx, beat_idx in selection:
        record = patients[patient_idx]
        plot_beat_with_peak(record.single_beats, record.peak_locations, idx=beat_idx)

In [None]:
# Collect (patient index, beat index) pairs of the beats whose peak-to-peak amplitude
# exceeds the default threshold of get_noisy_beats
noisy_beats = get_noisy_beats(patient_instances)
# Look up the label of every noisy beat and count how many fall in each class
check_noisy_beats_labels = [patient_instances[i].labels[j] for i, j in noisy_beats]
unique, counts = np.unique(check_noisy_beats_labels, return_counts=True)
print(f"Labels of noisy beats: {dict(zip(unique, counts))}")

In [None]:
# Compute the label distribution over all patients before any beat is removed.
# NOTE(review): calculate_label_distribution is defined outside this section; here it
# receives the list of per-patient label sequences - confirm it accepts that shape.
tot_count_n, tot_count_s, tot_count_v = calculate_label_distribution([patient.labels for patient in patient_instances])
print(f"Label Distribution: {tot_count_n} N beats, {tot_count_s} S beats, {tot_count_v} V beats")

In [None]:
# Split the noisy (patient index, beat index) pairs by the label of the beat
noisy_beats_N = [(i,j) for i, j in noisy_beats if patient_instances[i].labels[j] == 'N']
noisy_beats_S = [(i,j) for i, j in noisy_beats if patient_instances[i].labels[j] == 'S']
noisy_beats_V = [(i,j) for i, j in noisy_beats if patient_instances[i].labels[j] == 'V']

In [None]:
# Visualize the first noisy 'N' beats (plot_noisy_beats plots 5 by default)
plot_noisy_beats(patient_instances, noisy_beats_N)

In [None]:
from collections import defaultdict

def group_noisy_beats(noisy_beats_N):
    """
    Groups (patient index, beat index) tuples by their first element.

    Args:
        noisy_beats_N (list of tuples): The noisy beats to group.

    Returns:
        list of list: One list of tuples per distinct first index, in order of first
        appearance; tuples keep their input order inside each group.
    """
    buckets = {}
    for entry in noisy_beats_N:
        buckets.setdefault(entry[0], []).append(entry)

    return [members for members in buckets.values()]

In [None]:
# Group N noisy beats by patients
grouped_noisy_beats_N = group_noisy_beats(noisy_beats_N)
# Build one list of noisy beat indices per patient, stored at the patient's own index.
# Patients without any noisy 'N' beat get an empty list, so that `noisy_N_list[k]`
# always refers to patient k (a list built from the groups alone would skip such
# patients and shift every following entry).
noisy_N_list = [[] for _ in patient_instances]
for patient_idx, beat_idx in noisy_beats_N:
    noisy_N_list[patient_idx].append(beat_idx)

In [None]:
# List the patients for which at least one noisy 'N' beat was found.
# Every group returned by group_noisy_beats shares a single patient index, so each
# set collected below holds exactly one index.
patient_indexes = []
for group in grouped_noisy_beats_N:
    idx = [item[0] for item in group]
    unique_idx = set(idx)
    if unique_idx not in patient_indexes:
        patient_indexes.append(unique_idx)

# Printed for manual inspection: a patient index missing here has no noisy 'N' beat
print(f"Patient indexes: {patient_indexes}")

In [None]:
# Remove noisy beats labelled as 'N'
# Every per-beat attribute is filtered with the same beat indices so that beats, labels
# and features stay aligned. The indices refer to the lists as they are BEFORE this cell
# runs: executing the cell a second time would drop other beats.
per_beat_attributes = (
    "single_beats",
    # Only the noisy beat itself is dropped from the contiguous beats; the beats
    # preceding and following it are kept (this may need to be revisited).
    "contiguous_beats",
    "mean", "std", "amplitude", "peak_value",
    "pre_PP", "post_PP", "avg_PP", "width",
    "labels", "peak_locations",
    "FWHM", "skewness", "pre_skew", "post_skew",
    "kurtosis", "pre_kurt", "post_kurt",
    "entropy", "RMS",
    "neg_neg_jump", "pre_pos_pos_jump", "post_pos_pos_jump",
    "rise_time", "fall_time", "area",
    "local_hrv", "energy", "dominant_frequency",
)

# Key the indices to remove by the patient index they belong to. Taking them directly
# from the (patient index, beat index) pairs keeps each patient tied to its own beats
# even when some patients have no noisy 'N' beat; sets make the membership test O(1).
noisy_N_by_patient = {}
for patient_idx, beat_idx in noisy_beats_N:
    noisy_N_by_patient.setdefault(patient_idx, set()).add(beat_idx)

for i, patient in enumerate(patient_instances):
    beats_to_remove = noisy_N_by_patient.get(i, set())
    for attribute in per_beat_attributes:
        values = getattr(patient, attribute)
        # Always rebuild as a list so every patient ends up with the same container type
        setattr(patient, attribute, [value for j, value in enumerate(values) if j not in beats_to_remove])

In [None]:
# Define a function to check class distribution
def calculate_class_distribution(patient_instances):
    """
    Prints the total number of N, S and V beats over all patients.

    The per-patient counts come from ``calculate_label_distribution`` applied to each
    patient's ``labels``; nothing is returned.

    Args:
        patient_instances (list): The list of patient instances.
    """
    tot_count_n = tot_count_s = tot_count_v = 0
    for patient in patient_instances:
        count_n, count_s, count_v = calculate_label_distribution(patient.labels)
        tot_count_n, tot_count_s, tot_count_v = (
            tot_count_n + count_n,
            tot_count_s + count_s,
            tot_count_v + count_v,
        )
    print(f"Label Distribution: {tot_count_n} N beats, {tot_count_s} S beats, {tot_count_v} V beats")

# Print the class distribution of the labels currently stored on the patient instances
calculate_class_distribution(patient_instances)

## Minority Classes Noisy Beats Inspection

As the noisy N beats have been removed, the indexes of the noisy S and V beats must be re-computed.
In this subsection the goal is to try to reconstruct the beats associated with the minority class labels via autoencoders.
As some of the beats have a labelled peak position far from the actual peak, this discrepancy must be taken into account.

In [None]:
# Define a function to check if the peak position is correct
def check_peak_position(patients, threshold=10):
    """
    Flags the beats whose labelled peak location disagrees with the observed maximum.

    For every beat, the index of its largest sample is compared with the labelled
    peak location. The beat is flagged when the labelled location falls outside
    ``[argmax - threshold, argmax + threshold]``. Beats whose maximum lies on the
    first or last sample are skipped, and so are empty beats.

    Args:
        patients (list): The list of patient instances. Each one must expose
            ``single_beats`` and ``peak_locations``.
        threshold (int, optional): Maximum tolerated index distance between the
            labelled peak and the observed maximum. Defaults to 10.
    Returns:
        list: A list of ``(patient_index, beat_index)`` tuples, one per flagged beat.
    """
    incorrect_peak_positions = []
    for patient_id, patient in enumerate(patients):
        for beat_id, (beat, peak_location) in enumerate(zip(patient.single_beats, patient.peak_locations)):
            # np.argmax raises on an empty sequence and there is no peak to verify
            if len(beat) == 0:
                continue
            peak_pos = np.argmax(beat)
            # A maximum on the border of the window is not compared with the label
            if peak_pos in (0, len(beat) - 1):
                continue
            # Flag the beat when the label is farther than `threshold` from the maximum
            if not (peak_pos - threshold <= peak_location <= peak_pos + threshold):
                incorrect_peak_positions.append((patient_id, beat_id))

    return incorrect_peak_positions

# Collect the (patient index, beat index) pairs of the beats whose labelled peak looks wrong
incorrect_peak_positions = check_peak_position(patient_instances)
# Report how many beats were flagged.
# Note that many, but not necessarily all of these will correspond to noisy beats
print(f"Num. mislabelled peaks: {len(incorrect_peak_positions)}")

In [None]:
# Group the flagged pairs by patient, then keep only the beat index of every pair
grouped_incorrect_peak_positions = group_noisy_beats(incorrect_peak_positions)
incorrect_peak_positions_list = [
    [pair[1] for pair in group]
    for group in grouped_incorrect_peak_positions
]

In [None]:
# Collect the distinct sets of patient indexes covered by the grouped mislabelled peaks
patient_indexes = []
for group in grouped_incorrect_peak_positions:
    members = {pair[0] for pair in group}
    if members in patient_indexes:
        continue
    patient_indexes.append(members)

print(f"Patient indexes: {patient_indexes}")

In [None]:
# Remove mislabelled peaks
# Every per-beat attribute is filtered with the same indexes so that they all stay aligned
PER_BEAT_ATTRIBUTES = (
    'single_beats', 'contiguous_beats', 'mean', 'std', 'amplitude', 'peak_value',
    'pre_PP', 'post_PP', 'avg_PP', 'width', 'labels', 'peak_locations', 'FWHM',
    'skewness', 'pre_skew', 'post_skew', 'kurtosis', 'pre_kurt', 'post_kurt',
    'entropy', 'RMS', 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
    'rise_time', 'fall_time', 'area', 'local_hrv', 'energy', 'dominant_frequency',
)

# Map each patient index to the set of beat indexes to drop. The map is built from the
# (patient index, beat index) pairs themselves, so the removal stays correct even when
# some patients have no mislabelled peak at all.
mislabelled_by_patient = {}
for patient_id, beat_id in incorrect_peak_positions:
    mislabelled_by_patient.setdefault(patient_id, set()).add(beat_id)

for patient_id, patient in enumerate(patient_instances):
    beats_to_drop = mislabelled_by_patient.get(patient_id, set())
    for attribute in PER_BEAT_ATTRIBUTES:
        kept = [value for j, value in enumerate(getattr(patient, attribute)) if j not in beats_to_drop]
        setattr(patient, attribute, kept)

In [None]:
# Print the class distribution again, now that the mislabelled peaks have been removed
calculate_class_distribution(patient_instances)

In [None]:
# Get the indices of the noisy beats; each entry is used below as a (patient index, beat index) pair
noisy_beats = get_noisy_beats(patient_instances)
# Look up the label of every noisy beat
check_noisy_beats_labels = [patient_instances[i].labels[j] for i, j in noisy_beats]
# Count how many noisy beats carry each label and print the result
unique, counts = np.unique(check_noisy_beats_labels, return_counts=True)
print(f"Labels of noisy beats: {dict(zip(unique, counts))}")

In [None]:
# Split the noisy beats of the minority classes by label in a single pass
noisy_beats_S = []
noisy_beats_V = []
for patient_id, beat_id in noisy_beats:
    beat_label = patient_instances[patient_id].labels[beat_id]
    if beat_label == 'S':
        noisy_beats_S.append((patient_id, beat_id))
    if beat_label == 'V':
        noisy_beats_V.append((patient_id, beat_id))

In [None]:
# Visualize a batch of the noisy beats labelled as 'S'
plot_noisy_beats(patient_instances, noisy_beats_S)

In [None]:
# Visualize a batch of the noisy beats labelled as 'V'
plot_noisy_beats(patient_instances, noisy_beats_V)

Notice how few the noisy S and V beats are. For this reason, reconstructing the actual beats may be unfeasible.
So, noisy beats are completely removed.

In [None]:
# Group the noisy 'S' beats by patient, then keep only the beat index of every pair
grouped_noisy_beats_S = group_noisy_beats(noisy_beats_S)
noisy_beats_S_list = [
    [pair[1] for pair in group]
    for group in grouped_noisy_beats_S
]

In [None]:
# Collect the distinct sets of patient indexes covered by the grouped noisy 'S' beats
patient_indexes = []
for group in grouped_noisy_beats_S:
    members = {pair[0] for pair in group}
    if members in patient_indexes:
        continue
    patient_indexes.append(members)

print(f"Patient indexes: {patient_indexes}")

In [None]:
# As noisy 'S' beats are not found for all the patients, only the affected patients are processed
# Every per-beat attribute is filtered with the same indexes so that they all stay aligned
PER_BEAT_ATTRIBUTES = (
    'single_beats', 'contiguous_beats', 'mean', 'std', 'amplitude', 'peak_value',
    'pre_PP', 'post_PP', 'avg_PP', 'width', 'labels', 'peak_locations', 'FWHM',
    'skewness', 'pre_skew', 'post_skew', 'kurtosis', 'pre_kurt', 'post_kurt',
    'entropy', 'RMS', 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
    'rise_time', 'fall_time', 'area', 'local_hrv', 'energy', 'dominant_frequency',
)

# Map each affected patient index to the set of its noisy 'S' beat indexes. Working from
# the (patient index, beat index) pairs avoids relying on two separately built lists
# being positionally aligned.
noisy_S_by_patient = {}
for patient_id, beat_id in noisy_beats_S:
    noisy_S_by_patient.setdefault(patient_id, set()).add(beat_id)

# Keep the affected patient indexes available as a flat list of integers
patient_indexes = list(noisy_S_by_patient)

# Remove the noisy beats labelled as 'S'
for patient_id, beats_to_drop in noisy_S_by_patient.items():
    patient = patient_instances[patient_id]
    for attribute in PER_BEAT_ATTRIBUTES:
        kept = [value for j, value in enumerate(getattr(patient, attribute)) if j not in beats_to_drop]
        setattr(patient, attribute, kept)

# Check new label distribution
calculate_class_distribution(patient_instances)

Since the noisy 'S' beats have been removed, the indexes of the noisy 'V' beats must be re-computed.

In [None]:
# Get the indices of the noisy beats again; each entry is used below as a (patient index, beat index) pair
noisy_beats = get_noisy_beats(patient_instances)
# Look up the label of every noisy beat and print the count per label
check_noisy_beats_labels = [patient_instances[i].labels[j] for i, j in noisy_beats]
unique, counts = np.unique(check_noisy_beats_labels, return_counts=True)
print(f"Labels of noisy beats: {dict(zip(unique, counts))}")
# Keep only the noisy beats labelled as 'V'
noisy_beats_V = [(i,j) for i, j in noisy_beats if patient_instances[i].labels[j] == 'V']

In [None]:
# Group the noisy 'V' beats by patient, then keep only the beat index of every pair
grouped_noisy_beats_V = group_noisy_beats(noisy_beats_V)
noisy_beats_V_list = [
    [pair[1] for pair in group]
    for group in grouped_noisy_beats_V
]

In [None]:
# Collect the distinct sets of patient indexes covered by the grouped noisy 'V' beats
patient_indexes = []
for group in grouped_noisy_beats_V:
    members = {pair[0] for pair in group}
    if members in patient_indexes:
        continue
    patient_indexes.append(members)

print(f"Patient indexes: {patient_indexes}")

In [None]:
# As noisy 'V' beats are not found for all the patients, only the affected patients are processed
# Every per-beat attribute is filtered with the same indexes so that they all stay aligned
PER_BEAT_ATTRIBUTES = (
    'single_beats', 'contiguous_beats', 'mean', 'std', 'amplitude', 'peak_value',
    'pre_PP', 'post_PP', 'avg_PP', 'width', 'labels', 'peak_locations', 'FWHM',
    'skewness', 'pre_skew', 'post_skew', 'kurtosis', 'pre_kurt', 'post_kurt',
    'entropy', 'RMS', 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
    'rise_time', 'fall_time', 'area', 'local_hrv', 'energy', 'dominant_frequency',
)

# Map each affected patient index to the set of its noisy 'V' beat indexes. Working from
# the (patient index, beat index) pairs avoids relying on two separately built lists
# being positionally aligned.
noisy_V_by_patient = {}
for patient_id, beat_id in noisy_beats_V:
    noisy_V_by_patient.setdefault(patient_id, set()).add(beat_id)

# Keep the affected patient indexes available as a flat list of integers
patient_indexes = list(noisy_V_by_patient)

# Remove the noisy beats labelled as 'V'
for patient_id, beats_to_drop in noisy_V_by_patient.items():
    patient = patient_instances[patient_id]
    for attribute in PER_BEAT_ATTRIBUTES:
        kept = [value for j, value in enumerate(getattr(patient, attribute)) if j not in beats_to_drop]
        setattr(patient, attribute, kept)

# Check new label distribution
calculate_class_distribution(patient_instances)

# Split patients into train-validation and test sets

In [None]:
# Define a function to compute class proportions
def calculate_class_proportions(patient_instances):
    """
    Calculates and prints the class proportions of the labels.

    Args:
        patient_instances (list): The list of patient instances. Each one must
            expose a ``labels`` attribute.

    Returns:
        tuple: ``(n_ratio, v_ratio, s_ratio)``. Note the N, V, S order, which
        differs from the N, S, V order of the per-patient counts. All three
        ratios are 0.0 when there are no beats at all.
    """
    tot_count_n = 0
    tot_count_s = 0
    tot_count_v = 0
    for patient in patient_instances:
        count_n, count_s, count_v = calculate_label_distribution(patient.labels)
        tot_count_n += count_n
        tot_count_s += count_s
        tot_count_v += count_v

    total = tot_count_n + tot_count_v + tot_count_s
    if total == 0:
        # No beats: proportions are undefined, report zeros instead of dividing by zero
        n_ratio = v_ratio = s_ratio = 0.0
    else:
        n_ratio = tot_count_n / total
        v_ratio = tot_count_v / total
        s_ratio = tot_count_s / total
    print(f"Label proportions: {round(n_ratio, 4)} N beats, {round(v_ratio, 4)} V beats, {round(s_ratio, 4)} S beats")
    return n_ratio, v_ratio, s_ratio

# Check class proportions
n_ratio, v_ratio, s_ratio = calculate_class_proportions(patient_instances)

In [None]:
from sklearn.model_selection import train_test_split

# Class ratios of the current split, one triple per subset (all zero before the first split)
n_ratio_train = s_ratio_train = v_ratio_train = 0
n_ratio_val = s_ratio_val = v_ratio_val = 0
n_ratio_test = s_ratio_test = v_ratio_test = 0

# Search state: the seed is incremented at every attempt
random_state = 999
best_random_state = random_state
max_iterations = 100
iteration = 0

# Best split seen so far together with its total deviation
best_diff = float('inf')
best_split = None


def _split_deviations():
    """Return the nine absolute gaps between the subsets' class ratios and the overall ones."""
    return [
        abs(n_ratio_train - n_ratio), abs(s_ratio_train - s_ratio), abs(v_ratio_train - v_ratio),
        abs(n_ratio_val - n_ratio), abs(s_ratio_val - s_ratio), abs(v_ratio_val - v_ratio),
        abs(n_ratio_test - n_ratio), abs(s_ratio_test - s_ratio), abs(v_ratio_test - v_ratio),
    ]


# Keep trying new seeds until every gap is within 0.001 or the iteration budget is spent
while iteration < max_iterations:
    if not any(gap > 0.001 for gap in _split_deviations()):
        break

    # Carve out the test set first, then a validation set of the same size
    X_train_val, X_test = train_test_split(patient_instances, test_size=0.15, random_state=random_state)
    X_train, X_val = train_test_split(X_train_val, test_size=len(X_test), random_state=random_state)

    # Label distribution and proportions of the train set
    print("Train set:")
    calculate_class_distribution(X_train)
    n_ratio_train, v_ratio_train, s_ratio_train = calculate_class_proportions(X_train)

    # Label distribution and proportions of the validation set
    print("Validation set:")
    calculate_class_distribution(X_val)
    n_ratio_val, v_ratio_val, s_ratio_val = calculate_class_proportions(X_val)

    # Label distribution and proportions of the test set
    print("Test set:")
    calculate_class_distribution(X_test)
    n_ratio_test, v_ratio_test, s_ratio_test = calculate_class_proportions(X_test)

    # Accumulate the gaps one by one, in a fixed order, to score this split
    total_diff = 0
    for gap in _split_deviations():
        total_diff += gap

    # Remember this split when it beats the best one seen so far
    if total_diff < best_diff:
        best_random_state = random_state
        best_diff = total_diff
        best_split = (X_train, X_val, X_test)

    random_state += 1
    iteration += 1

# When the budget ran out, fall back to the best split that was found
if iteration >= max_iterations:
    print("Max iterations reached")
    X_train, X_val, X_test = best_split
    print(f"Best Random State: {best_random_state}")

In [None]:
# Print the class proportions of the final train, validation and test sets.
# The last call is the final expression of the cell, so its returned tuple is also displayed.
print("Train set:")
calculate_class_proportions(X_train)
print("Validation set:")
calculate_class_proportions(X_val)
print("Test set:")
calculate_class_proportions(X_test)

# Build Network Input


In [None]:
def _collect(patients, attribute):
    """Concatenate the given per-beat attribute of all the patients into one flat list."""
    return [item for patient in patients for item in getattr(patient, attribute)]


# Single beats of the train, validation and test sets
X_train_single_beats = _collect(X_train, 'single_beats')
X_val_single_beats = _collect(X_val, 'single_beats')
X_test_single_beats = _collect(X_test, 'single_beats')

# Contiguous beats of the train, validation and test sets
X_train_contiguous_beats = _collect(X_train, 'contiguous_beats')
X_val_contiguous_beats = _collect(X_val, 'contiguous_beats')
X_test_contiguous_beats = _collect(X_test, 'contiguous_beats')

# Labels of the train, validation and test sets
y_train = _collect(X_train, 'labels')
y_val = _collect(X_val, 'labels')
y_test = _collect(X_test, 'labels')

# Report the number of beats in every set
print("-> Single Beats")
print(f"Train dim.: {len(X_train_single_beats)}")
print(f"Validation dim.: {len(X_val_single_beats)}")
print(f"Test dim.: {len(X_test_single_beats)}")
print("-> Contiguous Beats")
print(f"Train dim.: {len(X_train_contiguous_beats)}")
print(f"Validation dim.: {len(X_val_contiguous_beats)}")
print(f"Test dim.: {len(X_test_contiguous_beats)}")

# Report the number of labels in every set
print("-> Labels")
print(f"Train labels dim.: {len(y_train)}")
print(f"Validation labels dim.: {len(y_val)}")
print(f"Test labels dim.: {len(y_test)}")

After the train-validation-test split was performed, it was seen that some beats (10 of the normal beats and 1 of the abnormal ones) were not 100 samples in length. This is because these beats were close to the signal truncations. The short N beats are removed, while padding is used for the short abnormal beats.

In [None]:
def separate_short_beats(beats, labels, target_len=100):
    """
    Find the beats shorter than ``target_len`` and report their positions per label.

    A line is printed for every short beat labelled 'N', 'S' or 'V'. Short beats
    carrying any other label are ignored.

    Args:
        beats (list): The beats to inspect.
        labels (list): The label of every beat, aligned with ``beats``.
        target_len (int, optional): Minimum length of a regular beat. Defaults to 100.

    Returns:
        tuple: Three lists with the positions of the short N, S and V beats.
    """
    short_n, short_s, short_v = [], [], []

    for position, beat in enumerate(beats):
        beat_label = labels[position]
        beat_len = len(beat)
        # Beats that are long enough are of no interest here
        if beat_len >= target_len:
            continue
        if beat_label == 'N':
            print(f"Beat {position} is a {beat_len} N beat")
            short_n.append(position)
        elif beat_label == 'S':
            print(f"Beat {position} is a {beat_len} S beat")
            short_s.append(position)
        elif beat_label == 'V':
            print(f"Beat {position} is a {beat_len} V beat")
            short_v.append(position)

    return short_n, short_s, short_v

# Expected number of samples of a single beat and of a contiguous beat
SINGLE_BEAT_LEN = 100
CONT_BEAT_LEN = 200

# Locate the short single beats of every set, separately for each label
print("--> Single Beats")
print("Train set:")
short_beats_N_train, short_beats_S_train, short_beats_V_train = separate_short_beats(X_train_single_beats, y_train, target_len=SINGLE_BEAT_LEN)
print("Validation set:")
short_beats_N_val, short_beats_S_val, short_beats_V_val = separate_short_beats(X_val_single_beats, y_val, target_len=SINGLE_BEAT_LEN)
print("Test set:")
short_beats_N_test, short_beats_S_test, short_beats_V_test = separate_short_beats(X_test_single_beats, y_test, target_len=SINGLE_BEAT_LEN)

# Locate the short contiguous beats of every set, separately for each label
print("--> Contiguous Beats")
print("Train set:")
short_contiguous_beats_N_train, short_contiguous_beats_S_train, short_contiguous_beats_V_train = separate_short_beats(X_train_contiguous_beats, y_train, target_len=CONT_BEAT_LEN)
print("Validation set:")
short_contiguous_beats_N_val, short_contiguous_beats_S_val, short_contiguous_beats_V_val = separate_short_beats(X_val_contiguous_beats, y_val, target_len=CONT_BEAT_LEN)
print("Test set:")
short_contiguous_beats_N_test, short_contiguous_beats_S_test, short_contiguous_beats_V_test = separate_short_beats(X_test_contiguous_beats, y_test, target_len=CONT_BEAT_LEN)

In [None]:
# Define a function to pad the sequence by repeating its last value
def pad_sequence(seq, target_length):
    """
    Extend ``seq`` up to ``target_length`` by repeating its last value.

    A sequence that is already long enough is returned unchanged (same object);
    otherwise the padded result of ``np.pad`` is returned.
    """
    missing = target_length - len(seq)
    if missing > 0:
        return np.pad(seq, (0, missing), 'constant', constant_values=seq[-1])
    return seq

In [None]:
def _pad_selected(beats, indexes, target_len):
    """Pad the beats at the given positions up to target_len and leave the others untouched."""
    selected = set(indexes)
    return [pad_sequence(beat, target_len) if i in selected else beat for i, beat in enumerate(beats)]


# Short abnormal beats (both 'S' and 'V') are padded instead of being dropped.
# Single and contiguous beats are treated the same way, so that no short abnormal
# beat is left with a different length than the rest of its set.

# Apply padding to the short S and V sequences in Single beats
X_train_beats = _pad_selected(X_train_single_beats, short_beats_S_train + short_beats_V_train, SINGLE_BEAT_LEN)
X_val_beats = _pad_selected(X_val_single_beats, short_beats_S_val + short_beats_V_val, SINGLE_BEAT_LEN)
X_test_beats = _pad_selected(X_test_single_beats, short_beats_S_test + short_beats_V_test, SINGLE_BEAT_LEN)

# Apply padding to the short S and V sequences in Contiguous beats
X_train_beats_cont = _pad_selected(X_train_contiguous_beats, short_contiguous_beats_S_train + short_contiguous_beats_V_train, CONT_BEAT_LEN)
X_val_beats_cont = _pad_selected(X_val_contiguous_beats, short_contiguous_beats_S_val + short_contiguous_beats_V_val, CONT_BEAT_LEN)
X_test_beats_cont = _pad_selected(X_test_contiguous_beats, short_contiguous_beats_S_test + short_contiguous_beats_V_test, CONT_BEAT_LEN)

In [None]:
# Re-run the short-beat check on the padded sets; only short N beats should still be reported.
# The target lengths are passed explicitly for both representations so that this check
# always uses the same thresholds as the padding step.
print("--> Single Beats")
print("Train set:")
short_beats_N_train, short_beats_S_train, short_beats_V_train = separate_short_beats(X_train_beats, y_train, target_len=SINGLE_BEAT_LEN)
print("Validation set:")
short_beats_N_val, short_beats_S_val, short_beats_V_val = separate_short_beats(X_val_beats, y_val, target_len=SINGLE_BEAT_LEN)
print("Test set:")
short_beats_N_test, short_beats_S_test, short_beats_V_test = separate_short_beats(X_test_beats, y_test, target_len=SINGLE_BEAT_LEN)

print("--> Contiguous Beats")
print("Train set:")
short_contiguous_beats_N_train, short_contiguous_beats_S_train, short_contiguous_beats_V_train = separate_short_beats(X_train_beats_cont, y_train, target_len=CONT_BEAT_LEN)
print("Validation set:")
short_contiguous_beats_N_val, short_contiguous_beats_S_val, short_contiguous_beats_V_val = separate_short_beats(X_val_beats_cont, y_val, target_len=CONT_BEAT_LEN)
print("Test set:")
short_contiguous_beats_N_test, short_contiguous_beats_S_test, short_contiguous_beats_V_test = separate_short_beats(X_test_beats_cont, y_test, target_len=CONT_BEAT_LEN)

In [None]:
def _without(items, dropped):
    """Return the items whose position is not listed in `dropped`."""
    return [item for position, item in enumerate(items) if position not in dropped]


# Remove the short N beats from all the sets.
# Beats and labels are filtered with the same positions so that they stay aligned.

# Single Beats and their labels
X_train_single = _without(X_train_beats, short_beats_N_train)
y_train_single = _without(y_train, short_beats_N_train)
X_val_single = _without(X_val_beats, short_beats_N_val)
y_val_single = _without(y_val, short_beats_N_val)
X_test_single = _without(X_test_beats, short_beats_N_test)
y_test_single = _without(y_test, short_beats_N_test)

# Contiguous Beats and their labels
X_train_contiguous = _without(X_train_beats_cont, short_contiguous_beats_N_train)
y_train_contiguous = _without(y_train, short_contiguous_beats_N_train)
X_val_contiguous = _without(X_val_beats_cont, short_contiguous_beats_N_val)
y_val_contiguous = _without(y_val, short_contiguous_beats_N_val)
X_test_contiguous = _without(X_test_beats_cont, short_contiguous_beats_N_test)
y_test_contiguous = _without(y_test, short_contiguous_beats_N_test)

In [None]:
# Single Beats: turn the three sets into NumPy arrays and report their shapes
X_train_single, X_val_single, X_test_single = (
    np.array(part) for part in (X_train_single, X_val_single, X_test_single)
)
print("Single Beats")
print(X_train_single.shape,X_val_single.shape,X_test_single.shape)

# Contiguous Beats: turn the three sets into NumPy arrays and report their shapes
X_train_contiguous, X_val_contiguous, X_test_contiguous = (
    np.array(part) for part in (X_train_contiguous, X_val_contiguous, X_test_contiguous)
)
print("Contiguous Beats")
print(X_train_contiguous.shape,X_val_contiguous.shape,X_test_contiguous.shape)

In [None]:
# Single Beats: turn the three label sets into NumPy arrays and report their shapes
y_train_single, y_val_single, y_test_single = (
    np.array(part) for part in (y_train_single, y_val_single, y_test_single)
)
print("Single Beats")
print(y_train_single.shape,y_val_single.shape,y_test_single.shape)

# Contiguous Beats: turn the three label sets into NumPy arrays and report their shapes
y_train_contiguous, y_val_contiguous, y_test_contiguous = (
    np.array(part) for part in (y_train_contiguous, y_val_contiguous, y_test_contiguous)
)
print("Contiguous Beats")
print(y_train_contiguous.shape,y_val_contiguous.shape,y_test_contiguous.shape)

In [None]:
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

# One hot encode labels
num_classes = 3
encoder = LabelEncoder()
# scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed the old
# keyword in 1.4: try the current spelling first, fall back for older releases.
try:
    one_hot_encoder = OneHotEncoder(sparse_output=False, categories='auto')
except TypeError:
    one_hot_encoder = OneHotEncoder(sparse=False, categories='auto')

# Single Beats
# Step 1: labels -> integer codes (fitted on train only), as column vectors
y_train_single_encoded = encoder.fit_transform(y_train_single).reshape(-1, 1)
y_val_single_encoded = encoder.transform(y_val_single).reshape(-1, 1)
y_test_single_encoded = encoder.transform(y_test_single).reshape(-1, 1)

# Step 2: integer codes -> dense one-hot rows.
# NOTE: this overwrites y_*_single with 2-D one-hot matrices; the integer codes
# stay available in y_*_single_encoded.
y_train_single = one_hot_encoder.fit_transform(y_train_single_encoded)
y_val_single = one_hot_encoder.transform(y_val_single_encoded)
y_test_single = one_hot_encoder.transform(y_test_single_encoded)

# Contiguous Beats
# The same encoder objects are refitted here, so afterwards they reflect the
# contiguous-beat training labels.
y_train_contiguous_encoded = encoder.fit_transform(y_train_contiguous).reshape(-1, 1)
y_val_contiguous_encoded = encoder.transform(y_val_contiguous).reshape(-1, 1)
y_test_contiguous_encoded = encoder.transform(y_test_contiguous).reshape(-1, 1)

y_train_contiguous = one_hot_encoder.fit_transform(y_train_contiguous_encoded)
y_val_contiguous = one_hot_encoder.transform(y_val_contiguous_encoded)
y_test_contiguous = one_hot_encoder.transform(y_test_contiguous_encoded)

In [None]:
import pandas as pd
# Build Train dataframe with extracted features
# Each column name is also the name of the per-beat attribute read from a patient
feature_names = ['mean', 'std', 'amplitude', 'peak_value',
                 'pre_PP', 'post_PP', 'avg_PP',
                 'width', 'FWHM',
                 'skewness', 'pre_skew', 'post_skew',
                 'kurtosis', 'pre_kurt', 'post_kurt',
                 'entropy', 'RMS',
                 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
                 'rise_time', 'fall_time', 'area',
                 'local_hrv', 'energy', 'dominant_frequency']

# Collect the rows in a plain list and build the DataFrame once at the end:
# growing a DataFrame with `df.loc[len(df)] = row` re-allocates on every insert
# (quadratic overall) and starts from object-dtype columns.
train_rows = []
for patient in tqdm.tqdm(X_train, desc="Building Train Feature Dataframe",
                         total=len(X_train)):
    # One row per single beat of the patient
    for i in range(len(patient.single_beats)):
        train_rows.append([getattr(patient, name)[i] for name in feature_names])

X_train_feat = pd.DataFrame(train_rows, columns=feature_names)

X_train_feat.head()

In [None]:
# Inspect data: show the summary produced by DataFrame.describe() for the train features
X_train_feat.describe()

In [None]:
# Check for missing values: count the NaN entries in each feature column
print("The number of missing values per attribute is the following:")
X_train_feat.isna().sum()

In [None]:
# Check for infinite values: count the +/-Inf entries in each feature column
print("The number of Inf values per attribute is the following:")
print(np.isinf(X_train_feat).sum())

In [None]:
# Build Val dataframe with extracted features
# Each column name is also the name of the per-beat attribute read from a patient
feature_names = ['mean', 'std', 'amplitude', 'peak_value',
                 'pre_PP', 'post_PP', 'avg_PP',
                 'width', 'FWHM',
                 'skewness', 'pre_skew', 'post_skew',
                 'kurtosis', 'pre_kurt', 'post_kurt',
                 'entropy', 'RMS',
                 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
                 'rise_time', 'fall_time', 'area',
                 'local_hrv', 'energy', 'dominant_frequency']

# Collect the rows in a plain list and build the DataFrame once at the end:
# growing a DataFrame with `df.loc[len(df)] = row` re-allocates on every insert
# (quadratic overall) and starts from object-dtype columns.
val_rows = []
for patient in tqdm.tqdm(X_val, desc="Building Val Feature Dataframe",
                         total=len(X_val)):
    # One row per single beat of the patient
    for i in range(len(patient.single_beats)):
        val_rows.append([getattr(patient, name)[i] for name in feature_names])

X_val_feat = pd.DataFrame(val_rows, columns=feature_names)

In [None]:
# Concatenate the two dataframes in case the ML approach is considered
# This is done since the aim is to perform Cross-Validation over the entire train-validation set
ML_MODEL = 0

if ML_MODEL:
    X_train_feat = pd.concat([X_train_feat, X_val_feat], axis=0)
    # A bare expression inside an `if` body is not echoed by the notebook,
    # so the merged shape is printed explicitly.
    print(X_train_feat.shape)

In [None]:
# Build test dataframe with extracted features
# Each column name is also the name of the per-beat attribute read from a patient
feature_names = ['mean', 'std', 'amplitude', 'peak_value',
                 'pre_PP', 'post_PP', 'avg_PP',
                 'width', 'FWHM',
                 'skewness', 'pre_skew', 'post_skew',
                 'kurtosis', 'pre_kurt', 'post_kurt',
                 'entropy', 'RMS',
                 'neg_neg_jump', 'pre_pos_pos_jump', 'post_pos_pos_jump',
                 'rise_time', 'fall_time', 'area',
                 'local_hrv', 'energy', 'dominant_frequency']

# Collect the rows in a plain list and build the DataFrame once at the end:
# growing a DataFrame with `df.loc[len(df)] = row` re-allocates on every insert
# (quadratic overall) and starts from object-dtype columns.
test_rows = []
for patient in tqdm.tqdm(X_test, desc="Building test Feature Dataframe",
                         total=len(X_test)):
    # One row per single beat of the patient
    for i in range(len(patient.single_beats)):
        test_rows.append([getattr(patient, name)[i] for name in feature_names])

X_test_feat = pd.DataFrame(test_rows, columns=feature_names)

In [None]:
import seaborn as sns
# Visualize correlation among features
def plot_correlation_matrix(df):
    """
    Plot the correlation heatmap of a dataframe and print the average correlation of each feature.

    Parameters:
        df: DataFrame whose pairwise column correlations (``df.corr()``) are plotted.
    """
    correlation_matrix = df.corr()
    plt.figure(figsize=(20, 10))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm')
    plt.title('Correlation Matrix Heatmap', fontsize=13)
    plt.show()
    # Compute the average correlation of each feature with all the others.
    # The diagonal holds each feature's correlation with itself, so it is masked
    # out (set to NaN, which mean() skips) to keep it from inflating the average.
    diagonal = np.eye(len(correlation_matrix), dtype=bool)
    average_correlation = correlation_matrix.mask(diagonal).mean()
    average_correlation_sorted = average_correlation.sort_values(ascending=False)
    print("-> Average correlation values for each feature:")
    print(average_correlation_sorted)

# Plot the correlation matrix of the train features (heatmap + printed per-feature averages)
plot_correlation_matrix(X_train_feat)

From the correlation matrix we notice that some pairs of features show high Pearson correlation values (>0.9). This is a case of multicollinearity: one of the two variables in each highly correlated pair should be removed to avoid feeding the model redundant information. The feature to drop from each pair is chosen by considering the average correlation of each feature with respect to all the others.

In [None]:
# Drop highly correlated features
# Columns to remove, listed in the order they are dropped
redundant_features = ['peak_value', 'amplitude', 'width', 'avg_PP', 'std', 'mean', 'RMS']

# Train and test frames are always pruned; the validation frame only when it
# is still a separate frame (ML_MODEL off)
frames_to_prune = [X_train_feat, X_test_feat]
if not ML_MODEL:
    frames_to_prune.append(X_val_feat)

for frame in frames_to_prune:
    for feature in redundant_features:
        frame.drop(feature, axis=1, inplace=True)

In [None]:
# Plot the correlation matrix again, now on the train features left after the column drops
plot_correlation_matrix(X_train_feat)

In [None]:
# Min-Max Normalization of Features
# The per-column extrema come from X_train_feat only and are reused for every split
max_df = X_train_feat.max()
min_df = X_train_feat.min()
# A constant feature has max == min; dividing by that zero range would fill the
# column with NaN/inf, so a range of 1 is used there instead (the train column
# then scales to 0).
range_df = (max_df - min_df).replace(0, 1)

X_train_feat = (X_train_feat - min_df) / range_df
X_test_feat = (X_test_feat - min_df) / range_df

if not ML_MODEL:
    X_val_feat = (X_val_feat - min_df) / range_df

## Three class split

In [None]:
# Encode labels
num_classes = 3


def _to_class_codes(labels, fit=False):
    """
    Return 1-D integer class codes for raw or already one-hot encoded labels.

    Parameters:
        labels: 1-D array of raw labels, or a 2-D one-hot matrix.
        fit (bool): Refit the shared LabelEncoder on raw labels before transforming (default: False).
    """
    labels = np.asarray(labels)
    if labels.ndim == 2 and labels.shape[1] > 1:
        # One-hot rows: the position of the hot column is the class code
        return labels.argmax(axis=1)
    labels = labels.ravel()
    return encoder.fit_transform(labels) if fit else encoder.transform(labels)


# Label Encoder
# The one-hot cell above overwrites y_*_single with 2-D one-hot matrices, and
# LabelEncoder only accepts 1-D input, so the codes are recovered through the
# helper, which handles both the raw and the one-hot form.
# NOTE(review): the feature dataframes hold one row per entry of
# patient.single_beats, while y_*_single had the short-beat indices filtered
# out — confirm that the row counts of X_*_feat and these labels match.
y_train_encoded = _to_class_codes(y_train_single, fit=True)
y_val_encoded = _to_class_codes(y_val_single)
y_test_encoded = _to_class_codes(y_test_single)

if ML_MODEL:
    # Train and validation labels are merged, mirroring the merged feature frame
    y_train_encoded = np.concatenate((y_train_encoded, y_val_encoded), axis=0)
else:
    # One Hot Encoding
    from sklearn.preprocessing import OneHotEncoder

    # scikit-learn renamed `sparse` to `sparse_output` in 1.2 and removed the
    # old keyword in 1.4: try the current spelling first, fall back otherwise.
    try:
        one_hot_encoder = OneHotEncoder(sparse_output=False, categories='auto')
    except TypeError:
        one_hot_encoder = OneHotEncoder(sparse=False, categories='auto')
    y_train_enc = y_train_encoded.reshape(-1, 1)
    y_val_enc = y_val_encoded.reshape(-1, 1)
    y_test_enc = y_test_encoded.reshape(-1, 1)

    y_train_feat = one_hot_encoder.fit_transform(y_train_enc)
    y_val_feat = one_hot_encoder.transform(y_val_enc)
    y_test_feat = one_hot_encoder.transform(y_test_enc)

# Save 

In [None]:
# Feature dataframes and labels
save(X_train_feat, 'X_train_feat')
save(X_test_feat, 'X_test_feat')

if not ML_MODEL:
    # y_train_feat / y_val_feat / y_test_feat are only created when ML_MODEL is
    # off, so saving them unconditionally raised a NameError in ML mode.
    save(y_train_feat, 'y_train_feat')
    save(y_test_feat, 'y_test_feat')
    save(X_val_feat, 'X_val_feat')
    save(y_val_feat, 'y_val_feat')

# Beat dataframes and labels
save(X_train_single, 'X_train_single')
save(y_train_single, 'y_train_single')
save(X_test_single, 'X_test_single')
save(y_test_single, 'y_test_single')

save(X_train_contiguous, 'X_train_contiguous')
save(y_train_contiguous, 'y_train_contiguous')
save(X_test_contiguous, 'X_test_contiguous')
save(y_test_contiguous, 'y_test_contiguous')

# Validation beats are saved only when validation is kept as a separate split
if not ML_MODEL:
    save(X_val_single, 'X_val_single')
    save(y_val_single, 'y_val_single')
    save(X_val_contiguous, 'X_val_contiguous')
    save(y_val_contiguous, 'y_val_contiguous')

# Two Class Split

In [None]:
# Collapse the three integer classes into two: every code 2 becomes 1 (in place)
# NOTE(review): this presumes code 0 is the normal class — confirm against encoder.classes_
y_train_encoded[y_train_encoded == 2] = 1
y_val_encoded[y_val_encoded == 2] = 1
y_test_encoded[y_test_encoded == 2] = 1

# Save the encoded labels
save(y_test_encoded, 'y_test_encoded_binary')

if not ML_MODEL:
    save(y_val_encoded, 'y_val_encoded_binary')
else:
    # With ML_MODEL on, the validation labels were already appended to
    # y_train_encoded when the three-class labels were built; appending them a
    # second time here would duplicate them.
    save(y_train_encoded, 'y_train_encoded_binary')

# One hot encode the labels
y_train_encoded = one_hot_encoder.fit_transform(y_train_encoded.reshape(-1, 1))
y_test_encoded = one_hot_encoder.transform(y_test_encoded.reshape(-1, 1))
save(y_train_encoded, 'y_train_one_hot_binary')
save(y_test_encoded, 'y_test_one_hot_binary')

if not ML_MODEL:
    y_val_encoded = one_hot_encoder.transform(y_val_encoded.reshape(-1, 1))
    save(y_val_encoded, 'y_val_one_hot_binary')

## Two Class Split for autoencoder

In [None]:
# Get 'N' beats: every single beat, across all patients, whose label is 'N'
norm_beats = [
    beat
    for patient in patient_instances
    for beat, label in zip(patient.single_beats, patient.labels)
    if label == 'N'
]
# Get abnormal beats: every single beat whose label is anything other than 'N'
abnorm_beats = [
    beat
    for patient in patient_instances
    for beat, label in zip(patient.single_beats, patient.labels)
    if label != 'N'
]
# Check dimensionality
print(f"Normal beats dim.: {len(norm_beats)}")
print(f"Abnormal beats dim.: {len(abnorm_beats)}")

In [None]:
# Apply padding to all elements in abnorm_beats and stack the results into one array.
# pad_sequence is defined elsewhere in this notebook; presumably it brings every
# beat to length 100 so they stack into a 2-D array — TODO confirm against its definition.
abnorm_beats = np.array([pad_sequence(seq, 100) for seq in abnorm_beats])

In [None]:
# Append a trailing axis of size 1 to the abnormal-beat array
abnorm_beats = abnorm_beats[..., np.newaxis]
abnorm_beats.shape

In [None]:
# Keep only the normal beats that are exactly 100 samples long,
# then append a trailing axis of size 1
norm_beats = np.array([beat for beat in norm_beats if len(beat) == 100])[..., np.newaxis]
norm_beats.shape

In [None]:
# Split the 'N' beats into train, validation and test sets
from sklearn.model_selection import train_test_split

# Both splits share the same hold-out fraction and seed
split_options = dict(test_size=0.1, random_state=99)
# First hold out the test set, then carve the validation set out of the remainder
X_train_norm, X_test_norm = train_test_split(norm_beats, **split_options)
X_train_norm, X_val_norm = train_test_split(X_train_norm, **split_options)

# Check dimensionality
for split_name, split in (("Train", X_train_norm),
                          ("Validation", X_val_norm),
                          ("Test", X_test_norm)):
    print(f"Normal beats {split_name} dim.: {split.shape}")

In [None]:
# Save the normal-beat splits and the abnormal beats, each under its own name
for item, file_name in ((X_train_norm, 'X_train_norm'),
                        (X_val_norm, 'X_val_norm'),
                        (X_test_norm, 'X_test_norm'),
                        (abnorm_beats, 'abnorm_beats')):
    save(item, file_name)