# Spectra

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import cairosvg
import numpy as np
import seaborn as sns
import skunk
import xarray as xr
from bonner.plotting import DEFAULT_MATPLOTLIBRC, save_figure
from matplotlib import pyplot as plt
from matplotlib.ticker import LogFormatterMathtext

from lib.analyses import (
    compute_cross_individual_spectra,
    compute_within_individual_spectra,
)
from lib.datasets import (
    compute_shared_stimuli,
    filter_by_stimulus,
    nsd,
    sample_neuroids,
    split_by_repetition,
)
from lib.spectra import compute_spectra_with_n_fold_cross_validation, plot_spectra
from lib.utilities import MANUSCRIPT_HOME

# Zero-based index of the subject that the between-subject spectra in this
# notebook are computed relative to.
REFERENCE_SUBJECT = 0

# Folder that the figure PDFs are written to and the human(s) SVG icons are
# read from.
FIGURES_HOME = MANUSCRIPT_HOME / "figures"

# Apply the shared matplotlib rc settings to every figure created below.
sns.set_theme(
    rc=DEFAULT_MATPLOTLIBRC,
    style="ticks",
    context="paper",
)

## general, significance-test

In [None]:
def _filter_and_split(dataset, stimuli):
    """Restrict `dataset` to `stimuli`, then split it with `n_repetitions=2`."""
    restricted = filter_by_stimulus(dataset, stimuli=stimuli)
    return split_by_repetition(restricted, n_repetitions=2)


# "general" ROI data for every subject, keyed by zero-based subject index.
datasets = {
    subject: nsd.load_dataset(
        subject=subject,
        roi="general",
        preprocessing="fithrf",
        z_score=True,
    )
    for subject in range(nsd.N_SUBJECTS)
}

# Within-subject data: the stimulus set is computed separately for each
# subject from that subject's dataset alone.
datasets_within = {
    subject: _filter_and_split(
        dataset,
        compute_shared_stimuli([dataset], n_repetitions=2),
    )
    for subject, dataset in datasets.items()
}

# Between-subject data: one stimulus set, computed across all subjects, is
# applied to every subject.
shared_stimuli = compute_shared_stimuli(datasets.values(), n_repetitions=2)
datasets_cross = {
    subject: _filter_and_split(dataset, shared_stimuli)
    for subject, dataset in datasets.items()
}

In [None]:
# Spectra computed from each subject's own repetition-split data.
# No n_permutations / n_bootstraps are passed here, so the library defaults
# apply; the plotting cell reads "covariance (permuted)" from these results.
spectra_within = compute_within_individual_spectra(datasets_within)
# Spectra computed on the shared stimuli, with REFERENCE_SUBJECT passed as the
# reference individual.
spectra_cross = compute_cross_individual_spectra(
    datasets_cross,
    reference_individual=REFERENCE_SUBJECT,
)

In [None]:
def _permuted_covariance_samples(individual, repetitions):
    """Return one subject's fold-averaged permuted covariance (first 10 components).

    `repetitions[0]` is passed as both `x_train` and `x_test`, and
    `repetitions[1]` as both `y_train` and `y_test`. The result gains a
    length-1 "subject" dimension labelled with the one-based subject number.
    """
    spectra = compute_spectra_with_n_fold_cross_validation(
        x_train=repetitions[0],
        y_train=repetitions[1],
        x_test=repetitions[0],
        y_test=repetitions[1],
        n_folds=8,
        n_permutations=5_000,
        n_bootstraps=5_000,
    )
    leading = spectra["covariance (permuted)"].isel(component=range(10))
    return leading.mean("fold").expand_dims(subject=[f"{individual + 1}"])


# Stack the per-subject null samples along a new "subject" dimension.
nulls_samples = xr.concat(
    [
        _permuted_covariance_samples(individual, repetitions)
        for individual, repetitions in datasets_within.items()
    ],
    dim="subject",
)

