In [1]:
from emgdecompy.decomposition import *
from emgdecompy.contrast import *
from emgdecompy.preprocessing import *

In [8]:
from codecs import raw_unicode_escape_decode
import numpy as np
import pandas as pd
import altair as alt
import panel as pn
from panel.interact import interact, fixed
import math
from sklearn.metrics import mean_squared_error
from emgdecompy.preprocessing import (
    flatten_signal,
    center_matrix,
    butter_bandpass_filter,
)

pn.extension("vega")
from scipy.io import loadmat

In [69]:
# Preprocess the full recording: flatten it, band-pass filter every row
# (apply_along_axis with axis=1; 10-900 Hz, fs=2048, order 6), centre it with
# center_matrix(), square it, and average over axis 0 so `c_sq_mean` holds one
# value per sample. pulse_plot() indexes an array like this by firing index to
# get each pulse's "Strength".
# NOTE(review): `raw` is only assigned in the data-loading cell further down
# (In [84]); on Restart & Run All this cell fails unless that cell runs first.
signal = flatten_signal(raw)
signal = np.apply_along_axis(
        butter_bandpass_filter,
        axis=1,
        arr=signal,
        lowcut=10,
        highcut=900,
        fs=2048,
        order=6,
    )
centered = center_matrix(signal)
c_sq = centered ** 2
# Average over axis 0 (presumably channels) -> one value per sample.
c_sq_mean = c_sq.mean(axis=0)

In [77]:
# Preprocess the `single_raw` subset exactly like the full recording:
# band-pass filter every row (10-900 Hz, fs=2048, order 6), centre, square,
# and average over axis 0 to get one value per sample for pulse_plot().
# NOTE(review): `single_raw` is assigned in the data-loading cell further down
# (In [84]); run that cell first on a fresh kernel.
single_signal = flatten_signal(single_raw)
single_signal = np.apply_along_axis(
    butter_bandpass_filter,
    axis=1,
    # Filter the single subset itself -- previously this filtered `signal`
    # (the full recording), silently discarding `flatten_signal(single_raw)`.
    arr=single_signal,
    lowcut=10,
    highcut=900,
    fs=2048,
    order=6,
)
single_centered = center_matrix(single_signal)
single_c_sq = single_centered ** 2
single_c_sq_mean = single_c_sq.mean(axis=0)

In [84]:
# Load the example recording and compute averaged MUAP shapes for every motor unit.
# NOTE(review): the path is relative to the notebook's working directory;
# consider a configurable DATA_DIR constant in the config cell.
# NOTE(review): muap_dict() and pulse_plot() are defined in cells further down,
# so on Restart & Run All this cell runs before they exist.
raw_data_dict = loadmat('../data/raw/gl_10.mat')
pt = raw_data_dict['MUPulses']  # firing indices for each motor unit
raw = raw_data_dict['SIG']  # raw EMG signal
shape_dict = muap_dict(raw, pt)

# NOTE(review): presumably the first motor unit's pulse train and the first
# element/row of SIG -- confirm the shapes against the .mat structure.
single_pt = pt[0][0]
single_raw = raw[0]
single_shape = muap_dict(single_raw, single_pt)
# `single_c_sq_mean` comes from the single-signal preprocessing cell (In [77]).
pulse_plot(single_pt, single_c_sq_mean, 0)

In [68]:
def _pulse_trains(pt):
    """Normalise `pt` into a list of 1-D integer firing-index arrays, one per motor unit."""
    pt = np.asarray(pt)
    if pt.dtype == object:
        # e.g. MUPulses from loadmat: an object array whose elements are the
        # per-motor-unit firing arrays (possibly stored as 2-D row vectors).
        return [np.asarray(train).ravel().astype(np.int64) for train in pt.ravel()]
    # A plain numeric array is treated as a single pulse train.
    return [pt.ravel().astype(np.int64)]


def _muap_shape(raw, train, l):
    """Average the (2*l + 1)-sample window around every firing in `train`, per channel."""
    n_channels, n_samples = raw.shape
    width = 2 * l + 1

    # One row of sample indices per firing: firing - l ... firing + l.
    windows = train[:, np.newaxis] + np.arange(-l, l + 1)
    # Clamp windows running off either end of the recording to the first/last
    # sample instead of wrapping (negative indices) or raising IndexError.
    windows = np.clip(windows, 0, n_samples - 1)

    # raw[:, windows] has shape (channels, firings, width); average over firings
    # and flatten channel-major to line up with "sample" and "channel".
    signal = raw[:, windows].mean(axis=1).ravel()

    return {
        "sample": np.tile(np.arange(width), n_channels),
        "signal": signal,
        "channel": np.repeat(np.arange(n_channels), width),
    }


