In [1]:
import os, re, glob, json
from collections import Counter
from pathlib import Path

import numpy as np
import pandas as pd
import nibabel as nib
import matplotlib.pyplot as plt
from nilearn import image
from nilearn import plotting
from nilearn.decoding import Decoder
from nilearn.glm import threshold_stats_img
from nilearn.glm.first_level import FirstLevelModel
from nilearn.glm.second_level import SecondLevelModel, non_parametric_inference
from nilearn.image import load_img, resample_to_img
from nilearn.plotting import plot_design_matrix
from nilearn.plotting import plot_stat_map
from sklearn.metrics import classification_report
from sklearn.model_selection import LeaveOneGroupOut


In [2]:
# Dataset root on the local machine; edit this single path to relocate the analysis.
BASE = "/local/anpa/ds003242-1.0.0"

# Every directory constant below ends with exactly one "/" so that names can be
# appended directly without producing doubled separators.
DERIVATIVES = f"{BASE}/derivatives/"
DERIVATIVES_FMRIPREP = f"{DERIVATIVES}fmriprep/"
FIRST_LEVEL_SEP_RUNS = f"{DERIVATIVES}firstlevel_separate_runs/"

# Task label (not referenced by the cells below in this notebook section).
TASK = "CIC"
# Value handed to the decoder as `t_r`; presumably the repetition time in seconds -- TODO confirm.
TR = 2.0

In [3]:
def resample_mask_to_bold(anat_mask, bold_img):
    """Resample an anatomical mask onto the grid of a BOLD image.

    Mirrors the approach of the nilearn first-level tutorial:
    https://nilearn.github.io/dev/auto_examples/04_glm_first_level/plot_first_level_details.html

    Parameters
    ----------
    anat_mask
        Mask image to resample (passed to ``resample_to_img`` as the source).
    bold_img
        Image whose space the mask is resampled to (the target).

    Returns
    -------
    The value returned by ``resample_to_img``, computed with nearest-neighbour
    interpolation.
    """
    resampling_options = dict(
        interpolation="nearest",
        copy_header=True,
        force_resample=True,
    )
    mask_in_bold_space = resample_to_img(anat_mask, bold_img, **resampling_options)
    return mask_in_bold_space

def fprep_func_dir(sub):
    """Return the fMRIPrep ``func`` directory of one subject.

    Parameters
    ----------
    sub : str
        Subject identifier without the ``sub-`` prefix.

    Returns
    -------
    pathlib.Path
        ``<fmriprep derivatives>/sub-<sub>/func``.
    """
    # Build on the shared constant rather than re-spelling "derivatives/fmriprep",
    # so the location is defined in exactly one place. Path() collapses any
    # repeated separators in the constant.
    return Path(DERIVATIVES_FMRIPREP) / f"sub-{sub}" / "func"

In [4]:
# Subject IDs are the names of the first-level output folders with the
# "sub-" prefix removed, in sorted order.
all_sub_dirs = sorted(
    entry.name.split("sub-")[-1]
    for entry in Path(FIRST_LEVEL_SEP_RUNS).glob("sub-*")
    if entry.is_dir()
)
all_sub_dirs[:3]

['SAXSISO01b', 'SAXSISO01f', 'SAXSISO01s']

In [5]:
# Split the subject IDs by their final character. Judging by the variable
# names, "f" marks the fasting session, "s" the social-isolation session and
# "b" the baseline session.
baseline_participants = [sub for sub in all_sub_dirs if sub[-1:] == "b"]
fasting_participants = [sub for sub in all_sub_dirs if sub[-1:] == "f"]
social_participants = [sub for sub in all_sub_dirs if sub[-1:] == "s"]

In [6]:
def find_zmaps(participants, condition):
    """Collect the z-map files of one condition for a list of participants.

    Each participant's first-level directory is searched recursively for files
    matching ``*<condition>_*_zmap.nii.gz``.

    Parameters
    ----------
    participants : list of str
        Subject identifiers without the ``sub-`` prefix.
    condition : str
        Condition name as it appears in the file names (e.g. ``"Food"``).

    Returns
    -------
    list of pathlib.Path
        Matching files, sorted so the sample order does not depend on the
        order in which the filesystem happens to list directory entries.
    """
    pattern = f"*{condition}_*_zmap.nii.gz"
    return sorted(
        zmap
        for sub in participants
        for zmap in Path(f"{FIRST_LEVEL_SEP_RUNS}/sub-{sub}").rglob(pattern)
    )


