In [None]:

# Core
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple, Optional, Union, Any, Callable

## Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import ticker
from matplotlib.patches import Rectangle
import ipywidgets as widgets
import mpl_interactions.ipyplot as iplt

import seaborn as sns

## Harp/Bonsai
sys.path.append('../../src/')
from bonsai import load_bonsai_config
load_bonsai_config(r"C:\git\AllenNeuralDynamics\aind-vr-foraging\Bonsai")
import harp
import harp.processing
import data_io

# Global visualisation settings
sns.set_context("talk")  # seaborn preset that scales fonts/lines up for presentation figures

# Optional overrides (currently disabled; uncomment to apply):
# sns.set_style("darkgrid")  # alternatives: whitegrid, dark, white, ticks
# plt.rcParams.update({
#     "axes.titlesize": 18,     # axes title
#     "axes.labelsize": 14,     # x and y labels
#     "xtick.labelsize": 13,    # tick labels
#     "ytick.labelsize": 13,
#     "legend.fontsize": 13,
#     "font.size": 13,          # default text size
#     "pdf.fonttype": 42,       # embed TrueType fonts in PDF/PS output
#     "ps.fonttype": 42,
#     "font.family": "Arial",
# })

# Presumably intended as a default matplotlib figsize (width, height);
# NOTE(review): not referenced by the plotting cells in this notebook.
default_img_size = (15, 8)

## Background (from GitHub)