def muap_dict(raw, pt, l=31):
    """
    Returns multi-level dictionary containing sample number, average signal, and channel
    for each motor unit by averaging the peak shapes over every firing for each MUAP.

    Windows that extend past the start or end of the recording are clamped to
    the first/last sample. The input `pt` is never modified.

    Parameters
    ----------
        raw: numpy.ndarray
            Raw EMG signal (passed through flatten_signal()).
        pt: numpy.ndarray
            Indices of firing times. An object array (e.g. MUPulses loaded
            from a .mat file) holds one pulse train per motor unit; a numeric
            array is treated as a single motor unit's pulse train.
        l: int
            One half of action potential discharge time in samples.

    Returns
    -------
        dict
            ``{"mu_<i>": {"sample": ..., "signal": ..., "channel": ...}}`` with one
            entry per motor unit; each array has length ``channels * (2 * l + 1)``.
    """
    raw = flatten_signal(raw)
    return {
        f"mu_{i}": _muap_shape(raw, train, l)
        for i, train in enumerate(_pulse_trains(pt))
    }


In [4]:
def muap_dict_by_peak(raw, peak, mu_index=0, l=31):
    """
    Returns the dictionary of shapes for a selected peak, by channel.
    It is called by the select_peak() function when a peak is selected by a user.

    The window covers samples ``peak - l`` through ``peak + l`` inclusive. Any
    position that falls before the first or after the last sample of the
    recording is filled with zeros on every channel.

    Parameters
    ----------
        raw: numpy.ndarray
            Raw EMG signal (passed through flatten_signal()).
        peak: int
            Peak timing (sample index) to plot.
        mu_index: int
            Motor Unit the peak belongs to, to keep dict format consistent.
        l: int
            One half of action potential discharge time in samples.

    Returns
    -------
        dict
            ``{f"mu_{mu_index}": {"sample": ..., "signal": ..., "channel": ...}}``;
            each array has length ``channels * (2 * l + 1)``.
    """
    raw = flatten_signal(raw)
    n_channels, n_samples = raw.shape
    width = 2 * l + 1
    peak = int(peak)

    # Absolute sample indices covered by the window around the peak.
    idx = np.arange(peak - l, peak + l + 1)
    in_range = (idx >= 0) & (idx < n_samples)

    # Zero-pad out-of-range positions on every channel. Slicing
    # raw[:, peak - l:peak + l + 1] directly would wrap around for a negative
    # start and truncate past the end, breaking the length match with
    # "sample"/"channel".
    shape = np.zeros((n_channels, width), dtype=raw.dtype)
    shape[:, in_range] = raw[:, idx[in_range]]

    return {
        f"mu_{mu_index}": {
            # 0..2l repeated once per channel
            "sample": np.tile(np.arange(width), n_channels),
            # channel-major: all of channel 0, then channel 1, ...
            "signal": shape.ravel(),
            # each channel number repeated 2l+1 times
            "channel": np.repeat(np.arange(n_channels), width),
        }
    }

