# Cross-similarity

In [None]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the local `lib` package are picked up without
# restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
import contextlib
import itertools
from collections.abc import Collection, Sequence

import numpy as np
import pandas as pd
import seaborn as sns
import xarray as xr
from bonner.plotting import DEFAULT_MATPLOTLIBRC, save_figure
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

from lib.datasets import (
    compute_shared_stimuli,
    filter_by_stimulus,
    nsd,
    split_by_repetition,
)
from lib.spectra import (
    assign_data_to_geometrically_spaced_bins,
    compute_spectra_with_n_fold_cross_validation,
    offset_spectra,
)
from lib.utilities import MANUSCRIPT_HOME


def convert_datasets_to_mni_space(
    x: dict[int, xr.DataArray],
    /,
    *,
    order: int = 1,
    resolution: float = 1.8,
) -> dict[int, dict[int, xr.DataArray]]:
    """Restrict datasets to shared stimuli and voxels, then resample them in MNI space.

    Each dataset is filtered to the stimuli returned by
    ``compute_shared_stimuli`` (two repetitions), split by repetition,
    converted with ``nsd.convert_array_to_mni``, restricted to the
    ``(x, y, z)`` neuroid coordinates present for every subject, and finally
    passed through ``nsd.resample_1mm_mni``.

    Args:
        x: Datasets keyed by subject.
        order: Forwarded to ``nsd.convert_array_to_mni`` (presumably an
            interpolation order -- confirm against that helper).
        resolution: Forwarded to ``nsd.resample_1mm_mni`` and embedded in the
            name of every returned array.

    Returns:
        A mapping ``subject -> repetition -> array`` in which every array is
        named ``"<original name>.resampled.<resolution>mm"``.
    """
    stimuli = compute_shared_stimuli(x.values(), n_repetitions=2)

    # Pass 1: keep only the shared stimuli and separate the repetitions.
    by_repetition = {}
    for subject, dataset in x.items():
        by_repetition[subject] = split_by_repetition(
            filter_by_stimulus(dataset, stimuli=stimuli),
            n_repetitions=2,
        )

    # Pass 2: convert every repetition and index neuroids by coordinate.
    in_mni = {}
    for subject, repetitions in by_repetition.items():
        converted = {}
        for repetition, dataset in repetitions.items():
            array = nsd.convert_array_to_mni(dataset, subject=subject, order=order)
            converted[repetition] = array.set_index({"neuroid": ["x", "y", "z"]})
        in_mni[subject] = converted

    # Only repetition 0 of each subject is consulted to find the common voxels.
    voxel_sets = [
        set(repetitions[0]["neuroid"].data) for repetitions in in_mni.values()
    ]
    shared_voxels = list(set.intersection(*voxel_sets))

    # Pass 3: select the common voxels, resample, re-index, and rename.
    resampled = {}
    for subject, repetitions in in_mni.items():
        outputs = {}
        for repetition, dataset in repetitions.items():
            selected = dataset.load().sel(neuroid=shared_voxels)
            output = nsd.resample_1mm_mni(selected, resolution=resolution)
            output = output.set_index({
                "presentation": ["stimulus", "repetition"],
                "neuroid": ["x", "y", "z"],
            })
            outputs[repetition] = output.rename(
                f"{dataset.name}.resampled.{resolution}mm",
            )
        resampled[subject] = outputs

    return resampled


