In [1]:
import re
import time
from collections import OrderedDict
from itertools import accumulate, product
from pathlib import Path

import networkx as nx
import nibabel as nib
import numpy as np
import pandas as pd
import statsmodels.api as sm
import matplotlib as mpl
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib import ticker

# Render figures through the PGF backend (LaTeX does the typesetting).
mpl.use("pgf")

# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8-*"
# and newer releases no longer ship the old names. Prefer the new name and
# fall back to the legacy one so the notebook runs on old and new installs.
whitegrid_style = "seaborn-v0_8-whitegrid"
if whitegrid_style not in plt.style.available:
    whitegrid_style = "seaborn-whitegrid"
plt.style.use(whitegrid_style)

plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["text.usetex"] = True
# Fonts are configured in the LaTeX preamble, not taken from rcParams.
plt.rcParams["pgf.rcfonts"] = False
plt.rcParams["pgf.texsystem"] = "lualatex"
# LaTeX preamble handed to the PGF backend. Written as a raw string so every
# backslash is literal: the previous non-raw literal contained the invalid
# escape sequence "\d" (in \defaultfontfeatures), which newer Python versions
# flag with a SyntaxWarning. The resulting string is unchanged.
plt.rcParams["pgf.preamble"] = r"""
\usepackage{fontspec}
\usepackage[T1]{fontenc}
\usepackage[utf8]{inputenc}
\usepackage{unicode-math}

\defaultfontfeatures{
    Extension = .otf,
}

\setmainfont{HelveticaNeueLTStd}[
    UprightFont=*-Roman,
    ItalicFont=*-It,
    BoldFont=*-Md,
    BoldItalicFont=*-MdIt,
    FontFace={xl}{n}{*-UltLt},
    FontFace={xl}{it}{*-UltLtIt},
    FontFace={l}{n}{*-Lt},
    FontFace={l}{it}{*-LtIt},
    FontFace={mb}{n}{*-Md},
    FontFace={mb}{it}{*-MdIt},
    FontFace={k}{n}{*-Blk},
    FontFace={k}{it}{*-BlkIt},
    Scale=0.9,
]
\setsansfont{HelveticaNeueLTStd}[
    UprightFont=*-Roman,
    ItalicFont=*-It,
    BoldFont=*-Md,
    BoldItalicFont=*-MdIt,
    FontFace={xl}{n}{*-UltLt},
    FontFace={xl}{it}{*-UltLtIt},
    FontFace={l}{n}{*-Lt},
    FontFace={l}{it}{*-LtIt},
    FontFace={mb}{n}{*-Md},
    FontFace={mb}{it}{*-MdIt},
    FontFace={k}{n}{*-Blk},
    FontFace={k}{it}{*-BlkIt},
    Scale=0.9,
]

\setmathfont{latinmodern-math.otf}
\setmathfont[
    range=\mathup,
    Scale=0.9,
]{HelveticaNeueLTStd-Roman}
"""

## Commonality math

In [2]:
def interleave(b: tuple[int, ...], a: tuple[int, ...]) -> tuple[int, ...]:
    """Scatter the values of ``b`` into the slots that ``a`` marks as set.

    The result has one entry per element of ``a``. Positions where ``a`` is
    truthy receive the values of ``b`` in order; every other position is 1.
    """
    mask = np.asarray(a, dtype=bool)
    merged = np.ones(mask.size, dtype=int)
    merged[mask] = b
    return tuple(merged)


def commonality_polynomial(terms: tuple[int, ...]) -> list[tuple[tuple[int, ...], int]]:
    """Expand one commonality component into signed model selections.

    ``terms`` is a 0/1 indicator vector. Every 0/1 assignment to the positions
    set in ``terms`` is enumerated; positions not set in ``terms`` are always 1
    (see ``interleave``). Each resulting selection vector is paired with a
    sign: -1 when an even number of the enumerated positions is switched on,
    +1 when that number is odd.
    """
    polynomial = []
    for exponents in product((0, 1), repeat=sum(terms)):
        sign = 1 if sum(exponents) % 2 else -1
        polynomial.append((interleave(exponents, terms), sign))
    return polynomial


