In [1]:
from cx_model import CXDetector

import numpy as np
import random

np.random.seed(42)
random.seed(42)

import wfdb
from scipy.signal import butter, lfilter, medfilt
from sklearn.preprocessing import RobustScaler

import os

import warnings
warnings.filterwarnings('ignore')

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib.patches as patches

  from pandas.core import datetools


# Training the model

In [2]:
def butter_bandpass(lowcut, highcut, fs, order=5):
    """Design a Butterworth band-pass filter.

    Parameters
    ----------
    lowcut, highcut : float
        Lower and upper band edges, in the same unit as ``fs``.
    fs : float
        Sampling frequency; the band edges are normalised by ``fs / 2``.
    order : int, optional
        Order handed to ``scipy.signal.butter`` (default 5).

    Returns
    -------
    tuple
        The ``(b, a)`` coefficient arrays returned by ``butter``.
    """
    nyquist = 0.5 * fs
    band_edges = [lowcut / nyquist, highcut / nyquist]
    return butter(order, band_edges, btype='band')


def butter_bandpass_filter(data, lowcut, highcut, fs, order=5):
    """Band-pass filter ``data`` with a Butterworth filter.

    The coefficients come from ``butter_bandpass(lowcut, highcut, fs, order)``
    and are applied in a single forward pass with ``scipy.signal.lfilter``.

    Returns the filtered signal produced by ``lfilter``.
    """
    numerator, denominator = butter_bandpass(lowcut, highcut, fs, order=order)
    return lfilter(numerator, denominator, data)

def _read_signal(file, low_freq, high_freq, sample_freq):
    """Load one WFDB record and return three preprocessed channels.

    The record and its ``'atr'`` annotation are read with ``wfdb``. Columns
    0, 2 and 4 of ``record.p_signal`` are each:

    1. trimmed by 1500 samples at both ends,
    2. band-pass filtered (order 4) between ``low_freq`` and ``high_freq``,
    3. median filtered with ``medfilt``'s default kernel,
    4. rescaled with a freshly fitted ``RobustScaler``.

    Annotation sample indices are shifted by -1500 so they line up with the
    trimmed signals; annotations inside the trimmed head therefore end up
    with negative indices.

    Returns
    -------
    tuple
        ``(signal_ch1, signal_ch2, signal_ch3, annotated_intervals)`` where
        ``annotated_intervals`` is a list of ``(sample_index, aux_note)``.
    """
    trim = 1500
    record = wfdb.rdrecord(file)
    annotation = wfdb.rdann(file, 'atr')

    # Re-base the annotation positions onto the trimmed signal.
    annotated_intervals = [
        (sample - trim, note)
        for sample, note in zip(annotation.sample, annotation.aux_note)
    ]

    processed = []
    for column in (0, 2, 4):
        trace = record.p_signal[:, column][trim:-trim]
        trace = butter_bandpass_filter(trace, low_freq,
                                       high_freq, sample_freq, order=4)
        trace = medfilt(trace)
        # Each channel gets its own scaler, fitted on that channel only.
        trace = RobustScaler().fit_transform(trace.reshape(-1, 1)).reshape(-1, )
        processed.append(trace)

    signal_ch1, signal_ch2, signal_ch3 = processed
    return signal_ch1, signal_ch2, signal_ch3, annotated_intervals


def _read_clinical(file):
    start_idx = 0
    with open(file+'.hea', 'r') as ifp:
        lines = ifp.readlines()
        
    for line_idx, line in enumerate(lines):
        if line.startswith('#'):
            start_idx = line_idx
            break

    names = []
    values = []
    for line in lines[start_idx+1:]:
        _, name, value = line.split()
        names.append(name)
        values.append(value)

    return names, values

def _process_clinical_df(clin_df):
    clin_df = clin_df.drop(['Gestation'], axis=1)
    clin_df = clin_df.replace('None', np.NaN)
    clin_df = clin_df.replace('N/A', np.NaN)
    clin_df['ID'] = clin_df['RecID']
    for col in ['Rectime', 'Age', 'Abortions', 'Weight']:
        clin_df[col] = clin_df[col].astype(float)
    clin_df = clin_df.drop_duplicates()
    clin_df = clin_df[['file', 'Rectime', 'Age', 'Parity', 'Abortions']]
    return clin_df


