# Mean traces

In [None]:
import matplotlib.pyplot as plt
from pathlib import Path
import ap_features as apf
import numpy as np
import mps
import mps_motion
from mps_motion import (
    Mechanics,
    scaling,
    OpticalFlow,
    frame_sequence as fs,
)

import config

In [None]:
def _summarize_motion(u, v, twitch_frame, contraction_vel_frame):
    """Reduce displacement ``u`` and velocity ``v`` to the stored summary traces.

    Returns a dict with the spatial means of the norm and x/y components of
    ``u`` and ``v`` (``.mean().compute()``), the ``.max().compute()`` of their
    norms, and the full ``u`` / ``v`` arrays at the given frame indices
    (indexed along axis 2 of ``.array``).
    """
    return {
        "u_norm": u.norm().mean().compute(),
        "u_x": u.x.mean().compute(),
        "u_y": u.y.mean().compute(),
        "v_norm": v.norm().mean().compute(),
        "v_x": v.x.mean().compute(),
        "v_y": v.y.mean().compute(),
        "u_norm_max": u.norm().max().compute(),
        "v_norm_max": v.norm().max().compute(),
        "u_peak_twitch": u.array[:, :, twitch_frame, :].compute(),
        "v_peak_contraction_vel": v.array[:, :, contraction_vel_frame, :].compute(),
    }


def get_default_displacement_velocity_traces(dataid):
    """Compute (or load from cache) the default displacement/velocity traces.

    Displacements are computed with Farneback optical flow on the MPS data at
    ``config.mps_paths[dataid]``; velocities are derived from them with
    ``config.DEFAULT_SPACING[dataid]``. Summary traces are computed once on
    the full field (stored under ``"original"``) and once after masking with
    ``u_norm_max < mean(u_norm_max)``.

    The result is cached as ``default_traces_{dataid}.npy`` in
    ``config.results_dir``. Delete that file to force recomputation, e.g.
    after changing the config.

    Parameters
    ----------
    dataid : int
        Dataset index (0 or 1) into the per-dataset config entries.

    Returns
    -------
    dict
        Traces keyed by ``"u_norm"``, ``"v_norm"``, ``"time"``, ... (see the
        ``result`` dict below); ``"original"`` holds the unmasked traces.
    """
    path = config.results_dir / f"default_traces_{dataid}.npy"
    config.results_dir.mkdir(exist_ok=True, parents=True)
    if path.is_file():
        # The cache is a pickled dict written by np.save below, hence
        # allow_pickle; never point this at untrusted files.
        return np.load(path, allow_pickle=True).item()

    # Hand-picked frame indices per dataset (index = dataid).
    peak_contraction_vel = [49, 176][dataid]
    peak_twitch = [56, 210][dataid]
    N = 6  # grid size passed to local_averages

    # Named ``mps_data`` (not ``data``) so it is never shadowed by the result.
    mps_data = mps.MPS(config.mps_paths[dataid])
    opt_flow = mps_motion.OpticalFlow(
        mps_data,
        flow_algorithm="farneback",
    )
    u = opt_flow.get_displacements(
        reference_frame=config.DEFAULT_REFERENCE_FRAMES[dataid]
    )
    # time_stamps / 1000 converts the time axis (presumably ms) to seconds.
    # The spacing is per-dataset, like the reference frame above.
    v = Mechanics(u, t=mps_data.time_stamps / 1000.0).velocity(
        spacing=config.DEFAULT_SPACING[dataid]
    )

    original = _summarize_motion(u, v, peak_twitch, peak_contraction_vel)
    original["time"] = mps_data.time_stamps

    # Mask built from the *unmasked* per-pixel max displacement norm;
    # NOTE(review): exact apply_mask semantics (which side is kept) live in
    # mps_motion — confirm.
    u_norm_max = original["u_norm_max"]
    mask = u_norm_max < u_norm_max.mean()
    v.apply_mask(mask)
    u.apply_mask(mask)

    result = _summarize_motion(u, v, peak_twitch, peak_contraction_vel)
    result.update(
        time=mps_data.time_stamps,
        frame_peak_contraction_vel=mps_data.frames[:, :, peak_contraction_vel],
        frame_peak_twitch=mps_data.frames[:, :, peak_twitch],
        original=original,
        u_loc=u.norm().local_averages(N=N),
        v_loc=v.norm().local_averages(N=N),
    )
    np.save(path, result)
    return result

In [None]:
fig, ax = plt.subplots(2, 4, sharex="col", sharey="row", figsize=(10, 6))
# Dataset `dataid` uses column `col` (raw + baseline) and `col + 1` (corrected);
# row 0 is displacement, row 1 is velocity.
for dataid, col in enumerate([0, 2]):
    d = get_default_displacement_velocity_traces(dataid)

    # The velocity trace is shorter than the time axis (finite difference over
    # DEFAULT_SPACING frames). Trim time to the actual velocity length instead
    # of slicing with `-spacing`, which gives an empty array when spacing == 0.
    t_u = d["time"]
    t_v = d["time"][: len(d["v_norm"])]

    (l_orig,) = ax[0, col].plot(t_u, d["u_norm"], color="tab:blue")
    ax[1, col].plot(t_v, d["v_norm"], color="tab:blue")

    bkg_u = apf.background.correct_background(t_u, d["u_norm"], method="subtract")
    bkg_v = apf.background.correct_background(t_v, d["v_norm"], method="subtract")

    (l_bkg,) = ax[0, col].plot(bkg_u.x, bkg_u.background, color="tab:orange")
    (l_cor,) = ax[0, col + 1].plot(bkg_u.x, bkg_u.corrected, color="tab:red")

    ax[1, col].plot(bkg_v.x, bkg_v.background, color="tab:orange")
    ax[1, col + 1].plot(bkg_v.x, bkg_v.corrected, color="tab:red")

for axi in fig.axes:
    axi.grid()

for j, label in enumerate(
    [r"$\overline{\| u \|}$ [$\mu m$]", r"$\overline{\| v \|}$ [$\mu m/s$]"]
):
    ax[j, 0].set_ylabel(label)

titles = [
    "Dataset 1",
    "Dataset 1 (corrected)",
    "Dataset 2",
    "Dataset 2 (corrected)",
]
for axi, title in zip(ax[0, :], titles):
    axi.set_title(title)

for axi in ax[1, :]:
    axi.set_xlabel("Time [ms]")

# Line handles come from the last dataset; colors are identical for both.
fig.legend(
    (l_orig, l_bkg, l_cor),
    ("Original", "Baseline", "Corrected"),
    loc="center",
    ncol=3,
)
# Make sure the output directory exists (mirrors results_dir handling).
config.figdir.mkdir(exist_ok=True, parents=True)
fig.savefig(
    config.figdir / "mean_traces.png",
    dpi=300,
)
