# RSA

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import itertools
from collections.abc import Callable

import numpy as np
import pandas as pd
import seaborn as sns
import torch
import xarray as xr
from bonner.caching import cache
from bonner.computation.metrics import pearson_r, spearman_r
from bonner.plotting import DEFAULT_MATPLOTLIBRC, save_figure
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from lib.datasets import (
    compute_shared_stimuli,
    filter_by_stimulus,
    nsd,
    split_by_repetition,
)
from lib.spectra import CrossDecomposition
from lib.utilities import MANUSCRIPT_HOME

# Directory that the figures produced by this notebook are saved into.
FIGURES_HOME = MANUSCRIPT_HOME / "figures"

# Global seaborn/matplotlib styling; applies to every figure created afterwards.
sns.set_theme(context="paper", style="ticks", rc=DEFAULT_MATPLOTLIBRC)

In [None]:
def compute_rsm(x: torch.Tensor) -> torch.Tensor:
    """Build a representational similarity matrix (RSM) from ``x``.

    The last two axes of ``x`` are swapped and the result is handed to
    ``pearson_r`` with ``correction=0`` and ``return_diagonal=False``.
    Presumably ``x`` is (stimuli, features), so the output is the
    (stimuli, stimuli) matrix of Pearson correlations between stimulus
    response patterns -- confirm against ``bonner.computation.metrics``.
    """
    features_by_stimulus = x.transpose(-2, -1)
    return pearson_r(
        features_by_stimulus,
        correction=0,
        return_diagonal=False,
    )


def extract_upper_triangle(rsm: torch.Tensor) -> torch.Tensor:
    """Flatten the strictly-upper-triangular entries of the last two axes.

    Only ``rsm.shape[-1]`` is used to build the index pairs, so the last two
    axes are assumed to be square. Entries come out in row-major order, and
    any leading batch axes are preserved: an input of shape ``(..., n, n)``
    yields ``(..., n * (n - 1) // 2)``.
    """
    size = rsm.shape[-1]
    rows, cols = torch.triu_indices(size, size, offset=1)
    return rsm[..., rows, cols]


def convert_dataarray_to_tensor(x: xr.DataArray) -> torch.Tensor:
    """Return the values of ``x`` as a float64 torch tensor.

    ``torch.from_numpy`` wraps the NumPy buffer without copying, and
    ``.to(dtype=torch.float64)`` only copies when a conversion is needed. If
    the data is already float64, the result may therefore share memory with
    the DataArray.
    """
    values = x.to_numpy()
    wrapped = torch.from_numpy(values)
    return wrapped.to(dtype=torch.float64)


def _get_correlation_function(correlation: str) -> Callable:
    match correlation:
        case "Pearson":
            func = pearson_r
        case "Spearman":
            func = spearman_r
        case _:
            raise ValueError
    return func


def reconstruct_data(x: xr.DataArray, /, *, n: int) -> xr.DataArray:
    """Reconstruct ``x`` from ``n`` components of a self cross-decomposition.

    A randomized ``CrossDecomposition`` is fit with ``x`` on both sides. ``x``
    is projected with the "right" transform and then mapped back with
    ``inverse_transform(..., components=n)``. Presumably this keeps the top
    ``n`` components, giving a low-rank approximation of ``x`` -- confirm the
    exact meaning of ``components`` in ``lib.spectra``.
    """
    decomposition = CrossDecomposition(randomized=True)
    decomposition.fit(x, x)
    return decomposition.inverse_transform(
        decomposition.transform(x, direction="right"),
        direction="right",
        components=n,
    )


