# 1. Import the required libraries

In [None]:
# Standard code libraries
import os
import platform
import glob

import matplotlib.pyplot as plt

# Custom code libraries from ReSurfEMG
from resurfemg.config.config import Config
from resurfemg.data_connector.tmsisdk_lite import Poly5Reader
from resurfemg.data_classes.data_classes import (
VentilatorDataGroup, EmgDataGroup)

%matplotlib widget

# 2. Load the ventilator and sEMG data

In [None]:
# Identify all recordings available for the selected patient/measurement_date

# Root directory for test data
config = Config()
root_patient_data_directory = \
    config.get_directory('test_data')

# Path separator of the current OS ('\\' on Windows, '/' elsewhere).
path_sep = os.sep

# Recursively collect every .Poly5 recording below the root directory.
# glob returns matches in filesystem order, which is arbitrary; sort so the
# "first" EMG / ventilator file picked below is reproducible.
emg_pattern = os.path.join(root_patient_data_directory, '**/*.Poly5')
emg_and_vent_files = sorted(glob.glob(emg_pattern, recursive=True))

# Classify on the file name only: testing the full path would label every
# file as ventilator data whenever a parent directory name contains 'vent'.
emg_files = []
vent_files = []
for file in emg_and_vent_files:
    if 'vent' in os.path.basename(file):
        vent_files.append(file)
    else:
        emg_files.append(file)

if not emg_files or not vent_files:
    raise FileNotFoundError(
        f"Expected at least one EMG and one ventilator .Poly5 file under "
        f"{root_patient_data_directory}, found {len(emg_files)} EMG and "
        f"{len(vent_files)} ventilator file(s).")

emg_file_chosen = emg_files[0]
vent_file_chosen = vent_files[0]
print("The chosen files are:\n", emg_file_chosen, '\n', vent_file_chosen)

In [None]:
# Load the EMG and ventilator data recordings from the selected files.
data_emg = Poly5Reader(emg_file_chosen)
data_vent = Poly5Reader(vent_file_chosen)

# The sample arrays are indexed [channel, sample] (see y_emg[0, :] below), so
# the truncation to num_samples must act on the sample axis (axis 1). Slicing
# with [:num_samples] alone would act on the channel axis instead.
data_emg_samples = data_emg.samples[:, :data_emg.num_samples]
fs_emg = data_emg.sample_rate
data_vent_samples = data_vent.samples[:, :data_vent.num_samples]
fs_vent = data_vent.sample_rate

# Define the time series of the EMG and ventilator recordings
y_emg = data_emg_samples
y_vent = data_vent_samples

# Define the time axes in seconds: sample i lies at i / fs
t_emg = [i / fs_emg for i in range(y_emg.shape[1])]
t_vent = [i / fs_vent for i in range(y_vent.shape[1])]

In [None]:
# Store the EMG data in a group of TimeSeries objects
emg_timeseries = EmgDataGroup(
    y_emg,
    fs=fs_emg,
    labels=['ECG', 'EMGdi'],
    units=2*['uV'])


# Data is stored in:
# fs        --> emg_timeseries.fs
# labels    --> emg_timeseries.labels
# units     --> emg_timeseries.units
# ECG       --> emg_timeseries.channels[0] = TimeSeries object
# EMGdi     --> emg_timeseries.channels[1] = TimeSeries object
#   with, e.g. for channel 0:
#   emg_timeseries.channels[0].fs = fs
#   emg_timeseries.channels[0].y_raw = y_emg[0, :]
#   emg_timeseries.channels[0].t_data = time axis data for y_raw
#   emg_timeseries.channels[0].y_clean = None     (filled by filter/gating)
#   emg_timeseries.channels[0].y_env              (filled by envelope)
#   emg_timeseries.channels[0].y_baseline = None  (filled by baseline)
#   etc.
# The ECG channel is detected automatically from its 'ECG' label.

In [None]:
# Store the ventilator data in a group of TimeSeries objects
vent_timeseries = VentilatorDataGroup(
    y_vent,
    fs=fs_vent,
    labels=['Paw', 'F', 'Vvent'],
    units=['cmH2O', 'L/s', 'L'])