def compute_all_cross_individual_spectra(
    datasets: dict[int, dict[int, xr.DataArray]],
    *,
    rank_ranges: Sequence[tuple[int | None, int | None]] | None = None,
    log_bin: bool = True,
    n_folds: int = 8,
    stop: int = 10_000,
) -> dict[frozenset[int], xr.Dataset]:
    """Compute cross-validated spectra for every unordered pair of individuals.

    Pairs are drawn with replacement, so each individual is also paired with
    itself. For each pair, two comparisons are computed (repetition 0 of the
    first individual against repetition 1 of the second, and vice versa) and
    averaged at the end.

    Two alignments are computed per comparison:

    * ``"functional"``: the test arguments match the training arguments.
    * ``"anatomical"``: the two test arguments are swapped relative to
      training. This is best-effort: if it raises, a ``RuntimeWarning`` is
      emitted and the pair keeps the functional alignment only.

    Args:
        datasets: Mapping ``individual -> repetition -> data``; repetitions
            ``0`` and ``1`` are read for every individual.
        rank_ranges: Optional ``(start, stop)`` slices over the ``component``
            dimension. When given, the spectrum is summed within each slice
            and the slices are stacked along a new ``range`` dimension that
            carries ``start`` (shifted by one) and ``stop`` coordinates.
        log_bin: Only used when ``rank_ranges`` is ``None``. If true, the
            spectrum is averaged within each ``rank`` bin.
        n_folds: Forwarded to ``compute_spectra_with_n_fold_cross_validation``.
        stop: Only used when ``rank_ranges`` is ``None``; forwarded as ``stop``
            to ``assign_data_to_geometrically_spaced_bins``.

    Returns:
        Spectra keyed by ``frozenset({individual_1, individual_2})`` (a
        one-element set when an individual is paired with itself), averaged
        over the ``comparison`` dimension.
    """
    # Imported locally so this function does not depend on a new file-level import.
    import warnings

    def _spectrum(x, y, *, comparison: int, alignment: str):
        # "functional" evaluates on the same (x, y) assignment used for
        # training; "anatomical" swaps the two test arguments.
        x_test, y_test = (x, y) if alignment == "functional" else (y, x)
        return compute_spectra_with_n_fold_cross_validation(
            x_train=x,
            y_train=y,
            x_test=x_test,
            y_test=y_test,
            n_folds=n_folds,
        ).expand_dims({"comparison": [comparison], "alignment": [alignment]})

    spectra = {}
    for (individual_1, dataset_1), (individual_2, dataset_2) in tqdm(
        itertools.combinations_with_replacement(datasets.items(), r=2),
        desc="pair of individuals",
        leave=False,
    ):
        # Comparison 0 and comparison 1 use opposite repetitions of the pair.
        pairings = (
            (dataset_1[0], dataset_2[1]),
            (dataset_1[1], dataset_2[0]),
        )

        spectra_ = [
            _spectrum(x, y, comparison=comparison, alignment="functional")
            for comparison, (x, y) in enumerate(pairings)
        ]

        # The anatomical alignment stays optional, but a failure is reported
        # instead of being silently discarded so real bugs remain visible.
        try:
            anatomical = [
                _spectrum(x, y, comparison=comparison, alignment="anatomical")
                for comparison, (x, y) in enumerate(pairings)
            ]
        except Exception as error:  # noqa: BLE001 - deliberate best-effort
            warnings.warn(
                "skipping anatomical alignment for individuals "
                f"({individual_1}, {individual_2}): {error!r}",
                RuntimeWarning,
                stacklevel=2,
            )
        else:
            spectra_ += anatomical

        spectrum = xr.merge(spectra_)

        if rank_ranges is not None:
            # The loop variables are named `range_start`/`range_stop` so they
            # cannot be confused with the `stop` parameter used for binning.
            spectrum = xr.concat(
                [
                    spectrum.isel(component=slice(range_start, range_stop))
                    .sum("component")
                    .expand_dims({"range": [i_range]})
                    .assign_coords({
                        "start": (
                            "range",
                            [range_start + 1] if range_start is not None else [1],
                        ),
                        "stop": (
                            "range",
                            [range_stop]
                            if range_stop is not None
                            else [spectrum.sizes["component"]],
                        ),
                    })
                    for i_range, (range_start, range_stop) in enumerate(rank_ranges)
                ],
                dim="range",
            )
        else:
            spectrum = spectrum.assign_coords({
                "rank": (
                    "component",
                    assign_data_to_geometrically_spaced_bins(
                        spectrum["component"].data,
                        density=3,
                        start=1,
                        stop=stop,
                    ),
                ),
            })

            if log_bin:
                spectrum = spectrum.groupby("rank").mean()

        spectra[frozenset([individual_1, individual_2])] = spectrum.mean("comparison")

    return spectra


def compute_cross_correlations(
    spectra: dict[frozenset[int], xr.Dataset],
    *,
    individuals: Collection[int],
    reference_individual: int,
) -> pd.DataFrame:
    """Normalize cross-individual spectra against each individual's own spectrum.

    For every individual other than the reference, the spectrum of the
    (reference, individual) pair is divided by the square root of the product
    of the two within-individual spectra. The ``covariance`` variable is
    renamed to ``correlation``.

    Args:
        spectra: Spectra keyed by ``frozenset`` of the two individuals,
            including the one-element keys for each individual with itself.
        individuals: Individuals to compare against the reference; the
            reference itself is skipped.
        reference_individual: Individual every other one is compared to.

    Returns:
        The per-individual results concatenated into one data frame, with an
        ``individual`` column holding the one-based label as a string and
        rows containing missing values removed.
    """
    reference_key = frozenset([reference_individual, reference_individual])

    frames = []
    for individual in individuals:
        if individual == reference_individual:
            continue

        cross = spectra[frozenset([reference_individual, individual])]
        within_individual = spectra[frozenset([individual, individual])]
        within_reference = spectra[reference_key]
        normalized = cross / np.sqrt(within_individual * within_reference)

        labelled = normalized.expand_dims({"individual": [f"{1 + individual}"]})
        frame = labelled.rename({"covariance": "correlation"}).to_dataframe()
        frames.append(frame.reset_index().dropna())

    return pd.concat(frames)