def compute_rsa_correlation(
    rsm_x: torch.Tensor,
    rsm_y: torch.Tensor,
    /,
    *,
    correlation: str,
    n_bootstraps: int = 5_000,
    subsample_fraction: float = 0.9,
    seed: int = 0,
    batch_size: int = 500,
) -> tuple[float, torch.Tensor]:
    """Correlate the upper triangles of two RSMs and estimate their variability.

    The point estimate correlates the strict upper triangles of the full RSMs.
    The variability estimate repeats the following ``n_bootstraps`` times:

    1. Draw a random subset of ``int(subsample_fraction * n_stimuli)`` stimuli
       *without* replacement. This is subsampling, despite the "bootstrap"
       naming.
    2. Restrict both RSMs to that subset.
    3. Recompute the correlation.

    Draws are grouped into batches of ``batch_size``, and the metric is called
    once per batch.

    Args:
        rsm_x: Square RSM. Presumably its rows and columns refer to the same
            stimuli, in the same order, as those of ``rsm_y``. Only the shapes
            can be checked here.
        rsm_y: Square RSM with the same shape as ``rsm_x``.
        correlation: ``"Pearson"`` or ``"Spearman"``.
        n_bootstraps: Number of subsampling draws (must be >= 1).
        subsample_fraction: Fraction of stimuli kept per draw, in (0, 1].
        seed: Seed for the NumPy generator. Calls with the same seed and
            stimulus count draw identical subsets, so their per-draw results
            are paired.
        batch_size: Draws per batch. Trades peak memory for fewer metric calls.

    Returns:
        ``(r, r_bootstrapped)``: the full-data correlation as a float, and the
        per-draw correlations concatenated over batches (presumably a 1-D
        tensor of length ``n_bootstraps``).

    Raises:
        ValueError: If any of the following holds:

            - the RSMs are non-square or have different shapes;
            - ``n_bootstraps < 1``;
            - the fraction lies outside (0, 1];
            - a draw would hold fewer than three stimuli, giving too few pairs
              for a correlation;
            - ``correlation`` is unknown.
    """
    # Validate up front: mismatched RSMs would otherwise either raise an
    # IndexError deep in the loop or silently compare different stimuli.
    if rsm_x.ndim != 2 or rsm_x.shape[0] != rsm_x.shape[1]:
        msg = f"RSMs must be square matrices, got shape {tuple(rsm_x.shape)}"
        raise ValueError(msg)
    if rsm_y.shape != rsm_x.shape:
        msg = f"RSM shapes differ: {tuple(rsm_x.shape)} vs {tuple(rsm_y.shape)}"
        raise ValueError(msg)
    if n_bootstraps < 1:
        msg = f"n_bootstraps must be >= 1, got {n_bootstraps}"
        raise ValueError(msg)
    if not 0 < subsample_fraction <= 1:
        msg = f"subsample_fraction must lie in (0, 1], got {subsample_fraction}"
        raise ValueError(msg)

    n_stimuli = rsm_x.shape[-1]
    n_samples = int(subsample_fraction * n_stimuli)
    if n_samples < 3:
        msg = (
            f"each draw keeps only {n_samples} of {n_stimuli} stimuli; "
            "at least 3 are needed for a meaningful correlation"
        )
        raise ValueError(msg)

    func = _get_correlation_function(correlation)
    rng = np.random.default_rng(seed=seed)

    r_bootstrapped = []
    for bootstrap_indices in tqdm(
        itertools.batched(range(n_bootstraps), n=batch_size),
        desc="bootstrap",
        leave=False,
    ):
        triangles_x, triangles_y = [], []
        for _ in bootstrap_indices:
            samples = rng.permutation(n_stimuli)[:n_samples]
            # Keep only the upper triangle of each draw before stacking. The
            # values are the same as stacking whole matrices first, but this
            # avoids holding batch_size full (n_samples, n_samples) copies.
            triangles_x.append(extract_upper_triangle(rsm_x[samples, :][:, samples]))
            triangles_y.append(extract_upper_triangle(rsm_y[samples, :][:, samples]))

        # Columns are draws; return_diagonal=True keeps one r per draw.
        r_bootstrapped.append(
            func(
                torch.stack(triangles_x).T,
                torch.stack(triangles_y).T,
                return_diagonal=True,
            ),
        )

    r = float(
        func(
            extract_upper_triangle(rsm_x),
            extract_upper_triangle(rsm_y),
        ),
    )
    return r, torch.concatenate(r_bootstrapped)


