# Simulation and parameter recovery for the dynamic foraging task

In [None]:
pip install -e /root/capsule

In [None]:
%matplotlib inline
%load_ext autoreload
%autoreload 2
import numpy as np
from aind_behavior_gym.dynamic_foraging.task import CoupledBlockTask, UncoupledBlockTask
from aind_dynamic_foraging_models.generative_model import ForagerCollection

## Get all foragers

In [None]:
# Build the forager collection and fetch the table of all available foragers.
# `df` is indexed by column names in the next cell (presumably a pandas
# DataFrame; confirm against get_all_foragers()).
forager_collection = ForagerCollection()
df = forager_collection.get_all_foragers()
# Last expression of the cell, so the notebook displays the column names.
df.columns

In [None]:
# Restrict the displayed forager table to four columns: agent class, preset
# name, number of free parameters, and the parameter listing.
df[["agent_class_name", "preset_name", "n_free_params", "params"]]

## Initialize an agent

In [None]:
# Generative agent: the "CompareToThreshold" preset, created with seed 42.
# The values set here are later dumped as the ground truth for parameter recovery.
generative_params = {
    "threshold": 0.2,
    "learn_rate": 0.3,
    "softmax_inverse_temperature": 10,
    "biasL": 3,
}
forager_gen = ForagerCollection().get_preset_forager("CompareToThreshold", seed=42)
forager_gen.set_params(**generative_params)

# Alternative generative agent: uncomment to use the "Hattori2019" preset instead.
# forager_gen = ForagerCollection().get_preset_forager("Hattori2019", seed=42)
# forager_gen.set_params(
#     learn_rate_rew=0.2659,
#     learn_rate_unrew=0.0643,
#     forget_rate_unchosen=0.5541,
#     softmax_inverse_temperature=5.1442,
#     biasL=0.5099,
# )

# Alternative generative agent: uncomment to use the "Rescorla-Wagner" preset instead.
# forager_gen = ForagerCollection().get_preset_forager("Rescorla-Wagner", seed=42)
# forager_gen.set_params(
#     biasL=0,
# )


# Task environment: uncoupled block task, no reward baiting, 1000 trials.
# The coupled, baited variant is kept commented out for quick switching.
# task = CoupledBlockTask(reward_baiting=True, num_trials=1000, seed=42)
task = UncoupledBlockTask(reward_baiting=False, num_trials=1000, seed=53)

## Simulate the task

In [None]:
# Run the model: the agent performs the whole task. Everything read below is
# taken from `forager_gen` after this call.
forager_gen.perform(task)

# Capture the results
# `ground_truth_params` is iterated with `.items()` in the parameter-check cell,
# so `model_dump()` is expected to return a name -> value mapping.
ground_truth_params = forager_gen.params.model_dump()
ground_truth_choice_prob = forager_gen.choice_prob
# ground_truth_value = forager_gen.value
# Get the history
# These two histories are what gets passed to `forager_fit.fit()` in the
# parameter-recovery cell.
choice_history = forager_gen.get_choice_history()
reward_history = forager_gen.get_reward_history()

# Plot the session results
# NOTE(review): `if_plot_latent=True` presumably adds the agent's latent
# variables to the figure; confirm against plot_session().
fig, axes = forager_gen.plot_session(if_plot_latent=True)

## Parameter Recovery

In [None]:
# Parameter recovery: fit a fresh agent of the same preset to the simulated
# choice and reward histories.
forager_fit = ForagerCollection().get_preset_forager("CompareToThreshold", seed=42)
# forager_fit = ForagerCollection().get_preset_forager("Hattori2019", seed=42)

# Options handed to the optimizer through `DE_kwargs` (presumably SciPy's
# differential evolution; confirm against fit()). A seeded generator is passed
# as `seed`.
de_options = {
    "workers": 4,
    "disp": True,
    "seed": np.random.default_rng(42),
}

# Parameters to clamp during fitting. Left empty here; uncomment an entry to
# fix that parameter.
clamped = {
    # "biasL": 0,
    # "softmax_inverse_temperature": 5.0,
}

forager_fit.fit(
    choice_history,
    reward_history,
    # fit_bounds_override={"softmax_inverse_temperature": [0, 100]},
    clamp_params=clamped,
    DE_kwargs=de_options,
    k_fold_cross_validation=None,
)

# Keep a handle on the result object for the reporting cells below.
fitting_result = forager_fit.fitting_result

In [None]:
# Check fitted parameters against the values used to generate the data.
fit_names = fitting_result.fit_settings["fit_names"]
# Look up each ground-truth value by name, in `fit_names` order, so the
# "Ground truth" list lines up with the printed parameter names regardless of
# the key order of `ground_truth_params`.
# NOTE(review): assumes `fitting_result.x` is ordered like `fit_names`, as the
# side-by-side printout implies; confirm against the fitting code.
ground_truth = [ground_truth_params[name] for name in fit_names]
print(f"Num of trials: {len(choice_history)}")
print(f"Fitted parameters: {fit_names}")
print(f'Ground truth: {[f"{num:.4f}" for num in ground_truth]}')
print(f'Fitted:       {[f"{num:.4f}" for num in fitting_result.x]}')
print(f"Likelihood-Per-Trial: {fitting_result.LPT}")
print(f"Prediction accuracy full dataset: {fitting_result.prediction_accuracy}\n")

In [None]:
# Plot the fitted session results
# Note: this rebinds `axes`, replacing the axes returned by the simulation plot.
fig_fitting, axes = forager_fit.plot_fitted_session(if_plot_latent=True)

# Overlay the ground truth Q-values for comparison
# NOTE(review): `ground_truth_q_value` is not defined in the visible cells (the
# simulation cell only has a commented-out `ground_truth_value`); capture it
# there before uncommenting the lines below.
# axes[0].plot(ground_truth_q_value[0], lw=1, color="red", ls="-", label="actual_Q(L)")
# axes[0].plot(ground_truth_q_value[1], lw=1, color="blue", ls="-", label="actual_Q(R)")
# axes[0].legend(fontsize=6, loc="upper left", bbox_to_anchor=(0.6, 1.3), ncol=4)