In [None]:
def _draw_null_quantiles(ax, permuted, quantiles, line_colors, text_colors, positions):
    """Overlay labelled quantile curves of a permutation null on `ax`.

    Args:
        ax: Axes to draw on.
        permuted: The "covariance (permuted)" variable of a spectra dataset. It
            is averaged over "fold", reduced to each quantile over
            "permutation", and averaged over "individual".
        quantiles: Quantiles (as fractions) to draw, one curve per entry.
        line_colors: Line color for each quantile.
        text_colors: Label color for each quantile.
        positions: (x, y) data coordinates of each percentage label.

    Raises:
        ValueError: If the four sequences differ in length (`strict=True`).
    """
    # The fold average does not depend on the quantile, so compute it once.
    fold_mean = permuted.mean("fold")
    for quantile, line_color, text_color, (x, y) in zip(
        quantiles,
        line_colors,
        text_colors,
        positions,
        strict=True,
    ):
        nulls = fold_mean.quantile(quantile, dim="permutation").mean("individual")
        ax.plot(
            nulls["rank"],
            nulls,
            ls="-",
            c=line_color,
            mew=0,
            alpha=0.75,
        )
        ax.text(
            x,
            y,
            f"{quantile * 100:.1f}%",
            c=text_color,
            ha="center",
            va="center",
            fontsize="xx-small",
            backgroundcolor="white",
            bbox={"facecolor": "white", "alpha": 0.75, "pad": 1},
        )


