In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import loadmat
from scipy.stats import norm
import pickle

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import hsv_to_rgb
from matplotlib.patches import Circle

In [None]:
from bayesee.evaluation import *

In [None]:
# Reload edited imported modules (e.g. the bayesee package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Matplotlib style sheet 'bayesee.academic' — presumably shipped with the bayesee package; confirm.
plt.style.use('bayesee.academic')

In [None]:
# All data paths are resolved relative to the directory one level above the working directory.
notebook_dir = Path.cwd()
repo_path = notebook_dir.parents[0]
print(repo_path)

In [None]:
subject = "AZ"
file_name = repo_path / f"data/covert-search/large-field/p3_data_{subject}.pickle"

# The pickle holds a (stimulus, response) pair. pickle.load can execute arbitrary code,
# so only point this at trusted, locally generated experiment files.
with file_name.open("rb") as pickle_file:
    stimulus, response = pickle.load(pickle_file)

# Unpack the experiment metadata into the module-level names later cells rely on.
metadata = stimulus["metadata"]
monitor_width, monitor_height = metadata["monitor_size"]
(
    spot_centers,
    stimulus_size,
    n_location,
    spot_size,
    stimulus_ppd,
    target_amplitude,
    target,
) = (
    metadata[key]
    for key in (
        "spot_centers",
        "stimulus_size",
        "n_location",
        "spot_size",
        "stimulus_ppd",
        "target_amplitude",
        "target",
    )
)

# Per-location statistics previously derived for this subject; the "dp" column supplies
# the per-location d' values used by the ideal-observer simulation.
file_name = (
    repo_path
    / f"data/covert-search/large-field/derived/p2_spatial_statistics_{subject}.csv"
)
spatial_statistics_human = pd.read_csv(file_name)
local_dp = spatial_statistics_human["dp"].to_numpy()

In [None]:
# Prior over the response alternatives: index 0 = "target absent" with probability 0.5,
# the remaining 0.5 spread uniformly over the n_location - 1 spots.
prior = np.full(n_location, 0.5 / (n_location - 1))
prior[0] = 0.5
assert np.allclose(prior.sum(), 1.0)
log_prior_ratio = np.log(prior / prior[0])

n_trial = 100000
# Fixed seed so Restart & Run All reproduces the exported CSV exactly.
rng = np.random.default_rng(0)

