In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from scipy.io import loadmat
from sklearn.metrics import confusion_matrix

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import hsv_to_rgb
from matplotlib.patches import Circle

In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Project-specific matplotlib style; presumably shipped by the `bayesee` package
# (dotted style name) -- it must be importable for this line to resolve.
plt.style.use('bayesee.academic')

In [None]:
# Repository root, taken as the parent of the notebook's working directory.
repo_path = Path.cwd().parent
print(repo_path)

In [None]:
# Subject tag: selects the input .mat file and also names the CSV written at the end.
subject = "AZ_t1"
file_name = repo_path / f"data/covert-search/large-field/p4_data_{subject}.mat"

# loadmat yields a dict of MATLAB variable name -> array, plus loader metadata keys.
data = loadmat(file_name)
print(data.keys())

In [None]:
# Array-valued fields are used as loaded (target_location and human_response are
# later compared element-wise, so they share a shape).
target_amplitude, target_location, spot_center, human_response = (
    data[key] for key in ("targetAmplitude", "tLocation", "spotCenters", "hResponse")
)

# MATLAB scalars come back from loadmat as 2-D arrays, hence the [0][0] unwrapping.
n_location, spot_diameter, stimulus_design_size, ppd = (
    data[key][0][0] for key in ("nLocations", "spotLength", "totalLength", "ppd")
)

# The first row of monitorPx holds the monitor size as (width, height).
monitor_size_row = data["monitorPx"][0]
monitor_width, monitor_height = monitor_size_row[0], monitor_size_row[1]

In [None]:
# Spot centres are in stimulus-design coordinates (column 0 = row/y, column 1 =
# col/x). Offset them so the design square is centred on the monitor.
row_offset = (monitor_height - stimulus_design_size) // 2
col_offset = (monitor_width - stimulus_design_size) // 2

shifted_spot_center = spot_center.copy()
shifted_spot_center[:, 0] += row_offset
shifted_spot_center[:, 1] += col_offset

# Per-trial correctness: the response names the same label as the target.
accurate_response = human_response == target_location

In [None]:
# One row per stimulus location, ids 1 .. n_location-1 (label 0 is used for the
# "no location" response/target in the confusion-matrix cells).
spatial_statistics = pd.DataFrame(data={"id": list(range(1, n_location))})

In [None]:
# Boolean pixel mask per location: True inside the disc of diameter
# `spot_diameter` centred on that spot's (row, col) monitor coordinate.
# Masks are built directly in a comprehension; a pre-filled placeholder list is
# unnecessary (and np.zeros(width, height) would read `height` as a dtype).
pixel_col, pixel_row = np.meshgrid(np.arange(monitor_width), np.arange(monitor_height))

# Halve before squaring so integer-typed MATLAB scalars cannot overflow.
spot_radius_sq = (spot_diameter / 2) ** 2

list_spot_region = [
    (pixel_row - shifted_spot_center[index_location, 0]) ** 2
    + (pixel_col - shifted_spot_center[index_location, 1]) ** 2
    <= spot_radius_sq
    for index_location in range(n_location - 1)
]

In [None]:
# Binary image of all stimulus spots: 1 inside any spot, 0 elsewhere.
stimulus_region = np.zeros((monitor_height, monitor_width))
for spot_mask in list_spot_region[: n_location - 1]:
    stimulus_region[spot_mask] = 1

In [None]:
fig, ax = plt.subplots()
ax.imshow(stimulus_region)

# Label every spot with its 1-based location id (x = column, y = row).
for location_id, (center_row, center_col) in enumerate(
    shifted_spot_center[: n_location - 1, :2], start=1
):
    ax.text(center_col, center_row, f"{location_id}", ha="center", va="center")