def commonality_polynomials(n: int):
    """Yield ``(terms, polynomial)`` for every non-zero 0/1 vector of length ``n``.

    The all-zero vector is skipped; for every other vector the matching
    polynomial comes from ``commonality_polynomial``.
    """
    for terms in product((0, 1), repeat=n):
        if not any(terms):
            continue  # the empty component has no polynomial
        yield terms, commonality_polynomial(terms)

In [3]:
def weight(v: tuple[int, ...], w: tuple[int, ...]) -> int:
    """Weighted count of the positions in which ``v`` and ``w`` differ.

    A mismatch in the first position costs 10, a mismatch anywhere else
    costs 1, so vectors that disagree in their first entry end up far apart.
    Pairs are formed with ``zip``, so surplus entries of the longer vector are
    ignored.

    Returns 0 when the compared entries are all equal.
    """
    return sum(
        (10 if i == 0 else 1)
        for i, (a, b) in enumerate(zip(v, w))
        if a != b
    )


def traveling_salesman_sort(commonality):
    """Order ``(component, value)`` pairs so that neighbours have similar components.

    The pairwise ``weight`` between component vectors is used as edge weight
    of a graph, and an approximate traveling-salesman cycle through that graph
    gives the order. The cycle is rotated to start at the first visited
    component whose first entry is 1.

    Parameters
    ----------
    commonality
        Iterable of ``(component, value)`` pairs, where ``component`` is a
        0/1 vector.

    Returns
    -------
    list
        The same pairs in the new order. Inputs with fewer than two pairs are
        returned unchanged (as a new list), because there is nothing to sort
        and no meaningful graph to build.
    """
    commonality = list(commonality)
    if len(commonality) < 2:
        return commonality

    v, _ = zip(*commonality)  # extract component vectors

    adjacency = np.array([[weight(v_i, v_j) for v_j in v] for v_i in v])

    graph = nx.from_numpy_array(adjacency)
    cycle = nx.approximation.traveling_salesman_problem(graph, cycle=True)

    # Start at the first component whose first entry is set. If there is none,
    # fall back to the last index of the cycle, which is what a plain
    # for/break search would leave behind.
    pos = next(
        (pos for pos, i in enumerate(cycle) if v[i][0] == 1),
        len(cycle) - 1,
    )

    # The final element of the cycle is dropped before rotating
    # (presumably the repeated start node of the closed tour; see networkx docs).
    path = cycle[pos:-1] + cycle[:pos]

    return [commonality[i] for i in path]

## Load time series data

In [4]:
# Preprocessing stages, in the order in which they are plotted
steps = [
    "resample",
    "smooth",
    "ica_aroma",
    "temporal_filter",
]

# Regressor groups that enter the commonality analysis
variable_groups = [
    "task",
    "ica_aroma_signal",
    "ica_aroma_noise",
    "motion",
    "wm_csf",
    "a_comp_cor",
    "global_signal",
]

# Regular expressions (matched against whole column names) that assign
# confound columns to a group. "task" has no entry here because its
# regressors come from the design matrix.
variable_group_patterns = {
    "ica_aroma_signal": [
        r"aroma_signal_[0-9]+",
    ],
    "ica_aroma_noise": [
        r"aroma_noise_[0-9]+",
    ],
    "motion": [
        r"framewise_displacement",
        r"dvars",
        r"std_dvars",
        r"rmsd",
        r"(trans|rot)_[xyz](_derivative1)?(_power2)?",
    ],
    "wm_csf": [
        r"(white_matter|csf)(_derivative1)?(_power2)?",
        r"csf_wm",
    ],
    "a_comp_cor": [
        r"a_comp_cor_0[0-4]",
    ],
    "global_signal": [
        r"global_signal(_derivative1)?(_power2)?",
    ],
}

