# Data access and visualization

The first thing to do is to load the data and be able to visualize it.

In [None]:
import core.dataloader as crloader
import core.data_plots as crplt
import core.preprocesses as crpre
# Load the dataset splits from the local PTB-XL copy, requesting the
# recordings at a sampling rate of 100 Hz.
data = crloader.load_data(
    data_path='../data/physionet.org/files/ptb-xl/1.0.2',
    sampling_rate=100,
)
print("\nData loaded.\n")

In [None]:
# Access a specific patient's data (looked up in the train split annotations).
patient_id = 5678
ecg_ids = crloader.get_patient_id_ecg_ids(patient_id=patient_id,
                                          annotations=data['train']['annotations'])

print(f"Patient {patient_id} has {len(ecg_ids)} ECGs.")
if len(ecg_ids) == 0:
    # Fail with an explicit message rather than a bare IndexError on ecg_ids[-1].
    raise ValueError(f"Patient {patient_id} has no ECGs in the train split; "
                     "pick another patient_id.")

# Entries unpack as (ecg_id, ecg_date) pairs. The last one is treated as the
# most recent -- assumes get_patient_id_ecg_ids returns them in chronological
# order (TODO confirm).
ecg_id, ecg_date = ecg_ids[-1]

# channel=-1 appears to request every lead: the filtering cells below index
# the result as signals[:, channel].
signals = crloader.get_signal_from_ecg_id(ecg_id=ecg_id,
                                          raw_data=data['train']['data'],
                                          channel=-1)

annots = crloader.get_annotations_from_ecg_id(ecg_id=ecg_id,
                                              annotations=data['train']['annotations'])

In [None]:
# Visualize the ECG signal for all channels and print its annotations.
# Plot `signals`, the record already resolved for this ecg_id through
# crloader.get_signal_from_ecg_id. Indexing data['train']['data'] directly with
# the ECG ID only selects the right record if that container happens to be
# keyed by ECG ID; the loader helper resolves it explicitly.
data_display = crplt.plot_ecg_channels(raw_data=signals,
                                       title=f"ECG ID {ecg_id} from {ecg_date}")

print(annots)

# Data preprocessing

By preprocessing the signals, one can make them smoother, remove outliers, etc.

### Filtering

A common preprocessing step when working with signals is smoothing/filtering. It removes some of the outliers and noise from the signal, which allows a better analysis.

Some of the most used signal filtering techniques are:
- Savitzky-Golay filter
- Gaussian filter
- Median filter
- Low-pass filter
- High-pass filter
- Butterworth filter (band-pass filter)
- Convolution filter

The biggest challenge of filtering is the manual tuning. Finding the right parameters is empirical work.

In [None]:
# Work on a single lead of the multi-lead record.
channel = 0
original_signal = signals[:, channel]

# Savitzky-Golay smoothing: 5-sample window, polynomial order 2.
savgol_ecg = crpre.smooth_signal_savgol(
    ecg_signal=original_signal,
    window_length=5,
    polyorder=2,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=savgol_ecg,
    title="Savitzky-Golay filter",
)

In [None]:
# Gaussian smoothing with standard deviation sigma=3.
gaussian_ecg = crpre.smooth_signal_gaussian(
    ecg_signal=original_signal,
    sigma=3,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=gaussian_ecg,
    title="Gaussian filter",
)

In [None]:
# Median filtering with a kernel of 3 samples.
median_ecg = crpre.smooth_signal_median(
    ecg_signal=original_signal,
    kernel_size=3,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=median_ecg,
    title="Median filter",
)

In [None]:
# Low-pass filtering: 5th-order filter with its cut-off at `lowcut` Hz,
# for a signal sampled at 100 Hz.
lowcut = 45
lowpass_ecg = crpre.smooth_signal_lowpass(
    ecg_signal=original_signal,
    sample_rate=100,
    order_filter=5,
    cut=lowcut,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=lowpass_ecg,
    title=f"Low-pass filter at {lowcut} Hz",
)