In [5]:
def muap_plot(
    mu_data, mu_index, peak_data=None, l=31, peak="", method="RMSE", preset="standard"
):
    """
    Returns a plot for MUAP shapes separated by channel.
    If peak_data is specified, also plots overlay of contribution of the peak to the shape per channel.
    Called by select_peak() function.

    Parameters
    ----------
        mu_data: dict
            Dictionary containing MUAP shapes for each motor unit.
        mu_index: int
            Index of motor unit to examine
        peak_data: dict
            Dictionary containing shapes for a given peak per channel.
            Specifying it creates the overlay of peak contribution
        l: int
            One half of action potential discharge time in samples.
            Not used by the plot itself; kept for interface compatibility.
        peak: str
            Time of the peak in seconds, used for the title of the plot.
        method: str or function
            Passed to mismatch_score() to calculate the mean (over all channels)
            mismatch score between the averaged shape and the given peak. Its
            name is shown in the plot title.
        preset: str
            Name of preset to use, for arranging the channels on the plot

    Returns
    -------
        altair.Chart
            Altair chart faceted by channel, showing the MU shapes and, when
            peak_data is given, the peak shapes per channel.
    """
    # The long-form frames (one row per channel x sample) can exceed Altair's
    # default row limit.
    alt.data_transformers.disable_max_rows()

    df = pd.DataFrame(mu_data[f"mu_{mu_index}"])
    df["Source"] = "MUAP"
    plot_title = f"MUAP Shapes for MU {mu_index}"
    legend_position = None  # Hide the legend when only MUAPs are shown

    layout = channel_preset(preset)
    sort_order = layout["sort_order"]
    cols = layout["cols"]

    if peak_data:
        peak_df = pd.DataFrame(peak_data[f"mu_{mu_index}"])
        peak_df["Source"] = "Peak Contribution"
        df = pd.concat([df, peak_df])
        err = round(
            mismatch_score(mu_data, peak_data, mu_index, method=method, channel=-1)
        )
        # Label the score with the metric actually used, not always "RMSE".
        method_name = getattr(method, "__name__", method)
        plot_title = (
            f"Peak at {peak} s contribution per Channel to MU {mu_index}. "
            f"{method_name} = {err}"
        )
        legend_position = alt.Legend(
            orient="none",
            title=None,
            legendX=400,
            legendY=-40,
            direction="horizontal",
            titleAnchor="middle",
        )  # Show the legend when showing the overlay

    # Clicking a legend entry highlights that source and dims the other.
    selection = alt.selection_multi(fields=["Source"], bind="legend")

    plot = (
        alt.Chart(df, title=plot_title)
        .encode(
            x=alt.X("sample", axis=None),
            y=alt.Y("signal", axis=None),
            color=alt.Color(
                "Source",
                scale={"range": ["#fd3a4a", "#99a7f1"]},
                legend=legend_position,
            ),
            opacity=alt.condition(selection, alt.value(1), alt.value(0.2)),
            facet=alt.Facet(
                "channel",
                columns=cols,
                spacing={"row": 0},
                header=alt.Header(
                    titleFontSize=0,
                    labelFontSize=14,
                ),
                sort=sort_order,
            ),
        )
        .mark_line()
        .properties(width=112, height=100)
        .configure_title(fontSize=14, anchor="middle")
        .configure_axis(labelFontSize=14)
        .configure_view(strokeWidth=0)
        .add_selection(selection)
    )

    return plot


In [83]:
def _pulse_frame(pt, c_sq_mean, mu_index, fs):
    """Build the per-firing DataFrame (index, strength, rate, time) for one motor unit."""
    if pt.dtype == object:
        # e.g. MUPulses from loadmat: one (possibly 2-D) pulse train per element.
        trains = [np.asarray(train).ravel() for train in pt.ravel()]
    else:
        # A plain numeric array is a single pulse train, labelled motor unit 0.
        trains = [pt.ravel()]

    if 0 <= mu_index < len(trains):
        pulses = trains[mu_index]
    else:
        # Unknown motor unit: plot no data.
        pulses = np.array([], dtype=int)

    # Instantaneous firing rate; the first firing has no predecessor, so 0.
    hertz = np.zeros(len(pulses))
    hertz[1:] = 1 / np.diff(pulses) * fs

    return pd.DataFrame(
        {
            "Pulse": pulses,
            "Strength": c_sq_mean[pulses],
            "Motor Unit": mu_index,
            "Hz": hertz,
            "seconds": pulses / fs,
        }
    )


