In [None]:
from pathlib import Path

import mat73
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

In [None]:
# Session to process. NOTE(review): "202230920" has nine digits and does not parse as a YYYYMMDD
# date — confirm it matches the file name on disk (it is also embedded in the channel variable names).
file_prefix = "202230920"
session_id = "004"
min_ptp = 0.2  # minimum peak-to-peak of a ch4 trigger window for it to be kept as an EOD
stimulus_sampling_rate = 2.5e6  # samples per unit time of the stimulus waveforms (presumably Hz)
relevant_channels = ["Ch1", "Ch3", "Ch32", "Ch4"]

raw_data = mat73.loadmat(f"raw/{file_prefix}_{session_id}.mat")
stimuli = mat73.loadmat("raw/StimSet220610.mat")["StimSet220610"]
# Keep only the channels used below, keyed by lower-case name ("ch1", "ch3", ...).
data = {
    channel.lower(): raw_data[f"V{file_prefix}_{session_id}_{channel}"]
    for channel in relevant_channels
}
del raw_data  # drop the reference to the full file contents; only `data` is used from here on
recording_sampling_rate = 1 / data["ch4"]["interval"]
# ch3 is later sliced with sample indices computed from ch4's rate, so both channels must share it.
assert np.isclose(
    data["ch3"]["interval"], data["ch4"]["interval"]
), "ch3 and ch4 must have the same sampling interval."

# ch1: EO motor command trigger times.
trigger_times = data["ch1"]["times"]

# ch32 markers, excluding the first and last one, grouped into consecutive (start, end) pairs.
# Presumably each trial is bracketed by two markers carrying the same stimulus code — TODO confirm.
inner_markers = slice(1, -1)
time_intervals = data["ch32"]["times"][inner_markers].reshape(-1, 2)

marker_code_pairs = data["ch32"]["codes"][inner_markers, 0].reshape(-1, 2)
assert np.all(
    marker_code_pairs[:, 0] == marker_code_pairs[:, 1]
), "Delivered stimulus ID must match."
# One stimulus code per trial.
delivered_eod_id = marker_code_pairs[:, 0]

# Snippet window around each waveform peak, in seconds (before, after). The same window is used for
# the stimulus and for the recordings, so their snippets span the same duration.
snippet_window_s = (0.2e-3, 0.3e-3)
# Window, in seconds after each trigger (start, stop), in which the recorded EOD peak is searched.
peak_search_window_s = (3e-3, 5e-3)

stimulus_snippet_margins = [round(sec * stimulus_sampling_rate) for sec in snippet_window_s]
recording_peak_margins = [round(sec * recording_sampling_rate) for sec in peak_search_window_s]
recording_snippet_margins = [round(sec * recording_sampling_rate) for sec in snippet_window_s]

# Trigger times -> recording sample indices.
# NOTE(review): assumes the recording's first sample is at time 0; confirm against the channel start time.
trigger_times_int = np.round(trigger_times * recording_sampling_rate).astype(int)

def _peak_snippet(trace, margins):
    """Cut a window around the global maximum of `trace` and divide it by its maximum.

    Parameters
    ----------
    trace : ndarray
        Samples to search (presumably 1-D).
    margins : sequence of two ints
        Number of samples kept before the peak and after it (the peak sample itself is included).

    Returns
    -------
    ndarray
        ``trace[peak - margins[0] : peak + margins[1]]`` divided by its maximum. The start index is
        clipped at 0 so a peak near the start of `trace` gives a shorter snippet instead of a
        negative-start slice (which would wrap around and usually come back empty).
    """
    peak = int(np.argmax(trace))
    snippet = trace[max(0, peak - margins[0]) : peak + margins[1]]
    return snippet / snippet.max()


