In [None]:
from bids import BIDSLayout
from tqdm.auto import tqdm
from tqdm.contrib.concurrent import process_map
from pathlib import Path
import nibabel as nib
import numpy as np
from src.metrics import corr
from sklearn.preprocessing import StandardScaler
from nilearn.glm import fdr_threshold
from scipy.stats import ttest_1samp, norm
import pandas as pd
import matplotlib as mpl
from nilearn import plotting, image
import os, h5py
from joblib import Parallel, delayed
import re
# Folder receiving the HTML figure saved further down.
os.makedirs("figs", exist_ok=True)

# Cohort prefix used in subject labels, and the number of runs expected per subject.
lang = "EN"
n_runs = 9

layout = BIDSLayout("data/li2022/derivatives", validate=False, is_derivative=True)

# Keep only labels of the exact form "<lang><digits>"; labels carrying a suffix
# after the digits, or a prefix before <lang>, do not match the pattern.
subjects = [label for label in layout.get_subjects() if re.match(lang+r"\d+$", label)]
n_subjects = len(subjects)

# Build the mean EN subject (voxel-wise average across subjects, one image per run)

In [None]:
def average(run):
    """Write the voxel-wise mean image of one run across all `lang` subjects.

    Parameters
    ----------
    run : int
        Zero-based position of the run in each subject's path-sorted file list.

    Raises
    ------
    ValueError
        If no subject label in the layout matches "<lang><digits>".
    AssertionError
        If a subject does not have exactly `n_runs` files.

    Notes
    -----
    The output path is derived from an existing subject file by replacing
    "sub-<subject>" with "sub-mean<lang>" and that file's run entity with
    "run-<run+1, two digits>". The path is printed once the image is written.
    """
    layout = BIDSLayout("data/li2022/derivatives", validate=False, is_derivative=True)
    # Use the same "<lang><digits>" filter as the module-level `subjects` list.
    # A plain startswith(lang) also matches derived labels that this notebook
    # creates by appending a suffix to a subject id ("...brain", "...ISC"),
    # which would make a re-run average non-NIfTI files and disagree with
    # `n_subjects`.
    subjects = [s for s in layout.get_subjects() if re.match(lang + r"\d+$", s)]
    if not subjects:
        raise ValueError(f"No subject matching '{lang}<digits>' found in the layout")

    a = None
    affine = None
    for subject in tqdm(subjects, desc=f"Run {run}", leave=False):
        imgs = sorted(f.path for f in layout.get(subject=subject))
        assert len(imgs) == n_runs
        img = nib.load(imgs[run])
        if a is None:
            # The first image provides the accumulator and the affine of the output.
            a = img.get_fdata()
            affine = img.affine
        else:
            a += img.get_fdata()
    a /= len(subjects)

    # Only one file is needed as a naming template, so this is done once after
    # the loop (using the last subject, as the original loop effectively did).
    template = layout.get(subject=subjects[-1])[0]
    new_path = template.path.replace("sub-" + subjects[-1], "sub-mean" + lang)
    new_path = new_path.replace(f"run-{template.entities['run']}", f"run-{run+1:02d}")

    Path(new_path).parent.mkdir(parents=True, exist_ok=True)
    nib.Nifti1Image(a, affine).to_filename(new_path)
    print(new_path)

In [None]:
for run in tqdm(range(n_runs)):
    average(run)

# A BIDSLayout indexes the dataset when it is constructed, so the module-level
# `layout` created before this cell does not know about the "sub-mean<lang>"
# files that were just written. Re-index so that the later cells querying the
# mean subject through `layout` find them on a fresh kernel as well.
layout = BIDSLayout("data/li2022/derivatives", validate=False, is_derivative=True)

# Keep only brain voxels

In [None]:
# Resample the mask image onto the grid of the first file returned by the
# layout (nearest-neighbour interpolation), then keep the indices of the
# non-zero voxels.
mask = image.resample_to_img(
    nib.load("data/li2022/colin27_t1_tal_lin_mask.nii"),
    nib.load(layout.get()[0]),
    interpolation="nearest",
)
mask = np.where(mask.get_fdata().astype(bool))

In [None]:
def slice_brain(subject, path):
    """Store the masked voxels of one image as an HDF5 file.

    The target path is `path` with "sub-<subject>" replaced by
    "sub-<subject>brain" and ".nii.gz" replaced by ".hf5". When the target
    already exists the function returns without doing anything. Otherwise the
    image is loaded, indexed with the module-level `mask`, transposed and
    written as the gzip-compressed dataset "data".
    """
    renamed = path.replace("sub-" + subject, "sub-" + subject + "brain")
    target = Path(renamed.replace(".nii.gz", ".hf5"))
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        with h5py.File(target, "w") as f:
            voxels = nib.load(path).get_fdata()[mask].T
            f.create_dataset("data", data=voxels, compression="gzip")

