In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import loadmat
from scipy.stats import norm
import pickle

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import hsv_to_rgb
from matplotlib.patches import Circle

In [None]:
from bayesee.evaluation import *

In [None]:
# Re-import edited modules (e.g. the local bayesee package) automatically
# before each cell executes, so library edits are picked up without a restart.
%load_ext autoreload
%autoreload 2
# Apply the project's matplotlib style sheet; presumably shipped with the
# bayesee package — confirm it is installed in the kernel's environment.
plt.style.use('bayesee.academic')

In [None]:
# All data paths in this notebook are built relative to the repository root,
# taken to be the parent of the directory the kernel was started in.
repo_path = Path.cwd().parents[0]
print(f"{repo_path}")

In [None]:
# Prior over hypotheses: index 0 = target absent, index j >= 1 = target at location j.
prior = np.array((0.5, 0.25, 0.25))
assert np.allclose(prior.sum(), 1.0)
assert prior[0] > 0, "log_prior_ratio is taken relative to the target-absent prior"
# Log prior odds of each hypothesis relative to "target absent" (entry 0 is 0).
log_prior_ratio = np.log(prior / prior[0])
# Per-trial log-likelihood ratio buffer; entry 0 (target absent) stays 0.
log_likelihood_ratio = np.zeros_like(prior)

n_trial = 80000

# Seed the global NumPy generator so the simulated trials are reproducible
# under Restart & Run All.
random_seed = 0
np.random.seed(random_seed)

# First half of the trials: target at a uniformly drawn location 1..len(prior)-1.
# Second half: target absent (location 0).
target_location = np.zeros((n_trial,), dtype=np.int64)
target_location[: n_trial // 2] = np.random.randint(1, prior.size, size=n_trial // 2)
# The trial design must agree with the prior the observer uses.
assert np.isclose(np.mean(target_location == 0), prior[0]), "trial design does not match prior[0]"
response_location = np.zeros_like(target_location, dtype=np.int64)

In [None]:
# Stimulus grid: d' at location 1 is dp_base; d' at location 2 is dp_base * ratio.
array_dp_base = np.array((1.25, 1.875, 2.5, 3.75, 5.0))
array_ratio = np.linspace(0, 2, 20)

output_dir = repo_path / "data/covert-search/large-field/no-sync"
if not output_dir.is_dir():
    # Fail before simulating rather than at the first to_csv call.
    raise FileNotFoundError(f"output directory not found: {output_dir}")

trial_location = np.asarray(target_location, dtype=np.int64)
# Guard against hidden state: target_location must still be the setup-cell array.
assert trial_location.shape == (n_trial,), "re-run the setup cell: target_location was overwritten"


def simulate_ideal_observer(location, local_dp, log_prior):
    """Simulate the ideal observer's response on every trial at once.

    Parameters
    ----------
    location : np.ndarray of int, shape (n,)
        Target location per trial; 0 = target absent, j >= 1 = target at location j.
    local_dp : np.ndarray of float, shape (k,)
        d' at each of the k candidate locations.
    log_prior : np.ndarray of float, shape (k + 1,)
        Log prior odds of each hypothesis relative to "target absent".

    Returns
    -------
    np.ndarray of int, shape (n,)
        Index of the hypothesis with the largest log posterior ratio on each
        trial (the first maximum wins ties, as with np.argmax).
    """
    n = location.size
    # One standard-normal draw per (trial, location), in the same order the
    # former per-trial loop drew them from the global generator.
    noise = np.random.normal(size=(n, local_dp.size))
    log_likelihood = np.zeros((n, local_dp.size + 1))
    # Column 0 (target absent) stays 0; the other columns hold the LLR under noise alone.
    log_likelihood[:, 1:] = noise * local_dp - local_dp**2 / 2
    # A present target shifts its location's response by d', adding d'^2 to that LLR.
    present_rows = np.flatnonzero(location > 0)
    present_location = location[present_rows]
    log_likelihood[present_rows, present_location] += local_dp[present_location - 1] ** 2
    return np.argmax(log_prior + log_likelihood, axis=1)


for dp_base in array_dp_base:
    ratio_frames = []
    for ratio in array_ratio:
        ratio_local_dp = np.array([dp_base, dp_base * ratio])
        ratio_frames.append(
            pd.DataFrame(
                {
                    "ratio": ratio,
                    "location": trial_location,
                    "response_location": simulate_ideal_observer(
                        trial_location, ratio_local_dp, log_prior_ratio
                    ),
                }
            )
        )
    # Concatenate once instead of growing a frame inside the loop.
    model_simulation = pd.concat(ratio_frames, ignore_index=True)

    file_name = output_dir / f"scaled_duplet_ideal_observer_base{dp_base}.csv"
    model_simulation.to_csv(
        file_name,
        index=False,
    )

In [None]:
array_color = [
    "#BD5500",
    "#0077BB",
    "#33BBEE",
    "#EE3377",
    "#009988",
    "#CC3311",
    "#BBBBBB",
]
# NOTE(review): dashed horizontal reference lines are drawn only for the first
# three dp_base values, at each curve's value for the largest ratio — confirm intent.
n_reference_curves = 3
# NOTE(review): d' ratio marked with a vertical line on every panel; the reason
# for 0.5 is not stated in this notebook.
reference_ratio = 0.5


def summarize_rates(simulation: pd.DataFrame) -> pd.DataFrame:
    """Accuracy, correct-rejection rate and hit rate for each d' ratio.

    ``simulation`` must have columns ``ratio``, ``location`` (0 = target absent)
    and ``response_location``. Returns a frame indexed by ratio, in order of
    first appearance, with columns ``accuracy``, ``cr_rate`` and ``hit_rate``.
    A rate is NaN when a ratio has no trials of the corresponding kind.
    """
    correct = simulation["location"] == simulation["response_location"]
    absent = simulation["location"] == 0
    ratio = simulation["ratio"]
    summary = pd.DataFrame(
        {
            "accuracy": correct.groupby(ratio).mean(),
            "cr_rate": correct[absent].groupby(ratio[absent]).mean(),
            "hit_rate": correct[~absent].groupby(ratio[~absent]).mean(),
        }
    )
    return summary.reindex(ratio.unique())


# Local names only: this cell must not overwrite target_location /
# response_location / array_ratio, which the simulation cell reads.
fig, axs = plt.subplots(nrows=3, ncols=1, figsize=(8, 20))
for dp_base_index, dp_base in enumerate(array_dp_base):
    file_name = (
        repo_path
        / f"data/covert-search/large-field/no-sync/scaled_duplet_ideal_observer_base{dp_base}.csv"
    )
    rates = summarize_rates(pd.read_csv(file_name))
    color = array_color[dp_base_index]

    axs[0].plot(rates.index, rates["accuracy"], c=color, label=f"dp_base={dp_base}")
    axs[1].plot(rates.index, rates["cr_rate"], c=color)
    axs[2].plot(rates.index, rates["hit_rate"], c=color)

    if dp_base_index < n_reference_curves:
        for ax, column in zip(axs, ("accuracy", "cr_rate", "hit_rate")):
            ax.axhline(rates[column].iloc[-1], c=color, ls="--", lw=1)

for ax in axs:
    ax.axvline(reference_ratio, c="k", ls="--", lw=1)

axs[0].legend(loc="best", fontsize=12)

axs[0].set(ylabel="Accuracy")
axs[1].set(ylabel="CR Rate")
axs[2].set(xlabel="Duplet d' Ratio", ylabel="Hit Rate")
plt.show()