In [1]:
import typing
import numpy as np
import pickle as pkl
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from database_tools.processing.cardiac import estimate_spo2
from database_tools.processing.detect import detect_peaks, detect_notches
from database_tools.tools.dataset import ConfigMapper
from gcp_utils import constants

In [2]:
def predict_cardiac_metrics(red, ir, red_idx, ir_idx, cm: ConfigMapper) -> dict:
    """Compute pulse rate and SpO2 for one pair of red / IR signals.

    Args:
        red: Red-channel signal, forwarded to the SpO2 calculation.
        ir: IR-channel signal, forwarded to the SpO2 calculation.
        red_idx (dict): Red-channel indices with 'peaks' and 'troughs' entries
            (used for SpO2 only).
        ir_idx (dict): IR-channel indices with 'peaks' and 'troughs' entries
            (used for both pulse rate and SpO2).
        cm (ConfigMapper): Configuration; only ``cm.deploy.bpm_fs`` is read
            here and passed on as ``fs``.

    Returns:
        dict: ``{'pulse_rate': int, 'spo2': float, 'r': float}``. The pulse
        rate is truncated (not rounded) by ``int``.
    """
    # Pulse rate comes from the IR channel's peaks alone.
    bpm = _calc_pulse_rate(ir_idx, fs=cm.deploy.bpm_fs)
    oxygen_saturation, ratio = _calc_spo2(red, ir, red_idx, ir_idx)

    # Cast to plain Python numbers so the result holds no numpy scalars.
    return dict(
        pulse_rate=int(bpm),
        spo2=float(oxygen_saturation),
        r=float(ratio),
    )

def _preprocess_for_cardiac(sig: list, cm: ConfigMapper):
    """Clean one signal frame and locate its peaks and troughs.

    Removes data too far from the signal median (handles large noise and
    motion artifacts), then finds peaks and troughs inside the longest run
    of good data.

    Steps:
        1. Copy the input into a flat float32 array and replace NaN with 0.
        2. Keep the longest run of samples within ``cm.deploy.sig_amp_thresh``
           of the median (everything else is overwritten by the helper).
        3. Detect peaks/troughs inside that run, drop amplitude and spacing
           outliers, and enforce peak/trough alternation.
        4. Flatten the samples that lie outside the surviving indices.

    Args:
        sig (list): Raw signal samples (any array-like that numpy can flatten).
        cm (ConfigMapper): Configuration; reads ``cm.deploy.sig_amp_thresh``,
            ``cm.deploy.peak_amp_thresh`` and ``cm.deploy.peak_dist_thresh``.

    Returns:
        tuple: ``(sig, idx)`` where ``sig`` is the cleaned float32 array and
        ``idx`` is a dict with 'peaks' and 'troughs' index arrays. Both index
        arrays are empty when no valid peak/trough pair survives; in that case
        the signal is returned without the final flattening step.
    """
    # Prep data: a single flat float32 copy, so the caller's input is never mutated.
    sig = np.array(sig, dtype=np.float32).reshape(-1)
    sig[np.isnan(sig)] = 0

    # Remove signal outliers; [i, j) is the window of good data that was kept.
    sig, i, j = _remove_sig_outliers(sig, cm.deploy.sig_amp_thresh)

    # Detect on the good window only, then shift back to whole-signal indices.
    idx = detect_peaks(sig[i:j])
    peaks, troughs = idx['peaks'] + i, idx['troughs'] + i
    peaks = _remove_peak_outliers(sig, peaks, cm.deploy.peak_amp_thresh, cm.deploy.peak_dist_thresh)
    troughs = _remove_peak_outliers(sig, troughs, cm.deploy.peak_amp_thresh, cm.deploy.peak_dist_thresh)

    idx = _get_ordered_idx(dict(peaks=peaks, troughs=troughs))
    peaks, troughs = idx['peaks'], idx['troughs']

    peaks_troughs = np.concatenate([peaks, troughs])
    if peaks_troughs.size == 0:
        # Nothing survived the filters: there are no boundaries to flatten
        # around, and np.min / np.max would raise on an empty array.
        return sig, idx

    min_idx = int(np.min(peaks_troughs))
    max_idx = int(np.max(peaks_troughs))

    # Flatten the head up to (but not including) the sample just before the
    # first index. The stop is clamped at 0: a bare `min_idx - 1` becomes -1
    # when min_idx == 0 and would overwrite almost the entire signal.
    sig[0:max(min_idx - 1, 0)] = sig[i]

    # Flatten the tail after the last index. `j` is one past the good window
    # and can equal len(sig) when the window reaches the end, so clamp it.
    sig[max_idx + 1:] = sig[min(j, sig.shape[0] - 1)]
    return sig, idx