In [None]:
# One slice_brain task per (subject, file) pair, dispatched to worker processes.
brain_tasks = (
    delayed(slice_brain)(subject, img.path)
    for subject in subjects
    for img in layout.get(subject=subject)
)
_ = Parallel(n_jobs=-2, verbose=10)(brain_tasks)

# ISC (inter-subject correlation)

In [None]:
def read(path):
    """Load the image at `path` and return the pair `(path, data)`.

    The path travels with the array so that results coming back from a worker
    pool can be matched to (and sorted by) the file they were read from.
    """
    data = nib.load(path).get_fdata()
    return path, data

In [None]:
# Query the mean subject through `lang` (as the later ISC extraction cell does)
# instead of a hard-coded "meanEN", so changing `lang` keeps the cells aligned.
mean_imgs = sorted(f.path for f in layout.get(subject=f"mean{lang}"))
# Only the first three runs (in path order) are loaded here, matching the
# per-subject loading in the correlation loop.
mean_imgs = sorted(process_map(read, mean_imgs[:3]))
# Concatenate the runs along the last axis; multiplying the across-subject
# mean by the subject count turns it back into the across-subject sum, from
# which one subject can later be subtracted.
mean_imgs = np.concatenate([data for _path, data in mean_imgs], axis=-1) * n_subjects

In [None]:
# For each subject: correlate its data with the sum of all the other subjects
# (`mean_imgs` holds the sum over subjects, so subtracting the subject leaves
# the others), along the last axis.
corrs = []
for subject in tqdm(subjects):
    subject_imgs = sorted(f.path for f in layout.get(subject=subject))
    subject_imgs = sorted(process_map(read, subject_imgs[:3], leave=False))
    subject_imgs = np.concatenate([data for _path, data in subject_imgs], axis=-1)
    corrs.append(corr(mean_imgs - subject_imgs, subject_imgs, axis=-1))
corrs = np.stack(corrs)

In [None]:
# One-sample t-test across subjects on the Fisher-transformed correlations;
# voxels whose test is undefined (NaN) are given p = 1.
# NOTE(review): the p-values are two-sided while the z conversion below is
# one-sided, so strongly negative mean correlations also get a high z — confirm
# whether a one-sided test (alternative="greater") is intended.
pvalues = np.nan_to_num(ttest_1samp(np.arctanh(corrs), popmean=0, axis=0).pvalue, nan=1)
# norm.isf(p) is the same quantity as norm.ppf(1 - p) but keeps its precision
# for tiny p-values: 1 - p rounds to exactly 1.0 once p drops below ~1e-16,
# which turned the z-score into +inf (and later broke the colour scale).
zscores = norm.isf(pvalues)
thresh = fdr_threshold(zscores.reshape(-1), 5e-2)
signif = np.where(zscores > thresh)

In [None]:
affine = nib.load(layout.get()[0]).affine

# Convert the voxel indices of the significant voxels with `affine` and store
# them, one row per voxel, together with the voxel's z-score.
coords = image.coord_transform(*signif, affine)
df = pd.DataFrame(dict(zip(["x", "y", "z"], coords)))
df["zscore"] = zscores[signif]
df.to_csv(f"data/li2022/ISC_voxels_{lang}.csv", index=False)

# The full maps are stored as a dict; np.save writes it as a pickled object,
# so reading it back requires allow_pickle=True.
isc_results = {"pval": pvalues, "zscore": zscores, "thresh": thresh, "signif": signif}
np.save(f"data/li2022/ISC_voxels_{lang}.npy", isc_results)

In [None]:
# Reload the saved ISC results. Both paths are built from `lang` so that they
# always point at the files written by the saving cell (the .npy path used to
# be hard-coded to "EN" while the .csv path followed `lang`).
# allow_pickle is required because the .npy file holds a pickled dict; only
# load files produced by this notebook.
isc = np.load(f"data/li2022/ISC_voxels_{lang}.npy", allow_pickle=True).item()
zscores = isc["zscore"]
signif = isc["signif"]
thresh = isc["thresh"]
df = pd.read_csv(f"data/li2022/ISC_voxels_{lang}.csv")

