In [None]:
#warnings
import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

# Core
import sys
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict

## Plotting
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

## Harp/Bonsai
# Make the repository's local `src/` directory importable (it provides `bonsai`, and
# presumably the analysis helpers imported further down — TODO confirm).
# The path is resolved to an absolute one so a later os.chdir cannot break imports, and it
# is only appended once so re-running this cell does not keep growing sys.path.
SRC_DIR = str(Path('../../src/').resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from bonsai import load_bonsai_config
# NOTE(review): machine-specific absolute path; consider moving it to a configurable constant.
load_bonsai_config(r"C:\git\AllenNeuralDynamics\aind-vr-foraging\Bonsai")

# Global viz settings.
# The seaborn style is applied first so that the explicit rcParams overrides below
# (font sizes, font family) take precedence over anything the style sets.
sns.set_style('darkgrid')  # other options: whitegrid, dark, white, ticks
plt.rcParams.update({
    'axes.titlesize': 18,    # axes title
    'axes.labelsize': 14,    # x and y axis labels
    'xtick.labelsize': 13,   # x tick labels
    'ytick.labelsize': 13,   # y tick labels
    'legend.fontsize': 13,   # legend text
    'font.size': 13,         # default text size
    'pdf.fonttype': 42,      # embed TrueType (Type 42) fonts in PDF output
    'ps.fonttype': 42,       # same for PostScript output
    'font.family': 'Arial',
})

# Default figure size for plots in this notebook (width, height) in inches.
default_img_size = (15, 8)


from psychometric_across_animals import load_session_data, parse_reward_sites

# Raw session directories to analyse.
session_path = [
    Path(r"Z:\scratch\vr-foraging\672103\20231013T100814"),
    Path(r"Z:\scratch\vr-foraging\672107\20231013T111657"),
    Path(r"Z:\scratch\vr-foraging\672106\20231013T101026"),
    Path(r"Z:\scratch\vr-foraging\672104\20231013T092240"),
    Path(r"Z:\scratch\vr-foraging\672102\20231012T094718"),
]

# Load each session, store it under "<subject>_<session folder name>", and attach the
# parsed reward-site table to that same stored dict.
session_data = {}
for session in session_path:
    print(session)
    current_session = load_session_data(session)
    animal_id = current_session["config"].streams.Subject.data["metadata"]["subject"]
    session_id = session.name
    key = "_".join((animal_id, session_id))
    session_data[key] = current_session
    current_session["reward_sites"] = parse_reward_sites(current_session)

In [None]:
# Fraction of each session (measured from the first to the last reward site) to include.
# Reward sites after this point are dropped — presumably to exclude late-session
# disengagement; TODO confirm the rationale.
session_fraction = .85
fig, ax = plt.subplots(1, 2, figsize=(12,4))

for session in session_data:
    this: Dict[str, object] = session_data[session]
    this_reward_sites: pd.DataFrame = this["reward_sites"]
    if this_reward_sites.empty:
        # No reward sites means no session duration; skip rather than IndexError on index[0].
        print(f"Skipping {session}: no reward sites")
        continue
    session_duration = this_reward_sites.index[-1] - this_reward_sites.index[0]
    threshold = this_reward_sites.index[0] +  session_duration * session_fraction
    # Looked up once per session instead of twice per patch type.
    patches = this["config"].streams.TaskLogic.data["environmentStatistics"]["patches"]
    subject = this["config"].streams.Subject.data["metadata"]["subject"]
    for patch_type, patch_type_df in this_reward_sites[this_reward_sites.index < threshold].groupby("patch_label"):
        patch = next((p for p in patches if p["label"] == patch_type), None)
        if patch is None:
            raise ValueError(f"Patch label {patch_type!r} not found in the task logic of {session}")
        patch_init_reward = patch["patchRewardFunction"]["initialRewardAmount"] > 0
        # Left panel (0) for patches with a positive initial reward, right panel (1) otherwise.
        # Truthiness, not `is True`: a numpy comparison returns np.bool_, which is never
        # identical to the Python singleton True.
        idx = 0 if patch_init_reward else 1
        visits = patch_type_df.groupby("site_number")["is_choice"]
        # Arrays are indexed by the site number itself (0..max), so 1-based or non-contiguous
        # site numbers neither overflow nor shift. Sites without visits stay NaN and are not drawn.
        site_index = pd.RangeIndex(int(patch_type_df["site_number"].max()) + 1)
        # Stay probability per site = number of choices / number of visits.
        choice_probability = (visits.sum() / visits.size()).reindex(site_index).to_numpy(dtype=float)
        n_choices = visits.size().reindex(site_index, fill_value=0).to_numpy(dtype=float)
        ax[idx].plot(
            np.arange(0, choice_probability.shape[0]),
            choice_probability, label=subject,
            lw = 3)


# Left panel: patches whose initial reward amount is > 0; right panel: all other patches.
# The x limits are hard-coded per panel to show only the first visits.
for axis, title, max_visit in zip(ax, ("Rewarded", "Non-rewarded"), (4, 1)):
    axis.set_title(title)
    axis.set_xlabel("Visit number")
    axis.set_ylabel("P(Stay)")
    axis.set_xlim([0, max_visit])
    axis.set_ylim([-0.05, 1.05])
    axis.legend()

# The file name is derived from session_fraction so it cannot drift from the analysis
# (0.85 -> "psychometric_85perc.svg").
fig.savefig(f"psychometric_{session_fraction * 100:.0f}perc.svg", bbox_inches="tight")