# Two figures are produced from the same layout: with null quantiles the
# result is saved as "significance-test.pdf"; with the empty tuple the human(s)
# icons are inserted instead and the result is saved as "general.pdf".
for null_quantiles in ((0.997, 0.955, 0.683), ()):
    show_nulls = bool(null_quantiles)
    palette = sns.color_palette("Grays", n_colors=len(null_quantiles))[::-1]
    palette_text = sns.color_palette("Grays", n_colors=5)[::-1][: len(null_quantiles)]
    subject_labels = [f"{subject + 1}" for subject in range(nsd.N_SUBJECTS)]
    legend_options = {
        "loc": "upper right" if show_nulls else "lower left",
        "title": "subject",
        "ncols": 2,
        "columnspacing": 0.5,
        "handletextpad": 0.25,
    }

    fig, axes = plt.subplots(figsize=(5.5, 3), ncols=2, sharex=True, sharey=True)

    # ---- left panel: within-subject spectra ----
    ax = axes[0]
    plot_spectra(
        spectra=spectra_within,
        ax=ax,
        hue="individual",
        hue_order=subject_labels,
        hide_insignificant=show_nulls,
    )
    ax.set_title("within-subject", pad=10)
    ax.set_ylabel("covariance")
    ax.set_xlabel("rank")
    ax.legend(**legend_options)

    if show_nulls:
        _draw_null_quantiles(
            ax,
            spectra_within["covariance (permuted)"],
            null_quantiles,
            palette,
            palette_text,
            positions=[
                (5e1, 1e-5),
                (1e2, 1.5e-6),
                (2e2, 2e-7),
            ],
        )

        # Inset showing 20 individual permutations of the subject-averaged
        # null samples; the inset is created inside the rc context so that the
        # smaller tick/label settings apply to it.
        with plt.rc_context(
            rc=DEFAULT_MATPLOTLIBRC
            | {
                "axes.linewidth": 0.75,
                "xtick.major.size": 0,
                "ytick.major.size": 0,
                "xtick.labelsize": "x-small",
                "ytick.labelsize": "x-small",
                "axes.titlesize": "x-small",
            },
        ):
            ax_nulls = ax.inset_axes([0.12, 0.1, 0.36, 0.32])
            sns.lineplot(
                ax=ax_nulls,
                data=nulls_samples.mean("subject")
                .isel(permutation=range(20))
                .to_dataframe()
                .reset_index(),
                x="component",
                y="covariance (permuted)",
                color="darkgray",
                estimator=None,
                units="permutation",
                lw=0.5,
            )
            ax_nulls.set_ylabel("")
            ax_nulls.set_xlabel("")
            ax_nulls.axhline(0, ls="-", c="gray", lw=1)
            ax_nulls.set_xticks([1, 10])
            ax_nulls.set_ylim(bottom=-7e-4, top=7e-4)
            ax_nulls.set_xlim(left=1e0, right=1e1)
            ax_nulls.xaxis.set_major_formatter(LogFormatterMathtext())
            ax_nulls.ticklabel_format(axis="y", scilimits=(-4, -4))
            ax_nulls.set_yticks([-5e-4, 0, 5e-4])
            ax_nulls.text(
                6.5,
                -5.5e-4,
                "with randomly\nshuffled images",
                ha="center",
                va="baseline",
                fontsize="xx-small",
            )
    else:
        # Placeholder axes that skunk later replaces with the "human" SVG.
        ax_inset = ax.inset_axes([0.7, 0.65, 0.22, 0.22])
        ax_inset.axis("off")
        skunk.connect(ax_inset, "human")

    # ---- right panel: between-subject spectra ----
    ax = axes[1]
    plot_spectra(
        ax=ax,
        spectra=spectra_cross,
        hue="individual",
        palette="flare_r",
        hue_reference=f"{1 + REFERENCE_SUBJECT}",
        hue_order=subject_labels,
        # The reference subject's legend entry is marked with an asterisk.
        hue_labels=[
            f"{subject + 1}*" if subject == REFERENCE_SUBJECT else f"{subject + 1}"
            for subject in range(nsd.N_SUBJECTS)
        ],
        hide_insignificant=show_nulls,
    )

    ax.set_title(
        f"between-subject,\nrelative to subject {REFERENCE_SUBJECT + 1}",
    )
    ax.set_ylabel("cross-covariance")
    ax.set_xlabel("rank")
    # Vertical marker at the number of stimuli shared across subjects.
    ax.axvline(len(shared_stimuli), ls="--", c="gray", lw=0.5, ymax=0.45)
    ax.text(
        s="number of\nshared images",
        x=3e3,
        y=2e-9,
        fontsize="xx-small",
        ha="center",
        va="bottom",
    )
    ax.legend(**legend_options)

    if show_nulls:
        _draw_null_quantiles(
            ax,
            spectra_cross["covariance (permuted)"],
            null_quantiles,
            palette,
            palette_text,
            positions=[
                (8e0, 2e-4),
                (1.5e1, 4e-5),
                (2.5e1, 8e-6),
            ],
        )
        ax.text(
            1.5e1,
            8e-7,
            "percentiles of\nnull distribution",
            ha="center",
            va="center",
            fontsize="xx-small",
        )
    else:
        # Placeholder axes that skunk later replaces with the "humans" SVG.
        ax_inset = ax.inset_axes([0.6, 0.65, 0.23, 0.23])
        ax_inset.axis("off")
        skunk.connect(ax_inset, "humans")

    # The y axis is shared, so re-enable the tick labels on the right panel.
    # (`labelleft` is the y-axis spelling of the first-side label switch.)
    ax.yaxis.set_tick_params(labelleft=True)

    # The axes are shared, so scales and limits are set once on this panel.
    ax.set_xscale("log")
    ax.set_yscale("log")

    ax.set_xlim(left=1, right=1e4)
    ax.set_xticks([1, 1e1, 1e2, 1e3, 1e4])
    ax.set_ylim(top=1e-1, bottom=1e-9)

    if show_nulls:
        save_figure(
            fig,
            filepath=FIGURES_HOME / "significance-test.pdf",
        )
    else:
        # Substitute the connected placeholder axes with the SVG icons and
        # convert the resulting SVG to PDF.
        svg = skunk.insert(
            {
                "human": f"{FIGURES_HOME}/human.svg",
                "humans": f"{FIGURES_HOME}/humans.svg",
            },
        )
        cairosvg.svg2pdf(bytestring=svg, write_to=f"{FIGURES_HOME}/general.pdf")

## general-all

In [None]:
def _cross_spectra_relative_to(reference):
    """Compute cross-individual spectra with `reference` as the reference individual.

    Permutations and bootstraps are both set to 0.
    """
    return compute_cross_individual_spectra(
        datasets_cross,
        reference_individual=reference,
        n_permutations=0,
        n_bootstraps=0,
    )


