# Exercise 1: Extending the Pan-Tompkins Algorithm

The previous video presented a basic version of the Pan-Tompkins algorithm. In this exercise, we will add features to its decision rules to improve its performance.


## Imports

In [None]:
import numpy as np

ts = np.arange(0, 5, 1/100)
sinusoid = 3 * np.sin(2 * np.pi * 1 * ts + np.pi) + 10

In [None]:
import glob
import os
import warnings

import numpy as np
import pandas as pd
import scipy as sp
import scipy.signal

# np.warnings was removed in NumPy 1.24; use the standard library module instead.
warnings.filterwarnings('ignore')

## Performance Evaluation Helpers

First, we need a way to measure the performance of our QRS estimates. We will optimize for precision and recall, and the two functions below help us compute them.

In [None]:
def Evaluate(reference_peak_indices, estimate_peak_indices, tolerance_samples=40):
    """Evaluates algorithm performance for a single dataset.
    
    Reference and estimate peak indices are not expected to overlap exactly.
    Instead, a QRS estimate counts as correct if it is within <tolerance_samples>
    samples of a reference location.
    
    Args:
        reference_peak_indices: (np.array) ground-truth array of QRS complex locations
        estimate_peak_indices: (np.array) array of QRS complex estimates
        tolerance_samples: (number) How close a QRS estimate needs to be to a reference
            location to be correct.
    Returns:
        n_correct: (number) The number of QRS complexes that were correctly detected
        n_missed: (number) The number of QRS complexes that the algorithm failed
            to detect
        n_extra: (number) The number of spurious QRS complexes detected by the
            algorithm
    """
    # Keep track of the number of QRS peaks that were found correctly
    n_correct = 0
    # ... that were missed
    n_missed = 0
    # ... and that are spurious
    n_extra = 0
    
    # Loop counters
    i, j = 0, 0
    while (i < len(reference_peak_indices)) and (j < len(estimate_peak_indices)):
        # Iterate through the arrays of QRS peaks, counting the number of peaks
        # that are correct, missed, and extra.
        ref = reference_peak_indices[i]
        est = estimate_peak_indices[j]
        if abs(ref - est) < tolerance_samples:
            # If the reference peak and the estimate peak are within <tolerance_samples>,
            # then we mark this beat correctly detected and move on to the next one.
            n_correct += 1
            i += 1
            j += 1
            continue
        if ref < est:
            # Else, if they are farther apart and the reference is before the estimate,
            # then the detector missed a beat and we advance the reference array.
            n_missed += 1
            i += 1
            continue
        # Else, the estimate is before the reference. This means we found an extra beat
        # in the estimate array. We advance the estimate array to check the next beat.
        j += 1
        n_extra += 1
    # Don't forget to count the number of missed or extra peaks at the end of the array.
    n_missed += len(reference_peak_indices[i:])
    n_extra += len(estimate_peak_indices[j:])
    return n_correct, n_missed, n_extra
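
To get a feel for how the matching behaves, here is a small toy example. The function `evaluate_sketch` is a hypothetical, condensed copy of the two-pointer logic in `Evaluate`, repeated only so the snippet runs on its own; the peak locations are made up for illustration.

```python
import numpy as np

def evaluate_sketch(reference, estimate, tolerance_samples=40):
    """Condensed copy of the two-pointer matching used by Evaluate (illustrative)."""
    n_correct = n_missed = n_extra = 0
    i = j = 0
    while i < len(reference) and j < len(estimate):
        ref, est = int(reference[i]), int(estimate[j])
        if abs(ref - est) < tolerance_samples:
            # Close enough: count a correct detection and advance both arrays.
            n_correct += 1
            i += 1
            j += 1
        elif ref < est:
            # Reference beat has no nearby estimate: it was missed.
            n_missed += 1
            i += 1
        else:
            # Estimate has no nearby reference beat: it is spurious.
            n_extra += 1
            j += 1
    # Whatever is left over at the end is missed or extra.
    n_missed += len(reference) - i
    n_extra += len(estimate) - j
    return n_correct, n_missed, n_extra

# Made-up sample indices: two estimates land near a reference beat, two do not.
reference = np.array([100, 400, 700, 1000])
estimate = np.array([110, 250, 690, 1300])
counts = evaluate_sketch(reference, estimate)
print(counts)  # (2, 2, 2)
```

Here the estimates at 110 and 690 match the reference beats at 100 and 700, the reference beats at 400 and 1000 are missed, and the estimates at 250 and 1300 are counted as extra.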

Now we need a function that can compute precision and recall for us.

In [None]:
def PrecisionRecall(n_correct, n_missed, n_extra):
    # TODO: Compute precision and recall from the input arguments.
    precision = None
    recall = None
    return precision, recall
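
As a reminder of the definitions, precision is the fraction of detections that are real beats, and recall is the fraction of real beats that were detected. The sketch below is one possible way to express this, treating `n_correct` as true positives, `n_extra` as false positives, and `n_missed` as false negatives. The name `precision_recall_sketch`, the zero-division guards, and the example counts are assumptions made for illustration.

```python
def precision_recall_sketch(n_correct, n_missed, n_extra):
    """Illustrative precision/recall computation from detection counts."""
    n_detected = n_correct + n_extra     # everything the detector reported
    n_reference = n_correct + n_missed   # everything that was really there
    # Assumption: report 0.0 rather than dividing by zero on empty inputs.
    precision = n_correct / n_detected if n_detected else 0.0
    recall = n_correct / n_reference if n_reference else 0.0
    return precision, recall

# Made-up counts: 80 correct, 20 missed, 10 extra.
precision, recall = precision_recall_sketch(n_correct=80, n_missed=20, n_extra=10)
print('Precision = {:0.2f}, Recall = {:0.2f}'.format(precision, recall))
# Precision = 0.89, Recall = 0.80
```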

## Pan-Tompkins Algorithm

We will start with the same algorithm that you saw in the last video. This starter code differs only in that we do not call `LocalizeMaxima` on the output peaks. In this dataset the QRS complexes can point up or down, and searching for a maximum when the QRS complex points downward would hurt the algorithm's performance. Instead, we will accept the approximate QRS locations that the algorithm detects.

The current version of the algorithm has a precision of 0.89 and a recall of 0.74. Verify this by running the cells below. Your task is to improve the performance of the algorithm by adding the following features.

### Refractory Period Blanking
Recall from the physiology lesson that the QRS complex is a result of ventricular depolarization, and that cellular depolarization happens when ions travel across the cell membrane. There is a physiological constraint on how soon consecutive depolarizations can occur: 200 ms. Read more about it [here](https://en.wikipedia.org/wiki/Refractory_period_(physiology)#Cardiac_refractory_period). We can take advantage of this in our algorithm by removing detections that occur within 200 ms of another one. When two detections conflict, preserve the larger one.
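
One possible way to approach this is sketched below, not as the reference solution but as a starting point. It assumes that `peaks` is sorted in increasing order and that "larger" means the higher value in `filtered_signal`. The function name `refractory_blanking_sketch` and the toy data are hypothetical.

```python
import numpy as np

def refractory_blanking_sketch(filtered_signal, peaks, fs, refractory_period_ms=200):
    """Drop detections closer than the refractory period, keeping the larger one."""
    refractory_samples = int(round(fs * refractory_period_ms / 1000))
    kept = []
    for peak in peaks:
        if kept and peak - kept[-1] < refractory_samples:
            # Too close to the previous detection: keep whichever is larger.
            if filtered_signal[peak] > filtered_signal[kept[-1]]:
                kept[-1] = peak
        else:
            kept.append(peak)
    return np.array(kept, dtype=int)

# Toy data at 300 Hz, so the refractory period is 60 samples.
fs = 300
peaks = np.array([100, 130, 400, 450, 800])
filtered_signal = np.zeros(1000)
filtered_signal[peaks] = [5.0, 2.0, 1.0, 4.0, 3.0]

kept = refractory_blanking_sketch(filtered_signal, peaks, fs)
print(kept)  # [100 450 800]
```

The peak at 130 is dropped in favor of the larger one at 100, and the peak at 400 is replaced by the larger one at 450.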

### Adaptive Thresholding
The QRS complex height can change over time as contact with the electrodes changes or shifts. Instead of using a fixed threshold, we should use one that changes over time. Make the detection threshold 70% of the average peak height for the last 8 peaks.
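
A minimal sketch of this idea follows. The description above does not say how to choose the threshold before any peaks have been accepted, so this sketch assumes it is seeded with 70% of the mean height of all candidate peaks. The function name `adaptive_threshold_sketch`, its `n_history` and `fraction` parameters, and the toy heights are also assumptions.

```python
import numpy as np

def adaptive_threshold_sketch(filtered_signal, peaks, n_history=8, fraction=0.7):
    """Accept peaks above 70% of the mean height of the last 8 accepted peaks."""
    heights = filtered_signal[peaks]
    # Assumption: seed the threshold from all candidates until history exists.
    threshold = fraction * np.mean(heights)
    recent_heights = []
    accepted = []
    for peak, height in zip(peaks, heights):
        if height > threshold:
            accepted.append(peak)
            recent_heights.append(height)
            # Update the threshold from the most recent accepted peaks only.
            threshold = fraction * np.mean(recent_heights[-n_history:])
    return np.array(accepted, dtype=int)

# Toy data: tall peaks (10, later 8) alternate with small ones (2).
filtered_signal = np.array([10, 2, 10, 2, 10, 8, 2, 8, 2, 8], dtype=float)
peaks = np.arange(len(filtered_signal))

accepted = adaptive_threshold_sketch(filtered_signal, peaks)
print(accepted)  # [0 2 4 5 7 9]
```

Because the threshold follows the recent peak heights, the shorter peaks of height 8 are still accepted after the amplitude drops, while the small peaks of height 2 are rejected throughout.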

### T-Wave Discrimination
One error mode is to detect T-waves as QRS complexes. We can avoid picking T-waves by doing the following:
  * Find peaks that follow a previous one by 360 ms or less.
  * Compute the maximum absolute slope within 60 ms of each peak, with the 60 ms converted to a number of samples, e.g. `np.max(np.abs(np.diff(ecg[peak - 60ms: peak + 60ms])))`.
  * If the slope of the second peak is less than half of the slope of the first peak, discard the second peak as a T-wave.

Read another description of this technique [here](https://en.wikipedia.org/wiki/Pan%E2%80%93Tompkins_algorithm#T_wave_discrimination).
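
The steps above could be sketched roughly as follows. This is one interpretation rather than the reference solution: it assumes `peaks` is sorted, compares each peak against the most recently kept peak, and clips the slope window at the signal edges. The names `max_abs_slope` and `twave_discrimination_sketch` and the synthetic signal are hypothetical.

```python
import numpy as np

def max_abs_slope(signal, peak, half_window):
    """Maximum absolute sample-to-sample slope within half_window of peak."""
    start = max(peak - half_window, 0)
    stop = min(peak + half_window, len(signal))
    return np.max(np.abs(np.diff(signal[start:stop])))

def twave_discrimination_sketch(signal, peaks, fs, twave_window_ms=360, slope_window_ms=60):
    """Discard peaks that closely follow another peak and have a much smaller slope."""
    twave_samples = int(round(fs * twave_window_ms / 1000))
    slope_samples = int(round(fs * slope_window_ms / 1000))
    if len(peaks) == 0:
        return np.array([], dtype=int)
    kept = [peaks[0]]
    for peak in peaks[1:]:
        previous = kept[-1]
        if peak - previous <= twave_samples:
            slope_previous = max_abs_slope(signal, previous, slope_samples)
            slope_current = max_abs_slope(signal, peak, slope_samples)
            if slope_current < 0.5 * slope_previous:
                continue  # Shallow slope soon after a beat: likely a T-wave.
        kept.append(peak)
    return np.array(kept, dtype=int)

# Synthetic signal: two narrow, steep bumps and one broad, shallow bump.
fs = 300
t = np.arange(1200)
signal = np.exp(-0.5 * ((t - 300) / 3.0) ** 2)           # QRS-like
signal += 0.4 * np.exp(-0.5 * ((t - 380) / 20.0) ** 2)   # T-wave-like
signal += np.exp(-0.5 * ((t - 900) / 3.0) ** 2)          # QRS-like
peaks = np.array([300, 380, 900])

kept = twave_discrimination_sketch(signal, peaks, fs)
print(kept)  # [300 900]
```

The broad bump at sample 380 follows the first peak by 80 samples (about 267 ms) and has a far smaller slope, so it is discarded.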

After implementing these three techniques, you should see a significant increase in precision and recall. I ended up with a precision of 0.95 and a recall of 0.87. See if you can beat that!

In [None]:
def BandpassFilter(signal, fs=300):
    """Bandpass filter the signal between 5 and 15 Hz."""
    b, a = sp.signal.butter(3, (5, 15), btype='bandpass', fs=fs)
    return sp.signal.filtfilt(b, a, signal)

def MovingSum(signal, fs=300):
    """Moving sum operation with window size of 150ms."""
    n_samples = int(round(fs * 0.150))
    return pd.Series(signal).rolling(n_samples, center=True).sum().values

def FindPeaks(signal, order=10):
    """A simple peak detection algorithm."""
    msk = (signal[order:-order] > signal[:-order * 2]) & (signal[order:-order] > signal[order * 2:])
    for o in range(1, order):
        msk &= (signal[order:-order] > signal[o: -order * 2 + o])
        msk &= (signal[order:-order] > signal[order * 2 - o: -o])
    return msk.nonzero()[0] + order

def ThresholdPeaks(filtered_signal, peaks):
    """Threshold detected peaks to select the QRS complexes."""
    thresh = np.mean(filtered_signal[peaks])
    return peaks[filtered_signal[peaks] > thresh]

def AdaptiveThresholdPeaks(filtered_signal, peaks):
    # TODO: Implement adaptive thresholding
    pass
    
def RefractoryPeriodBlanking(filtered_signal, peaks, fs, refractory_period_ms=200):
    # TODO: Implement refractory period blanking
    pass

def TWaveDiscrimination(signal, peaks, fs, twave_window_ms=360, slope_window_ms=60):
    # TODO: Implement t-wave discrimination
    pass

def PanTompkinsPeaks(signal, fs):
    """Pan-Tompkins QRS complex detection algorithm."""
    filtered_signal = MovingSum(
        np.square(
            np.diff(
                BandpassFilter(signal, fs))), fs)
    peaks = FindPeaks(filtered_signal)
    #peaks = RefractoryPeriodBlanking(filtered_signal, peaks, fs)  # TODO: Uncomment this line
    peaks = ThresholdPeaks(filtered_signal, peaks)                 # TODO: Remove this line
    #peaks = AdaptiveThresholdPeaks(filtered_signal, peaks)        # TODO: Uncomment this line
    #peaks = TWaveDiscrimination(signal, peaks, fs)                # TODO: Uncomment this line
    return peaks

## Load Data and Evaluate Performance

As we add features to the algorithm, we can continue to evaluate it and see the change in performance. Use the code below to compute an overall precision and recall for QRS detection. You must first implement the `PrecisionRecall` function above.

In [None]:
# This dataset is sampled at 300 Hz.
fs = 300
files = glob.glob('../../data/cinc/*.npz')

# Keep track of the total number of correct, missed, and extra detections.
total_correct, total_missed, total_extra = 0, 0, 0

for i, fl in enumerate(files):
    # For each file, load the data...
    with np.load(fl) as npz:
        ecg = npz['ecg']
        reference_peak_indices = npz['qrs']
    # Compute our QRS location estimates...
    estimate_peak_indices = PanTompkinsPeaks(ecg, fs)

    # Compare our estimates against the reference...
    n_correct, n_missed, n_extra = Evaluate(reference_peak_indices, estimate_peak_indices)

    # And add them to our running totals.
    total_correct += n_correct
    total_missed += n_missed
    total_extra += n_extra
    print('\r{}/{} files processed...'.format(i+1, len(files)), end='')
print('') # print a newline

# Compute and report the overall performance.
precision, recall = PrecisionRecall(total_correct, total_missed, total_extra)
print('Total performance:\n\tPrecision = {:0.2f}\n\tRecall = {:0.2f}'.format(precision, recall))