# Data is stored in:
# fs        --> vent_timeseries.fs
# labels    --> vent_timeseries.labels
# units     --> vent_timeseries.units
# Paw       --> vent_timeseries.channels[0] = TimeSeries object
# F         --> vent_timeseries.channels[1] = TimeSeries object
# Vvent     --> vent_timeseries.channels[2] = TimeSeries object
#   with, e.g. for channel 0:
#   vent_timeseries.channels[0].fs = fs
#   vent_timeseries.channels[0].y_raw = y_vent[0, :]
#   vent_timeseries.channels[0].t_data = time axis data for y_raw
#   vent_timeseries.channels[0].y_clean = None
#   vent_timeseries.channels[0].y_baseline = None  (filled by baseline)
#   etc.
# The 'Paw', 'F', and 'Vvent' channels are detected automatically from their
# labels (their indices are used later as p_aw_idx and v_vent_idx).

# 3. Pre-process the data

In [None]:
# Band-pass filter both EMG channels (ECG and EMGdi) between 20 and 500 Hz
emg_timeseries.filter(
    signal_type='raw',    # start from the raw, freshly loaded data
    hp_cf=20.0,           # high-pass cut-off frequency (Hz)
    lp_cf=500.0,          # low-pass cut-off frequency (Hz)
    channel_idxs=[0, 1],  # all channels; None would select all as well
)

# Result --> emg_timeseries.channels[channel_idx].y_clean


In [None]:
# Gate the clean (filtered) EMG at the ECG peak locations
emg_timeseries.gating(
    signal_type='clean',      # gate the clean, just filtered, data
    gate_width_samples=None,  # gate width; None defaults to fs // 10
    ecg_peak_idxs=None,       # ECG peak sample idxs; None -> auto-detected
    ecg_raw=None,             # ECG signal used for peak detection (see below)
    bp_filter=True,           # filter ecg_raw between 1-500 Hz before detection
    channel_idxs=None,        # None -> all channels
)
# ecg_raw: only used to detect ECG peaks when no ecg_peak_idxs are provided.
#   If `None` and no ECG channel is detected from the labels, the raw channel
#   data is used.

# Result --> emg_timeseries.channels[channel_idx].y_clean


In [None]:
# Compute the envelope of the clean EMG using the average rectified value
emg_timeseries.envelope(
    env_window=None,      # window width; None defaults to fs // 5
    env_type='arv',       # 'arv' = average rectified, 'rms' (default) = RMS
    signal_type='clean',  # envelope of the filtered and gated signal
)

# Result --> emg_timeseries.channels[channel_idx].y_env

