# Singular vectors

In [None]:
# IPython autoreload: re-import edited modules (e.g. `lib.*`) before each cell runs.
%reload_ext autoreload
%autoreload 2

In [None]:
import functools
from collections.abc import Sequence

import seaborn as sns
import xarray as xr
from bonner.datasets.allen2021_natural_scenes import plot_brain_map
from bonner.plotting import DEFAULT_MATPLOTLIBRC, save_figure
from matplotlib import pyplot as plt
from matplotlib.axes import Axes

from lib.datasets import (
    compute_shared_stimuli,
    filter_by_stimulus,
    nsd,
    split_by_repetition,
)
from lib.spectra import CrossDecomposition
from lib.utilities import MANUSCRIPT_HOME


def plot_row_of_singular_vectors(
    *,
    ax: Axes,
    ranks: Sequence[int],
    subject: int,
    singular_vectors: xr.DataArray,
    bottom: float,
    separation: float = 0.15,
    height: float = 0.55,
    width: float = 0.55,
    cmap: str = "cold_hot",
    threshold: float = 1e-8,
    view: tuple[float, float] = (0, 200),
) -> None:
    """Draw one horizontal row of 3D brain maps, one inset per singular-vector rank.

    Insets are placed inside ``ax`` using axes-relative coordinates. The last
    entry of ``ranks`` sits flush against the right edge of ``ax``. Each earlier
    entry is shifted a further ``separation`` to the left. With the default
    ``width`` larger than ``separation``, neighbouring insets overlap. Insets are
    created right-to-left, so the leftmost rank is created last (presumably so
    it is drawn on top of its neighbour — confirm against matplotlib zorder
    rules if this matters).

    Args:
        ax: Parent axes that hosts the 3D insets.
        ranks: Values of the ``component`` coordinate to plot, in left-to-right
            display order.
        subject: Subject index forwarded to ``plot_brain_map``.
        singular_vectors: DataArray with a ``component`` coordinate; each
            ``.sel(component=rank)`` slice is rendered as one brain map.
        bottom: Lower edge of the row, as a fraction of ``ax``.
        separation: Horizontal step between consecutive insets (axes fraction).
        height: Height of each inset (axes fraction).
        width: Width of each inset (axes fraction).
        cmap: Colormap forwarded to ``plot_brain_map``.
        threshold: Threshold forwarded to ``plot_brain_map``.
        view: Camera view forwarded to ``plot_brain_map``.
    """
    for i_rank, rank in enumerate(reversed(ranks)):
        # i_rank == 0 is the rightmost inset; each step moves one separation left.
        ax_ = ax.inset_axes(
            (1 - width - separation * i_rank, bottom, width, height),
            projection="3d",
            facecolor=None,
        )
        # Rasterize the dense 3D surface so vector outputs (e.g. PDF) stay small.
        ax_.set_rasterized(True)

        plot_brain_map(
            singular_vectors.sel(component=rank),
            ax=ax_,
            subject=subject,
            cmap=cmap,
            threshold=threshold,
            view=view,
        )
        # Make the inset background fully transparent after plotting so
        # overlapping insets do not hide each other behind an opaque face.
        ax_.set_facecolor((1, 1, 1, 0))
        ax_.set_title(f"{rank}", y=0.9)


# Subject whose data are loaded and whose brain surface is plotted.
REFERENCE_SUBJECT = 0

# Directory that receives the manuscript's figure files.
FIGURES_HOME = MANUSCRIPT_HOME / "figures"

# Global seaborn/matplotlib styling shared by every figure in this notebook.
sns.set_theme(
    context="paper",
    style="ticks",
    rc=DEFAULT_MATPLOTLIBRC,
)

In [None]:
# Number of repetitions used both to select shared stimuli and to split the data.
# The two calls below must agree on this value, so it is defined exactly once.
N_REPETITIONS = 2

dataset = nsd.load_dataset(subject=REFERENCE_SUBJECT, roi="general")

# Restrict to the stimuli returned by compute_shared_stimuli (presumably those
# with at least N_REPETITIONS presentations — confirm in lib.datasets), then
# split the responses into one dataset per repetition.
datasets = split_by_repetition(
    filter_by_stimulus(
        dataset,
        stimuli=compute_shared_stimuli([dataset], n_repetitions=N_REPETITIONS),
    ),
    n_repetitions=N_REPETITIONS,
)

# Cross-decompose the first repetition split against the second.
cross_decomposition = CrossDecomposition(randomized=True)
cross_decomposition.fit(datasets[0], datasets[1])

# Take the left singular vectors and index the "neuroid" dimension by its
# (x, y, z) coordinates, which the brain-map plotting relies on to place values.
singular_vectors = cross_decomposition.singular_vectors(direction="left").set_index({
    "neuroid": ["x", "y", "z"],
})

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
# The parent axes only hosts the 3D insets; hide its frame and ticks.
ax.axis("off")

# Bind everything shared by both rows once; each row then only needs its
# ranks and vertical position.
draw_row = functools.partial(
    plot_row_of_singular_vectors,
    ax=ax,
    singular_vectors=singular_vectors,
    subject=REFERENCE_SUBJECT,
)

# (ranks, bottom) per row: the leading ranks on top, a spread of higher
# ranks underneath.
row_specs = (
    ([1, 2, 3, 4, 5, 10], 0.5),
    ([20, 50, 100, 500, 1000, 5000], 0.05),
)
for row_ranks, row_bottom in row_specs:
    draw_row(ranks=row_ranks, bottom=row_bottom)

output_path = FIGURES_HOME / "singular-vectors.pdf"
save_figure(fig, filepath=output_path, dpi=300)