---
# Model Architectures
----

Build, simulate and test different model architectures: random responding, noisy win-stay-lose-shift and Rescorla-Wagner.

Model theory: [Wilson & Collins — model fitting](https://ccn.studentorg.berkeley.edu/pdfs/papers/WilsonCollins_modelFitting.pdf)

---
```
: zach.wolpe@medibio.com.au
: 02.07.2024
```
---


In [None]:
import plotly.graph_objects as go
from scipy.optimize import minimize # finding optimal params in models
from scipy import stats             # statistical tools
import numpy as np                  # matrix/array functions
import pandas as pd                 # loading and manipulating data
import ipywidgets as widgets        # interactive display
import matplotlib.pyplot as plt     # plotting
%matplotlib inline
from tqdm import tqdm

np.random.seed(2021)                # set seed for reproducibility


from plotly.subplots import make_subplots
import plotly.graph_objects as go


from src.rescorla_wagner_model import (RoscorlaWagner)
from src.rescorla_wagner_model_plots import (RescorlaWagnerPlots)
from src.rescorla_wagner_model_simulation import (RescorlaWagnerSimulate)
from src.rescorla_wagner_model_diagnostics import (RoscorlaWagerModelDiagnostics)

from src.cog_sci_learning_model_base import (MultiArmedBanditModels, add_diag_line)



In [None]:
from src.rescorla_wagner_model import (RoscorlaWagner)
from src.rescorla_wagner_model_plots import (RescorlaWagnerPlots)
from src.rescorla_wagner_model_simulation import (RescorlaWagnerSimulate)
from src.rescorla_wagner_model_diagnostics import (RoscorlaWagerModelDiagnostics)

In [None]:

# Model classes exercised by the three sections below.
from src.cog_sci_random_response_model import (RandomResponseModel)
from src.cog_sci_win_stay_lose_shift_model import (WinStayLoseShiftModel)
# NOTE(review): MultiArmedBanditModels is also imported in the setup cell;
# presumably repeated so this cell can run on its own — confirm intent.
from src.cog_sci_learning_model_base import (MultiArmedBanditModels)
from src.cog_sci_roscorla_wagner_model import RoscorlaWagnerModel


# Re-seed the global NumPy RNG so the simulations that follow are reproducible
# regardless of which cells ran before (assumes the models draw from np.random).
np.random.seed(2021)                # set seed for reproducibility

---
## Model 1: Random Responding
---

### Initialize

In [None]:
# True parameter for the random-responding model — presumably the choice
# bias b; confirm against RandomResponseModel.
b = 0.3
rr = RandomResponseModel()
# Simulate 1000 trials; the results are stored on the instance and read back
# below via rr.simulated_experiment / rr.simulated_params.
rr.simulate(b, N=1000)
rr.simulated_params
# rr.simulated_experiment   # uncomment to inspect the raw simulated trials

### Example fit

In [None]:
# Bounds of the one-dimensional search for b.
b_bounds = (0,1)
action = rr.simulated_experiment['action']
reward = rr.simulated_experiment['reward']

# Fit b twice: brute-force search and the optimize_scikit helper (initial guess 0.5).
res_brute = rr.optimize_brute_force(loss_function=rr.neg_log_likelihood, bounds=b_bounds, actions=action, rewards=reward)
b_hat_brute_force = res_brute['b_pred']

# optimize_scikit returns a 3-tuple; the third element is the fitted parameter vector.
_, _, b_hat_scikit = rr.optimize_scikit(loss_function=rr.neg_log_likelihood, init_guess=[0.5], args=(action, reward), bounds=[b_bounds])
b_hat_scikit = b_hat_scikit[0]  # single-parameter model: unwrap the only element

# Display the brute-force result mapping next to the scikit estimate.
res_brute, b_hat_scikit

In [None]:
# Full 3-tuple returned by optimize_scikit (the previous cell kept only the
# third element, the fitted parameters).
res = rr.optimize_scikit(loss_function=rr.neg_log_likelihood, init_guess=[0.5], args=(action, reward), bounds=[b_bounds])
res

### Examine Likelihood

In [None]:
# Plot the negative log-likelihood for the simulated data; b_true=b passes the
# simulating value (presumably drawn as a reference marker — confirm).
# The trailing ';' suppresses echoing the return value as cell output.
rr.plot_neg_log_likelihood(b_true=b);

### Compare Optimization Procedures

In [None]:
# Compare the available fitting procedures on the random-responding model.
# results — presumably the per-procedure estimates; res_plot — the comparison figure.
results, res_plot = rr.compare_fitting_procedures()
# Show the figure explicitly, as the Rescorla-Wagner section does. A bare
# `res_plot;` is a no-op whose trailing semicolon also suppresses cell output.
res_plot.show()

### Sensitivity Analysis: Parameter Recovery

In [None]:
# Parameter recovery for b using the method's default settings (presumably a
# sweep over true b values, each simulated and refitted — confirm).
res, plot = rr.perform_sensitivity_analysis()
print(res.head(2))   # first rows of the recovery results table
plot.show()

----
# Model 2. Noisy win-stay-lose-shift
----

### Init

In [None]:
# Simulate the noisy win-stay-lose-shift model with a known epsilon.
wsls = WinStayLoseShiftModel()
EPSILON = 0.3
wsls.simulate(EPSILON, N=1000, noise=0)
# Jupyter only displays a cell's last expression, so bare mid-cell
# expressions were silently discarded; print what should be inspected.
print(wsls.simulated_params)
# (raw trials are available via wsls.simulated_experiment)

action = wsls.simulated_experiment['action']
reward = wsls.simulated_experiment['reward']
# Sanity check: NLL at a probe epsilon that differs from the true value.
print('neg. log-likelihood at epsilon=0.4:', wsls.neg_log_likelihood(0.4, action, reward))
wsls.plot_neg_log_likelihood(EPSILON)

### Compare Optimisation Strategies

In [None]:
# Compare fitting procedures for WSLS; only the returned figure is kept and displayed.
_, f = wsls.compare_fitting_procedures()
f

In [None]:
# Parameter recovery for WSLS, presented like the random-responding section:
# unpack the results table and figure rather than echoing the raw tuple, whose
# repr does not render the figure.
# NOTE(review): assumes the same (results, figure) return shape as
# rr.perform_sensitivity_analysis() — confirm for WinStayLoseShiftModel.
wsls_res, wsls_plot = wsls.perform_sensitivity_analysis()
print(wsls_res.head(2))
wsls_plot.show()

---
# Model 3. Rescorla-Wagner Model
---

### Instantiate

In [None]:
# Simulate the Rescorla-Wagner model with known parameters (presumably
# alpha = learning rate, theta = softmax inverse temperature — confirm).
rwm = RoscorlaWagnerModel()
ALPHA = 0.2
THETA = 3
N = 1000
rwm.simulate(ALPHA, THETA, N=N, noise=True)
# Bare mid-cell expressions are never displayed by Jupyter; print instead.
print(rwm.simulated_params)
# (raw trials are available via rwm.simulated_experiment)

# simulated experiment
action = rwm.simulated_experiment['action']
reward = rwm.simulated_experiment['reward']
# Displayed output: NLL of the simulated data at the true parameters.
rwm.neg_log_likelihood((ALPHA, THETA), action, reward)


## Perform Sensitivity Analysis

In [None]:
# Sensitivity analysis: parameter recovery over a grid of true (alpha, theta)
# values (presumably simulating and refitting at each grid point — confirm).
res, plot = rwm.perform_sensitivity_analysis(
    alpha_range=np.linspace(0.1, 1, 10),
    theta_range=np.linspace(1, 10, 10),
    N=N,  # same number of trials as the simulation above
    log_progress=True)
print(res.head(2))
plot.show()

## Compare Fitting Procedures

In [None]:

# Compare fitting procedures on the Rescorla-Wagner model over the same
# (alpha, theta) grid as the sensitivity analysis, with matching bounds.
results, res_plot = rwm.compare_fitting_procedures(
    alpha_range=np.linspace(0.1, 1, 10),
    theta_range=np.linspace(1, 10, 10),
    fit_brute_force=True,
    bounds=[(0.1, 1), (1, 10)],
    N=N,  # same number of trials as the simulation above
)
res_plot.show()

results

## Re-simulate and Plot Rewards / Q-value Estimates

In [None]:
# Re-simulate with the true parameters (ALPHA/THETA/N rather than duplicated
# literals) and plot the reward and Q-value traces.
rwm.simulate(ALPHA, THETA, N=N, noise=True)
# Keep the globals used by the fitting cells below in sync with this fresh
# simulation. Otherwise the fits would run on the previous draw, while the
# likelihood plot (which presumably reads rwm.simulated_experiment) shows this one.
action = rwm.simulated_experiment['action']
reward = rwm.simulated_experiment['reward']

plt.figure(figsize=(10, 6))
rwm.plot_reward()
plt.show()

plt.figure(figsize=(10, 6))
rwm.plot_Q_estimates()
plt.show()


## Visualise the Log-Likelihood Space

In [None]:
# Examine the negative log-likelihood surface over the (alpha, theta) grid.
# The first return value gets its own name: unpacking it into `plt` would
# shadow matplotlib.pyplot for every later cell (e.g. re-running the plotting
# cell above would then fail on plt.figure).
negll_plot, negll, theta_pred, theta_range, alpha_pred, alpha_range = rwm.plot_neg_log_likelihood()

## Brute Force Search

In [None]:
# Search bounds for (alpha, theta); `bounds` is also used by the single-run
# scikit fit further down.
alpha_bounds = (0,1)
theta_bounds = (.1,10)
bounds = (alpha_bounds, theta_bounds)

# Brute-force search for (alpha, theta) on the current action/reward globals;
# the result mapping carries the estimates and the BIC.
brute_force_results = \
    rwm.optimize_brute_force(loss_function=rwm.neg_log_likelihood, bounds=bounds, actions=action, rewards=reward)
alpha_hat_brute_force = brute_force_results['alpha_pred']
theta_hat_brute_force = brute_force_results['theta_pred']
BIC_brute_force = brute_force_results['BIC']

def log_results(alpha_true, theta_true, alpha_pred, theta_pred, BIC=None, name=None):
    """Print a boxed summary comparing true and recovered (alpha, theta).

    Parameters
    ----------
    alpha_true, theta_true
        Parameter values used to simulate the data.
    alpha_pred, theta_pred
        Parameter values recovered by the fitting procedure.
    BIC
        Optional information criterion of the fit; printed as ``None`` when omitted.
    name
        Optional label of the optimisation procedure (e.g. ``'Brute Force'``).

    Returns
    -------
    None
        The summary is written to stdout.
    """
    msg = f"""
    ----------------------------------------------------------------------------------
        : optimisation class: {name}

        : alpha (true):                     {alpha_true}
        : theta (true):                     {theta_true}
        : alpha (pred):                     {alpha_pred}
        : theta (pred):                     {theta_pred}
        : BIC:                              {BIC}
    ----------------------------------------------------------------------------------
    """
    print(msg)

# Summarise the brute-force fit against the true simulation parameters.
log_results(ALPHA, THETA, alpha_hat_brute_force, theta_hat_brute_force, BIC_brute_force, 'Brute Force')

## Scikit Optimize (multi-start over initial parameters)

In [None]:

# Multi-start scikit fit: run the optimiser from every (alpha, theta) initial
# guess on the grids below (presumably keeping the lowest-NLL run — confirm).
# Search the same space as the brute-force fit (`bounds`) so the procedures are
# compared on equal terms. Because theta_init_range starts at 0.1, the theta
# bound must reach down to 0.1 for every initial guess to lie inside the bounds.
negLL, params_opt, BIC, optimal_init_params = rwm.optimize_scikit_model_over_init_parameters(
    actions=action,
    rewards=reward,
    loss_function=None,
    alpha_init_range=np.linspace(0, 1, 5),
    theta_init_range=np.linspace(.1, 10, 7),
    bounds=bounds,
    log_progress=True,
)
alpha_hat_sci_opt, theta_hat_sci_opt = params_opt
print('optimal_init_params: ', optimal_init_params)

log_results(ALPHA, THETA, alpha_hat_sci_opt, theta_hat_sci_opt, BIC, 'Scikit')

In [None]:
# Re-print the multi-start scikit summary (same call as at the end of the previous cell).
log_results(ALPHA, THETA, alpha_hat_sci_opt, theta_hat_sci_opt, BIC, 'Scikit')

## Scikit Single Run (unreliable)

In [None]:
# Single scikit run from one initial guess, within the brute-force `bounds`.
# Presumably flagged "unreliable" because one local optimisation can depend on
# its starting point — the multi-start fit above addresses that.
# NOTE(review): the alpha component of init_guess (0.2) equals the true ALPHA,
# which may flatter this run — consider a neutral starting point.
_, _, params = rwm.optimize_scikit(
    loss_function=rwm.neg_log_likelihood,
    init_guess=[0.2, 1],
    args=(action, reward),
    bounds=bounds)

# Fitted parameters shown next to the true values.
params, ALPHA, THETA