In [5]:
# Index applied to each image's data object to extract the time series that is analysed
voxel_coordinate: tuple[int, int, int] = (61, 15, 47)
# Multiplied with the sample index to position samples on the time axis,
# which the tick formatter of the time series panels interprets as seconds
repetition_time: float = 2.0

In [6]:
# Task design matrix
design_file = "data/sub-01_task-faces_run-01_feature-taskBased_desc-design_matrix.tsv"

# Confound table to use for each preprocessing step
confound_files = {
    "resample": "data/merge_with_header.tsv",
    "smooth": "data/merge_with_header.tsv",
    "ica_aroma": "data/merge_with_header_regfilt.tsv",
    "temporal_filter": "data/merge_with_header_regfilt_bptf_addmean.tsv",
}

# Image to use for each preprocessing step
image_files = {
    "resample": "data/vol0000_xform-00000_merged_masked.nii.gz",
    "smooth": "data/vol0000_xform-00000_merged_masked_afni.nii.gz",
    "ica_aroma": "data/vol0000_xform-00000_merged_masked_afni_grandmeanscaled_regfilt.nii.gz",
    "temporal_filter": "data/vol0000_xform-00000_merged_masked_afni_grandmeanscaled_regfilt_bptf_addmean.nii.gz",
}

# Steps whose image data gets multiplied by the grand mean scaling factor after loading
need_to_scale = {"resample", "smooth"}

In [7]:
# Reconstruct grand mean scaling factor
#
# The factor is the mean ratio between the global signal stored in the
# "resample" confound table and the one stored in the motion outlier table.

unscaled = pd.read_table("data/confounds_expansion_desc-motion_outliers.tsv")["global_signal"]
data_frame = pd.read_table(confound_files["resample"])
grand_mean_scaling_factor = data_frame["global_signal"].div(unscaled).mean()

In [8]:
# Task design matrix; it is stored as the "task" regressor group for every step.
# Unlike the confound tables, it is not demeaned when the regressors are assembled.
design = pd.read_table(design_file)

In [9]:
# Time series at `voxel_coordinate` for every preprocessing step
image_data = dict()
for step in image_files:
    series = nib.load(image_files[step]).dataobj[voxel_coordinate].astype(float)
    if step in need_to_scale:
        # bring this image onto the grand-mean-scaled intensity range
        series *= grand_mean_scaling_factor
    image_data[step] = series

# Regressors for every preprocessing step, keyed by variable group
regressor_data = dict()
for step, confound_file in confound_files.items():
    regressor_data[step] = groups = {"task": design}

    confounds = pd.read_table(confound_file)
    confounds = confounds.sub(confounds.mean())  # demean

    for variable_group, patterns in variable_group_patterns.items():
        # collect columns pattern by pattern, keeping the pattern order
        matched = []
        for pattern in patterns:
            matched.extend(
                name for name in confounds.columns if re.fullmatch(pattern, name)
            )

        groups[variable_group] = confounds[matched]

In [19]:
# Quick look at the raw voxel time series of all steps in one figure
overview_fig, overview_ax = plt.subplots()
for step in steps:
    overview_ax.plot(image_data[step])

## Calculation

In [11]:
# n: number of variable groups, m: number of preprocessing steps
n, m = len(variable_groups), len(steps)

# (component, signed model selections) for every non-empty combination of groups
p = [*commonality_polynomials(n)]

# Spread measure used for the bar plot; this is the standard deviation,
# which is what the axis label describes
var = np.std
var_label = r"$\sqrt{\text{Variance}}$ [AU]"

In [12]:
def predict(y, x):
    """Return the in-sample OLS prediction of ``y`` from the columns of ``x``.

    ``x`` must provide ``fillna`` and ``values`` (e.g. a pandas DataFrame).
    Missing values are replaced by 0 and an intercept column is prepended
    before the model is fitted.
    """
    intercept = np.ones([y.size, 1])
    design_matrix = np.hstack([intercept, x.fillna(0).values])
    fitted_model = sm.OLS(endog=y, exog=design_matrix).fit()
    return fitted_model.predict()