# Directory under the manuscript root where the figures below are saved.
FIGURES_HOME = MANUSCRIPT_HOME / "figures"

# Apply one seaborn/matplotlib style to every figure drawn after this point.
sns.set_theme(context="paper", style="ticks", rc=DEFAULT_MATPLOTLIBRC)

# Zero-based index of the subject the others are compared against; figure
# titles display it as `REFERENCE_SUBJECT + 1`.
REFERENCE_SUBJECT = 3

In [None]:
# Load the "general" ROI for every subject, keyed by zero-based subject index.
datasets = {
    subject: nsd.load_dataset(
        subject=subject,
        roi="general",
    )
    for subject in range(nsd.N_SUBJECTS)
}
# Restrict all subjects to shared stimuli and voxels and resample them in MNI
# space; the result maps subject -> repetition -> array.
datasets_mni = convert_datasets_to_mni_space(
    datasets,
    order=1,
)

## cross-detectability

In [None]:
# Spectra for every pair of individuals (log-binned by rank, the default).
spectra = compute_all_cross_individual_spectra(datasets_mni)
# Keep only the pairs that involve the reference subject, average over folds,
# and label each row with the other subject's one-based number.
spectra = pd.concat(
    [
        spectra[frozenset([REFERENCE_SUBJECT, subject])]
        .mean("fold")
        .to_dataframe()
        .reset_index()
        .assign(individual=f"{1 + subject}")
        for subject in range(nsd.N_SUBJECTS)
        if subject != REFERENCE_SUBJECT
    ],
)

# Divide the covariance by rank^(-1.5) (matching the y-axis label below) before
# handing the frame to `offset_spectra`, grouped by alignment.
spectra = offset_spectra(
    spectra.assign(covariance=lambda x: x["covariance"] / (x["rank"] ** (-1.5))),
    keys=["alignment"],
)

fig, ax = plt.subplots(figsize=(3, 3))

# Marker colour per alignment.
cs = {
    "functional": "dimgray",
    "anatomical": "darkgray",
}

for alignment in ["functional", "anatomical"]:
    spectra_ = spectra.loc[spectra["alignment"] == alignment].drop(
        columns=["alignment", "individual"],
    )

    # Summarize across the remaining rows within each rank bin.
    groupby = spectra_.groupby("rank")
    mean = groupby.mean().reset_index()
    std = groupby.std().reset_index()

    # presumably one row per non-reference subject in each rank bin, hence
    # N_SUBJECTS - 1 observations -- verify against the frame built above.
    sem = std["covariance"] / np.sqrt(nsd.N_SUBJECTS - 1)
    # Ranks whose mean exceeds zero by more than 3 SEM are drawn filled.
    mask = (mean["covariance"] - 3 * sem) > 0

    # Filled markers: ranks that pass the 3-SEM criterion. Note that the error
    # bars show the standard deviation, not the SEM.
    ax.errorbar(
        mean["rank"][mask],
        mean["covariance"][mask],
        std["covariance"][mask],
        ls="None",
        marker="o" if alignment == "functional" else "^",
        c=cs[alignment],
        label=alignment,
        mew=0,
        ms=6 if alignment == "anatomical" else 5,
    )

    # Hollow, semi-transparent markers: ranks that fail the criterion. No label,
    # so each alignment appears once in the legend.
    ax.errorbar(
        mean["rank"][~mask],
        mean["covariance"][~mask],
        std["covariance"][~mask],
        ls="None",
        marker="o" if alignment == "functional" else "^",
        c=cs[alignment],
        mew=1,
        mfc="none",
        alpha=0.5,
        ms=6 if alignment == "anatomical" else 5,
    )

ax.set_xscale("log")
ax.set_ylim(bottom=-0.005, top=0.025)
ax.set_ylabel(r"cross-covariance / (rank)$^\mathsf{-1.5}$")
ax.set_xticks([1, 1e1, 1e2])
ax.set_xlim(left=1, right=4e2)
ax.axhline(0, ls="--", c="gray")
ax.legend(loc="upper right", title="alignment")
ax.set_title(
    f"between-subject detectability\nrelative to subject {REFERENCE_SUBJECT + 1}",
)
ax.set_xlabel("rank")

save_figure(
    fig,
    filepath=FIGURES_HOME / "cross-detectability.pdf",
)