def clean_rsm(rsm: torch.Tensor, /, *, indices: np.ndarray) -> np.ndarray:
    """Reorder an RSM for display and blank everything but its upper triangle.

    Rows and columns are both permuted by ``indices``. The diagonal and the
    lower triangle are then set to NaN, so a plot shows each stimulus pair
    once. The input tensor is not modified.

    Args:
        rsm: Square similarity matrix.
        indices: Order in which to display the stimuli (applied to both axes).

    Returns:
        The reordered matrix as a NumPy array, with NaN on and below the
        diagonal.
    """
    # Advanced indexing returns a copy, so editing it in place is safe.
    reordered = rsm[indices, :][:, indices]
    size = reordered.shape[-1]
    # Mask by position, not by value, so that exact-zero similarities above
    # the diagonal survive.
    rows, cols = torch.tril_indices(size, size, offset=0)
    reordered[rows, cols] = torch.nan
    return reordered.cpu().numpy()


def compute_all_pairwise_rsa_correlations(
    rsms: dict[str, dict[int, dict[int, torch.Tensor]]],
) -> pd.DataFrame:
    """Summarize RSA correlations for every ordered pair of subjects.

    The function loops over each metric ("Pearson", "Spearman") and each
    ordered subject pair ``(subject_1, subject_2)``, including
    ``subject_1 == subject_2``. For every pair it compares the repetition-0
    RSM of ``subject_1`` with the repetition-1 RSM of ``subject_2``, for all
    four combinations of "high-D"/"low-D" RSMs. Each comparison runs through
    ``compute_rsa_correlation`` and is wrapped by ``bonner.caching.cache``
    under ``figures/rsa_correlations/metric=<metric>/...``.

    A fifth "mean" row averages the two cross comparisons (high-D vs low-D
    and low-D vs high-D). Its subsampling distribution is the element-wise
    mean of theirs. Both use the default seed of ``compute_rsa_correlation``,
    so their draws are paired, assuming equal stimulus counts.

    Args:
        rsms: Nested mapping ``rsms[kind][subject][repetition]`` with ``kind``
            in {"high-D", "low-D"} and ``repetition`` in {0, 1}.

    Returns:
        Long-format DataFrame with one row per (metric, subject pair,
        comparison). Its columns are ``mean``, ``std``, ``0.025``, ``0.975``,
        ``comparison``, ``subject (trial 1)``, ``subject (trial 2)`` and
        ``metric``. The index is not reset, so it repeats 0-4 per pair.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    kinds = ("high-D", "low-D")
    # (trial-1 kind, trial-2 kind) in the fixed row order:
    # high/high, high/low, low/high, low/low.
    kind_pairs = tuple(itertools.product(kinds, repeat=2))

    # Move each RSM to the device once rather than before every comparison.
    # `.to` returns new tensors (or the same ones on CPU); `rsms` is unchanged.
    rsms_on_device = {
        kind: {
            subject: {
                repetition: rsm.to(device)
                for repetition, rsm in by_repetition.items()
            }
            for subject, by_repetition in rsms[kind].items()
        }
        for kind in kinds
    }

    rows = []
    for metric in ("Pearson", "Spearman"):
        prefix = f"figures/rsa_correlations/metric={metric}"
        for subject_1, subject_2 in itertools.product(range(nsd.N_SUBJECTS), repeat=2):
            means: list[float] = []
            bootstraps: list[torch.Tensor] = []
            for kind_1, kind_2 in kind_pairs:
                cacher = cache(
                    f"{prefix}/{kind_1}-{kind_2}-S{subject_1}T0-S{subject_2}T1.pkl",
                )
                mean, bootstrap = cacher(compute_rsa_correlation)(
                    rsms_on_device[kind_1][subject_1][0],
                    rsms_on_device[kind_2][subject_2][1],
                    correlation=metric,
                )
                means.append(mean)
                bootstraps.append(bootstrap)

            # "mean" row: average of the two cross comparisons (positions 1, 2).
            means.append((means[1] + means[2]) / 2)
            bootstraps.append(torch.stack([bootstraps[1], bootstraps[2]]).mean(dim=0))
            stacked = torch.stack(bootstraps)

            rows.append(
                pd.DataFrame(
                    {
                        # Keep the point estimates as Python floats (float64).
                        # A default-dtype tensor would round them to float32,
                        # while the std and quantiles stay float64.
                        "mean": means,
                        "std": stacked.std(dim=-1).tolist(),
                        "0.025": torch.quantile(stacked, q=0.025, dim=-1).tolist(),
                        "0.975": torch.quantile(stacked, q=0.975, dim=-1).tolist(),
                        "comparison": [
                            *(f"{kind_1} vs {kind_2}" for kind_1, kind_2 in kind_pairs),
                            "mean",
                        ],
                    },
                ).assign(
                    **{
                        "subject (trial 1)": subject_1,
                        "subject (trial 2)": subject_2,
                        "metric": metric,
                    },
                ),
            )
    return pd.concat(rows)

In [None]:
# Number of components kept when building the "low-D" reconstructions.
N_PCS = 10

# Load each NSD subject's dataset, keyed by subject index.
datasets = {
    subject: nsd.load_dataset(subject=subject) for subject in range(nsd.N_SUBJECTS)
}

# Stimuli shared by all subjects. n_repetitions=2 presumably means "seen at
# least twice by every subject" -- TODO confirm in lib.datasets.
shared_stimuli = compute_shared_stimuli(datasets.values(), n_repetitions=2)

# Restrict every subject to the shared stimuli and split by repetition. Later
# cells index the result as datasets[subject][repetition], repetition in (0, 1).
datasets = {
    subject: split_by_repetition(
        filter_by_stimulus(dataset, stimuli=shared_stimuli),
        n_repetitions=2,
    )
    for subject, dataset in datasets.items()
}

# RSMs keyed as rsms[kind][subject][repetition]:
#   "high-D": RSM of the full responses;
#   "low-D":  RSM of the responses reconstructed from N_PCS components.
rsms = {
    "high-D": {
        subject: {
            repetition: compute_rsm(
                convert_dataarray_to_tensor(datasets[subject][repetition]),
            )
            for repetition in (0, 1)
        }
        for subject in range(nsd.N_SUBJECTS)
    },
    "low-D": {
        subject: {
            repetition: compute_rsm(
                convert_dataarray_to_tensor(
                    reconstruct_data(datasets[subject][repetition], n=N_PCS),
                ),
            )
            for repetition in (0, 1)
        }
        for subject in range(nsd.N_SUBJECTS)
    },
}

In [None]:
# Order stimuli for display:
#   1. Fit a cross-decomposition between subject 0 (repetition 0) and
#      subject 1 (repetition 1).
#   2. Project subject 0's data with the "left" transform.
#   3. Sort stimuli by the first column of the projection (presumably the
#      leading component -- confirm in lib.spectra).
cross_decomposition = CrossDecomposition(randomized=True)
cross_decomposition.fit(datasets[0][0], datasets[1][1])
transformed = cross_decomposition.transform(datasets[0][0], direction="left")[
    :,
    0,
].to_numpy()
indices = np.argsort(transformed)

# imshow settings shared by every RSM panel.
kwargs_shared = {
    "cmap": "RdBu_r",
    "rasterized": True,
}

# Color limits for the full ("high-D") RSMs...
kwargs_full = {
    "vmin": -0.1,
    "vmax": 0.1,
} | kwargs_shared

# ...and for the reconstructed ("low-D") RSMs, which span a wider range.
kwargs_reconstructed = {
    "vmin": -0.46,
    "vmax": 0.46,
} | kwargs_shared

# Layout:
#   A/B     = high-D RSMs (subjects 1 and 2);
#   C/D     = low-D RSMs;
#   E/F     = their colorbars;
#   G/H/I/J = thin padding cells around the colorbars (G also holds the
#             colorbar title);
#   K       = RSA-correlation summary spanning the full width.
fig = plt.figure(figsize=(4.5, 6.5))
axes = fig.subplot_mosaic(
    """
    ABG
    ABE
    ABE
    ABH
    CDI
    CDF
    CDF
    CDJ
    KKK
    """,
    width_ratios=[20, 20, 1],
    height_ratios=[1, 50, 50, 1, 1, 50, 50, 1, 127.5],
)

# RSMs are reordered by `indices` and blanked outside the strict upper
# triangle by clean_rsm. Subject 1 uses repetition 0 and subject 2 uses
# repetition 1, matching the pair selected for panel K below.
images = {
    "A": axes["A"].imshow(
        clean_rsm(rsms["high-D"][0][0], indices=indices),
        **kwargs_full,
    ),
    "B": axes["B"].imshow(
        clean_rsm(rsms["high-D"][1][1], indices=indices),
        **kwargs_full,
    ),
    "C": axes["C"].imshow(
        clean_rsm(rsms["low-D"][0][0], indices=indices),
        **kwargs_reconstructed,
    ),
    "D": axes["D"].imshow(
        clean_rsm(rsms["low-D"][1][1], indices=indices),
        **kwargs_reconstructed,
    ),
}
# NOTE(review): redundant with "rasterized": True in kwargs_shared; harmless.
for image in images.values():
    image.set_rasterized(True)

# Hide ticks and spines on every panel except the summary plot K.
for key, ax in axes.items():
    if key != "K":
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

axes["A"].set_ylabel("high-D")
axes["A"].set_title("subject 1", pad=23)
axes["B"].set_title("subject 2", pad=23)
axes["C"].set_ylabel("low-D")

# Gray italic labels giving the matrix size along the top and right edges of
# panel A (positions are in panel A's axes coordinates).
# NOTE(review): "766 stimuli" is hard-coded; presumably it equals len(indices)
# (the number of shared stimuli) -- consider deriving it from the data.
kwargs = {
    "c": "gray",
    "transform": axes["A"].transAxes,
    # "ha": "center",
    # "va": "baseline",
    "fontsize": "x-small",
    "style": "italic",
}
axes["A"].text(
    0.5,
    1.05,
    "766 stimuli",
    ha="center",
    va="baseline",
    **kwargs,
)
axes["A"].text(
    1.025,
    0.5,
    "766 stimuli",
    rotation=270,
    ha="left",
    va="center",
    **kwargs,
)

# Colorbar for the high-D RSMs (panels A and B share the kwargs_full limits).
cax = axes["E"]
cb = fig.colorbar(mappable=images["B"], cax=cax, use_gridspec=True)
cb.outline.set_visible(False)
cax.tick_params(length=0, labelleft=True, labelright=False)
cb.set_ticks([-0.1, 0, 0.1])

# Colorbar for the low-D RSMs (panels C and D share kwargs_reconstructed).
cax = axes["F"]
cb = fig.colorbar(mappable=images["D"], cax=cax, use_gridspec=True)
cb.outline.set_visible(False)
cax.tick_params(length=0, labelleft=True, labelright=False)
cb.set_ticks([-0.4, 0, 0.4])

# Title above the colorbar column; the RSMs are built with pearson_r.
axes["G"].set_title("Pearson\ncorrelation")

# Panel K: RSA correlations between subject 1 (trial 1) and subject 2
# (trial 2), taken from the table computed or loaded for all subject pairs.
ax = axes["K"]
palette = sns.color_palette("Purples_r", n_colors=40)
# If True, show the average of the two cross comparisons as a single
# "high-D vs low-D" point instead of both directions separately.
average_cross_comparisons = True

output = compute_all_pairwise_rsa_correlations(rsms=rsms)

output_ = output.loc[
    (output["subject (trial 1)"] == 0) & (output["subject (trial 2)"] == 1)
]
if average_cross_comparisons:
    arrow_width = 4.75
    output_ = output_.loc[
        np.isin(output_["comparison"], ["high-D vs high-D", "low-D vs low-D", "mean"])
    ]
    # Rows are Pearson then Spearman. Within each metric they are
    # [high-D vs high-D, low-D vs low-D, mean]; reorder them to
    # [high-D vs high-D, mean, low-D vs low-D].
    output_ = output_.iloc[[0, 2, 1, 3, 5, 4], :]
    palette = [
        palette[5],
        palette[15],
        palette[-15],
        palette[5],
        palette[15],
        palette[-15],
    ]
    # The "mean" row averages both cross directions, so label it as the
    # high-D vs low-D comparison on the x axis.
    # NOTE(review): assigning into a frame derived via .loc/.iloc may emit a
    # pandas SettingWithCopyWarning; a .copy() beforehand would avoid it.
    output_["comparison"] = output_["comparison"].replace("mean", "high-D vs low-D")
else:
    arrow_width = 5
    output_ = output_.loc[output_["comparison"] != "mean"]
    palette = [
        palette[5],
        palette[14],
        palette[16],
        palette[-15],
        palette[5],
        palette[14],
        palette[16],
        palette[-15],
    ]

# Two equally sized groups of points: Pearson first, then Spearman.
n_points = len(output_)
n_groups = 2
n_points_per_group = n_points // n_groups

# NOTE(review): despite its name, this scales the 2.5%-97.5% subsampling
# interval used below (not a standard deviation). The plotted bars are
# therefore twice as wide as that interval -- confirm this is intended.
std_factor = 2

x = list(range(n_points))
for i_point, (x_, y, low, high) in enumerate(
    zip(
        x,
        output_["mean"].to_numpy(),
        output_["0.025"].to_numpy(),
        output_["0.975"].to_numpy(),
        strict=False,
    ),
):
    # Asymmetric error bar: (below, above) extents passed as a (2, 1) array.
    ax.errorbar(
        x_,
        y,
        std_factor * np.array([y - low, high - y]).reshape((2, 1)),
        ls="None",
        c=palette[i_point],
        marker="o",
        barsabove=True,
    )

# Bracket-style metric labels, centred over each group of points.
kwargs = {
    "ha": "center",
    "va": "center",
    "arrowprops": {
        "arrowstyle": f"-[, widthB={arrow_width}, lengthB=0.5",
        "lw": 1.0,
        "color": "dimgray",
    },
    "c": "dimgray",
    "fontsize": "small",
}
x_ = (n_points_per_group - 1) / 2
ax.annotate("Pearson\n(linear)", xy=(x_, 0.14), xytext=(x_, 0.07), **kwargs)

x_ = n_points_per_group + (n_points_per_group - 1) / 2
ax.annotate("Spearman\n(rank-order)", xy=(x_, 0.14), xytext=(x_, 0.07), **kwargs)

# Label each point with its comparison name.
xticklabels = output_["comparison"].to_list()
ax.set_xticks(x, xticklabels, rotation=45, ha="right", rotation_mode="anchor")
ax.set_ylim(bottom=0, top=0.3)
ax.set_xlim(left=-1, right=n_points)
ax.set_ylabel("RSA correlation")

fig.suptitle("representational similarity matrices (RSMs)", x=0.5, y=1)

# Saved as FIGURES_HOME / "rsa.pdf"; rasterized images are rendered at 300 dpi.
save_figure(
    fig,
    filepath=FIGURES_HOME / "rsa.pdf",
    dpi=300,
)