# Schematic

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import pandas as pd
import seaborn as sns
import xarray as xr
from bonner.datasets.allen2021_natural_scenes import plot_brain_map
from bonner.plotting import DEFAULT_MATPLOTLIBRC, add_colorbar, save_figure
from matplotlib import pyplot as plt

from lib.datasets import (
    compute_shared_stimuli,
    filter_by_stimulus,
    nsd,
    split_by_repetition,
)
from lib.spectra import (
    assign_data_to_geometrically_spaced_bins,
    compute_spectra_with_n_fold_cross_validation,
    compute_spectrum,
)
from lib.utilities import MANUSCRIPT_HOME

# Apply the shared plotting theme before any figure is created.
sns.set_theme(context="paper", style="ticks", rc=DEFAULT_MATPLOTLIBRC)

# Subject identifier passed to both `nsd.load_dataset` and `plot_brain_map`.
REFERENCE_SUBJECT = 0

# Root directory for saved figures; every panel in this notebook is written
# beneath its "cache" subdirectory.
FIGURES_HOME = MANUSCRIPT_HOME / "figures"

In [None]:
# Load the reference subject's responses for the "general" ROI.
dataset = nsd.load_dataset(subject=REFERENCE_SUBJECT, roi="general")

# Keep only the stimuli reported by `compute_shared_stimuli` for
# n_repetitions=2, then split the filtered data by repetition.
shared_stimuli = compute_shared_stimuli([dataset], n_repetitions=2)
dataset = filter_by_stimulus(dataset, stimuli=shared_stimuli)
dataset = split_by_repetition(dataset, n_repetitions=2)

## Method comparison

In [None]:
# Index 0 and 1 of `dataset` -- presumably the two repetition splits produced
# by `split_by_repetition` (TODO confirm). The three methods below differ only
# in which split is used as x/y for fitting and for evaluation.
rep_0 = dataset[0]
rep_1 = dataset[1]

spectra = {
    # Same split everywhere: fit and evaluate on rep_0 against itself.
    "PCA": compute_spectrum(
        x_train=rep_0,
        y_train=rep_0,
        x_test=rep_0,
        y_test=rep_0,
        normalize=False,
    ),
    # Fit on rep_0 against itself, evaluate rep_0 against rep_1.
    "cvPCA": compute_spectrum(
        x_train=rep_0,
        y_train=rep_0,
        x_test=rep_0,
        y_test=rep_1,
        normalize=False,
    ),
    # Fit and evaluate rep_0 against rep_1, with 8-fold cross-validation.
    "cross-decomposition (ours)": compute_spectra_with_n_fold_cross_validation(
        x_train=rep_0,
        y_train=rep_1,
        x_test=rep_0,
        y_test=rep_1,
        n_folds=8,
        seed=0,
        normalize=False,
    ),
}
n_voxels = rep_0.sizes["neuroid"]

# Scale each spectrum by the number of voxels, label every component with a
# "rank" bin, average within bins, and tag the result with its method name so
# the spectra can be concatenated along a new "method" dimension.
binned_spectra = {}
for method, spectrum in spectra.items():
    per_voxel = spectrum / n_voxels
    rank_bins = assign_data_to_geometrically_spaced_bins(
        spectrum["component"].data,
        density=3,
        start=1,
        stop=10_000,
    )
    per_voxel = per_voxel.assign_coords({"rank": ("component", rank_bins)})
    bin_means = per_voxel.groupby("rank").mean()
    binned_spectra[method] = bin_means.expand_dims({"method": [method]})

# The cross-validated spectrum is additionally averaged over its "fold"
# dimension (presumably the only entry that has one -- TODO confirm).
binned_spectra["cross-decomposition (ours)"] = binned_spectra[
    "cross-decomposition (ours)"
].mean("fold")

# Long-format table with one row per (method, rank) pair, for plotting.
spectra = (
    xr.concat(binned_spectra.values(), dim="method").to_dataframe().reset_index()
)


palette = sns.color_palette("tab10", 3)