In [None]:
# Calculate the baseline for the EMG envelopes and Paw.
# window_s and step_s are sample counts. Cast them with int(): if the reader
# reports the sampling rate as a float, fs_emg // 5 would be a float too.
emg_timeseries.baseline(
    percentile=33,
    window_s=int(7.5 * fs_emg),
    step_s=int(fs_emg // 5),
    method='default',
    signal_type='env',
    augm_percentile=25,
    ma_window=None,
    perc_window=None,
)

vent_timeseries.baseline(
    channel_idxs=[0],
    signal_type='raw')


# percentile:       Percentile of the signal in each window taken as the
#                   baseline value
# window_s:         Window length in samples (default: int(7.5 * fs))
# step_s:           Number of consecutive samples sharing one baseline value
#                   (default: fs // 5)
# method:           'default' or 'slopesum'
# signal_type:      The baseline is calculated over the envelope (y_env) for
#                   EMG and over the raw signal for Paw (channel 0)
# augm_percentile:  Augmented percentile for the slopesum baseline
# ma_window:        Moving-average window in samples for the average dy/dt in
#                   the slopesum baseline
# perc_window:      'perc_window' parameter of the slopesum baseline

# --> Data is stored in:
# emg_timeseries.channels[channel_idx].y_baseline
# vent_timeseries.channels[0].y_baseline (Paw)

In [None]:
# Overview: clean (filtered and gated) EMG with its envelope on the left,
# ventilator channels on the right; all panels share the time axis.
fig, axis = plt.subplots(
    nrows=3,
    ncols=2,
    figsize=(12, 6),
    sharex=True,
)

# Left column, top two rows: the two EMG channels
axes_emg = axis[:2, 0]
colors = ['tab:cyan', 'tab:orange']
emg_timeseries.plot_full(
    axes_emg, signal_type='clean', baseline_bool=False)
emg_timeseries.plot_full(
    axes_emg, signal_type='env', colors=colors)
axes_emg[0].set_title('EMG data')
axes_emg[-1].set_xlabel('t (s)')

# Right column: the ventilator channels
axes_vent = axis[:, 1]
vent_timeseries.plot_full(axes_vent)
axes_vent[0].set_title('Ventilator data')
axes_vent[-1].set_xlabel('t (s)')

# Hide the unused bottom-left panel
axis[-1, 0].axis('off')

# x is shared, so this zooms every panel to 370-410 s
axes_vent[-1].set_xlim([370, 410])

# 4. Identify Pocc peaks of interest in Paw

In [None]:
# Find occlusion pressures (Pocc) in Paw, searching from 360 s onwards.
# NOTE(review): the value of this attribute access is discarded; it is kept in
# case the `peep` property computes/caches PEEP used below -- confirm.
vent_timeseries.peep
# start_idx is a sample index: cast to int in case fs is a float
vent_timeseries.find_occluded_breaths(
    vent_timeseries.p_aw_idx, start_idx=int(360 * vent_timeseries.fs))

# Determine the on- and offsets of the Pocc peaks relative to the Paw baseline
paw = vent_timeseries.channels[vent_timeseries.p_aw_idx]
paw.peaks['Pocc'].detect_on_offset(baseline=paw.y_baseline)

# Find the ventilator breaths as tidal volume peaks (presumably stored as
# v_vent.peaks['ventilator_breaths'], which is plotted further below)
v_vent = vent_timeseries.channels[vent_timeseries.v_vent_idx]
vent_timeseries.find_tidal_volume_peaks()

# Display the Pocc peak times in seconds
paw.peaks['Pocc'].peak_df['peak_idx'] / paw.fs

# 5. Identify all sEMG breaths and find those closest to the Pocc peaks

In [None]:
# Detect the breaths in the sEAdi channel (EMGdi, channel 1), including their
# on- and offsets relative to the envelope baseline
emg_di = emg_timeseries.channels[1]
emg_di.detect_emg_breaths(peak_set_name='breaths')
emg_di.peaks['breaths'].detect_on_offset(baseline=emg_di.y_baseline)

# Link each ventilator Pocc peak (given as a time in s) to the closest EMG
# breath; the linked breaths form the new 'Pocc' peak set of sEAdi
t_pocc_peaks_vent = (
    paw.peaks['Pocc'].peak_df['peak_idx'].to_numpy() / paw.fs
)
emg_di.link_peak_set(
    peak_set_name='breaths',
    t_reference_peaks=t_pocc_peaks_vent,
    linked_peak_set_name='Pocc',
)

# Display the start times (s) of the linked sEAdi Pocc breaths
emg_di.peaks['Pocc'].peak_df['start_idx'] / emg_di.fs


In [None]:
# Plot the identified Pocc peaks in Paw and sEAdi
fig, axis = plt.subplots(nrows=3, ncols=2, figsize=(10, 6), sharex=True)

# Left column, top two rows: EMG envelopes, with the linked Pocc markers on
# the sEAdi panel
axes_emg = axis[:-1, 0]
emg_timeseries.plot_full(axes_emg, signal_type='env')
emg_di.plot_markers(peak_set_name='Pocc', axes=axes_emg[1])
axes_emg[1].set_ylabel('sEAdi (uV)')
axes_emg[0].set_title('EMG data')
axes_emg[-1].set_xlabel('t (s)')

# Right column: ventilator channels, with the Pocc markers on Paw and the
# ventilator breaths on the third panel (Vvent)
axes_vent = axis[:, 1]
vent_timeseries.plot_full(axes_vent)
paw.plot_markers(peak_set_name='Pocc', axes=axes_vent[0])
v_vent.plot_markers(peak_set_name='ventilator_breaths',
                    axes=axes_vent[2], colors='c')

axes_vent[0].set_title('Ventilator data')
axes_vent[-1].set_xlabel('t (s)')
# x is shared, so this zooms every panel to 370-410 s
axes_vent[-1].set_xlim([370, 410])

# Hide the unused bottom-left panel
axis[-1, 0].axis('off')

In [None]:
# Plot each Pocc peak individually: sEAdi in the top row, Paw in the bottom row
n_peaks = len(emg_di.peaks['Pocc'].peak_df)
if n_peaks == 0:
    # plt.subplots rejects ncols=0; fail with an explicit message instead.
    raise ValueError('No Pocc peaks were linked to sEAdi breaths; '
                     'nothing to plot.')
# squeeze=False keeps `axis` 2-D, so axis[0, :] also works for a single peak
# (otherwise subplots returns a 1-D array and axis[0, :] raises IndexError).
fig, axis = plt.subplots(nrows=2, ncols=n_peaks, figsize=(10, 6),
                         sharey='row', squeeze=False)

axes_emg = axis[0, :]
emg_di.plot_peaks(axes=axes_emg, peak_set_name='Pocc')
emg_di.plot_markers(axes=axes_emg, peak_set_name='Pocc')
axes_emg[0].set_ylabel('sEAdi (uV)')

axes_vent = axis[1, :]
paw.plot_peaks(axes=axes_vent, peak_set_name='Pocc')
paw.plot_markers(axes=axes_vent, peak_set_name='Pocc')

# Use `ax` as loop variable so the subplot array `axis` is not overwritten
for ax in axes_vent:
    ax.set_xlabel('t (s)')

# 6. Calculate features: ETP and PTP

In [None]:
# Pressure-time product of each occlusion (PTPocc), calculated relative to the
# Paw baseline and stored in the 'Pocc' peak_df under the name 'PTPocc'
paw.calculate_time_products(
    peak_set_name='Pocc',
    aub_reference_signal=paw.y_baseline,
    parameter_name='PTPocc',
)

print(paw.peaks['Pocc'].peak_df)

In [None]:
# Calculate the sEAdi electrical time product (ETPdi) of each linked Pocc
# breath; it is stored in the sEAdi 'Pocc' peak_df
emg_di.calculate_time_products(
    peak_set_name='Pocc',
    parameter_name='ETPdi',
)

print(emg_di.peaks['Pocc'].peak_df)

# 7. Test the quality of the Pocc and sEMG peaks

In [None]:
# Run the Pocc quality checks, using PTPocc as the time-product parameter
parameter_names = {'time_product': 'PTPocc'}
paw.test_pocc_quality(
    'Pocc',
    parameter_names=parameter_names,
    verbose=True,
)


In [None]:
# Show the Pocc peak_df again: the quality test has updated its peak validity
print(paw.peaks['Pocc'].peak_df)

In [None]:
# Run the sEAdi quality checks, using ETPdi as the time-product parameter
parameter_names = {'time_product': 'ETPdi'}
emg_di.test_emg_quality(
    'Pocc',
    parameter_names=parameter_names,
)

In [None]:
# Show the sEAdi 'Pocc' peak_df again: the quality test has updated its
# peak validity
print(emg_di.peaks['Pocc'].peak_df)

In [None]:
# Plot each Pocc peak with its bell-curve fit: sEAdi (top row) and Paw
# (bottom row)
n_peaks = len(emg_di.peaks['Pocc'].peak_df)
if n_peaks == 0:
    # plt.subplots rejects ncols=0; fail with an explicit message instead.
    raise ValueError('No Pocc peaks were linked to sEAdi breaths; '
                     'nothing to plot.')
# squeeze=False keeps `axis` 2-D, so axis[0, :] also works for a single peak
# (otherwise subplots returns a 1-D array and axis[0, :] raises IndexError).
fig, axis = plt.subplots(nrows=2, ncols=n_peaks, figsize=(10, 6),
                         sharey='row', squeeze=False)

axes_emg = axis[0, :]
emg_di.plot_peaks(axes=axes_emg, peak_set_name='Pocc')
emg_di.plot_markers(axes=axes_emg, peak_set_name='Pocc')
emg_di.plot_curve_fits(axes=axes_emg, peak_set_name='Pocc')
axes_emg[0].set_ylabel('sEAdi (uV)')

axes_vent = axis[1, :]
paw.plot_peaks(axes=axes_vent, peak_set_name='Pocc')
paw.plot_markers(axes=axes_vent, peak_set_name='Pocc')

# Use `ax` as loop variable so the subplot array `axis` is not overwritten
for ax in axes_vent:
    ax.set_xlabel('t (s)')