In [1]:
from typing import Tuple
import numpy as np
import pickle as pkl
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from database_tools.processing.cardiac import estimate_spo2
from database_tools.processing.detect import detect_peaks, detect_notches
from database_tools.tools.dataset import ConfigMapper
from gcp_utils import constants
from database_tools.processing.modify import bandpass

In [2]:
def _preprocess_ppg(sig: list, cm: ConfigMapper):
    """Remove data too far from the signal medium (handles large noise and motion artifacts).
       Also, calculate peaks in longest run of good data.
    """
    # Prep data
    sig = np.array(sig).reshape(-1)
    sig = np.array(sig, dtype=np.float32)
    sig[np.isnan(sig)] = 0

    # Remove signal outliers
    sig = _remove_sig_outliers(sig, cm.deploy.sig_amp_thresh)
    # return sig, dict()

    # get, remove outliers from, and order peaks
    idx = detect_peaks(bandpass(sig, low=0.5, high=8.0, fs=200))
    peaks, troughs = idx['peaks'], idx['troughs']
    peaks = _remove_peak_outliers(sig, peaks, cm.deploy.peak_amp_thresh, cm.deploy.peak_dist_thresh)
    troughs = _remove_peak_outliers(sig, troughs, cm.deploy.peak_amp_thresh, cm.deploy.peak_dist_thresh)

    idx = dict(peaks=peaks, troughs=troughs)
    idx = _get_ordered_idx(idx)
    return sig, idx

def _get_ordered_idx(idx: dict) -> Tuple[list, list]:
    """Takes a list of peaks and troughs and removes
       out of order elements. Regardless of which occurs first,
       a peak or a trough, a peak must be followed by a trough
       and vice versa.

    Algorithm (if peaks start first)
    ---------
    - Loop through values starting with first peak
    - Is peak before valley?
        YES -> Is next peak after valley?
            YES -> Append peak and valley. Get next peak and valley.
            NO  -> Get next peak.
        NO  -> Get next valley.

    Args:
        peaks (list): Signal peaks.
        troughs (list): Signal troughs.

    Returns:
        first_repaired (list): Input with out of order items removed.
        second_repaired (list): Input with out of order items removed.

        Items are always returned with peaks idx as first tuple item.
    """
    order_lists = lambda x, y : (x, y, 0, 1) if x[0] < y[0] else (y, x, 1, 0)

    # Configure algorithm to start with lowest index.
    peaks, troughs = idx['peaks'], idx['troughs']

    try:
        first, second, flag1, flag2 = order_lists(peaks, troughs)
    except IndexError:
        return dict(peaks=np.array([]), troughs=np.array([]))

    result = dict(first=[], second=[])
    i, j = 0, 0
    for _ in enumerate(first):
        try:
            poi_1, poi_2 = first[i], second[j]
            if poi_1 < poi_2:  # first point of interest is before second
                poi_3 = first[i + 1]
                if poi_2 < poi_3:  # second point of interest is before third
                    result['first'].append(poi_1)
                    result['second'].append(poi_2)
                    i += 1; j += 1
                else:
                    i += 1
            else:
                j += 1
        except IndexError: # always thrown in last iteration
            result['first'].append(poi_1)
            result['second'].append(poi_2)

    # remove duplicates and return as peaks, troughs
    result['first'] = sorted(list(set(result['first'])))
    result['second'] = sorted(list(set(result['second'])))
    result = [result['first'], result['second']]
    return dict(peaks=np.array(result[flag1]), troughs=np.array(result[flag2]))

def _remove_sig_outliers(sig, amp_thresh, buffer=50):
    med = np.median(sig)
    print(med)
    mask = (sig > (med + amp_thresh)) | (sig < (med - amp_thresh))

    temp = sig.copy()
    temp[np.where(mask)] = 0
    temp[np.where(~mask)] = 1
    run_values, run_starts, run_lengths = _find_runs(temp)

    if run_lengths.shape[0] < 2:
        return sig
    else:
        for idx, val in enumerate(run_values):
            if (val == 0) | ( (val == 1) & (run_lengths[idx] < 800) ):
                if idx == (run_values.shape[0] - 1):
                    i = run_starts[idx]
                    sig[i::] = med
                else:
                    i, j = run_starts[idx], run_starts[idx+1]
                    sig[i:j+buffer] = med
        return sig

def _remove_peak_outliers(sig, idx, amp_thresh, dist_thresh):
    # remove indices whose amplitude is too far from mean
    values = sig[idx]
    mean = np.mean(values)
    mask = np.where( (values < (mean + amp_thresh)) & (values > (mean - amp_thresh)) )
    idx = idx[mask]

    # remove indices that are too from from each other
    diff = np.diff(idx, prepend=idx[0] - 10000, append=idx[-1] + 10000)
    delta = int(np.mean(diff[1:-1]))

    valid = []
    for i, distance1 in enumerate(diff[0:-1]):
        distance2 = diff[i+1]
        if (distance1 <= (delta + dist_thresh)) & (distance2 <= (delta + dist_thresh)):
            valid.append(idx[i])
        else:
            if len(valid) > 0:
                break
    return valid

def _find_runs(x):
    n = x.shape[0]
    loc_run_start = np.empty(n, dtype=bool)
    loc_run_start[0] = True
    np.not_equal(x[:-1], x[1:], out=loc_run_start[1:])
    
    run_starts = np.nonzero(loc_run_start)[0]
    run_values = x[loc_run_start]
    run_lengths = np.diff(np.append(run_starts, n))
    return (run_values, list(run_starts), run_lengths)