## cross-correlations

In [None]:
# Zero-based, end-exclusive slices over the component dimension; the figure
# titles below display them as ranks 1-10, 11-50 and 51-100.
rank_ranges = (
    (0, 10),
    (10, 50),
    (50, 100),
)

# Recompute the pairwise spectra, this time summed within each rank range
# (this rebinds the notebook-level `spectra`).
spectra = compute_all_cross_individual_spectra(
    datasets_mni,
    rank_ranges=rank_ranges,
)
# One data frame of cross-correlations per choice of reference subject.
correlations = {
    reference_subject: compute_cross_correlations(
        spectra=spectra,
        individuals=range(nsd.N_SUBJECTS),
        reference_individual=reference_subject,
    )
    for reference_subject in range(nsd.N_SUBJECTS)
}

In [None]:
# One colour per subject; also reused by the "cross-correlations-all" cell.
palette = sns.color_palette("flare_r", n_colors=nsd.N_SUBJECTS)

# One panel per rank range, all sharing both axes.
fig, axes = plt.subplots(
    figsize=(1.5 * len(rank_ranges), 3),
    ncols=len(rank_ranges),
    sharex=True,
    sharey=True,
    squeeze=False,
)

for i_range, ((start, stop), ax) in enumerate(
    zip(rank_ranges, axes.flat, strict=False),
):
    filter_range = correlations[REFERENCE_SUBJECT]["range"] == i_range
    for alignment in ("functional", "anatomical"):
        filter_alignment = correlations[REFERENCE_SUBJECT]["alignment"] == alignment

        # x position counter: non-reference subjects are placed at 1, 2, ...
        # so the reference subject leaves no gap on the axis.
        i_tick = 0
        for individual in range(nsd.N_SUBJECTS):
            if individual == REFERENCE_SUBJECT:
                continue

            filter_individual = (
                correlations[REFERENCE_SUBJECT]["individual"] == f"{1 + individual}"
            )

            # Rows for this (range, alignment, individual) combination.
            correlations_ = correlations[REFERENCE_SUBJECT].loc[
                filter_range & filter_alignment & filter_individual
            ]

            i_tick += 1

            mean = correlations_["correlation"].mean()
            std = correlations_["correlation"].std()
            # NOTE(review): pandas' `std` already defaults to ddof=1, so dividing
            # by sqrt(n - 1) rather than sqrt(n) looks like a second correction
            # that inflates the SEM -- confirm this is intended.
            sem = std / np.sqrt(len(correlations_) - 1)
            # Means more than 3 SEM above zero are drawn as filled markers.
            significant = (mean - 3 * sem) > 0

            # The x position is scaled slightly down/up so the two alignments
            # for the same subject sit side by side; error bars show the std.
            ax.errorbar(
                i_tick * (1 / 1.025 if alignment == "functional" else 1.025),
                mean,
                std,
                ls="None",
                marker="o" if alignment == "functional" else "^",
                c=palette[individual],
                mew=0 if significant else 1,
                mfc=palette[individual] if significant else "none",
                ms=6 if alignment == "anatomical" else 5,
                alpha=1 if alignment == "functional" else 0.5,
            )

    # Tick labels are the one-based numbers of the non-reference subjects.
    ax.set_xticks(
        1 + np.arange(nsd.N_SUBJECTS - 1),
        [
            f"{1 + subject}"
            for subject in range(nsd.N_SUBJECTS)
            if subject != REFERENCE_SUBJECT
        ],
    )
    ax.set_ylim(bottom=-0.2, top=1.05)
    ax.axhline(1, ls="--", c="gray", lw=0.5)
    ax.axhline(0, ls="--", c="gray", lw=0.5)

    if start is None or stop is None:
        title = "all ranks"
    else:
        title = f"rank {start + 1} to {stop}"
    ax.set_title(title)

# Annotate only the first panel; the text positions are hand-placed in data
# coordinates.
ax = axes.flat[0]
ax.set_ylabel("cross-correlation")
ax.text(4, 0.65, "functional", ha="center", va="center", fontsize="small")
ax.text(4, -0.1, "anatomical", ha="center", va="center", fontsize="small")

fig.supxlabel("subject", x=0.56, y=0.06)
fig.suptitle(
    f"between-subject similarity, relative to subject {REFERENCE_SUBJECT + 1}",
    x=0.56,
)

save_figure(fig, filepath=FIGURES_HOME / "cross-correlations.pdf")

## cross-correlations-all