In [13]:
# Demeaned voxel time series per step (new arrays; `image_data` is left untouched)
y_dict = {step: image_data[step] - image_data[step].mean() for step in steps}

# R squared of the OLS fit for every possible selection of variable groups,
# keyed per step by the 0/1 selection vector
rsquareds_dict = dict()
for step in steps:
    y = y_dict[step]

    intercept = np.ones([y.size, 1])
    group_values = [
        regressor_data[step][variable_group].fillna(0).values
        for variable_group in variable_groups
    ]

    step_rsquareds = dict()
    for is_selected in product((0, 1), repeat=len(variable_groups)):
        selected_values = [
            values for values, flag in zip(group_values, is_selected) if flag == 1
        ]
        exog = np.hstack([intercept, *selected_values])
        step_rsquareds[is_selected] = sm.OLS(endog=y, exog=exog).fit().rsquared

    rsquareds_dict[step] = step_rsquareds

In [14]:
# Largest spread (as measured by `var`) among the demeaned time series of all steps
maximum_variance = max(map(var, y_dict.values()))

In [15]:
# Labels

def circled(s):
    """Return LaTeX markup that draws ``str(s)`` inside a circle."""
    return r"\raisebox{.5pt}{\textcircled{\raisebox{-.9pt} {%s}}}" % (s,)

# Panel titles: a circled step number followed by a bold description
step_labels = {
    "resample": circled(0) + " \\textbf{Resampled image}",
    "smooth": circled(1) + " \\textbf{Post smoothing}",
    "ica_aroma": circled(2) + " \\textbf{Post ICA-AROMA component regression}",
    "temporal_filter": circled(3) + " \\textbf{Post temporal filter (high-pass only)}",
}

# Row labels for the variable groups
variable_group_labels = {
    "task": "Task",
    "ica_aroma_signal": "ICA-AROMA\nSignal",
    "ica_aroma_noise": "ICA-AROMA\nNoise",
    "motion": "Motion",
    "wm_csf": "WM/CSF",
    "global_signal": "Global signal",
    "a_comp_cor": "aCompCor",
    "auto_corr": "Autocorrelation",
}

In [16]:
# Plotting helpers

# Tick formatter that renders an axis value given in seconds as "MM:SS".
# The value goes through time.gmtime, so only the minute and second fields
# are shown (an hour or more wraps around).
time_formatter = ticker.FuncFormatter(
    lambda seconds, x: time.strftime("%M:%S", time.gmtime(seconds))
)

# Qualitative palette used for the variable groups. `matplotlib.cm.get_cmap`
# was deprecated in Matplotlib 3.7 and removed in 3.9; `pyplot.get_cmap` is
# available in old and new releases alike.
cmap = plt.get_cmap("Set2")

In [17]:
# One row per preprocessing step: time series on the left, bars on the right.
# squeeze=False keeps `axs` two-dimensional even when there is a single step.
fig, axs = plt.subplots(
    m,
    2,
    figsize=(10, 12),
    gridspec_kw=dict(width_ratios=[0.7, 1]),
    squeeze=False,
)
fig.tight_layout()

# Share the y axis across the left column and the x axis across the right
# column. Mutating the groupers returned by get_shared_{x,y}_axes() was
# deprecated in Matplotlib 3.6 and no longer works in newer releases;
# Axes.sharex / Axes.sharey is the supported way to link existing axes.
for k in range(1, m):
    axs[k, 0].sharey(axs[0, 0])
    axs[k, 1].sharex(axs[0, 1])