Related issue: [aind-vr-foraging Issue #81](https://github.com/AllenNeuralDynamics/aind-vr-foraging/issues/81)


### Notes:

- Sessions from 10/10/2023
    - 102
      - 7, 18.10, 28.47, 31.57
    - 103
      - 9.23, 10.06, 12.33, 15.08, 29.36 (manually given water), 38.11, 49.23
    - 107
      - 9.17, 10.29, 11.18, 13.40, 15.19


Checks:

1. Check for reward sites
2. Check for ChoiceFeedback events
3. Check for licks
4. Check for GiveReward events
5. Check the valve events (both onset time and pulse duration)


In [None]:
# Load data from a single session.
# NOTE(review): absolute network path; edit `session_path` to analyse another session.
session_path = Path(r"Z:\scratch\vr-foraging\672107\20231013T111657")

# Harp devices. The Behavior source is created with autoload=False; the streams
# used below are read explicitly later via `.load_from_file()`.
HarpBehavior = harp.HarpDevice("Behavior")
harp_behavior_data = data_io.HarpSource(device=HarpBehavior, path=session_path / "Behavior", name="behavior", autoload=False)

# Software-side logs and configuration.
software_events = data_io.SoftwareEventSource(path=session_path / "SoftwareEvents", name="software_events", autoload=True)
config = data_io.ConfigSource(path=session_path / "Config", name="config", autoload=True)
# Named after its own folder; it was previously also labelled "config", which
# made it indistinguishable from the ConfigSource above.
operation_control = data_io.OperationControlSource(path=session_path / "OperationControl", name="operation_control", autoload=False)

In [None]:
# Syntactic sugar
def add_position(df: pd.DataFrame, position: data_io.DataStream):
    """Attach the position sample nearest in time to every row of ``df``.

    ``df`` and ``position.data`` are each sorted by index and joined with
    ``pd.merge_asof(direction="nearest")`` on the ``"Seconds"`` key
    (presumably the index name of both inputs -- confirm). The merged frame
    is re-indexed by ``"Seconds"``, sorted, and its last column -- assumed to
    be the single value column contributed by ``position.data`` (TODO
    confirm) -- is renamed to ``"Position"``.

    Args:
        df: Events to annotate (DataFrame or named Series).
        position: Stream whose ``.data`` holds the position samples.

    Returns:
        A new frame indexed by ``"Seconds"`` with a ``"Position"`` column.
    """
    events_sorted = df.sort_index()
    positions_sorted = position.data.sort_index()
    merged = pd.merge_asof(events_sorted, positions_sorted, on="Seconds", direction="nearest")
    merged = merged.set_index("Seconds").sort_index()
    merged.columns = list(merged.columns[:-1]) + ["Position"]
    return merged

In [None]:
# Load treadmill data and convert the raw encoder signal to velocity.
treadmill_metadata = config.streams.Rig.data["treadmill"]
# Second column of the Behavior AnalogData file is taken as the encoder signal.
# It is summed per bin below, so presumably each sample is a per-sample pulse delta -- confirm.
encoder = harp.read_harp_bin(harp_behavior_data.streams.AnalogData.path).iloc[:, 1]

# Distance per encoder pulse (wheel circumference / pulses per revolution),
# sign-flipped when the rig reports an inverted direction.
converter = treadmill_metadata["wheelDiameter"] * np.pi / treadmill_metadata["pulsesPerRevolution"] * (-1 if treadmill_metadata["invertDirection"] else 1)
encoder = encoder * converter  # vectorised element-wise scaling

# Sum displacement in fixed bins, then divide by the bin width to obtain velocity.
ENCODER_BIN_MS = 33  # single source of truth for the bin width (rule string and divisor)
encoder.index = pd.to_datetime(encoder.index, unit="s")
# min_count=1 makes bins with no samples NaN; a plain .sum() turns them into 0,
# which made the following .interpolate() a no-op.
encoder = encoder.resample(f"{ENCODER_BIN_MS}ms").sum(min_count=1).interpolate(method="linear") / (ENCODER_BIN_MS / 1000)
# Back to float seconds on the original clock.
encoder.index = (encoder.index - pd.to_datetime(0)).total_seconds()

In [None]:
win = (0, 1500)  # x-axis window (s), relative to the first site's timestamp
#win = (673, 704)

show_speed = False  # overlay treadmill velocity on the raster
save_name = "session_w_speed"  # figure file prefix; set to None to skip saving

# --- Task events ----------------------------------------------------------------
# Find reward sites
sites = software_events.streams.ActiveSite.data
zero_index = sites.index[0]  # time origin for the plots below

reward_sites = sites.loc[sites["data"].apply(lambda x : x['label'] == 'Reward'), :]

# Find ChoiceFeedback events (i.e. successful stops)
choice_feedback = software_events.streams.ChoiceFeedback.data

# Check for licks: DIPort0 state changes, keeping only the transitions to True.
digital_inputs = HarpBehavior.module.DigitalInputs

harp_behavior_data.streams.DigitalInputState.load_from_file()
di_state = harp_behavior_data.streams.DigitalInputState.data
di_port0 = di_state["Value"].apply(lambda x : x.HasFlag(digital_inputs.DIPort0))
di_port0 = di_port0.loc[di_port0.diff() == True]  # keep samples where the state changed
lick_onset = di_port0.loc[di_port0 == True]       # ... and the new state is high

# Find give reward event
give_reward = software_events.streams.GiveReward.data

# --- Hardware reward (valve) events ----------------------------------------------
harp_behavior_data.streams.PulseSupplyPort0.load_from_file() # Duration of each pulse
pulse_duration = harp_behavior_data.streams.PulseSupplyPort0.data
digital_outputs = HarpBehavior.module.DigitalOutputs
harp_behavior_data.streams.OutputSet.load_from_file()
valve_output_pulse = harp_behavior_data.streams.OutputSet.data
# .copy(): the filtered frame gets a new column below; writing into a filtered
# slice without copying triggers pandas' SettingWithCopyWarning.
valve_output_pulse = valve_output_pulse.loc[valve_output_pulse["Value"].apply(lambda x: x.HasFlag(digital_outputs.SupplyPort0))].copy()

# Tag each valve opening with a PulseSupplyPort0 value chosen by
# find_closest(mode='below_zero') -- presumably the latest one at or before the
# opening (confirm). -1 marks openings with no match.
valve_output_pulse['pulse_duration'] = -1
for seconds, _ in valve_output_pulse.iterrows():
    idx, _ = harp.processing.find_closest(seconds, pulse_duration.index.values, mode='below_zero')
    if not np.isnan(idx):
        valve_output_pulse.loc[seconds, 'pulse_duration'] = pulse_duration.iloc[idx].Value
print(valve_output_pulse["pulse_duration"].unique())

# --- Plot -----------------------------------------------------------------------
label_dict = {
    "Gap": '#808080',
    "Reward_0": '#377eb8',
    "Reward_1": '#e41a1c',
    "InterPatch": '#b3b3b3'}

fig, axs = plt.subplots(2,1, figsize=(25,5), gridspec_kw=dict(height_ratios=[1, 4]))
_legend = {}

# Shade each site from its onset to the next site's onset; the last site has no
# end time, so it is not drawn. Reward sites are coloured per odor index.
site_times = sites.index
for start, end, site_data in zip(site_times[:-1], site_times[1:], sites["data"].iloc[:-1]):
    site_label = site_data["label"]
    if site_label == "Reward":
        site_label = f"{site_label}_{site_data['odor']['index']}"
    facecolor = label_dict[site_label]

    p = Rectangle(
        (start - zero_index, -2), end - start, 8,
        linewidth = 0, facecolor = facecolor, alpha = .5)
    _legend[site_label] = p
    axs[1].add_patch(p)

s, lw = 400, 2
# Plotting raster. Row 0 is reserved for ChoiceFeedback (disabled in this panel).
y_idx = 0
# _legend["ChoiceFeedback"] = axs[1].scatter(choice_feedback.index - zero_index,
#            choice_feedback.index * 0 + y_idx,
#            marker=".", s=s, lw=lw, c='#e6ab02',
#            label="ChoiceFeedback")
y_idx += 1
# Licks and valve openings share row 1.
_legend["Lick"] = axs[1].scatter(lick_onset.index - zero_index,
           lick_onset.index * 0 + y_idx,
           marker="|", s=s, lw=lw, c='k',
           label="Lick")
_legend["ValveOpen"] = axs[1].scatter(valve_output_pulse.index - zero_index,
           valve_output_pulse.index*0 + y_idx,
           marker=".", s=s, lw=lw, c='#ff7f00',
           label="ValveOpen")
y_idx += 1

axs[1].set_yticklabels([])
axs[1].set_xlabel("Time(s)")
axs[1].set_ylim(bottom=-1, top = 3)
axs[1].grid(False)
axs[1].set_xlim(win)

if show_speed:
    ax2 = axs[1].twinx()
    _legend["Velocity"] = ax2.plot(encoder.index - zero_index, encoder, c="k", label="Encoder", alpha = 0.8)[0]
    v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    _legend["Velocity_Threshold"] = ax2.plot(ax2.get_xlim(), (v_thr, v_thr), c="k", label="Encoder", alpha = 0.5, lw = 2, ls = "--")[0]
    ax2.grid(False)
    ax2.set_ylim((-5, 50))
    ax2.set_ylabel("Velocity (cm/s)")
axs[1].legend(_legend.values(), _legend.keys(), bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)

# Top panel: RewardAvailableInPatch as a step function. stairs() needs
# len(edges) == len(values) + 1, hence the values[:-1].
axs[0].stairs(software_events.streams.RewardAvailableInPatch.data["data"].values[:-1],
          software_events.streams.RewardAvailableInPatch.data["data"].index.values -  zero_index,
          lw = 3, color = 'k', fill=False)
axs[0].set_xlabel("Time(s)")
axs[0].grid(False)
axs[0].set_xlim(win)
axs[0].get_xaxis().set_visible(False)
axs[0].set_ylim(bottom=-1, top = 4)
axs[0].set_yticks([0,3])
axs[0].yaxis.tick_right()
axs[0].set_ylabel('Reward')
axs[0].yaxis.set_label_position("right")

if save_name is not None:
    fig.savefig(f"{save_name}_time.svg", bbox_inches='tight', pad_inches=0.1, transparent=True)
plt.show()

In [None]:
# Same raster as the time-based one, but in VR space: each event is placed at
# the position sample nearest in time (via add_position).
operation_control.streams.CurrentPosition.load_from_file()
position = operation_control.streams.CurrentPosition

fig, ax = plt.subplots(figsize=(25,5))
_legend = {}

_sites = add_position(sites, position=position)
# NOTE(review): the i-th row of `sites` is paired with the i-th row of `_sites`
# (which add_position sorts by time), so `sites` is assumed to be time-sorted.
site_positions = _sites["Position"].values

# Shade each site from its start position to the next site's start position;
# the last site has no end, so it is not drawn.
for idx, site_data in enumerate(sites["data"].iloc[:-1]):
    site_label = site_data["label"]
    if site_label == "Reward":
        site_label = f"{site_label}_{site_data['odor']['index']}"
    facecolor = label_dict[site_label]

    p = Rectangle(
        (site_positions[idx], -2), site_positions[idx+1] - site_positions[idx], 8,
        linewidth = 0, facecolor = facecolor, alpha = 0.5 )
    _legend[site_label] = p
    ax.add_patch(p)

s, lw = 400, 2
# Plotting raster: row 0 = ChoiceFeedback, row 1 = licks and valve openings.
y_idx = 0
_legend["ChoiceFeedback"] = ax.scatter(add_position(choice_feedback, position=position)["Position"].values,
           choice_feedback.index * 0 + y_idx,
           marker=".", s=s, lw=lw, c='#e6ab02',
           label="ChoiceFeedback")
y_idx += 1
_legend["Lick"] = ax.scatter(add_position(lick_onset, position=position)["Position"].values,
           lick_onset.index * 0 + y_idx,
           marker="|", s=s, lw=lw, c='k',
           label="Lick")
_legend["ValveOpen"] = ax.scatter(add_position(valve_output_pulse, position=position)["Position"].values,
           valve_output_pulse.index*0 + y_idx,
           marker=".", s=s, lw=lw, c='#ff7f00',
           label="ValveOpen")
y_idx += 1

ax.set_yticklabels([])
ax.set_xlabel("VrSpace(cm)")
# x-limits: positions of the first and last site starting strictly inside `win`.
in_window = (_sites.index - zero_index > win[0]) & (_sites.index - zero_index < win[1])
window_positions = _sites.loc[in_window, "Position"].values
if len(window_positions) > 0:
    ax.set_xlim(window_positions[[0, -1]])
else:
    # Indexing an empty array with [0, -1] would raise IndexError; keep autoscale instead.
    print(f"No sites start inside window {win}; x-limits left on autoscale.")
ax.set_ylim(bottom=-1, top = 3)
ax.grid(False)

if show_speed:
    ax2 = ax.twinx()
    _legend["Velocity"] = ax2.plot(add_position(encoder, position=position)["Position"].values, encoder.values, c="k", label="Encoder", alpha = 0.8)[0]
    v_thr = config.streams.TaskLogic.data["operationControl"]["positionControl"]["stopResponseConfig"]["velocityThreshold"]
    _legend["Velocity_Threshold"] = ax2.plot(ax2.get_xlim(), (v_thr, v_thr), c="k", label="Encoder", alpha = 0.5, lw = 2, ls = "--")[0]
    ax2.grid(False)
    ax2.set_ylim((-5, 50))
    ax2.set_ylabel("Velocity (cm/s)")
ax.legend(_legend.values(), _legend.keys(), bbox_to_anchor=(1.05, 0.5), loc='center left', borderaxespad=0.)
if save_name is not None:
    fig.savefig(f"{save_name}_space.svg", bbox_inches='tight', pad_inches=0.1, transparent=True)
plt.show()




In [None]:
# Distribution of site lengths, one panel per site label.
# The "label" column is stored on `sites` so it can be grouped on directly.
sites["label"] = sites["data"].map(lambda d: d["label"])
site_groups = list(sites.groupby("label"))
# One panel per label instead of a hard-coded 3 (which raised IndexError for more
# labels). squeeze=False keeps `axs` 2-D even when there is a single panel.
n_panels = max(len(site_groups), 1)
fig, axs = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
for i, (label, group) in enumerate(site_groups):
    ax = axs[0, i]
    size = group["data"].map(lambda d: d["length"])
    ax.hist(size.values, bins= np.arange(30, 100, 2.5), density = True)
    ax.vlines(size.mean(), 0, 0.1, color="r", label="Mean")
    ax.set_title(f"{label} // Mean: {size.mean():.2f}")
    ax.set_xlabel("Length (cm)")
    if i == 0:
        ax.set_ylabel("Density")
    ax.set_xlim((30, 100))
plt.show()

In [None]:
# Treadmill speed aligned on site entry, one figure per site label.
# Top panel: individual speed traces within `window` around each entry.
# Bottom panel: histogram of the mean speed from window[0] up to entry.
window = (-0.5, 1)  # seconds relative to site entry

# Labels are derived here so this cell does not depend on the "label" column
# being added to `sites` by another cell.
site_labels = sites["data"].map(lambda d: d["label"]).to_numpy()
for label, label_group in sites.groupby(site_labels):
    fig, axs = plt.subplots(2,1)
    pre_entry_speed = []
    for entry_time in label_group.index:
        trace = encoder.loc[entry_time + window[0] : entry_time + window[1]]
        # Mean of an empty slice is NaN (with a RuntimeWarning).
        pre_entry_speed.append(encoder.loc[entry_time + window[0] : entry_time].values.mean())
        # Traces go on the top axis. They were previously drawn on axs[1], where
        # the histogram below overwrote their labels and limits.
        axs[0].plot(trace.index - entry_time, trace.values, color = 'k', alpha=0.01)

    axs[0].set_xlabel("Time (s)")
    axs[0].set_ylabel("Speed(cm/s)")
    axs[0].set_ylim((-10, 60))

    pre_entry_speed = np.array(pre_entry_speed)
    axs[1].hist(pre_entry_speed, bins=np.arange(-10, 60, 2.5), density=True)
    # nanmean so one empty pre-entry window does not turn the mean line into NaN.
    axs[1].vlines(np.nanmean(pre_entry_speed), 0, 0.1, color="r", label="Mean")
    axs[1].set_xlabel(f"Average speed (cm/s) @ {window[0]}:{0}")
    fig.suptitle(label)
    plt.show()