# Naming scheme: <session>_<condition>.
fasting_food = find_zmaps(fasting_participants, "Food")
fasting_social = find_zmaps(fasting_participants, "Social")
fasting_control = find_zmaps(fasting_participants, "Control")

social_food = find_zmaps(social_participants, "Food")
social_social = find_zmaps(social_participants, "Social")
social_control = find_zmaps(social_participants, "Control")

baseline_food = find_zmaps(baseline_participants, "Food")
baseline_social = find_zmaps(baseline_participants, "Social")
baseline_control = find_zmaps(baseline_participants, "Control")

In [7]:
def compare(labels: tuple, z_maps: tuple):
    """Fit a cross-validated SVC decoder that separates the given conditions.

    Parameters
    ----------
    labels : tuple of str
        One class label per condition, e.g. ``('Food', 'Control')``.
    z_maps : tuple of list
        For each entry of ``labels``, the z-map image paths of that class.
        Must have the same length as ``labels``.

    Returns
    -------
    decoder : nilearn.decoding.Decoder
        The fitted decoder.
    data : tuple
        ``(X, y)``: the array of image paths and the array of class labels,
        in the order in which they were passed to ``fit``.

    Raises
    ------
    AssertionError
        If ``labels`` and ``z_maps`` differ in length.
    ValueError
        If no z-map was supplied at all.
    """
    assert len(labels) == len(z_maps), "Labels and z_maps must have the same length"

    # Accumulate in plain lists and convert once, instead of re-allocating the
    # arrays with np.concatenate on every iteration.
    samples, targets, run_ids = [], [], []
    for label, condition_maps in zip(labels, z_maps):
        condition_maps = list(condition_maps)
        samples.extend(condition_maps)
        targets.extend([label] * len(condition_maps))
        # The cross-validation group is the first character of the file name.
        # NOTE(review): presumably the run number -- confirm against the
        # first-level output naming. A single character cannot distinguish
        # more than 9 runs, and grouping by run keeps maps of the same subject
        # in both the training and the test folds.
        run_ids.extend(Path(zmap).name[0] for zmap in condition_maps)

    if not samples:
        raise ValueError("No z-maps were provided; check the glob patterns and input directories")

    X = np.array(samples)
    y = np.array(targets)
    groups = np.array(run_ids)

    decoder = Decoder(
        t_r=TR,
        estimator='svc',
        scoring='accuracy',
        mask=None,
        standardize=False,
        cv=LeaveOneGroupOut(),
        n_jobs=-1,
    )
    decoder.fit(X, y, groups=groups)

    # cv_scores_ maps each class label to its per-fold scores.
    classification_accuracy = np.mean(list(decoder.cv_scores_.values()))
    print(
        f"Classification accuracy: {classification_accuracy:.4f} / "
    )

    for class_label, fold_scores in decoder.cv_scores_.items():
        print(class_label, np.mean(fold_scores))

    return decoder, (X, y)

# Fasting day VS Baseline day. Binary classification. Food VS Control

Cross-classification accuracy should be decent here: the task (Food vs Control) is the same in both sessions, and only the session condition (fasting vs baseline) differs.

In [8]:
# Within-session decoder, fasting day: Food vs Control.
decoder_ff_fc, data_ff_fc = compare(
    labels=('Food', 'Control'),
    z_maps=(fasting_food, fasting_control),
)

Classification accuracy: 0.8401 / 
Control 0.8400537634408601
Food 0.8400537634408601


In [9]:
# Within-session decoder, baseline day: Food vs Control.
decoder_bf_bc, data_bf_bc = compare(
    labels=('Food', 'Control'),
    z_maps=(baseline_food, baseline_control),
)

Classification accuracy: 0.7886 / 
Control 0.7885584677419355
Food 0.7885584677419355


In [10]:
# Let's check cross-classification: the decoder trained on the fasting day
# predicts the baseline-day maps.
y_true = data_bf_bc[1]
y_pred = decoder_ff_fc.predict(data_bf_bc[0])
print(
    classification_report(
        y_true,
        y_pred,
        target_names=['Control', 'Food'],
    )
)

              precision    recall  f1-score   support

     Control       0.80      0.78      0.79       573
        Food       0.79      0.80      0.79       573

    accuracy                           0.79      1146
   macro avg       0.79      0.79      0.79      1146
weighted avg       0.79      0.79      0.79      1146