# First half of the trials: target at a uniformly random spot in 1..n_location-1.
# Second half: target absent (location 0).
target_location = np.zeros((n_trial,), dtype=np.int64)
target_location[: n_trial // 2] = rng.integers(1, n_location, size=n_trial // 2)

assert np.allclose(np.dot(target.flatten(), target.flatten()), 1.0)
# Every spot needs a d' value; a length mismatch would silently broadcast wrongly below.
assert len(local_dp) == n_location - 1


def simulate_ideal_observer_responses(
    target_location, local_dp, log_prior_ratio, rng, batch_size=10_000
):
    """Simulate the Bayesian ideal searcher's MAP response on every trial.

    Parameters
    ----------
    target_location : 1-D int array, 0 = target absent, k >= 1 = target at spot k.
    local_dp : 1-D array of d' values, one per spot (spot k uses local_dp[k - 1]).
    log_prior_ratio : 1-D array of log(prior / prior[0]) over 0..n_spot.
    rng : numpy Generator used for the internal noise.
    batch_size : trials simulated per vectorised batch (bounds memory use).

    Returns
    -------
    Array like ``target_location`` holding the arg-max posterior location per trial.
    """
    local_dp = np.asarray(local_dp, dtype=float)
    n_spot = local_dp.size
    # Leading 0 so a target-absent trial (index 0) adds nothing to any log-likelihood.
    dp_with_absent = np.concatenate(([0.0], local_dp))
    responses = np.empty_like(target_location)

    for start in range(0, len(target_location), batch_size):
        targets = target_location[start : start + batch_size]
        n_batch = targets.size
        log_likelihood_ratio = np.zeros((n_batch, n_spot + 1))
        # Log LR of a unit-variance Gaussian response z at a spot with sensitivity d':
        # d' * z - d'^2 / 2 when the target is absent there ...
        noise = rng.standard_normal((n_batch, n_spot))
        log_likelihood_ratio[:, 1:] = noise * local_dp - local_dp**2 / 2
        # ... and z -> z + d' at the target spot, which adds d'^2 to its log LR.
        log_likelihood_ratio[np.arange(n_batch), targets] += dp_with_absent[targets] ** 2
        responses[start : start + batch_size] = np.argmax(
            log_prior_ratio + log_likelihood_ratio, axis=1
        )
    return responses


response_location = simulate_ideal_observer_responses(
    target_location, local_dp, log_prior_ratio, rng
)
accurate_response = target_location == response_location

In [None]:
# Spot centres are stored as (row, col) in stimulus coordinates; offset them so the
# stimulus is centred on the monitor.
shifted_spot_center = metadata["spot_centers"].copy()
shifted_spot_center[:, 0] += (monitor_height - stimulus_size) // 2
shifted_spot_center[:, 1] += (monitor_width - stimulus_size) // 2

# Pixel coordinate grids, each of shape (monitor_height, monitor_width).
pixel_col, pixel_row = np.meshgrid(np.arange(monitor_width), np.arange(monitor_height))

# One boolean mask per spot: True for pixels within spot_size / 2 of the spot centre.
# Built directly, so each spot gets its own array (no shared pre-allocated buffer).
list_spot_region = [
    (pixel_row - shifted_spot_center[spot_index, 0]) ** 2
    + (pixel_col - shifted_spot_center[spot_index, 1]) ** 2
    <= spot_size**2 / 4
    for spot_index in range(n_location - 1)
]

In [None]:
# Binary monitor map: 1 inside any spot, 0 elsewhere.
stimulus_region = np.zeros((monitor_height, monitor_width))
for spot_mask in list_spot_region[: n_location - 1]:
    stimulus_region[spot_mask] = 1

fig, ax = plt.subplots(figsize=(6, 4))
ax.imshow(stimulus_region)

# Label each spot with its 1-based location number, drawn at (x=col, y=row).
spot_centers_on_monitor = shifted_spot_center[: n_location - 1]
for location_number, center in enumerate(spot_centers_on_monitor, start=1):
    center_row, center_col = center[0], center[1]
    ax.text(center_col, center_row, f"{location_number}", ha="center", va="center")

In [None]:
# NOTE(review): `location_near` / `location_far` were used here without ever being defined
# (NameError on a fresh kernel). They are now derived from the subject's eccentricity
# column: "near" = spots at or below the median eccentricity, "far" = spots above it.
# Confirm this split matches the intended definition.
spot_location = np.arange(1, n_location)
spot_ecc = spatial_statistics_human["ecc"].to_numpy()
median_ecc = np.median(spot_ecc)
location_near = spot_location[spot_ecc <= median_ecc]
location_far = spot_location[spot_ecc > median_ecc]

overall_accuracy = accurate_response.mean()
near_hit = accurate_response[np.isin(target_location, location_near)].mean()
far_hit = accurate_response[np.isin(target_location, location_far)].mean()
# Correct-rejection rate: accuracy on target-absent trials.
cr_rate = accurate_response[target_location == 0].mean()

print(
    f"Overall accuracy: {overall_accuracy:.4f}, near hit: {near_hit:.4f}, far hit: {far_hit:.4f}"
)

print(f"Overall cr_rate: {cr_rate:.4f}")

In [None]:
# One row per search location (1..n_location-1) carrying the subject's geometry columns.
spatial_statistics = pd.DataFrame({"location": range(1, n_location)}).assign(
    ecc=spatial_statistics_human["ecc"],
    orientation=spatial_statistics_human["orientation"],
)

In [None]:
# Per-location tallies over all simulated trials, indexed by location (0 = target absent).
target_count = np.bincount(target_location, minlength=n_location)
hit_count = np.bincount(
    target_location[target_location == response_location], minlength=n_location
)
miss_count = np.bincount(target_location[response_location == 0], minlength=n_location)

# Keep only the spots (locations 1..n_location-1).
spot_slice = slice(1, n_location)
n_present = target_count[spot_slice]
has_target = n_present != 0

# Rates fall back to 0 at spots that never held a target.
hit_rate = np.divide(
    hit_count[spot_slice], n_present, out=np.zeros(n_present.shape), where=has_target
)
spatial_statistics["hit_rate"] = hit_rate

miss_rate = np.divide(
    miss_count[spot_slice], n_present, out=np.zeros(n_present.shape), where=has_target
)
spatial_statistics["miss_rate"] = miss_rate

spatial_statistics["n_present"] = n_present

In [None]:
# Persist the ideal observer's per-location statistics next to the subject's derived data.
ideal_observer_csv = (
    repo_path
    / f"data/covert-search/large-field/derived/p3_spatial_statistics_ideal_observer_for_{subject}.csv"
)
spatial_statistics.to_csv(ideal_observer_csv, index=False)

In [None]:
# Hue encodes orientation.
# NOTE(review): this cell originally read an undefined `array_orientation` (NameError).
# Hues are now derived from the orientation column by rank: distinct orientations get
# evenly spaced hues in [0, 1), as hsv_to_rgb requires, whatever the orientation units.
# Confirm this matches the intended mapping.
orientation_values = spatial_statistics["orientation"].to_numpy()
unique_orientation, orientation_rank = np.unique(orientation_values, return_inverse=True)

orientation_hsv_color_map = np.zeros((n_location - 1, 3))
orientation_hsv_color_map[:, 0] = orientation_rank / max(len(unique_orientation), 1)
orientation_hsv_color_map[:, 1] = 1  # full saturation
orientation_hsv_color_map[:, 2] = 0.75  # value (brightness)
if n_location > 1:
    # Location 1 is drawn black (value 0).
    orientation_hsv_color_map[0, 2] = 0

rgb_orientation_color_map = hsv_to_rgb(orientation_hsv_color_map)

orientation_color_palette = dict(
    zip(spatial_statistics["orientation"], rgb_orientation_color_map)
)

In [None]:
interested_statistics = ["hit_rate", "miss_rate"]


def binomial_error_bars(rate, n_trials, n_standard_error=2):
    """Return +/- n_standard_error binomial standard-error bars shaped (2, k) for ``yerr``.

    Parameters
    ----------
    rate : per-location proportions (e.g. hit or miss rate).
    n_trials : number of trials behind each proportion; entries <= 0 get a zero-width bar.
    n_standard_error : half-width of the bar in standard errors.

    Returns
    -------
    Array of shape (2, k): identical lower and upper half-widths, as matplotlib's
    ``yerr`` accepts.
    """
    rate, n_trials = np.broadcast_arrays(
        np.asarray(rate, dtype=float), np.asarray(n_trials, dtype=float)
    )
    variance = np.divide(
        rate * (1 - rate), n_trials, out=np.zeros_like(rate), where=n_trials > 0
    )
    half_width = np.atleast_1d(n_standard_error * np.sqrt(variance))
    return np.tile(half_width, (2, 1))


# Each rate uses its own proportion and the actual number of target-present trials at that
# location. Only half of all trials contain a target, so n_trial / (n_location - 1)
# overstates the count.
n_present_per_location = spatial_statistics["n_present"].to_numpy()
errors_interested_statistics = [
    binomial_error_bars(spatial_statistics[statistic].to_numpy(), n_present_per_location)
    for statistic in interested_statistics
]

In [None]:
bar_width = 0.1

# Sort once: by eccentricity, then orientation, so bars within an eccentricity group run in
# orientation order and line up with the colour list built from the same rows.
df_sorted = spatial_statistics.sort_values(by=["ecc", "orientation"], kind="stable")
unique_ecc = df_sorted["ecc"].unique()
# spatial_statistics has a default RangeIndex, so these labels are also row positions;
# they reorder the (2, k) error arrays into sorted order.
sorted_positions = df_sorted.index.to_numpy()

for statistic_index, statistic in enumerate(interested_statistics):
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(8, 6))
    statistic_errors = errors_interested_statistics[statistic_index]

    for ecc_index, ecc in enumerate(unique_ecc):
        ecc_condition = (df_sorted["ecc"] == ecc).to_numpy()
        orientation_sorted = df_sorted.loc[ecc_condition, "orientation"]
        n_x_bar = len(orientation_sorted)
        # Bars sit exactly bar_width apart, centred on the tick at x = ecc_index.
        x_bar = ecc_index + (np.arange(n_x_bar) - (n_x_bar - 1) / 2) * bar_width
        y_bar = df_sorted.loc[ecc_condition, statistic]

        if statistic_errors is None:
            y_error = None
        else:
            y_error = statistic_errors[:, sorted_positions][:, ecc_condition]

        ax.bar(
            x_bar,
            y_bar,
            yerr=y_error,
            color=[orientation_color_palette[orientation] for orientation in orientation_sorted],
            width=bar_width,
            error_kw={
                "elinewidth": 3,
                "capsize": 6,
                "capthick": 3,
                "alpha": 0.75,
                "ecolor": "orange" if ecc_index == 0 else "k",
            },
        )

    ax.set(
        ylim=(0, 1),
        xlabel="Eccentricity (deg)",
        ylabel=statistic,
        title=f"Ideal observer matched to {subject}: {statistic}",
        xticks=range(len(unique_ecc)),
    )

    ax.set_xticklabels(unique_ecc)

    # Inset: monitor layout with each spot coloured by its orientation colour.
    inner_ax = fig.add_axes([-0.025, -0.1, 0.144, 0.2405])
    inner_ax.imshow(stimulus_region)
    inner_ax.set(xticks=[], yticks=[])

    for location_index in range(n_location - 1):
        circle = Circle(
            (
                shifted_spot_center[location_index, 1],
                shifted_spot_center[location_index, 0],
            ),
            spot_size / 2,
            color=rgb_orientation_color_map[location_index],
            linewidth=1,
        )
        inner_ax.add_patch(circle)

    plt.show()