In [None]:
# Grid of panels for every reference subject. The layout is hard-coded: 4 rows
# with two reference subjects per row, each occupying 3 columns (one per rank
# range), separated by a narrow spacer column (index 3) that is hidden below.
# This cell relies on `rank_ranges`, `correlations` and `palette` from the
# preceding cells.
fig, axes = plt.subplots(
    figsize=(7, 7),
    ncols=7,
    nrows=4,
    sharex=True,
    sharey=True,
    gridspec_kw={"hspace": 0.45},
    width_ratios=[1, 1, 1, 0.5, 1, 1, 1],
)

for reference_subject in range(nsd.N_SUBJECTS):
    i_row = reference_subject // 2
    i_col_group = reference_subject % 2

    for i_range, ((start, stop)) in enumerate(rank_ranges):
        # Left group uses columns 0-2; right group uses columns 4-6.
        i_col = i_range if i_col_group == 0 else i_col_group * 3 + i_range + 1
        ax = axes[i_row, i_col]

        filter_range = correlations[reference_subject]["range"] == i_range
        for alignment in ("functional", "anatomical"):
            filter_alignment = correlations[reference_subject]["alignment"] == alignment

            for individual in range(nsd.N_SUBJECTS):
                if individual == reference_subject:
                    continue

                filter_individual = (
                    correlations[reference_subject]["individual"] == f"{1 + individual}"
                )

                # Rows for this (range, alignment, individual) combination.
                correlations_ = correlations[reference_subject].loc[
                    filter_range & filter_alignment & filter_individual
                ]

                mean = correlations_["correlation"].mean()
                std = correlations_["correlation"].std()
                # NOTE(review): pandas' `std` already defaults to ddof=1, so
                # dividing by sqrt(n - 1) rather than sqrt(n) looks like a second
                # correction -- confirm this is intended.
                sem = std / np.sqrt(len(correlations_) - 1)
                # Means more than 3 SEM above zero are drawn as filled markers.
                significant = (mean - 3 * sem) > 0

                # Here each subject keeps its own one-based x position, so the
                # reference subject's slot stays empty in its own panels.
                ax.errorbar(
                    (1 + individual)
                    * (1 / 1.025 if alignment == "functional" else 1.025),
                    mean,
                    std,
                    ls="None",
                    marker="o" if alignment == "functional" else "^",
                    c=palette[individual],
                    mew=0 if significant else 1,
                    mfc=palette[individual] if significant else "none",
                    ms=6 if alignment == "anatomical" else 5,
                    alpha=1 if alignment == "functional" else 0.5,
                )

        ax.set_xticks(
            range(1, nsd.N_SUBJECTS + 1),
            [f"{1 + subject}" for subject in range(nsd.N_SUBJECTS)],
        )
        ax.set_ylim(bottom=-0.2, top=1.2)
        ax.axhline(1, ls="--", c="gray", lw=0.5)
        ax.axhline(0, ls="--", c="gray", lw=0.5)

        if start is None or stop is None:
            title = "all ranks"
        else:
            # Only the first panel of each group spells out the word "rank".
            title = (
                f"rank {start + 1} to {stop}"
                if i_range == 0
                else f"{start + 1} to {stop}"
            )
        # Rank-range titles appear on the top row only; the x-axis label goes
        # under the middle panel of each group on the bottom row.
        if i_row == 0:
            ax.set_title(title)
        if i_row == 3 and i_range == 1:
            ax.set_xlabel("subject", labelpad=10)

# Hide the spacer column between the two groups.
for row in range(4):
    axes[row, 3].axis("off")

# Alignment annotations, hand-placed in data coordinates on the first panel.
ax = axes[0, 0]
ax.text(4.5, 0.65, "functional", ha="center", va="center", fontsize="small")
ax.text(4.5, -0.1, "anatomical", ha="center", va="center", fontsize="small")

fig.supylabel("cross-correlation", x=0.04)
fig.suptitle("between-subject similarity, relative to ...")
# Place a "subject N" heading above each group of three panels, using
# hand-tuned figure coordinates; the top row uses a different vertical offset
# from the rows below it.
for subject in range(nsd.N_SUBJECTS):
    i_row = subject // 2
    i_col = subject % 2
    if i_row == 0:
        fig.text(
            x=0.296 + i_col * 0.432,
            y=0.92 - i_row * 0.232,
            s=f"subject {1 + subject}",
            fontsize="medium",
            ha="center",
        )
    else:
        fig.text(
            x=0.296 + i_col * 0.432,
            y=0.692 - (i_row - 1) * 0.210,
            s=f"subject {1 + subject}",
            fontsize="medium",
            ha="center",
        )

fig.tight_layout()

save_figure(fig, filepath=FIGURES_HOME / "cross-correlations-all.pdf")