In [None]:
# High-pass filtering: 5th-order filter with its cut-off at `highcut` Hz,
# for a signal sampled at 100 Hz.
highcut = 0.5
highpass_ecg = crpre.smooth_signal_highpass(ecg_signal=original_signal,
                                            sample_rate=100,
                                            order_filter=5,
                                            cut=highcut)
# Compare the raw lead against the high-pass output computed just above.
crplt.plot_filtered_signal(ecg_signal=original_signal,
                           smoothed_ecg=highpass_ecg,
                           title=f"High-pass filter at {highcut} Hz")

In [None]:
# Band-pass Butterworth filter keeping the lowcut..highcut band.
lowcut = 0.5   # lower edge: suppresses slow breathing-related noise
highcut = 45   # upper edge: suppresses power-line noise
band_ecg = crpre.smooth_signal_butterworth(
    ecg_signal=original_signal,
    sample_rate=100,
    order_filter=5,
    lowcut=lowcut,
    highcut=highcut,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=band_ecg,
    title=f"Butterworth filter ({lowcut}Hz - {highcut}Hz)",
)

In [None]:
# Convolution smoothing; `kernel` sets the width of the smoothing kernel.
kernel = 7
conv_ecg = crpre.smooth_signal_convolution(
    ecg_signal=original_signal,
    kernel=kernel,
)
crplt.plot_filtered_signal(
    ecg_signal=original_signal,
    smoothed_ecg=conv_ecg,
    title=f"Convolution filter (kernel wide {kernel})",
)

As mentioned, tuning a filter is hard work. As an example, I show the influence of different cut-off frequencies on a low-pass filter.

In [None]:
# Sweep the low-pass cut-off frequency (40, 42, ..., 48 Hz) to show how
# sensitive the result is to this single parameter. Comprehension variables
# do not leak, so the `lowcut` / `lowpass_ecg` globals from the earlier
# low-pass cell are left untouched.
cutoff_freqs = range(40, 50, 2)
lowpass_ecgs = [crpre.smooth_signal_lowpass(ecg_signal=original_signal,
                                            sample_rate=100,
                                            order_filter=5,
                                            cut=freq)
                for freq in cutoff_freqs]
cutoffs = [f"{freq}Hz" for freq in cutoff_freqs]

crplt.plot_filtered_signals(ecg_signal=original_signal,
                            smoothed_ecgs=lowpass_ecgs,
                            labels=cutoffs,
                            title="Low-pass filter search")

An application of the filtering is to remove the baseline wander.

Baseline wander is a typical artifact that corrupts the ECG. It can be caused by a variety of noise sources including respiration, body movements, and poor electrode contact. Its spectral content is usually confined to frequencies below 0.5 Hz.

The majority of baseline wander removal techniques can change the ECG and compromise its clinical relevance. For that reason, it is not an easy process.

A very basic baseline wander estimator was implemented using a sequence of median filters with different kernel sizes. Each kernel size is derived from the sampling rate and the window duration in seconds.

In [None]:
# Estimate and remove baseline wander with the median-filter cascade.
# `durations` are the median-filter window lengths in seconds. Both calls share
# one parameter set, so the plotted estimate always corresponds to the removal.
wander_params = dict(ecg_signal=original_signal, durations=[0.5, 2], sample_rate=100)
wander = crpre.estimate_baseline_wander(**wander_params)
rem_wander_ecg = crpre.remove_baseline_wander(**wander_params)

crplt.plot_filtered_signals(ecg_signal=original_signal,
                            smoothed_ecgs=[wander, rem_wander_ecg],
                            labels=['estimated wander', 'filtered'],
                            title="Remove baseline wander")