def _get_ordered_idx(idx: dict) -> typing.Tuple[list, list]:
    """Takes a list of peaks and troughs and removes
       out of order elements. Regardless of which occurs first,
       a peak or a trough, a peak must be followed by a trough
       and vice versa.

    Algorithm (if peaks start first)
    ---------
    - Loop through values starting with first peak
    - Is peak before valley?
        YES -> Is next peak after valley?
            YES -> Append peak and valley. Get next peak and valley.
            NO  -> Get next peak.
        NO  -> Get next valley.

    Args:
        peaks (list): Signal peaks.
        troughs (list): Signal troughs.

    Returns:
        first_repaired (list): Input with out of order items removed.
        second_repaired (list): Input with out of order items removed.

        Items are always returned with peaks idx as first tuple item.
    """
    order_lists = lambda x, y : (x, y, 0, 1) if x[0] < y[0] else (y, x, 1, 0)

    # Configure algorithm to start with lowest index.
    peaks, troughs = idx['peaks'], idx['troughs']

    try:
        first, second, flag1, flag2 = order_lists(peaks, troughs)
    except IndexError:
        return dict(peaks=np.array([]), troughs=np.array([]))

    result = dict(first=[], second=[])
    i, j = 0, 0
    for _ in enumerate(first):
        try:
            poi_1, poi_2 = first[i], second[j]
            if poi_1 < poi_2:  # first point of interest is before second
                poi_3 = first[i + 1]
                if poi_2 < poi_3:  # second point of interest is before third
                    result['first'].append(poi_1)
                    result['second'].append(poi_2)
                    i += 1; j += 1
                else:
                    i += 1
            else:
                j += 1
        except IndexError: # always thrown in last iteration
            result['first'].append(poi_1)
            result['second'].append(poi_2)

    # remove duplicates and return as peaks, troughs
    result['first'] = sorted(list(set(result['first'])))
    result['second'] = sorted(list(set(result['second'])))
    result = [result['first'], result['second']]
    return dict(peaks=np.array(result[flag1]), troughs=np.array(result[flag2]))

def _remove_sig_outliers(sig, amp_thresh):
    med = np.median(sig)
    mask = (sig > (med + amp_thresh)) | (sig < (med - amp_thresh))

    temp = sig.copy()
    temp[np.where(mask)] = 0
    temp[np.where(~mask)] = 1
    _, run_starts, run_lengths = _find_runs(temp)

    # find the longest run and set all other data to median
    k = np.argmax(run_lengths)
    i, j = run_starts[k], run_starts[k+1]
    sig[0:i] = med
    sig[j::] = med
    return sig, i, j

def _remove_peak_outliers(sig, idx, amp_thresh, dist_thresh):
    # remove indices whose amplitude is too far from mean
    values = sig[idx]
    mean = np.mean(values)
    mask = np.where( (values < (mean + amp_thresh)) & (values > (mean - amp_thresh)) )
    idx = idx[mask]

    # remove indices that are too from from each other
    diff = np.diff(idx, prepend=idx[0] - 10000, append=idx[-1] + 10000)
    delta = int(np.mean(diff[1:-1]))
    print(delta)
    valid = []
    for i, distance1 in enumerate(diff[0:-1]):
        distance2 = diff[i+1]
        if (distance1 <= (delta + dist_thresh)) & (distance2 <= (delta + dist_thresh)):
            valid.append(idx[i])
        else:
            if len(valid) > 0:
                break
    return valid

def _find_runs(x):
    n = x.shape[0]
    loc_run_start = np.empty(n, dtype=bool)
    loc_run_start[0] = True
    np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
    
    run_starts = np.nonzero(loc_run_start)[0]
    run_values = x[loc_run_start]
    run_lengths = np.diff(np.append(run_starts, n))
    return (run_values, list(run_starts), run_lengths)

def _calc_pulse_rate(idx, fs):
    pulse_rate = fs / np.mean(np.diff(idx['peaks'])) * 60
    return pulse_rate