In [11]:
# Reverse direction: the decoder trained on the baseline day predicts the
# fasting-day maps.
y_true = data_ff_fc[1]
y_pred = decoder_bf_bc.predict(data_ff_fc[0])
print(
    classification_report(
        y_true,
        y_pred,
        target_names=['Control', 'Food'],
    )
)

              precision    recall  f1-score   support

     Control       0.85      0.76      0.81       570
        Food       0.79      0.87      0.83       570

    accuracy                           0.82      1140
   macro avg       0.82      0.82      0.82      1140
weighted avg       0.82      0.82      0.82      1140



# Social isolation day VS Baseline day. Binary classification. Social VS Control

Cross-classification accuracy should be decent here: the task (Social vs Control) is the same in both sessions, and only the session condition (social isolation vs baseline) differs.

In [12]:
# Within-session decoder, social-isolation day: Social vs Control.
decoder_ss_sc, data_ss_sc = compare(
    labels=('Social', 'Control'),
    z_maps=(social_social, social_control),
)

Classification accuracy: 0.9601 / 
Control 0.9600694444444445
Social 0.9600694444444445


In [13]:
# Within-session decoder, baseline day: Social vs Control.
decoder_bs_bc, data_bs_bc = compare(
    labels=('Social', 'Control'),
    z_maps=(baseline_social, baseline_control),
)

Classification accuracy: 0.9597 / 
Control 0.959733422939068
Social 0.959733422939068


In [14]:
# Let's check cross-classification: the decoder trained on the
# social-isolation day predicts the baseline-day maps.
y_true = data_bs_bc[1]
y_pred = decoder_ss_sc.predict(data_bs_bc[0])
print(
    classification_report(
        y_true,
        y_pred,
        target_names=['Control', 'Social'],
    )
)

              precision    recall  f1-score   support

     Control       0.94      0.96      0.95       573
      Social       0.96      0.93      0.95       573

    accuracy                           0.95      1146
   macro avg       0.95      0.95      0.95      1146
weighted avg       0.95      0.95      0.95      1146



In [15]:
# Reverse direction: the decoder trained on the baseline day predicts the
# social-isolation-day maps.
y_true = data_ss_sc[1]
y_pred = decoder_bs_bc.predict(data_ss_sc[0])
print(
    classification_report(
        y_true,
        y_pred,
        target_names=['Control', 'Social'],
    )
)

              precision    recall  f1-score   support

     Control       0.98      0.94      0.96       576
      Social       0.94      0.98      0.96       576

    accuracy                           0.96      1152
   macro avg       0.96      0.96      0.96      1152
weighted avg       0.96      0.96      0.96      1152



# Cross-classification across domains: Food vs Control (fasting day) against Social vs Control (social isolation day)

In [16]:
# Cross-domain test: the Food-vs-Control decoder (fasting day) predicts the
# Social-vs-Control maps (social-isolation day).
y_pred = decoder_ff_fc.predict(data_ss_sc[0])
# Rename the decoder's positive class so predictions are comparable with the
# true labels. np.where builds a new array whose string width fits both
# labels; assigning in place into a fixed-width string array would silently
# truncate a label longer than the existing ones.
y_pred = np.where(y_pred == 'Food', 'Social', y_pred)

y_true = data_ss_sc[1]

print(classification_report(y_true, y_pred, target_names=['Control', 'Social']))

              precision    recall  f1-score   support

     Control       0.47      0.74      0.58       576
      Social       0.39      0.17      0.23       576

    accuracy                           0.45      1152
   macro avg       0.43      0.45      0.41      1152
weighted avg       0.43      0.45      0.41      1152



In [17]:
# Cross-domain test, reverse direction: the Social-vs-Control decoder
# (social-isolation day) predicts the Food-vs-Control maps (fasting day).
y_pred = decoder_ss_sc.predict(data_ff_fc[0])
# Rename the decoder's positive class so predictions are comparable with the
# true labels. np.where builds a new array whose string width fits both
# labels; assigning in place into a fixed-width string array would silently
# truncate a label longer than the existing ones.
y_pred = np.where(y_pred == 'Social', 'Food', y_pred)

y_true = data_ff_fc[1]

print(classification_report(y_true, y_pred, target_names=['Control', 'Food']))

              precision    recall  f1-score   support

     Control       0.49      0.96      0.65       570
        Food       0.22      0.01      0.02       570

    accuracy                           0.49      1140
   macro avg       0.36      0.49      0.34      1140
weighted avg       0.36      0.49      0.34      1140



In [19]:
# How often each class was predicted in the preceding cross-domain test
# (uses the `y_pred` left by the previous cell). `Counter` is imported in the
# notebook's import cell.
Counter(y_pred)

Counter({'Control': 1113, 'Food': 27})