def partition_data(directory, n_splits=5):
    """Split the records found in ``directory`` into cross-validation folds.

    Record names are the unique stems (text before the first ``'.'``) of the
    directory entries. The 4th character from the end of a stem selects its
    group: ``'n'`` records are left out entirely, ``'p'`` records form one
    group and every other record forms a second group. Stems shorter than
    four characters (e.g. the empty stem produced by dotfiles such as
    ``.ipynb_checkpoints``) cannot be classified and are skipped instead of
    raising ``IndexError``.

    Both groups are shuffled in place with ``np.random.shuffle`` (so the
    result depends on NumPy's global random state) and then cut into
    ``n_splits`` contiguous test slices; the last fold takes whatever remains.
    Each fold's training set is the sorted remainder of each group.

    Parameters
    ----------
    directory : str
        Folder containing the record files.
    n_splits : int, optional
        Number of folds (default 5).

    Returns
    -------
    list of (list of str, list of str)
        One ``(train_paths, test_paths)`` pair per fold; every path is
        ``directory + os.sep + stem``.
    """
    stems = np.unique([x.split('.')[0] for x in os.listdir(directory)])
    p_files, t_files = [], []
    for stem in stems:
        if len(stem) < 4:
            # Too short to carry the group letter; not a record name.
            continue
        group = stem[-4]
        if group == 'n':
            continue
        if group == 'p':
            p_files.append(stem)
        else:
            t_files.append(stem)

    np.random.shuffle(p_files)
    np.random.shuffle(t_files)

    def _test_slice(items, split):
        """Return the contiguous test slice of ``items`` for fold ``split``."""
        start = int(len(items) * (split / n_splits))
        if split == n_splits - 1:
            return items[start:]
        end = int(len(items) * ((split + 1) / n_splits))
        return items[start:end]

    def _with_dir(names):
        """Prefix every record name with ``directory`` and the OS separator."""
        return ['{}{}{}'.format(directory, os.sep, x) for x in names]

    folds = []
    for split in range(n_splits):
        test_p_files = _test_slice(p_files, split)
        test_t_files = _test_slice(t_files, split)

        train_p_files = sorted(set(p_files) - set(test_p_files))
        train_t_files = sorted(set(t_files) - set(test_t_files))

        test_files = test_t_files + test_p_files
        train_files = train_t_files + train_p_files

        folds.append((_with_dir(train_files), _with_dir(test_files)))

    return folds

In [None]:
# Build the cross-validation folds, then use only the first fold as a single
# train/test split for the rest of the notebook.
folds = partition_data('tpehgts')
train_files, test_files = folds[0]
# CXDetector comes from cx_model, so the meaning of its positional arguments
# is not visible here. NOTE(review): 20, 0.05 and 4.0 presumably are the
# sampling frequency and band-pass edges, since the evaluation helpers call
# _read_signal(file, 0.05, 4.0, 20.0) with the same values - confirm.
detector = CXDetector(20, 0.05, 4.0, 750, 125, 100, 100, _read_signal, _read_clinical, _process_clinical_df)
# fit() returns an object with a `.columns` attribute (printed in a later cell).
features = detector.fit(train_files)

(2315, 100)
it		avg		std		max		time
1		0.2731		0.025		0.297963	47.5285
2		0.2904		0.007		0.298383	42.9142
3		0.295		0.005		0.298636	78.2362
4		0.2986		0.002		0.300152	62.9635


In [None]:
# Show the column names of the object returned by detector.fit().
print(list(features.columns))

# Evaluating the model

In [None]:
from sklearn.metrics import roc_auc_score

def get_labels_preds(intervals, predictions):
    """Collect per-sample labels and predictions inside annotated intervals.

    ``intervals`` is consumed pairwise: items 0, 2, 4, ... are interval starts
    and items 1, 3, 5, ... the matching ends, each a ``(sample_index,
    aux_note)`` tuple. A trailing unpaired item is ignored. An interval is
    labelled 1 when its start note ends with ``'C'`` and 0 otherwise; an empty
    start note counts as 0 instead of raising ``IndexError``.

    Intervals that start before index 0 or whose end index is not smaller than
    ``len(predictions)`` are skipped.

    Parameters
    ----------
    intervals : list of (int, str)
        Alternating start/end annotations.
    predictions : pandas.DataFrame
        Frame with a ``'pred'`` column whose index holds the sample positions
        ``start_idx .. end_idx - 1`` for every interval that is kept.

    Returns
    -------
    tuple of (list, list)
        ``(labels, preds)`` of equal length, one entry per covered sample.
    """
    labels = []
    preds = []
    for (start_idx, start_type), (end_idx, _) in zip(intervals[::2], intervals[1::2]):
        if start_idx < 0 or end_idx >= len(predictions):
            continue
        label = 1 if start_type.endswith('C') else 0
        samples = list(range(start_idx, end_idx))
        labels.extend([label] * len(samples))
        preds.extend(predictions.loc[samples, 'pred'].values)

    return labels, preds

def _load_pred_labels_intervals(predictions):
    """Return ``(labels, preds, intervals)`` for one record's predictions.

    The record path is taken from the first entry of ``predictions['file']``.
    Its annotations are re-read through ``_read_signal`` (with the fixed
    arguments 0.05, 4.0 and 20.0) and matched against ``predictions`` by
    ``get_labels_preds``.
    """
    record_path = predictions['file'].values[0]
    intervals = _read_signal(record_path, 0.05, 4.0, 20.0)[3]
    labels, preds = get_labels_preds(intervals, predictions)
    return labels, preds, intervals