In [None]:
# For every subject plus the mean subject, write the significant voxels of
# each run to an HDF5 file next to a "sub-<subject>ISC" copy of the path.
# Files that already exist are skipped but still counted in the progress bar.
with tqdm(total=(n_subjects + 1) * n_runs) as pbar:
    for subject in [*subjects, f"mean{lang}"]:
        imgs = sorted(f.path for f in layout.get(subject=subject))
        assert len(imgs) == n_runs
        for img_file in imgs:
            new_img_file = img_file.replace("sub-"+subject, "sub-"+subject+"ISC")
            new_img_file = Path(new_img_file.replace(".nii.gz", ".hf5"))
            if not new_img_file.exists():
                new_img_file.parent.mkdir(parents=True, exist_ok=True)
                img = nib.load(img_file).get_fdata()[signif].T
                with h5py.File(new_img_file, "w") as f:
                    f.create_dataset("data", data=img, compression="gzip")
            pbar.update(1)

In [None]:
cmap = mpl.cm.jet
# Display a random 10% of the significant voxels; the seed makes the figure
# reproducible from one run to the next.
df_display = df.sample(len(df) // 10, random_state=0)
max_zscore = df_display.zscore.max()
# Named `color_norm` rather than `norm`: rebinding `norm` here replaced the
# `scipy.stats.norm` imported at the top of the notebook, so re-running the
# statistics cell after this one failed.
color_norm = mpl.colors.Normalize(vmin=0, vmax=max_zscore)
cbar = mpl.cm.ScalarMappable(norm=color_norm, cmap=cmap)
colors = cmap(color_norm(df_display.zscore))
# Marker sizes grow with the squared z-score, rescaled so the largest is 10.
sizes = df_display.zscore ** 2
sizes = 10 * sizes / sizes.max()

In [None]:
# Interactive 3D view of the sampled voxels, coloured and sized by z-score.
# The adjacency matrix is diagonal: only the nodes carry information here.
isc_view = plotting.view_connectome(
    adjacency_matrix=np.diag(df_display.zscore),
    node_coords=df_display[["x", "y", "z"]],
    node_color=colors,
    edge_cmap=cmap,
    node_size=sizes,
    symmetric_cmap=False,
    title=f"Correlation zscore from ISC on the {n_subjects} {lang} subjects, zscore FDR thresh {thresh:.3g}, 10% of the {len(df)} selected voxels are displayed",
)
isc_view.save_as_html("figs/li2022_isc.html")

# SRM (shared response model)

In [None]:
def read(subject, run, path):
    """Load one image, flatten it to 2-D and standardise along the last axis.

    NOTE(review): this definition reuses the name of the one-argument
    `read(path)` defined earlier in this notebook and replaces it once this
    cell has run; cells calling `read` with a single path will then fail in
    the same kernel. Consider giving one of the two a distinct name.

    Returns
    -------
    tuple
        `(subject, run, data)`, where `data` has all leading axes of the image
        flattened into one and the last axis kept.
    """
    data = nib.load(path).get_fdata()
    last_axis_len = data.shape[-1]
    data = data.reshape(-1, last_axis_len)
    # StandardScaler standardises each column, so the array is transposed to
    # put the last axis along the rows, scaled, and transposed back.
    scaler = StandardScaler(copy=False)
    data = scaler.fit_transform(data.T).T
    return subject, run, data

In [None]:
# Load the first two runs of the first two subjects in parallel.
# Runs are taken from the path-sorted file list and handed over as plain path
# strings, like every other cell of this notebook does; the previous version
# relied on the raw order of layout.get() and shipped the layout's file
# objects to the worker processes.
res = Parallel(n_jobs=-2, verbose=2)(
    delayed(read)(subject, run, path)
    for subject in subjects[:2]
    for run, path in enumerate(sorted(f.path for f in layout.get(subject=subject))[:2])
)
# Index the loaded arrays by (subject, run position).
res = {(subject, run): img for subject, run, img in res}

In [None]:
# Nested list indexed as imgs[subject][run], built from the (subject, run) dict.
imgs = [
    [res[subject, run] for run in range(2)]
    for subject in subjects[:2]
]

In [None]:
from fastsrm.identifiable_srm import IdentifiableFastSRM

In [None]:
# Model instance configured with 10 components.
srm = IdentifiableFastSRM(n_components=10)

In [None]:
# Fit the model on the nested per-subject / per-run list and keep the transformed output.
X = srm.fit_transform(imgs)

In [None]:
# First entry of the fitted model's basis list
# (presumably the basis of the first subject — TODO confirm against fastsrm).
W = srm.basis_list[0]