# PhysioNet/Computing in Cardiology Challenge 2020
## Classification of 12-lead ECGs

# Setup Notebook

In [2]:
# Standard-library and third-party imports
import os
import sys
import json
import random
import numpy as np
from scipy import signal
import matplotlib.pylab as plt
from ipywidgets import interact, fixed, IntSlider

# Local imports
# Put the parent of the current working directory on sys.path before importing
# kardioml (presumably the repository root when the notebook is run from its
# own folder -- confirm).
sys.path.insert(0, os.path.dirname(os.path.abspath(os.getcwd())))
from kardioml import DATA_PATH, ECG_LEADS
# NOTE(review): this imported waveform_plot_interact is shadowed by the function
# of the same name defined further down in this notebook.
from kardioml.visualization.plot_formatted_data import waveform_plot_interact

# Configure Notebook
import warnings
# Silences every warning category for the whole kernel session.
warnings.filterwarnings('ignore')
%matplotlib inline
# Reload edited modules automatically before each cell execution.
%load_ext autoreload
%autoreload 2

# Plot Waveforms

In [47]:
def waveform_plot(filename_id, filenames, path):
    """Plot every channel of one formatted recording with a randomly resampled overlay.

    The original waveform is drawn in black and the output of
    ``random_resample`` in red, both against a time axis in seconds so the
    stretch/compression introduced by resampling is visible.

    Parameters
    ----------
    filename_id : int
        Index into ``filenames`` of the recording to show.
    filenames : sequence of str
        Recording names without extension; ``<name>.npy`` and ``<name>.json``
        are read from ``path``.
    path : str
        Directory holding the formatted ``.npy`` / ``.json`` pairs.
    """
    # Get filename
    filename = filenames[filename_id]

    # Import waveforms
    waveforms = np.load(os.path.join(path, '{}.npy'.format(filename)))

    # Import meta data (context manager so the file handle is always closed)
    with open(os.path.join(path, '{}.json'.format(filename))) as meta_file:
        meta_data = json.load(meta_file)

    # Scale waveforms by the median amplitude of channel 0 at the sample
    # indices in meta_data['rpeaks'][0]
    # (presumably the R-peak locations of the first lead -- confirm).
    waveforms = waveforms / np.median(waveforms[meta_data['rpeaks'][0], 0])

    # Get label
    if meta_data['labels_short']:
        label = ' and '.join(meta_data['labels_full'])
    else:
        label = 'Other'

    # Heart rate for the title. random_resample treats the string 'nan' as a
    # possible value for meta_data['hr'], and int('nan') raises, so fall back
    # to a placeholder instead of crashing the widget.
    try:
        heart_rate = int(meta_data['hr'])
    except (TypeError, ValueError, OverflowError):
        heart_rate = 'n/a'

    # Random resample
    fs = meta_data['fs_resampled']
    waveforms_resample = random_resample(waveform=waveforms, meta_data=meta_data,
                                         fs_training=fs, probability=1., max_samples=19000)

    # Time arrays (seconds); the resampled trace has its own length
    time = np.arange(waveforms.shape[0]) / fs
    time_resample = np.arange(waveforms_resample.shape[0]) / fs

    # Setup figure
    fig = plt.figure(figsize=(15, 15), facecolor='w')
    fig.subplots_adjust(wspace=0, hspace=0.05)
    ax1 = plt.subplot2grid((1, 1), (0, 0))

    # ECG
    ax1.set_title(
        'File Name: {}\nAge: {}\nSex: {}\nLabel: {}\nHR: {} BPM'.format(
            filename, meta_data['age'], meta_data['sex'], label, heart_rate
        ),
        fontsize=20,
        loc='left',
        x=0,
    )
    # Stack the channels vertically, 3 units apart
    shift = 0
    for channel_id in range(waveforms.shape[1]):
        ax1.plot(time, waveforms[:, channel_id] + shift, '-k', lw=2)
        ax1.plot(time_resample, waveforms_resample[:, channel_id] + shift, '-r', lw=2)
        ax1.text(0.1, 0.25 + shift, ECG_LEADS[channel_id], color='red', fontsize=16, ha='left')
        shift += 3

    # Axes labels. The x-limits are left automatic because the resampled trace
    # can be longer than the original.
    ax1.set_xlabel('Time, seconds', fontsize=24)
    ax1.set_ylabel('ECG Amplitude, mV', fontsize=24)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=14)

    plt.show()

    