# Columns of `eod_data`: one row per trigger whose ch4 window passed the ptp threshold.
eod_columns = [
    "stimulus",
    "stimulus_time",
    "recording_digital",
    "recording_digital_time",
    "recording_real",
    "recording_real_time",
]
eod_records = []
for trial_index, eod_id in enumerate(delivered_eod_id):
    # Delivered stimulus for this trial: first waveform whose marker equals the trial's code.
    stimulus_row = np.where(stimuli["marker"] == eod_id)[0][0]
    stimulus_leod = _peak_snippet(stimuli["waveform"][stimulus_row], stimulus_snippet_margins)
    stimulus_times = np.arange(len(stimulus_leod)) / stimulus_sampling_rate

    # Triggers inside this trial's [t1, t2] marker interval (both ends inclusive).
    t1, t2 = time_intervals[trial_index]
    trigger_indices = np.where((trigger_times >= t1) & (trigger_times <= t2))[0]

    for trigger_sample in trigger_times_int[trigger_indices]:
        # The same post-trigger search window is read from both recorded channels.
        search_window = slice(
            trigger_sample + recording_peak_margins[0],
            trigger_sample + recording_peak_margins[1],
        )
        recording_digital = data["ch4"]["values"][search_window]
        # A window running past the end of the data comes back empty and np.ptp would raise on it.
        # `not (ptp > min_ptp)` (rather than `ptp <= min_ptp`) keeps skipping NaN windows.
        if recording_digital.size == 0 or not (np.ptp(recording_digital) > min_ptp):
            continue

        recording_digital = _peak_snippet(recording_digital, recording_snippet_margins)
        recording_real = _peak_snippet(
            data["ch3"]["values"][search_window], recording_snippet_margins
        )
        eod_records.append(
            dict(
                stimulus=stimulus_leod,
                stimulus_time=stimulus_times,
                recording_digital=recording_digital,
                recording_digital_time=np.arange(len(recording_digital)) / recording_sampling_rate,
                recording_real=recording_real,
                recording_real_time=np.arange(len(recording_real)) / recording_sampling_rate,
            )
        )

# Build the frame once (concat inside the loop is quadratic); `columns` keeps the schema if nothing passed.
eod_data = pd.DataFrame(eod_records, columns=eod_columns)

In [None]:
def interpolate_recording(dfrow):
    """Resample a row's real-channel recording onto that row's stimulus time grid.

    Parameters
    ----------
    dfrow : mapping (e.g. a DataFrame row)
        Must provide ``"stimulus_time"`` (target times), ``"recording_real_time"`` (sample times of
        the recording, expected increasing) and ``"recording_real"`` (the recorded values).

    Returns
    -------
    ndarray
        Linearly interpolated recording values at each stimulus time. For stimulus times outside the
        recording's time span, ``np.interp`` repeats the first/last recorded value.
    """
    return np.interp(
        dfrow["stimulus_time"],
        dfrow["recording_real_time"],
        dfrow["recording_real"],
    )

# Resample every row's real-channel recording onto its stimulus time grid for sample-wise comparison.
eod_data["recording_real_interpolated"] = eod_data.apply(interpolate_recording, axis=1)

# to_pickle does not create missing parent directories, so make sure the output folder exists.
output_path = Path("processed") / "delivered-vs-recorded.pkl"
output_path.parent.mkdir(parents=True, exist_ok=True)
eod_data.to_pickle(output_path)

In [None]:
id = 0

dfrow = eod_data.iloc[id]

%matplotlib qt
plt.figure()
plt.plot(dfrow["stimulus_time"], dfrow["stimulus"], label="stimulus", color='k', lw=4)
plt.plot(dfrow["recording_digital_time"], dfrow["recording_digital"], label="stimulus", color='b', lw=2)
plt.plot(dfrow["recording_real_time"], dfrow["recording_real"], label="stimulus", color='r', lw=2)
plt.plot(dfrow["stimulus_time"], dfrow["recording_real_interpolated"], label="interpolated", color='g', lw=1)
plt.tight_layout()
plt.show()

In [None]:
# Pool samples from every row: delivered stimulus vs. ch3 recording interpolated at the same times.
# Both were divided by their maxima, so identical waveforms would fall on the y = x line.
xs = np.hstack(eod_data["stimulus"].to_list())
ys = np.hstack(eod_data["recording_real_interpolated"].to_list())

fig, ax = plt.subplots()
ax.scatter(xs, ys, marker=".", color="k", s=1)
ax.axline((0, 0), slope=1, color="r", lw=0.5, label="y = x")
ax.set(xlabel="delivered stimulus (/ max)", ylabel="recorded ch3, interpolated (/ max)")
ax.axis("equal")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Distribution of the sample-wise difference between recorded (interpolated) and delivered waveforms.
residuals = ys - xs
grid = sns.displot(x=residuals, kind="kde")
grid.set_axis_labels("recorded - delivered (/ max)", "density")
grid.tight_layout()
plt.show()