In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import glob

import matplotlib.pyplot as plt
import numpy as np
from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask, UncoupledBlockTask
from aind_dynamic_foraging_models.generative_model import ForagerCollection

In [None]:
from pynwb import NWBHDF5IO

LOCAL_NWB_TMP = "/data/foraging_nwb_bonsai"

def get_nwb_from_local_tmp(session_id, nwb_dir=None):
    """Open a session's NWB file from a local directory and return the NWB object.

    Pass ``nwb_dir`` (or replace this function) to load NWB files from another place.

    Parameters
    ----------
    session_id : str
        Session identifier, e.g. ``"781370_2025-03-20_11-12-56"``. The file read is
        ``{nwb_dir}/{session_id}.nwb``.
    nwb_dir : str, optional
        Directory holding the ``.nwb`` files. Defaults to ``LOCAL_NWB_TMP``, looked up
        at call time so that reassigning the module constant takes effect.

    Returns
    -------
    The object returned by ``NWBHDF5IO(...).read()`` (a ``pynwb.NWBFile``).

    Notes
    -----
    The ``NWBHDF5IO`` handle is deliberately not closed. pynwb may read datasets
    lazily, so closing it here could make the returned object unusable.
    """
    if nwb_dir is None:
        nwb_dir = LOCAL_NWB_TMP
    io = NWBHDF5IO(f"{nwb_dir}/{session_id}.nwb", mode="r")
    return io.read()


def get_history_from_nwb(nwb):
    """Extract per-trial choice/reward history and task settings from an NWB file.

    TODO: move this to aind-behavior-nwb-util.

    Parameters
    ----------
    nwb
        Object exposing ``nwb.trials.to_dataframe()`` and a string ``nwb.protocol``.

    Returns
    -------
    tuple
        ``(baiting, choice_history, reward_history, p_reward, autowater_offered,
        random_number)`` where:

        - ``baiting`` (bool): False if ``nwb.protocol`` contains "without baiting"
          (case-insensitive), otherwise True.
        - ``choice_history`` (numpy array): ``animal_response`` with 0 -> 0, 1 -> 1,
          and 2 (or any other code) -> NaN.
        - ``reward_history`` (pandas Series): element-wise OR of
          ``rewarded_historyL`` and ``rewarded_historyR``.
        - ``p_reward`` (list of two arrays): ``[reward_probabilityL, reward_probabilityR]``.
        - ``autowater_offered`` (pandas Series): True where ``auto_waterL`` or
          ``auto_waterR`` equals 1.
        - ``random_number`` (list of two arrays):
          ``[reward_random_number_left, reward_random_number_right]``.
    """
    trials = nwb.trials.to_dataframe()

    is_autowater = trials.auto_waterL.eq(1) | trials.auto_waterR.eq(1)
    # Response code 2 becomes NaN; callers drop NaN choices (presumably ignored trials).
    choices = trials.animal_response.map({0: 0, 1: 1, 2: np.nan}).values
    rewards = trials.rewarded_historyL | trials.rewarded_historyR
    reward_probs = [
        trials[column].values
        for column in ("reward_probabilityL", "reward_probabilityR")
    ]
    rand_numbers = [
        trials[column].values
        for column in ("reward_random_number_left", "reward_random_number_right")
    ]

    has_baiting = "without baiting" not in nwb.protocol.lower()

    return has_baiting, choices, rewards, reward_probs, is_autowater, rand_numbers

In [None]:
# -- Fit both models to every session of one subject and save the fitted-session figures --
# subject_id = '781370'  # uncoupled, no baiting
# subject_id = '764769'  # uncoupled, baiting
# subject_id = '776293'  # uncoupled, baiting
subject_id = '769884'  # uncoupled, baiting

# Sessions with fewer valid (non-NaN choice) trials than this are skipped.
# NOTE(review): this loop previously compared against 10, while its own comment, the
# recorded status string and the single-session cell below all use 50. Lower this
# deliberately if short sessions should be fitted.
MIN_VALID_TRIALS = 50
RESULTS_DIR = '/results'