def _calc_spo2(ppg_red, ppg_ir, red_idx, ir_idx):
    red_peaks, red_troughs = red_idx['peaks'], red_idx['troughs']
    ir_peaks, ir_troughs = ir_idx['peaks'], ir_idx['troughs']

    # TODO: Choose based off of shorted list
    i = int(len(ir_peaks) / 2)

    red_high, red_low = np.max(ppg_red[red_peaks[i]]), np.min(ppg_red[red_troughs[i]])
    ir_high, ir_low = np.max(ppg_ir[ir_peaks[i]]), np.min(ppg_ir[ir_troughs[i]])

    ac_red = red_high - red_low
    ac_ir = ir_high - ir_low

    r = ( ac_red / red_low ) / ( ac_ir / ir_low )
    spo2 = round(104 - (17 * r), 1)
    return (spo2, r)

In [3]:
def plot_data(red, ir, red_idx, ir_idx):
    """Plot the red and IR signals with their peaks and troughs marked.

    Draws a two-row figure (red on top, IR below). Each row shows the signal
    as a line plus one marker trace for its peaks and one for its troughs,
    then the figure is displayed.

    Args:
        red (np.ndarray): Red-channel signal.
        ir (np.ndarray): IR-channel signal.
        red_idx (dict): Red-channel 'peaks' and 'troughs' indices.
        ir_idx (dict): IR-channel 'peaks' and 'troughs' indices.
    """
    fig = make_subplots(rows=2, cols=1)

    channels = [('ppg_red', red, red_idx), ('ppg_ir', ir, ir_idx)]
    for row, (name, sig, idx) in enumerate(channels, start=1):
        fig.add_scatter(y=sig, name=name, row=row, col=1)
        # Name the marker traces so the legend says what each one is.
        fig.add_scatter(x=idx['peaks'], y=sig[idx['peaks']], mode='markers',
                        name=f'{name} peaks', row=row, col=1)
        fig.add_scatter(x=idx['troughs'], y=sig[idx['troughs']], mode='markers',
                        name=f'{name} troughs', row=row, col=1)
    fig.show()

In [4]:
# Sample frame taken from the project's constants module; later cells read
# its 'red_frame' and 'ir_frame' entries.
frame = constants.NEW_BPM_FRAME

In [5]:
# NOTE(review): absolute, machine-specific config path — this cell cannot run
# on another machine as written; consider a repo-relative or configurable path.
cm = ConfigMapper('/home/cam/Documents/gcp_utils/gcp_utils/config.ini')

# Pull the red and IR signals out of the sample frame.
red_frame, ir_frame = frame['red_frame'], frame['ir_frame']

# Clean each channel independently and locate its peaks and troughs.
red_clean, red_idx = _preprocess_for_cardiac(red_frame, cm)
ir_clean, ir_idx = _preprocess_for_cardiac(ir_frame, cm)

# Pulse rate is computed from the IR indices; SpO2 uses both channels.
result = predict_cardiac_metrics(red_clean, ir_clean, red_idx, ir_idx, cm)

plot_data(red_clean, ir_clean, red_idx, ir_idx)

222
235
217
171


In [10]:
# Display the metrics dict returned by predict_cardiac_metrics.
# NOTE(review): this cell's execution count is out of sequence with the cells
# around it — re-run the notebook top to bottom before trusting the output.
result

{'pulse_rate': 71, 'spo2': 97.7, 'r': 0.36972686648368835}

In [7]:
# Hand-entered high/low sample values for the red and IR channels.
# NOTE(review): where these numbers come from is not shown — presumably read
# off one beat of the frame above; confirm.
red_high, red_low = 224304, 223715
ir_high, ir_low = 244011, 242606

# Peak-to-trough amplitude of each channel.
ac_red = red_high - red_low
ac_ir = ir_high - ir_low

# Same two formulas as _calc_spo2, applied to the hand-entered values.
# This overwrites any earlier `r` / `spo2` in the kernel namespace.
r = ( ac_red / red_low ) / ( ac_ir / ir_low )
spo2 = round(104 - (17 * r), 1)
spo2, r

(96.3, 0.454616719305503)

In [8]:
# Quadratic alternative to the linear `104 - 17 * r` formula: a*r^2 + b*r + c,
# evaluated at whatever `r` currently holds in the kernel namespace (the
# manual-calculation cell must have been run first).
# NOTE(review): the source of these coefficients is not shown in this notebook — confirm.
a = 1.5958422
b = -34.6596622
c = 112.6898759

a*pow(r,2) + b*r + c

97.26283683777868

In [9]:
# Plot the unprocessed red frame (the input that was passed to preprocessing).
fig = go.Figure()

fig.add_scatter(y=red_frame)