def random_resample(waveform, meta_data, fs_training, probability, max_samples):
    """Randomly stretch or compress a waveform in time to mimic a different heart rate.

    The heart rate is multiplied by a uniform random factor in [0.9, 1.1),
    truncated to an integer and clamped to 40-300 BPM. The waveform is then
    resampled along axis 0 so that its duration scales by ``hr / hr_new``.

    The input is returned unchanged when:
      * the heart rate is missing (the string 'nan', a float NaN) or not positive,
      * any of the labels bradycardia=3, sinus bradycardia=20 or
        sinus tachycardia=22 is set in ``meta_data['labels_training_merged']``,
      * the coin flip with the given ``probability`` fails, or
      * the computed sample count would be smaller than one.

    Parameters
    ----------
    waveform : np.ndarray
        Samples along axis 0.
    meta_data : dict
        Must provide 'hr' and 'labels_training_merged'.
    fs_training : number
        Sampling frequency of ``waveform`` in Hz.
    probability : float
        Chance in [0, 1] that resampling is applied.
    max_samples : int
        Upper bound on the length of the returned waveform.

    Returns
    -------
    np.ndarray
        The resampled waveform as float32, or ``waveform`` itself if skipped.
    """
    heart_rate = meta_data['hr']

    # The string 'nan' is the missing-value marker this code checks for; a
    # float NaN or a non-positive rate would break the arithmetic below too.
    if heart_rate == 'nan' or not np.isfinite(heart_rate) or heart_rate <= 0:
        return waveform

    # Skip recordings carrying any of the labels listed in the docstring.
    labels = meta_data['labels_training_merged']
    if any(labels[index] != 0 for index in (3, 20, 22)):
        return waveform

    # Honour the requested augmentation probability.
    if not _coin_flip(probability):
        return waveform

    # Get waveform duration
    duration = waveform.shape[0] / fs_training

    # Perturb the heart rate and keep it within physiological limits
    hr_new = int(heart_rate * np.random.uniform(0.9, 1.1))
    hr_new = min(max(hr_new, 40), 300)

    # Get new duration
    duration_new = duration * heart_rate / hr_new

    # Get number of samples, capped at max_samples
    samples = min(int(duration_new * fs_training), max_samples)
    if samples < 1:
        # resample_poly requires an up-sampling factor >= 1
        return waveform

    # Resample waveform
    return signal.resample_poly(waveform, samples, waveform.shape[0], axis=0).astype(np.float32)


def _coin_flip(probability):
    """Return True with the given probability (uses the ``random`` module)."""
    return random.random() < probability


def waveform_plot_interact(dataset):
    """Launch interactive plotting widget.

    Builds a slider over every ``.npy`` recording found in
    ``<DATA_PATH>/<dataset>/formatted`` and redraws ``waveform_plot`` for the
    selected index.

    Parameters
    ----------
    dataset : str
        Name of the dataset sub-directory under ``DATA_PATH``.
    """
    # Set data path
    path = os.path.join(DATA_PATH, dataset, 'formatted')

    # Get filenames: match on the real extension, keep stems that contain
    # dots intact, and sort because os.listdir returns entries in arbitrary
    # order, which would make a given slider index irreproducible.
    filenames = sorted(
        os.path.splitext(filename)[0]
        for filename in os.listdir(path)
        if filename.endswith('.npy')
    )

    interact(
        waveform_plot,
        filename_id=IntSlider(value=0, min=0, max=len(filenames) - 1, step=1,),
        filenames=fixed(filenames),
        path=fixed(path),
    )

In [49]:
# Plot visualization: launch the interactive viewer for dataset 'E'
waveform_plot_interact(dataset='E') # 2680 -- NOTE(review): presumably a slider index of interest; confirm

interactive(children=(IntSlider(value=0, description='filename_id', max=21836), Output()), _dom_classes=('widg…