# Only match .nwb files: get_nwb_from_local_tmp appends ".nwb" to the session id.
for session_name in sorted(glob.glob(f'{LOCAL_NWB_TMP}/{subject_id}_*.nwb'), reverse=True):
    print('############################################')
    session_id = session_name.split('/')[-1].split('.')[0]
    print(session_id)

    nwb = get_nwb_from_local_tmp(session_id=session_id)
    (
        baiting,
        choice_history,
        reward_history,
        _,
        autowater_offered,
        random_number,
    ) = get_history_from_nwb(nwb)

    # Remove NaN choices from both histories.
    ignored = np.isnan(choice_history)
    choice_history = choice_history[~ignored]
    reward_history = np.asarray(reward_history)[~ignored]

    # -- Skip sessions with too few valid trials --
    if len(choice_history) < MIN_VALID_TRIALS:
        fit_result = {
            "status": f"skipped. valid trials < {MIN_VALID_TRIALS}",
            "upload_figs_s3": {},
            "upload_pkls_s3": {},
            "upload_record_docDB": {},
        }
        print(f"Skipping session {session_id} due to too few trials n={len(choice_history)}.")
        continue

    # -- Initialize and fit models --
    # Alternative to the preset:
    # forager = ForagerCollection().get_forager(
    #     agent_class_name="ForagerCompareThreshold",
    #     agent_kwargs={
    #         'choice_kernel': "none",
    #     },
    # )
    forager_ctt = ForagerCollection().get_preset_forager("CompareToThreshold")
    fitting_result_ctt, _ = forager_ctt.fit(
        choice_history,
        reward_history,
        clamp_params={
            # "biasL": 0,
            # "softmax_inverse_temperature": 5.0
        },
        # A fresh seeded Generator per fit keeps each fit reproducible on its own.
        DE_kwargs=dict(workers=4, disp=True, seed=np.random.default_rng(42)),
        # k_fold_cross_validation=None
    )

    forager_hattori = ForagerCollection().get_preset_forager("Hattori2019")
    fitting_result_hattori, _ = forager_hattori.fit(
        choice_history,
        reward_history,
        DE_kwargs=dict(workers=4, disp=True, seed=np.random.default_rng(42)),
        # k_fold_cross_validation=None
    )

    # -- Report fitted parameters --
    # Labels are zipped in rather than indexed inside the f-string: reusing the outer
    # quote type inside an f-string is a SyntaxError before Python 3.12.
    for model_name, fitting_result in zip(
        ('CompareToThreshold', 'Hattori'), (fitting_result_ctt, fitting_result_hattori)
    ):
        fit_names = fitting_result.fit_settings["fit_names"]
        print(f"Model: {model_name}")
        print(f"Num of trials: {len(choice_history)}")
        print(f"Likelihood-Per-Trial: {fitting_result.LPT}")
        print(f"AIC: {fitting_result.AIC}")
        print(f"BIC: {fitting_result.BIC}")
        print(f"Prediction accuracy full dataset: {fitting_result.prediction_accuracy}")
        print(f"Fitted parameters: {fit_names}")
        print(f'Fitted:       {[f"{num:.4f}" for num in fitting_result.x]}\n')

    # -- Plot and save --
    fig_fitting_ctt, axes_ctt = forager_ctt.plot_fitted_session(if_plot_latent=True)
    fig_fitting_hattori, axes_hattori = forager_hattori.plot_fitted_session(if_plot_latent=True)

    fig_fitting_ctt.savefig(f'{RESULTS_DIR}/{session_id}-ctt.png', dpi=150)
    fig_fitting_hattori.savefig(f'{RESULTS_DIR}/{session_id}-hattori.png', dpi=150)

    # Render this session's figures now, then release them. Otherwise pyplot keeps every
    # figure alive until the cell finishes, and memory grows with the number of sessions.
    plt.show()
    plt.close(fig_fitting_ctt)
    plt.close(fig_fitting_hattori)