def pulse_plot(pt, c_sq_mean, mu_index, sel_type="single", fs=2048):
    """
    Plot firings for a given motor unit.

    Parameters
    ----------
        pt: np.array
            Pulse train(s). An object array (e.g. MUPulses from a .mat file)
            holds one train per motor unit; a numeric array is a single train
            treated as motor unit 0.
        c_sq_mean: np.array
            Centered, squared and averaged signal over the duration of the
            trial; indexed by firing index to get each pulse's strength.
        mu_index: int
            Motor Unit of interest to plot firings for. If it does not exist
            in `pt`, the charts contain no data.
        sel_type: str
            Currently unused; kept for backward compatibility.
        fs: int
            Sampling frequency in Hz, used to convert samples to seconds and
            firing intervals to rates. Defaults to 2048.

    Returns
    -------
        altair plot object
            Vertically concatenated overview (strength bars + firing rate with
            an interval brush) and two detail charts filtered by that brush.
    """
    color_pulse = "#35d3da"
    color_rate = "#9cb806"

    motor_df = _pulse_frame(pt, c_sq_mean, mu_index, fs)

    # TODO: Selection only makes sense if we are working with specific MU
    sel_peak = alt.selection_single(name="sel_peak")
    sel_interval = alt.selection_interval(encodings=["x"], name="sel_interval")

    chart_top_base = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
            )
        )
        .properties(width=1000, height=100)
    )

    chart_top_rate = (
        chart_top_base.mark_point(size=30, color=color_rate)
        .encode(
            alt.Y(
                "Hz:Q",
                axis=alt.Axis(
                    title="Instantaneous Firing Rate (Hz)",
                    grid=False,
                    format=".0f",
                    titleColor=color_rate,
                ),
            )
        )
        .add_selection(sel_interval)
    )

    chart_top_pulse = chart_top_base.mark_bar(
        size=3.5, color=color_pulse, opacity=0.3
    ).encode(
        alt.Y(
            "Strength:Q",
            axis=alt.Axis(
                title="Signal (A.U.)", grid=False, format="s", titleColor=color_pulse
            ),
        )
    )

    chart_top = alt.layer(chart_top_pulse, chart_top_rate).resolve_scale(
        y="independent"
    )

    chart_rate = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
                scale=alt.Scale(domain=sel_interval),
            ),
            alt.Y(
                "Hz:Q",
                axis=alt.Axis(
                    title="Instantaneous Firing Rate (Hz)",
                    grid=False,
                    format=".0f",
                    titleColor=color_rate,
                ),
            ),
            color=alt.condition(
                sel_peak, alt.value(color_rate), alt.value("lightgray"), legend=None
            ),
            tooltip=[
                alt.Tooltip("Hz", format=".2f"),
                alt.Tooltip("seconds", format=".2f"),
            ],
        )
        .properties(width=1000, height=250)
        .mark_point(size=30)
        .add_selection(sel_peak)
        .transform_filter(sel_interval)
    )

    chart_pulse = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
                scale=alt.Scale(domain=sel_interval),
            ),
            alt.Y(
                "Strength:Q",
                axis=alt.Axis(
                    title="Signal (A.U.)",
                    grid=False,
                    format="s",
                    titleColor=color_pulse,
                ),
            ),
            color=alt.condition(
                sel_peak, alt.value(color_pulse), alt.value("lightgray"), legend=None
            ),
        )
        .mark_bar(size=3.5)
        .add_selection(sel_peak)
        .properties(width=1000, height=250)
        .transform_filter(sel_interval)
    )

    return chart_top & chart_rate & chart_pulse

In [7]:
def select_peak(
    selection, mu_index, raw, shape_dict, pt, preset="standard", method="RMSE"
):
    """
    Retrieves a given peak (if any) and re-graphs MUAP plot via muap_plot() function.
    Called within dashboard() function, bound to the peak selection on pulse graphs.

    Side effect: sets the global ``selected_peak`` to the 0-based index of the
    selected firing, or -1 when nothing is selected.

    Parameters
    ----------
        selection: selection object
            Selection object to dig into and retrieve peak index to plot.

        mu_index: int
            Currently plotted Motor Unit.

        raw: numpy.ndarray
            Raw EMG signal array.

        shape_dict: dict
            Dictionary containing MUAP shapes for each motor unit.

        pt: numpy.ndarray
            Firing-time indices. An object array holds one pulse train per
            motor unit; a numeric array is a single train.

        preset: str
            Channel layout preset forwarded to muap_plot().

        method: str or function
            Mismatch metric forwarded to muap_plot().

    Returns
    -------
        panel.Column
            Panel layout wrapping the MUAP Vega plot.
    """
    global selected_peak

    if not selection:
        plot = muap_plot(shape_dict, mu_index, l=31, preset=preset, method=method)
        selected_peak = -1

    else:
        # The selected id appears to be 1-indexed (presumably Vega's internal
        # tuple id), so shift it to a 0-based firing index.
        selected_peak = selection[0] - 1

        # Object arrays (e.g. MUPulses from loadmat) hold one pulse train per
        # motor unit; a plain numeric array is a single train. ravel() keeps
        # single-firing trains 1-D, where squeeze() would make them 0-d.
        trains = pt.ravel() if pt.dtype == object else [pt]
        peak = np.ravel(trains[mu_index])[selected_peak]

        peak_data = muap_dict_by_peak(raw, peak, mu_index=mu_index, l=31)
        plot = muap_plot(
            shape_dict,
            mu_index,
            peak_data,
            l=31,
            peak=str(round(peak / 2048, 2)),
            preset=preset,
            method=method,
        )

    return pn.Column(
        pn.Row(
            pn.pane.Vega(plot, debounce=10, width=750),
        )
    )