def predict_cardiac_metrics(red: list, ir: list, red_idx: list, ir_idx: list, cm: ConfigMapper) -> dict:
    red, ir = np.array(red), np.array(ir)
    pulse_rate = _calc_pulse_rate(ir_idx, fs=cm.deploy.bpm_fs)
    spo2, r = _calc_spo2(red, ir, red_idx, ir_idx)
    result = {
        'pulse_rate': int(pulse_rate),
        'spo2': float(spo2),
        'r': float(r)
    }
    return result

def _calc_pulse_rate(idx, fs):
    pulse_rate = fs / np.mean(np.diff(idx['peaks'])) * 60
    return pulse_rate

def _calc_spo2(ppg_red, ppg_ir, red_idx, ir_idx, method='linear'):
    red_peaks, red_troughs = red_idx['peaks'], red_idx['troughs']
    ir_peaks, ir_troughs = ir_idx['peaks'], ir_idx['troughs']

    # choose where to calculate based on shorted list
    options = [red_peaks, red_troughs, ir_peaks, ir_troughs]
    lengths = [len(x) for x in options]

    # Return error code if missing needed value
    if 0 in lengths:
        return (-1, -1)

    i = int(len(options[np.argmin(lengths)]) / 2)

    red_high, red_low = np.max(ppg_red[red_peaks[i]]), np.min(ppg_red[red_troughs[i]])
    ir_high, ir_low = np.max(ppg_ir[ir_peaks[i]]), np.min(ppg_ir[ir_troughs[i]])

    ac_red = red_high - red_low
    ac_ir = ir_high - ir_low

    r = ( ac_red / red_low ) / ( ac_ir / ir_low )

    if method == 'linear':
        spo2 = round(104 - (17 * r), 1)
    elif method == 'curve':
        spo2 = (1.596 * (r ** 2)) + (-34.670 * r) + 112.690
    return (spo2, r)

In [3]:
def plot_data(red, ir, red_idx, ir_idx):
    fig = make_subplots(rows=2, cols=1)
    fig.add_scatter(y=red, name='ppg_red', row=1, col=1)
    fig.add_scatter(x=red_idx['peaks'], y=red[red_idx['peaks']], mode='markers', row=1, col=1)
    fig.add_scatter(x=red_idx['troughs'], y=red[red_idx['troughs']], mode='markers', row=1, col=1)

    fig.add_scatter(y=ir, name='ppg_red', row=2, col=1)
    fig.add_scatter(x=ir_idx['peaks'], y=ir[ir_idx['peaks']], mode='markers', row=2, col=1)
    fig.add_scatter(x=ir_idx['troughs'], y=ir[ir_idx['troughs']], mode='markers', row=2, col=1)
    fig.show()

In [11]:
frame = constants.NEW_BPM_FRAME

In [12]:
cm = ConfigMapper('/home/cam/Documents/gcp_utils/gcp_utils/config.ini')

red_frame, ir_frame = frame['red_frame'], frame['ir_frame']

# Prep data
sig = np.array(red_frame).reshape(-1)
sig[np.isnan(sig)] = 0

# Remove signal outliers
clean = _remove_sig_outliers(sig, cm.deploy.sig_amp_thresh)

223530.0


In [14]:
from database_tools.processing.modify import bandpass
idx = detect_peaks(bandpass(clean, low=0.5, high=8.0, fs=200))

In [16]:
fig = go.Figure()

fig.add_scatter(y=clean)
fig.add_scatter(x=idx['peaks'], y=clean[idx['peaks']], mode='markers')
fig.add_scatter(x=idx['troughs'], y=clean[idx['troughs']], mode='markers')

In [7]:
red_clean, red_idx = _preprocess_ppg(red_frame, cm)
ir_clean, ir_idx = _preprocess_ppg(ir_frame, cm)

result = predict_cardiac_metrics(red_clean, ir_clean, red_idx, ir_idx, cm)

print(f'spo2: {result["spo2"]}, r: {result["r"]}, pulse rate: {result["pulse_rate"]}')
plot_data(red_clean, ir_clean, red_idx, ir_idx)

223530.0
0.0 475 0
1
3
0 475
1.0 12 475
1
3
475 487
0.0 53 487
1
3
487 540
1.0 5 540
1
3
540 545
0.0 48 545
1
3
545 593
1.0 76 593
1
3
593 669
0.0 47 669
1
3
669 716
1.0 3373 716
0.0 11 4089
1
2
243706.5
0.0 548 0
1
3
0 548
1.0 14 548
1
3
548 562
0.0 18 562
1
3
562 580
1.0 11 580
1
3
580 591
0.0 156 591
1
3
591 747
1.0 10 747
1
3
747 757
0.0 28 757
1
3
757 785
1.0 3300 785
0.0 4 4085
1
3
4085 4089
1.0 1 4089
1
3
4089 4090
0.0 10 4090
1
2
spo2: -1.0, r: -1.0, pulse rate: 50


In [None]:
from gcp_utils.tools.utils import query_collection
from firebase_admin import firestore, initialize_app

initialize_app()
database = firestore.client()
col = database.collection(u'bpm_data_test').document('v2iHQmPIVfVW0IuhfZ1yCIegsB52').collection('frames')

docs = query_collection(col, 'status', '==', 'new')
data = [doc.to_dict() for doc in docs]