# Per-method marker styling: the cross-decomposition spectrum is drawn with
# filled, coloured squares; the two baselines use hollow gray markers.
marker_styles = {
    "cross-decomposition (ours)": {
        "marker": "s",
        "c": palette[0],
        "mew": 0,
    },
    "cvPCA": {
        "marker": "^",
        "c": "gray",
        "mew": 1,
        "mfc": "none",
    },
    "PCA": {
        "marker": "o",
        "mew": 1,
        "c": "gray",
        "mfc": "none",
    },
}

fig, ax = plt.subplots(figsize=(3, 3))
for method in pd.unique(spectra["method"].to_numpy()):
    kwargs = marker_styles.get(method)
    if kwargs is None:
        # A method without a defined style is an error.
        raise ValueError

    rows = spectra[spectra["method"] == method]
    # Markers only (no connecting line), one series per method.
    ax.plot(
        rows["rank"],
        rows["covariance"],
        ls="None",
        label=method,
        **kwargs,
    )

# Log-log axes with fixed limits.
ax.set(
    xlim=(1, 2e3),
    ylim=(1e-7, 1e-1),
    title="comparison between methods",
    ylabel="covariance",
    xlabel="rank",
    xscale="log",
    yscale="log",
)

_ = ax.legend(loc="lower left")

output_path = FIGURES_HOME / "cache" / "schematic-method_comparison.svg"
save_figure(fig, filepath=output_path)

## Activation heatmap

In [None]:
# Show only the top-left 80 x 120 corner of the first split's data so that
# individual cells remain visible (rows are labelled "stimuli" and columns
# "voxels" below).
activation_patch = dataset[0].to_numpy()[:80, :120]

fig, ax = plt.subplots(figsize=(4.5, 2.5))
image = ax.imshow(
    activation_patch,
    cmap="RdBu_r",
    vmin=-3,
    vmax=3,
    aspect="equal",
)

cb, _ = add_colorbar(ax=ax, mappable=image, location="bottom")
cb.set_label("activations (Z-score)", rotation=0)
cb.set_ticks(list(range(-3, 4)))

# Label the axes but drop ticks and the left/bottom spines: this is a
# schematic, not a quantitative plot.
ax.set(xlabel="voxels", ylabel="stimuli", xticks=[], yticks=[])
for side in ("left", "bottom"):
    ax.spines[side].set_visible(False)
# The colorbar sits at the bottom, so the x-label goes on top.
ax.xaxis.set_label_position("top")

output_path = FIGURES_HOME / "cache" / "schematic-activation_heatmap.svg"
save_figure(fig, filepath=output_path)

## Brain map

In [None]:
# Data at index 2 along "presentation" in the first split, drawn on a 3D axis.
single_presentation = dataset[0].isel(presentation=2)

fig, ax = plt.subplots(figsize=(5, 5), subplot_kw={"projection": "3d"})
# NOTE(review): this uses "RdBu" whereas the activation heatmap cell uses the
# reversed "RdBu_r" -- confirm the opposite colour direction is intentional.
plot_brain_map(
    single_presentation,
    ax=ax,
    subject=REFERENCE_SUBJECT,
    cmap="RdBu",
    vmin=-1.5,
    vmax=1.5,
)
ax.set_rasterized(True)

output_path = FIGURES_HOME / "cache" / "schematic-brain_map.png"
save_figure(fig, filepath=output_path)

## Equations

$$
\begin{align*}
    \text{cov} \left(X_\text{train}, Y_\text{train}\right)
    &= \dfrac{X_\text{train}^\top Y_\text{train}}{n_\text{train}}\\
    &= U \Sigma V^\top
\end{align*}
$$

$$
\begin{align*}
    \Sigma_\text{test}
    &= \text{cov} \left(X_\text{test} U, Y_\text{test} V\right)\\
    &= \dfrac{\left(X_\text{test} U \right)^\top \left(Y_\text{test} V\right)}{n_\text{test}}
\end{align*}
$$

$$
\text{cov} \left(X_\text{train}, Y_\text{train}\right)
= X_\text{train}^\top Y_\text{train} / n_\text{train}
= U_\text{train} \Sigma_\text{train} V_\text{train}^\top
$$

$$
\Sigma_\text{test}
= \text{cov} \left(X_\text{test} U_\text{train}, Y_\text{test} V_\text{train}\right)
= \left(X_\text{test} U_\text{train} \right)^\top \left(Y_\text{test} V_\text{train}\right) / n_\text{test}
$$