# Red reference ring around the monitor centre with radius 1.5 spot diameters;
# the same 1.5-diameter cut-off defines "near" locations further down.
circle = Circle(
    (monitor_width // 2, monitor_height // 2),
    spot_diameter * 1.5,
    edgecolor="r",
    facecolor="none",
    linewidth=3,
)
ax.add_patch(circle)

In [None]:
# Eccentricity of each location: Euclidean distance of the spot centre from the
# centre of the stimulus design square, converted from pixels to degrees (ppd).
# Coordinates are cast to float first: loadmat preserves MATLAB integer classes,
# and unsigned integer coordinates would wrap around when the centre is subtracted.
design_center = stimulus_design_size // 2
spot_offsets = spot_center[: n_location - 1, :2].astype(np.float64) - design_center
array_eccentral_distance = np.sqrt((spot_offsets**2).sum(axis=1))

array_eccentral_distance /= ppd

print(array_eccentral_distance)

In [None]:
# Collapse eccentricities (deg) that differ by less than `merge_tolerance_deg`
# onto one representative value (presumably these differences are pixel
# rounding of spots meant to share an eccentricity).
# Unique distances are swept in ascending order. A value within the tolerance of
# the current group's representative is mapped onto that representative;
# otherwise it starts a new group. Comparing against the representative, rather
# than pairwise, keeps a chain such as 5.0/5.2/5.4 in a single group instead of
# mapping 5.4 onto 5.2 while 5.2 itself is mapped onto 5.0.
merge_tolerance_deg = 0.5
pixel_precision_array_eccentral_distance = array_eccentral_distance.copy()

representative = None
for distance in np.unique(array_eccentral_distance):
    if representative is None or distance - representative >= merge_tolerance_deg:
        representative = distance
        continue
    print(f"merging {distance} -> {representative}")
    pixel_precision_array_eccentral_distance[
        array_eccentral_distance == distance
    ] = representative

print(pixel_precision_array_eccentral_distance)

In [None]:
# Merged eccentricity in degrees, rounded to 3 decimals, aligned with `id`.
spatial_statistics["ecc"] = np.round(pixel_precision_array_eccentral_distance, 3)

In [None]:
# Location positions ordered by merged eccentricity; a stable sort keeps ties in id order.
eccentral_distance_sorted_indexes = np.argsort(
    pixel_precision_array_eccentral_distance, kind="stable"
)

# Convert to confusion-matrix labels (position i -> label i + 1) and put label 0 first.
extra_eccentral_distance_sorted_indexes = np.concatenate(
    ([0], eccentral_distance_sorted_indexes + 1)
)
print(extra_eccentral_distance_sorted_indexes)

In [None]:
# Split location ids by eccentricity at 1.5 spot diameters (converted to deg).
# `far` is the complement of `near`, so a location lying exactly on the
# threshold is not silently dropped from both groups.
near_threshold_deg = spot_diameter * 1.5 / ppd
location_ids = np.arange(1, n_location)
is_near = pixel_precision_array_eccentral_distance < near_threshold_deg

indexes_near = location_ids[is_near]
indexes_far = location_ids[~is_near]
print(indexes_near, indexes_far)

In [None]:
# Binary image containing only the near spots.
near_region = np.zeros((monitor_height, monitor_width))
for location_id in indexes_near:
    near_region[list_spot_region[location_id - 1]] = 1

In [None]:
fig, ax = plt.subplots()
ax.imshow(near_region)

# Label only the near spots with their 1-based location id.
for location_id in indexes_near:
    center_row = shifted_spot_center[location_id - 1, 0]
    center_col = shifted_spot_center[location_id - 1, 1]
    ax.text(center_col, center_row, f"{location_id}", ha="center", va="center")

In [None]:
# Overall accuracy, plus accuracy on trials whose target label is a near / far location.
on_near_target = np.isin(target_location, indexes_near)
on_far_target = np.isin(target_location, indexes_far)

accuracy_all = accurate_response.mean()
accuracy_near = accurate_response[on_near_target].mean()
accuracy_far = accurate_response[on_far_target].mean()
print(accuracy_all, accuracy_near, accuracy_far)

In [None]:
# Counts over labels 0..n_location-1. The human response is passed as y_true and
# the target as y_pred, so rows index the response and columns the target.
confusion_mat = confusion_matrix(
    y_true=human_response.flatten(),
    y_pred=target_location.flatten(),
    labels=range(n_location),
)

# Reorder rows and columns together by eccentricity; label 0 stays in front.
sorted_confusion_mat = confusion_mat[
    np.ix_(
        extra_eccentral_distance_sorted_indexes,
        extra_eccentral_distance_sorted_indexes,
    )
]
array_index_location = np.arange(n_location)[extra_eccentral_distance_sorted_indexes]

In [None]:
# The [0, 0] cell (label 0 as both response and target) is the correct-rejection
# count. For display it is replaced by the largest count anywhere else, so the
# colour scale is not dominated by that one cell; the real value is annotated on
# the plot separately.
correct_rejection = sorted_confusion_mat[0, 0]

# Exclude exactly the [0, 0] cell, which is flat index 0. Deleting by value via
# np.where(...) would hand np.delete a (rows, cols) tuple that it interprets as
# flat indices, removing unrelated cells whenever another count ties with it.
second_largest_confusion_mat = np.max(sorted_confusion_mat.ravel()[1:])

plotted_sorted_confusion_mat = sorted_confusion_mat.copy()
plotted_sorted_confusion_mat[0, 0] = second_largest_confusion_mat
print(second_largest_confusion_mat)

In [None]:
fig, ax = plt.subplots(figsize=(15, 12))
im = ax.imshow(plotted_sorted_confusion_mat)

# Write each count into its cell, except the [0, 0] correct-rejection cell, which
# gets an arrow annotation instead. array_index_location is a permutation of
# 0..n_location-1, so the double loop visits every (row, col) position once.
for row in array_index_location:
    for col in array_index_location:
        if row == 0 and col == 0:
            continue
        ax.text(
            col,
            row,
            f"{sorted_confusion_mat[row, col]:.0f}",
            ha="center",
            va="center",
        )

ax.annotate(
    f"{correct_rejection:.0f}",
    xy=(0, 0),
    xytext=(-0.5, -1),
    arrowprops=dict(facecolor="black", shrink=0.05),
)

ax.set(
    xticks=np.arange(n_location),
    yticks=np.arange(n_location),
    xticklabels=array_index_location,
    yticklabels=array_index_location,
    xlabel="Target location sorted by eccentral distance",
    ylabel="Human response sorted by eccentral distance",
)

# Inset: stimulus layout with location ids, as a key for the tick labels.
inner_ax = fig.add_axes([0.25, 0.3, 0.2, 0.32])
inner_ax.imshow(stimulus_region)
inner_ax.set(xticks=[], yticks=[])
for location_id, (center_row, center_col) in enumerate(
    shifted_spot_center[: n_location - 1, :2], start=1
):
    inner_ax.text(
        center_col,
        center_row,
        f"{location_id}",
        ha="center",
        va="center",
        fontsize=16,
    )

plt.show()

In [None]:
# Per-location rates from the confusion matrix (rows = response label, columns =
# target label). Column sums count the trials with the target at each label.
target_present_totals = confusion_mat[:, 1:].sum(axis=0)
target_absent_total = confusion_mat[:, 0].sum()

local_hit_rate = np.diag(confusion_mat)[1:] / target_present_totals
local_miss_rate = confusion_mat[0, 1:] / target_present_totals
local_fa_rate = confusion_mat[1:, 0] / target_absent_total

cr = confusion_mat[0, 0] / target_absent_total

In [None]:
# Attach the per-location rates to the table (same order as `id`).
rate_columns = {
    "hit_rate": local_hit_rate,
    "miss_rate": local_miss_rate,
    "fa_rate": local_fa_rate,
}
for column_name, column_values in rate_columns.items():
    spatial_statistics[column_name] = column_values

In [None]:
# One map per rate: each near spot is painted with its rate, and every spot's
# rate is printed at its centre. The canvas starts as NaN instead of np.empty so
# unpainted pixels are not uninitialised memory; imshow renders NaN with the
# colormap's "bad" colour and ignores it when autoscaling.
for statistic in [local_hit_rate, local_miss_rate, local_fa_rate]:
    statistic_region = np.full((monitor_height, monitor_width), np.nan)
    for index_location in range(n_location - 1):
        if index_location + 1 in indexes_near:
            statistic_region[list_spot_region[index_location]] = statistic[
                index_location
            ]

    fig, ax = plt.subplots()
    ax.imshow(statistic_region)

    for index_location in range(n_location - 1):
        ax.text(
            shifted_spot_center[index_location, 1],
            shifted_spot_center[index_location, 0],
            f"{statistic[index_location]:.3f}",
            ha="center",
            va="center",
        )

    plt.show()

print(f"Correct Rejection:{cr}")

In [None]:
# Per-location trial counts from the confusion matrix (rows = response label,
# columns = target label; labels 1..n_location-1 are locations).
location_block = confusion_mat[1:, 1:]  # response location x target location
location_diagonal = np.diag(location_block)  # correct location responses

local_present_count = np.array(confusion_mat[:, 1:].sum(axis=0))
local_hit_count = np.array(location_diagonal)
local_miss_count = np.array(confusion_mat[0, 1:])
local_fa_count = np.array(confusion_mat[1:, 0])

# "fh" (presumably "false hit"): a location response that does not match the
# target location. `_from` totals per target column, `_to` per response row.
local_fh_from_count = np.array(location_block.sum(axis=0) - location_diagonal)
local_fh_to_count = np.array(location_block.sum(axis=1) - location_diagonal)

In [None]:
# Attach the per-location counts to the table (same order as `id`).
count_columns = {
    "present_count": local_present_count,
    "hit_count": local_hit_count,
    "miss_count": local_miss_count,
    "fa_count": local_fa_count,
    "fh_from_count": local_fh_from_count,
    "fh_to_count": local_fh_to_count,
}
for column_name, column_values in count_columns.items():
    spatial_statistics[column_name] = column_values

In [None]:
# One map per count: every spot is painted with its count and labelled with it.
# The canvas starts as NaN instead of np.empty so pixels outside the spots are
# not uninitialised memory; imshow renders NaN with the colormap's "bad" colour
# and ignores it when autoscaling.
for count in [
    local_present_count,
    local_hit_count,
    local_miss_count,
    local_fa_count,
    local_fh_from_count,
    local_fh_to_count,
]:
    count_region = np.full((monitor_height, monitor_width), np.nan)
    for index_location in range(n_location - 1):
        count_region[list_spot_region[index_location]] = count[index_location]

    fig, ax = plt.subplots()
    ax.imshow(count_region)

    for index_location in range(n_location - 1):
        ax.text(
            shifted_spot_center[index_location, 1],
            shifted_spot_center[index_location, 0],
            f"{count[index_location]:.0f}",
            ha="center",
            va="center",
        )

    plt.show()

In [None]:
# Polar angle of each spot around the design centre, mapped to [0, 1] so it can
# be used as an HSV hue. Position 0 is skipped by the loop and keeps the
# sentinel -1; the colour-map cell further down draws it black.
design_center = stimulus_design_size // 2
array_orientation = np.full(n_location - 1, -1.0)

for index_location in range(1, n_location - 1):
    # float() avoids unsigned-integer wrap-around if loadmat returned integer coordinates.
    dy = float(spot_center[index_location, 0]) - design_center
    dx = design_center - float(spot_center[index_location, 1])
    # arctan2 lies in [-pi, pi], so this lands in [0, 1].
    array_orientation[index_location] = (1 - np.arctan2(dy, dx) / np.pi) / 2

print(array_orientation)

In [None]:
# Orientation hue per location, rounded to 3 decimals; also used as the bar-plot hue key.
spatial_statistics["orientation"] = np.round(array_orientation, 3)

In [None]:
# HSV colour per location: hue = orientation, full saturation, value 0.75.
# Position 0 (orientation sentinel) gets value 0, i.e. black.
orientation_hsv_color_map = np.zeros((n_location - 1, 3))
orientation_hsv_color_map[:, 0] = array_orientation
orientation_hsv_color_map[:, 1] = 1
orientation_hsv_color_map[:, 2] = 0.75
orientation_hsv_color_map[:1, 2] = 0

rgb_orientation_color_map = hsv_to_rgb(orientation_hsv_color_map)

In [None]:
# Drop locations with any NaN statistic (e.g. a rate whose denominator count was 0).
spatial_statistics = spatial_statistics.dropna(axis=0, how="any")

In [None]:
# 0-based positions of the near locations, used to subset per-location arrays.
near_positions = indexes_near - 1
near_pixel_precision_array_eccentral_distance = pixel_precision_array_eccentral_distance[
    near_positions
]
near_rgb_orientation_color_map = rgb_orientation_color_map[near_positions]

In [None]:
# Rate vs eccentricity for the near locations, coloured by orientation hue, with
# an inset showing which spot each colour belongs to.
for statistic in [local_hit_rate, local_miss_rate, local_fa_rate]:
    fig, ax = plt.subplots()
    ax.scatter(
        near_pixel_precision_array_eccentral_distance,
        statistic[indexes_near - 1],
        s=250,
        c=near_rgb_orientation_color_map,
    )
    ax.set(xlabel="Eccentral distance (deg)", ylabel="Proportion")

    inner_ax = fig.add_axes([0.2, 0.2, 0.288, 0.461])
    inner_ax.imshow(near_region)
    inner_ax.set(xticks=[], yticks=[])

    for location_id in indexes_near:
        spot_row = shifted_spot_center[location_id - 1, 0]
        spot_col = shifted_spot_center[location_id - 1, 1]
        inner_ax.add_patch(
            Circle(
                (spot_col, spot_row),
                spot_diameter / 2,
                color=rgb_orientation_color_map[location_id - 1],
                linewidth=1,
            )
        )

    plt.show()

In [None]:
# Seaborn palette: each rounded orientation value (the hue levels of the
# "orientation" column) maps to that location's own RGB colour. The palette is
# built from the full per-location arrays, so keys and colours stay paired by
# location. It does not depend on which rows survived dropna() or on those rows
# lining up with `indexes_near`; extra keys are harmless to seaborn.
orientation_color_palette = dict(
    zip(array_orientation.round(3), rgb_orientation_color_map)
)

In [None]:
# Bar plot of each rate per eccentricity, one bar per orientation (hue), with an
# inset mapping colours to near spots.
for statistic in ["hit_rate", "miss_rate", "fa_rate"]:
    fig, ax = plt.subplots()
    sns.barplot(
        data=spatial_statistics,
        x="ecc",
        y=statistic,
        hue="orientation",
        palette=orientation_color_palette,
        width=0.4,
        ax=ax,
    )
    ax.legend_.remove()
    ax.set(xlabel="Eccentral distance (deg)")

    # Hand-tuned layout tweak: shift only the first bar patch right by 0.175 x-units.
    for first_patch in ax.patches[:1]:
        first_patch.set_x(first_patch.get_x() + 0.175)

    inner_ax = fig.add_axes([0.125, 0.7, 0.144, 0.2405])
    inner_ax.imshow(near_region)
    inner_ax.set(xticks=[], yticks=[])

    for location_id in indexes_near:
        spot_row = shifted_spot_center[location_id - 1, 0]
        spot_col = shifted_spot_center[location_id - 1, 1]
        inner_ax.add_patch(
            Circle(
                (spot_col, spot_row),
                spot_diameter / 2,
                color=rgb_orientation_color_map[location_id - 1],
                linewidth=1,
            )
        )

    plt.show()

In [None]:
# Persist the per-location table under the derived large-field outputs, named by subject.
output_csv_path = (
    repo_path
    / f"data/covert-search/large-field/derived/p4_spatial_statistics_{subject}.csv"
)
spatial_statistics.to_csv(output_csv_path, index=False)