Once we have the ECG signal, there is some basic analysis that one can do.
One of the most relevant pieces of information in an ECG is the [QRS complex](https://en.wikipedia.org/wiki/QRS_complex).
In layman's terms:
- R peaks are the highest peaks
- Q peaks are the minima just before the R peaks
- S peaks are the minima just after the R peaks

From the R peaks one can estimate the heart rate.

Some of the most used detectors are:
- Pan and Tompkins
- Hamilton
- Christov
- Stationary Wavelet Transform
- Two Moving Average

And you can find an implementation [here](https://github.com/berndporr/py-ecg-detectors).
After trying it, I was not satisfied with the results. In most cases the R peaks were completely off.

I implemented my own [Pan and Tompkins QRS complex detector](https://en.wikipedia.org/wiki/Pan%E2%80%93Tompkins_algorithm).

In [None]:
from core.pan_tompkins import PanTompkinsQRS

# Pan-Tompkins QRS detection on the selected lead (100 Hz signal;
# window_size=0.15, presumably in seconds).
peak_detector = PanTompkinsQRS(signal=original_signal, sample_rate=100, window_size=0.15)

# Show the intermediate stages of the detector before locating the peaks.
intermediate_stages = (
    (peak_detector.band_pass_sgn, "Bandpassed signal"),
    (peak_detector.mov_win_sgn, "Moving window integrated signal"),
)
for stage_signal, stage_title in intermediate_stages:
    crplt.plot_signal(signal=stage_signal,
                      xlabel="Samples",
                      ylabel="Amplitude",
                      title=stage_title)

# Locate the R peaks and overlay them on the raw lead.
peak_detector.find_r_peaks()
crplt.plot_signal_and_rpeaks(signal=original_signal,
                             rpeaks_loc=peak_detector.tuned_peaks,
                             xlabel="Samples",
                             ylabel="Amplitude",
                             title="R peaks")

# The heart-rate estimate is derived from the detected R peaks.
heart_bpm, heart_var = peak_detector.estimate_heartrate()
print(f"Heart rate: {heart_bpm:.2f} +- {heart_var:.2f} bpm")

# Data split check up

I want to see how the classes are distributed in the 3 different splits. I'm using the splits suggested by `Physionet`.

In [None]:
from core.utils import calculate_distribution

datasets = ["train", "val", "test"]

# Show the label distribution of every split twice: first per individual label
# (use_combo=False), then per label combination (use_combo=True). The loop ends
# on the combination view, which stays in `class_counts` / `class_percentages`
# for the single- vs multi-label summary that follows.
for use_combo in (False, True):
    class_counts = {}
    class_percentages = {}

    for dataset in datasets:
        class_counts[dataset], class_percentages[dataset] = calculate_distribution(
            data[dataset]["labels"], use_combo=use_combo)

    num_classes = len(class_counts["train"])
    print(f"Number of classes: {num_classes}")

    crplt.plot_data_distribution(class_counts, class_percentages)

In [None]:
# Split each split's label percentages into single-label entries and
# multi-label combinations (combination keys contain a space), and report
# how much of the single-label share is NORM.
for ds_type in datasets:
    percentages = class_percentages[ds_type]
    multi_label = sum(perc for key, perc in percentages.items() if " " in key)
    single_label = sum(perc for key, perc in percentages.items() if " " not in key)
    norm_label = sum(perc for key, perc in percentages.items()
                     if " " not in key and key.lower() == "norm")

    print(f"Dataset type: {ds_type}")
    print(f"\tPercentage of single labels: {single_label:.2f}%")
    print(f"\t\tNORM contribution: {norm_label:.2f}%")
    print(f"\t\tNon NORM contribution: {single_label - norm_label:.2f}%")
    print(f"\tPercentage of multi labels: {multi_label:.2f}%")

We can observe that we have 5 unique labels and 22 out of 31 possible combinations of labels.

There are a few other things to notice:
- there are data points without a class
- the class distribution is consistent across the splits
- the contribution of single labels is dominant
- the single label NORM represents +- 40% of the data in every split
- the data doesn't have a long-tail distribution (when label combinations are not considered)

Since this is a health use case, `sex` and `age` are often important factors. Let's check if that data is also properly distributed.

In [None]:
from collections import Counter

# Share and absolute count of each sex per split, taken from the `sex`
# attribute of the annotations.
for ds_type in datasets:
    sex_count = Counter(data[ds_type]["annotations"].sex)
    total = sum(sex_count.values())
    print(f"Dataset type: {ds_type}")
    for sex_key, sex_name in (("male", "Male"), ("female", "Female")):
        print(f"\t{sex_name} percentage: {sex_count[sex_key]/total:.2%} ({sex_count[sex_key]})")

Although there is a slight prevalence of `male`, the `sex` attribute is evenly distributed across the splits.

Let's analyze the `age` distribution per `sex`.

In [None]:
# Plot the age distribution per sex for every split.
for split_name in datasets:
    print(f"Dataset type: {split_name}")
    split_annotations = data[split_name]["annotations"]
    crplt.plot_distribution_age_sex(split_annotations)

The `age` distribution per `sex` in the different datasets is very similar.

There are 2 points to note:
- there are clearly data points with unrealistic ages (around 300)
- we have more `female` data in the extremes and more `male` data in the centre

This last point alerts us to check the results for bias later on.


Finally, let's check the `diagnostic` distribution per `sex`.

In [None]:
# Plot the diagnostic distribution per sex for every split.
for split_name in datasets:
    print(f"Dataset type: {split_name}")
    split_annotations = data[split_name]["annotations"]
    crplt.plot_distribution_diagnostic_sex(split_annotations)

The `diagnostic` distribution per `sex` in the different datasets is very similar.
We can also observe that `female` has more `NORM` diagnostics. Combined with the slight `male` prevalence, this means that we have fewer females in the data and that they are mostly healthy. One can expect a model to have a harder time detecting problems in females. Something to confirm later.

## Data augmentation

In this multi-label problem, we could see that the data is well distributed between the splits considering `diagnostic`, `sex` and `age`.

But we also noticed that the label `NORM` is dominant and represents +- 35% of the data in a split and the next most representative label is +- 20%.

The challenge now is to check whether we can augment the data such that the models perform better.
In theory, the best results are achieved with a more balanced label distribution inside the train split. In practice, this may not be true, and only by training and testing the models can one know.

I will focus the augmentation on the cases with one label only.
Once the data is resampled, I will add to the data all the multi labeled cases.

I will have at my disposal 6 data distributions to evaluate (over- and under-sampling count as separate distributions):
- use the data as it is
- use only the over/under-sampled single-label data
- use the over/under-sampled single-label data plus the multi-labeled cases
- use only the multi-labeled cases


In [None]:
# Recap of the original per-label distribution (no label combinations).
datasets = ["train", "val", "test"]
class_counts = {}
class_percentages = {}

for split in datasets:
    counts, split_percentages = calculate_distribution(data[split]["labels"], use_combo=False)
    class_counts[split] = counts
    class_percentages[split] = split_percentages

crplt.plot_data_distribution(class_counts, class_percentages)

train_counts = class_counts["train"]
num_classes = len(train_counts)
print(f"Number of classes: {num_classes}")
print(f"Classes: {class_counts['train']}")

In [None]:
from core.data_augmentation import DataAugmentor
aug = DataAugmentor()

# Each variant resamples ONLY the train split: just the "train" entries of
# `class_counts` / `class_percentages` are overwritten, so every plot compares
# the resampled train split against the original val/test distributions.
# Entries: (description, resampling method, extra keyword arguments).
train_variants = [
    ("Single label only", aug.get_single_label_only, {}),
    ("Multi labels only", aug.get_multi_label_only, {}),
    ("Single label under sampled", aug.get_undersampled, {"add_multi": False}),
    ("Single label under sampled and multi label", aug.get_undersampled, {"add_multi": True}),
    ("Single label over sampled", aug.get_oversampled, {"add_multi": False}),
    ("Single label over sampled and multi label", aug.get_oversampled, {"add_multi": True}),
]

for description, resample, extra_kwargs in train_variants:
    new_data, new_labels = resample(data["train"]["data"], data["train"]["labels"], **extra_kwargs)
    class_counts["train"], class_percentages["train"] = calculate_distribution(new_labels)
    # plot_data_distribution takes no title, so name the variant in the output.
    print(description)
    crplt.plot_data_distribution(class_counts, class_percentages)