# One set of between-subject spectra per choice of reference subject, keyed by
# the zero-based index of that reference subject.
spectra = {
    reference: _cross_spectra_relative_to(reference)
    for reference in range(nsd.N_SUBJECTS)
}

In [None]:
# One color per subject, shared by every panel and by the legend.
palette = sns.color_palette("flare_r", nsd.N_SUBJECTS)

fig, axes = plt.subplots(figsize=(5, 6), ncols=3, nrows=3, sharex=True, sharey=True)
# strict=False: the grid may hold more panels than there are subjects; the
# last panel is used for the legend below.
for reference_subject, ax in zip(range(nsd.N_SUBJECTS), axes.flat, strict=False):
    plot_spectra(
        spectra[reference_subject],
        ax=ax,
        hue="individual",
        # Pass the one-based string label, matching the entries of `hue_order`
        # and the other `plot_spectra` calls in this notebook (previously the
        # zero-based integer index was passed here).
        hue_reference=f"{reference_subject + 1}",
        hue_order=[f"{1 + subject}" for subject in range(nsd.N_SUBJECTS)],
        # The reference subject's label is marked with an asterisk.
        hue_labels=[
            f"{subject + 1}*" if subject == reference_subject else f"{subject + 1}"
            for subject in range(nsd.N_SUBJECTS)
        ],
        palette=palette,
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_title(f"subject {reference_subject + 1}")
    ax.set_xlim(left=1, right=1e3)
    ax.set_ylim(bottom=1e-7, top=1e-1)

# Use the last panel purely as a legend: draw one dummy marker per subject to
# obtain legend handles, then hide the axes themselves.
ax = axes.flat[-1]
for subject in range(nsd.N_SUBJECTS):
    ax.errorbar(
        [0, 0],
        [0, 0],
        [0, 0],
        marker="o",
        c=palette[subject],
        ls="None",
        label=f"{subject + 1}",
        mew=0,
    )
ax.legend(loc="center", title="subject", ncols=1)
ax.axis("off")

fig.supxlabel("rank", y=0.025, x=0.57)
fig.supylabel("cross-covariance", x=0.03)
fig.suptitle("between-subject, relative to ...", x=0.57)
fig.tight_layout()

save_figure(fig, filepath=FIGURES_HOME / "general-all.pdf")

## v1-to-v4

In [None]:
def _load_all_subjects(roi):
    """Load the `roi` dataset of every subject, keyed by zero-based subject index."""
    return {
        subject: nsd.load_dataset(
            subject=subject,
            roi=roi,
            preprocessing="fithrf",
            z_score=True,
        )
        for subject in range(nsd.N_SUBJECTS)
    }


def _restrict_then_split(dataset, stimuli):
    """Keep only `stimuli` in `dataset`, then split it with `n_repetitions=2`."""
    return split_by_repetition(
        filter_by_stimulus(dataset, stimuli=stimuli),
        n_repetitions=2,
    )


# Nested mapping: ROI name -> subject index -> dataset.
datasets = {roi: _load_all_subjects(roi) for roi in ("V1", "V2", "V3", "V4")}

# Within-subject data: the stimulus set is computed per subject (and per ROI)
# from that single dataset.
datasets_within = {
    roi: {
        subject: _restrict_then_split(
            dataset,
            compute_shared_stimuli([dataset], n_repetitions=2),
        )
        for subject, dataset in per_subject.items()
    }
    for roi, per_subject in datasets.items()
}

# Between-subject data: the stimulus set is computed once from the V1 datasets
# and then applied to every ROI.
shared_stimuli = compute_shared_stimuli(datasets["V1"].values(), n_repetitions=2)
datasets_cross = {
    roi: {
        subject: _restrict_then_split(dataset, shared_stimuli)
        for subject, dataset in per_subject.items()
    }
    for roi, per_subject in datasets.items()
}

In [None]:
def _with_roi_dimension(roi_spectra, roi):
    """Add a length-1 "region of interest" dimension labelled `roi`."""
    return roi_spectra.expand_dims({"region of interest": [roi]})


# Within-subject spectra per ROI (no permutations, no bootstraps).
spectra_within = {
    roi: _with_roi_dimension(
        compute_within_individual_spectra(
            per_subject,
            n_permutations=0,
            n_bootstraps=0,
        ),
        roi,
    )
    for roi, per_subject in datasets_within.items()
}

# Between-subject spectra per ROI relative to REFERENCE_SUBJECT (no
# permutations, no bootstraps).
spectra_cross = {
    roi: _with_roi_dimension(
        compute_cross_individual_spectra(
            per_subject,
            reference_individual=REFERENCE_SUBJECT,
            n_permutations=0,
            n_bootstraps=0,
        ),
        roi,
    )
    for roi, per_subject in datasets_cross.items()
}

In [None]:
rois = ("V1", "V2", "V3", "V4")

# Top row: within-subject spectra; bottom row: between-subject spectra.
# One column per ROI, with both axes shared across all panels.
fig, axes = plt.subplots(
    nrows=2,
    ncols=4,
    figsize=(6, 4),
    sharex=True,
    sharey=True,
)

# `axes.T` yields one (top, bottom) pair of axes per column.
for roi, (ax_within, ax_cross) in zip(rois, axes.T, strict=True):
    plot_spectra(
        spectra=spectra_within[roi],
        ax=ax_within,
        palette="crest_r",
        hue="individual",
        hue_order=[f"{1 + subject}" for subject in range(nsd.N_SUBJECTS)],
        hue_labels=[f"{1 + subject}" for subject in range(nsd.N_SUBJECTS)],
    )
    ax_within.set_title(roi)

    plot_spectra(
        spectra=spectra_cross[roi],
        ax=ax_cross,
        palette="flare_r",
        hue="individual",
        hue_reference=f"{1 + REFERENCE_SUBJECT}",
        hue_order=[f"{1 + subject}" for subject in range(nsd.N_SUBJECTS)],
        # The reference subject's legend entry is marked with an asterisk.
        hue_labels=[
            f"{subject + 1}*" if subject == REFERENCE_SUBJECT else f"{subject + 1}"
            for subject in range(nsd.N_SUBJECTS)
        ],
    )

# Per row: a compact legend in the first column, and the y label of the last
# column moved to the right-hand side so it can serve as a row caption.
legend_options = {
    "loc": "upper right",
    "title": "subject",
    "ncols": 2,
    "columnspacing": 0,
    "handletextpad": 0.0,
    "borderpad": 0,
    "labelspacing": 0.2,
    "borderaxespad": 0,
    "markerscale": 0.8,
}
for row in (0, 1):
    axes[row, 0].legend(**legend_options)
    axes[row, -1].yaxis.set_label_position("right")

# The labels do not depend on the row, so set them once, after the label
# positions above have been applied (previously this ran on every iteration).
axes[0, 0].set_ylabel("covariance")
axes[1, 0].set_ylabel("cross-covariance")
axes[0, -1].set_ylabel("within-subject", rotation=-90, labelpad=10)
axes[1, -1].set_ylabel(
    f"between-subject,\nrelative to subject {REFERENCE_SUBJECT + 1}",
    rotation=-90,
    labelpad=22,
)

# Placeholder axes that skunk later replaces with the SVG icons.
ax_inset = axes[0, -1].inset_axes([0.52, 0.7, 0.23, 0.23])
ax_inset.axis("off")
skunk.connect(ax_inset, "human")

ax_inset = axes[1, -1].inset_axes([0.5, 0.7, 0.3, 0.3])
ax_inset.axis("off")
skunk.connect(ax_inset, "humans")

# The axes are shared, so scales, ticks and limits are set on one panel only.
ax = axes[0, 0]
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_xticks([1, 1e1, 1e2, 1e3])
ax.set_ylim(bottom=1e-7, top=1e-1)

fig.supxlabel("rank", y=0.05, x=0.52)

fig.tight_layout()

# Insert the SVG icons into the figure and convert the result to PDF.
svg = skunk.insert(
    {
        "human": f"{FIGURES_HOME}/human.svg",
        "humans": f"{FIGURES_HOME}/humans.svg",
    },
)
cairosvg.svg2pdf(bytestring=svg, write_to=f"{FIGURES_HOME}/v1-to-v4.pdf")

## vary-n-voxels

In [None]:
# Five voxel counts spaced geometrically between 100 and 10**4.
# `np.geomspace` works in floating point and may return a value marginally
# below the integer it stands for, which `astype(int)` would truncate to the
# next lower integer; round to the nearest integer before casting.
n_voxels = np.rint(np.geomspace(100, 10**4, num=5)).astype(int)

dataset = nsd.load_dataset(
    subject=REFERENCE_SUBJECT,
    preprocessing="fithrf",
    roi="general",
    z_score=True,
)
# Stimulus set computed from the full dataset, before any voxel subsampling.
repeated_stimuli = compute_shared_stimuli([dataset], n_repetitions=2)

# Step 1: subsample the dataset to each voxel count with a fixed seed.
datasets = {
    n_voxels_: sample_neuroids(dataset, n_neuroids=n_voxels_, random_state=0)
    for n_voxels_ in n_voxels
}

# Step 2: restrict each subsample to the repeated stimuli and split it by
# repetition. `datasets` is rebound so the unsplit subsamples are not kept.
datasets = {
    n_voxels_: split_by_repetition(
        filter_by_stimulus(subsample, stimuli=repeated_stimuli),
        n_repetitions=2,
    )
    for n_voxels_, subsample in datasets.items()
}

In [None]:
def _spectra_across_voxel_counts(normalize):
    """Concatenate within-individual spectra over all voxel counts in `datasets`.

    Each subsample is passed as the single individual `0`, with permutations
    and bootstraps set to 0, and is labelled along a new "n_voxels" dimension
    by the string form of its voxel count.
    """
    per_count = [
        compute_within_individual_spectra(
            {0: repetitions},
            normalize=normalize,
            n_permutations=0,
            n_bootstraps=0,
        ).expand_dims(n_voxels=[str(count)])
        for count, repetitions in datasets.items()
    ]
    return xr.concat(per_count, dim="n_voxels")


# Spectra keyed by whether `normalize` was enabled.
spectra = {
    normalize: _spectra_across_voxel_counts(normalize)
    for normalize in (True, False)
}

In [None]:
# Left panel: spectra with normalize=False; right panel: normalize=True.
# The y axes are independent because the two panels use different ranges.
fig, axes = plt.subplots(figsize=(5, 3), ncols=2, sharex=True, sharey=False)

# Per-panel y settings, keyed by `normalize`:
# (lower limit, upper limit, exponents of the powers of ten used as ticks).
y_settings = {
    True: (1e-8, 1e-1, range(-8, 0)),
    False: (1e-4, 1e3, range(-4, 4)),
}

for normalize, ax in zip((False, True), axes.flat, strict=False):
    plot_spectra(
        spectra=spectra[normalize],
        ax=ax,
        hue="n_voxels",
        hue_order=[str(count) for count in n_voxels],
        hue_labels=[str(count) for count in n_voxels],
        palette="viridis_r",
    )
    ax.set_xscale("log")
    ax.set_yscale("log")

    ax.set_xlim(left=1, right=1e4)
    ax.set_xticks([1, 1e1, 1e2, 1e3, 1e4])

    y_bottom, y_top, tick_exponents = y_settings[normalize]
    ax.set_ylim(bottom=y_bottom, top=y_top)
    ax.set_yticks([10**x for x in tick_exponents])

    if normalize:
        prefix = "after"
    else:
        prefix = "before"
    ax.set_title(f"{prefix} normalization")

# Only the right-hand panel carries the legend.
axes[1].legend(loc="lower left", title="# voxels", alignment="right")

for ax in axes.flat:
    ax.set_ylabel("covariance")
    ax.set_xlabel("rank")

fig.suptitle(f"within-subject, subject {1 + REFERENCE_SUBJECT}", y=0.95)
fig.tight_layout()

save_figure(fig, filepath=FIGURES_HOME / "vary-n-voxels.pdf")