# Draw one row per preprocessing step: the voxel time series with the task
# model fit on the left, and the commonality decomposition as stacked
# horizontal bars (one row per variable group plus a "Total" row) on the right.
for k, step in enumerate(steps):
    y = y_dict[step]
    rsquareds = rsquareds_dict[step]

    # Evaluate every commonality polynomial on this step's R squared table.
    # (`sign` rather than `weight`, so the module-level function is not shadowed.)
    commonality = [
        (
            component,
            sum(sign * rsquareds[term] for term, sign in terms),
        )
        for component, terms in p
    ]

    axs[k, 0].set_title(step_labels[step], loc="left", pad=10)

    axs[k, 0].set_yticks([])
    axs[k, 0].grid(False)
    axs[k, 0].xaxis.set_major_formatter(time_formatter)
    axs[k, 0].set(xlabel="Time")

    # One tick per bar row: "Total" at the bottom, then the variable groups
    # from last to first, so that the first group ends up at the top.
    axs[k, 1].set_yticks([i - 0.125 for i in range(-1, n)])
    axs[k, 1].set_yticklabels(
        ["Total"] + [variable_group_labels[key] for key in reversed(variable_groups)],
        linespacing=0.9,
    )
    axs[k, 1].yaxis.tick_right()
    axs[k, 1].set(xlabel=var_label)

    # Sample times, shared by all curves of this panel
    sample_times = np.arange(len(y)) * repetition_time

    for key in ["task"]:
        y_hat = predict(y, regressor_data[step][key])

        axs[k, 0].plot(
            sample_times,
            y_hat,
            "o--",
            color=(
                *cmap(variable_groups.index(key))[:3],
                0.75,  # opacity
            ),
            linewidth=0.5,
            markersize=1,
        )

    axs[k, 0].plot(
        sample_times,
        y,
        "o-",
        color="black",
        linewidth=0.5,
        markersize=1,
    )

    negative_commonality = sum(
        proportion for _, proportion in commonality if proportion < 0
    )
    variance = var(y)

    # Only strictly positive, non-negligible components are drawn.
    commonality_sorted = traveling_salesman_sort(
        [
            (component, proportion)
            for component, proportion in commonality
            if not np.isclose(proportion, 0) and not proportion < 0
        ]
    )

    for i in range(n):
        intervals = []

        # Walk through the components in drawing order; `x` is the running
        # left edge, so a component sits at the same horizontal position in
        # every row it belongs to. Negative proportions were filtered above.
        x = 0
        for component, proportion in commonality_sorted:
            x += proportion

            if component[i] != 1:
                continue  # not relevant for this row

            intervals.append(((x - proportion) * variance, proportion * variance))

        axs[k, 1].broken_barh(
            intervals, (n - i - 1.5, 0.75), color=cmap(i), capstyle="butt"
        )

        # Right edge of the right-most bar in this row. A group without any
        # positive component has no bars; fall back to 0 instead of failing
        # on max() of an empty sequence.
        x_max = max(map(sum, intervals), default=0)

        if i == 0:
            # Mark the horizontal extent of the first group across all rows
            axs[k, 1].axvline(0, color=cmap(i), linewidth=0.5)
            axs[k, 1].axvline(x_max, color=cmap(i), linewidth=0.5)

        # R squared of the model that contains only this variable group
        rsquared = rsquareds[tuple(1 if j == i else 0 for j in range(n))]

        axs[k, 1].text(
            x_max + 0.05 * maximum_variance,
            n - i - 1.175,
            f"{rsquared * 100:.0f}%",
            ha="left",
            va="center",
            color="black",
            backgroundcolor="white",
        )

    # Grey "Total" bar in the bottom row
    intervals = [
        (0, variance / (1 + negative_commonality)),
    ]
    axs[k, 1].broken_barh(intervals, (-1.5, 0.75), color=(0.8, 0.8, 0.8, 1.0))


# Final spacing and margins for the current figure
plt.subplots_adjust(wspace=0.05, hspace=0.35, top=0.95, bottom=0.05, right=0.91)

# Write the current figure to disk, explicitly rendering it with the PGF backend
plt.savefig("confs.pdf", backend="pgf")