In [None]:
# -- Load data for a single session --
# Other sessions that have been inspected (uncomment one to switch):
# session_id = '781896_2025-04-10_14-11-57'
# session_id = '781370_2025-02-03_11-09-28'
# session_id = '781370_2025-02-05_11-25-51'
# session_id = '781370_2025-02-14_11-26-21'
# session_id = '781370_2025-02-17_11-11-23'
# session_id = '784806_2025-04-21_13-13-39'
# session_id = '770527_2025-01-15_11-01-55'
# session_id = '739977_2024-10-03_09-04-34'
# session_id = '786866_2025-04-10_11-24-47'
session_id = '781370_2025-03-20_11-12-56'

nwb = get_nwb_from_local_tmp(session_id=session_id)
baiting, choice_history, reward_history, _, autowater_offered, random_number = (
    get_history_from_nwb(nwb)
)

In [None]:
# -- Remove NaN choices --
# np.asarray makes this cell safe to re-run. After the first run `reward_history` is
# already an ndarray, so the previous `.to_numpy()` call raised AttributeError.
ignored = np.isnan(choice_history)
choice_history = choice_history[~ignored]
reward_history = np.asarray(reward_history)[~ignored]

# -- Skip if the number of valid trials is below the threshold --
MIN_VALID_TRIALS = 50
if len(choice_history) < MIN_VALID_TRIALS:
    fit_result = {
        "status": f"skipped. valid trials < {MIN_VALID_TRIALS}",
        "upload_figs_s3": {},
        "upload_pkls_s3": {},
        "upload_record_docDB": {},
    }
    print(f"Skipping session {session_id} due to too few trials n={len(choice_history)}.")
    # Clear previous results so later cells cannot silently report another session's fit.
    forager_ctt = forager_hattori = None
    fitting_result_ctt = fitting_result_hattori = None
else:
    # -- Initialize and fit models --
    # Alternative to the preset:
    # forager = ForagerCollection().get_forager(
    #     agent_class_name="ForagerCompareThreshold",
    #     agent_kwargs={
    #         'choice_kernel': "none",
    #     },
    # )
    forager_ctt = ForagerCollection().get_preset_forager("CompareToThreshold")
    fitting_result_ctt, _ = forager_ctt.fit(
        choice_history,
        reward_history,
        clamp_params={
            # "biasL": 0,
            # "softmax_inverse_temperature": 5.0
        },
        # A fresh seeded Generator per fit keeps each fit reproducible on its own.
        DE_kwargs=dict(workers=4, disp=True, seed=np.random.default_rng(42)),
        # k_fold_cross_validation=None
    )

    forager_hattori = ForagerCollection().get_preset_forager("Hattori2019")
    fitting_result_hattori, _ = forager_hattori.fit(
        choice_history,
        reward_history,
        DE_kwargs=dict(workers=4, disp=True, seed=np.random.default_rng(42)),
        # k_fold_cross_validation=None
    )

In [None]:
# Check fitted parameters.
# Labels are zipped in rather than indexed inside the f-string: reusing the outer quote
# type inside an f-string (f'{['a', 'b'][i]}') is a SyntaxError before Python 3.12.
for model_name, fitting_result in zip(
    ('CompareToThreshold', 'Hattori'), (fitting_result_ctt, fitting_result_hattori)
):
    fit_names = fitting_result.fit_settings["fit_names"]
    print(f"Model: {model_name}")
    print(f"Num of trials: {len(choice_history)}")
    print(f"Likelihood-Per-Trial: {fitting_result.LPT}")
    print(f"AIC: {fitting_result.AIC}")
    print(f"BIC: {fitting_result.BIC}")
    print(f"Prediction accuracy full dataset: {fitting_result.prediction_accuracy}")
    print(f"Fitted parameters: {fit_names}")
    print(f'Fitted:       {[f"{num:.4f}" for num in fitting_result.x]}\n')

In [None]:
# Plot each model's fit (with latent variables). Separate handles keep both figures
# reachable, e.g. for saving; previously the CompareToThreshold handle was overwritten.
fig_fitting_ctt, axes_ctt = forager_ctt.plot_fitted_session(if_plot_latent=True)
fig_fitting_hattori, axes_hattori = forager_hattori.plot_fitted_session(if_plot_latent=True)
# Backward-compatible aliases: `fig_fitting` / `axes` used to end up holding the Hattori plot.
fig_fitting, axes = fig_fitting_hattori, axes_hattori