In [None]:
# Core
import sys
import pandas as pd
import numpy as np
from pathlib import Path

## Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

## Harp/Bonsai
sys.path.append('../../src/')
from bonsai import load_bonsai_config
load_bonsai_config(r"C:\git\AllenNeuralDynamics\aind-vr-foraging\Bonsai")
import harp
import harp.processing
import data_io



In [None]:
# Global figure styling shared by every plot in this notebook.
sns.set_style('darkgrid')  # other seaborn styles: whitegrid, dark, white, ticks

mpl.rcParams.update({
    # Font sizes
    'axes.titlesize': 18,    # axes titles
    'axes.labelsize': 14,    # x / y axis labels
    'xtick.labelsize': 13,   # x tick labels
    'ytick.labelsize': 13,   # y tick labels
    'legend.fontsize': 13,   # legend text
    'font.size': 13,         # default text size
    # Type 42 (TrueType) fonts keep text editable in exported PDF/PS files
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Arial',
})

default_img_size = (15, 8)

In [None]:
# Root folder holding the VR-foraging session data; change this single constant
# to analyse data stored elsewhere.
DATA_ROOT = Path(r"Z:\scratch\vr-foraging")

# (subject id, session folder) pairs to analyse.
SESSIONS = [
    ("672103", "20231013T100814"),
    ("672107", "20231013T111657"),
    ("672106", "20231013T101026"),
    ("672104", "20231013T092240"),
    ("672102", "20231012T094718"),
]

# Full session directories, built with pathlib joins rather than hand-written separators.
session_paths = [DATA_ROOT / subject / session for subject, session in SESSIONS]




In [None]:
# Width of the bins used to turn encoder displacement into velocity.
VELOCITY_BIN_MS = 33
# Peri-event window, in seconds relative to the aligning event.
EVENT_WINDOW_S = (-1, 2)


def _load_encoder_velocity(harp_behavior_data, treadmill_metadata, bin_ms=VELOCITY_BIN_MS):
    """Load the treadmill encoder signal and convert it to binned velocity.

    Reads column 1 of the Behavior ``AnalogData`` stream and scales it by
    ``wheelDiameter * pi / pulsesPerRevolution``. The sign is flipped when
    ``invertDirection`` is set. The scaled values are summed into ``bin_ms``
    bins and divided by the bin width in seconds.

    Returns:
        pd.Series of velocity indexed by time in seconds (float).
    """
    bin_s = bin_ms / 1000
    sign = -1 if treadmill_metadata["invertDirection"] else 1
    distance_per_pulse = (
        treadmill_metadata["wheelDiameter"] * np.pi
        / treadmill_metadata["pulsesPerRevolution"] * sign
    )
    # NOTE(review): assumes AnalogData column 1 carries the encoder signal — TODO confirm.
    raw = harp.read_harp_bin(harp_behavior_data.streams.AnalogData.path).iloc[:, 1]
    displacement = raw * distance_per_pulse
    displacement.index = pd.to_datetime(displacement.index, unit="s")
    # resample().sum() already yields 0 (not NaN) for empty bins, so an
    # interpolation step afterwards would be a no-op.
    velocity = displacement.resample(f"{bin_ms}ms").sum() / bin_s
    velocity.index = (velocity.index - pd.to_datetime(0)).total_seconds()
    return velocity


def _interpatch_intervals(sites):
    """Return the InterPatch rows of ``sites`` with an added ``end`` column.

    ``end`` is the index value (onset) of the row that follows each InterPatch
    in ``sites``. It is NaN when the InterPatch is the last row. A new
    DataFrame is returned; ``sites`` is not modified.
    """
    is_interpatch = sites["data"].apply(lambda d: d["label"] == "InterPatch").to_numpy(dtype=bool)
    onsets = sites.index.to_numpy(dtype=float)
    # Shift onsets up by one row. The trailing NaN means "no following site";
    # the slice keeps lengths equal even when `sites` is empty.
    next_onsets = np.concatenate([onsets[1:], [np.nan]])[: len(onsets)]
    return sites.loc[is_interpatch].assign(end=next_onsets[is_interpatch])


def _plot_event_aligned(ax, velocity, event_times, window=EVENT_WINDOW_S, bin_ms=VELOCITY_BIN_MS):
    """Plot velocity snippets aligned to ``event_times`` plus their mean ± std.

    Each snippet spans ``window`` seconds around its event and is drawn as a
    thin black line. For the summary, every snippet is interpolated onto one
    shared time grid (NaN outside the snippet's own coverage). This keeps
    traces clipped at the recording edges, or sampled at a different bin
    phase, aligned before the NaN-aware mean/std. Events that are NaN or have
    no samples in the window are skipped. Nothing but the event line is drawn
    when no snippet remains.
    """
    grid = np.arange(window[0], window[1], bin_ms / 1000)
    aligned = []
    for t0 in event_times:
        if pd.isna(t0):
            continue
        snippet = velocity.loc[t0 + window[0]: t0 + window[1]]
        if snippet.empty:
            continue
        rel_t = snippet.index.to_numpy() - t0
        values = snippet.to_numpy()
        ax.plot(rel_t, values, color="k", alpha=0.2, lw=1)
        aligned.append(np.interp(grid, rel_t, values, left=np.nan, right=np.nan))
    if aligned:
        stack = np.vstack(aligned)
        mean = np.nanmean(stack, axis=0)
        std = np.nanstd(stack, axis=0)
        ax.plot(grid, mean, color="b", alpha=1, lw=3)
        ax.fill_between(grid, mean - std, mean + std, color="b", alpha=0.3)
    ax.axvline(0, color="r", lw=2)


for session_path in session_paths:
    # Harp Devices:
    HarpBehavior = harp.HarpDevice("Behavior")
    harp_behavior_data = data_io.HarpSource(device=HarpBehavior, path=session_path / "Behavior", name="behavior", autoload=False)

    software_events = data_io.SoftwareEventSource(path=session_path / "SoftwareEvents", name="software_events", autoload=True)
    config = data_io.ConfigSource(path=session_path / "Config", name="config", autoload=True)
    # Previously labelled name="config" (copy-paste from the line above).
    operation_control = data_io.OperationControlSource(path=session_path / "OperationControl", name="operation_control", autoload=False)

    treadmill_metadata = config.streams.Rig.data["treadmill"]
    encoder = _load_encoder_velocity(harp_behavior_data, treadmill_metadata)

    sites = software_events.streams.ActiveSite.data
    interpatches = _interpatch_intervals(sites)

    window = EVENT_WINDOW_S
    fig, axs = plt.subplots(2, 1, figsize=(6, 8), sharex=True, sharey=True)

    # The last InterPatch is excluded (as before); its "end" may not exist.
    events = interpatches.iloc[:-1]
    for ax, col in zip(axs, ["end", "start"]):
        # "end": onset of the site following each InterPatch; "start": InterPatch onset.
        event_times = events["end"] if col == "end" else events.index
        _plot_event_aligned(ax, encoder, event_times, window)
        ax.set_title(col)
        ax.set_xlabel("Time from event (s)")
        ax.set_ylabel("Velocity (cm/s)")

    # Axes are shared, so setting the limits once applies to both panels.
    axs[0].set_xlim(window)
    axs[0].set_ylim((-1, 50))
    fig.suptitle(str(session_path))
    plt.show()