def unweighted_auc(predictions):
    """Compute one ROC AUC over the samples of all files pooled together.

    For every distinct value in ``predictions['file']`` the matching rows are
    re-indexed by their ``'index'`` column and turned into labels and scores
    via ``_load_pred_labels_intervals``. The per-file lists are concatenated,
    samples whose score is NaN are dropped, and a single ``roc_auc_score`` is
    computed on the remainder.
    """
    pooled_labels, pooled_scores = [], []
    for file in np.unique(predictions['file']):
        file_rows = predictions[predictions['file'] == file]
        file_rows = file_rows.set_index('index', drop=True)
        file_labels, file_scores, _ = _load_pred_labels_intervals(file_rows)
        pooled_labels += file_labels
        pooled_scores += file_scores

    # Discard samples for which no prediction is available.
    keep = ~np.isnan(pooled_scores)
    y_true = np.array(pooled_labels)[keep]
    y_score = np.array(pooled_scores)[keep]
    return roc_auc_score(y_true, y_score)

def create_plots(predictions):
    """Show one four-panel figure per file found in ``predictions['file']``.

    Each file's signals and annotations are re-read with
    ``_read_signal(file, 0.05, 4.0, 20.0)``. The three channels go on the top
    three axes, with every annotated interval shaded, and that file's
    ``'pred'`` values go on the bottom axis.
    """
    def create_plot(signal_ch1, signal_ch2, signal_ch3, predictions, intervals):
        """Draw, display and close the figure for a single record."""
        channels = (signal_ch1, signal_ch2, signal_ch3)
        fig, axes = plt.subplots(4, 1, sharex=True, figsize=(15,3))
        for axis, channel in zip(axes, channels):
            axis.plot(channel)

        # Shared vertical extent so the shaded boxes span all three channels.
        upper = np.max([np.max(channel) for channel in channels])
        lower = np.min([np.min(channel) for channel in channels])

        for (start_idx, start_type), (end_idx, _) in zip(intervals[::2], intervals[1::2]):
            # Green: start note ends in 'C'; yellow: exactly '(c)'; red: rest.
            if start_type[-1] == 'C':
                shade = 'g'
            elif start_type == '(c)':
                shade = 'y'
            else:
                shade = 'r'

            width = end_idx - start_idx
            height = upper - lower
            for axis in axes[:3]:
                # A patch can only be attached to one axes, so build one each.
                box = patches.Rectangle((start_idx, lower), width, height, facecolor=shade, alpha=0.5)
                axis.add_patch(box)

        axes[3].plot(predictions)
        plt.show()
        plt.close()

    for file in np.unique(predictions['file']):
        sign_ch1, sign_ch2, sign_ch3, intervals = _read_signal(file, 0.05, 4.0, 20.0)
        file_preds = predictions[predictions['file'] == file]['pred'].values
        create_plot(sign_ch1, sign_ch2, sign_ch3, file_preds, intervals)

In [None]:
# Predict on the held-out files of the first fold, print the pooled AUC and
# plot every test record.
preds = detector.predict(test_files)
print(unweighted_auc(preds))
# NOTE(review): the number below was left here without a label; presumably it
# is the AUC printed by an earlier run of this cell - confirm.
create_plots(preds) # 0.8077784799594918

In [None]:
# Disabled code: the function below lives inside a string literal, so it is
# never defined or executed. It would not run as-is in this notebook:
#   * it calls `read_signal` (one argument) and a six-argument `create_plot`,
#     neither of which is defined at module level here;
#   * it uses `pd`, which is not imported in the visible cells;
#   * `last_value` is read before assignment if the first sample has no
#     prediction.
"""

def generate_predictions(file, X, idx, model, WINDOW_SIZE, DATA_DIR, OUTPUT_DIR):
    for col in ['ID', 'file']:
        if col in X.columns:
            X = X.drop(col, axis=1)

    signal_ch1, signal_ch2, signal_ch3, annotated_intervals = read_signal(DATA_DIR + '/' + file)
    ts_predictions = np.empty((len(signal_ch1),), dtype=object)
    predictions = model.predict_proba(X)[:, 1]
    for pred, x in zip(predictions, idx):
      for i in range(x, x+WINDOW_SIZE):
        if ts_predictions[i] is None:
          ts_predictions[i] = [pred]
        else:
          ts_predictions[i].append(pred)
    
    for i in range(len(signal_ch1)):
      if ts_predictions[i] is None:
        ts_predictions[i] = last_value
      else:
        avg = np.mean(ts_predictions[i])
        ts_predictions[i] = avg
        last_value = avg

    pd.Series(ts_predictions).to_csv('{}/{}.csv'.format(OUTPUT_DIR, file))
    create_plot(signal_ch1, signal_ch2, signal_ch3, ts_predictions, annotated_intervals, '{}/{}.png'.format(OUTPUT_DIR, file))

"""