In [None]:
import sys
sys.path.append('..')
sys.path.append('../..')
sys.path.append('../../..')

import selective_recruitment.globals as gl


import Functional_Fusion.atlas_map as am
import Functional_Fusion.dataset as ds

from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import SUITPy.flatmap as flatmap
import nitools as nt
import nilearn.plotting as plotting
from nilearn import datasets # this will be used to plot sulci on top of the surface
import matplotlib.colors as colors
from matplotlib.colors import LinearSegmentedColormap
from nilearn.plotting.cm import _cmap_d as nilearn_cmaps

import warnings
warnings.filterwarnings('ignore')

In [None]:
# Optional: uncomment this cell to preview the nilearn colormaps that can be
# used for `cmap` in the defaults cell below (one vertical strip per colormap).
# # choose from these colormaps
# nmaps = len(nilearn_cmaps)
# a = np.outer(np.arange(0, 1, 0.01), np.ones(10))

# # Initialize the figure
# plt.figure(figsize=(10, 4.2))
# plt.subplots_adjust(top=0.4, bottom=0.05, left=0.01, right=0.99)

# for index, cmap in enumerate(nilearn_cmaps):
#     plt.subplot(1, nmaps + 1, index + 1)
#     plt.imshow(a, cmap=nilearn_cmaps[cmap])
#     plt.axis("off")
#     plt.title(cmap, fontsize=10, va="bottom", rotation=90)

In [None]:
# setting working directory: use the first candidate location that exists on
# this machine. If none exists, the last candidate is kept (same fallback as
# before) so that downstream path construction still works.
wkdir_candidates = [
    # raw string: the Windows path contains backslash sequences ("\d", "\C",
    # "\s") that are invalid escape sequences in a normal string literal
    r'A:\data\Cerebellum\CerebellumWorkingMemory\selective_recruit',
    '/srv/diedrichsen/data/Cerebellum/CerebellumWorkingMemory/selective_recruit',
    '/Users/jdiedrichsen/Data/wm_cerebellum/selective_recruit',
    '/Volumes/diedrichsen_data$/data/Cerebellum/CerebellumWorkingMemory/selective_recruit',
]
wkdir = next((d for d in wkdir_candidates if Path(d).exists()),
             wkdir_candidates[-1])

figdir = Path(wkdir) / 'figures'

# Working memory task
Digit-span task with three factors: phase, recall direction, and load.

Phase: encoding (`phase == 0`), retrieval (`phase == 1`)

Recall direction: forward, backward

Load: 2, 4, 6

## Setting the defaults for data preparation

In [None]:
# create an instance of the dataset class
dataset = "WMFS"
Data = ds.get_dataset_class(base_dir = gl.base_dir, dataset = dataset)

# visualization settings
cmap = "cyan_orange" # colormap for activation maps (nilearn colormap name; see the optional preview cell above)
colorbar = True # whether to plot colorbar. NOTE: the cortical plotting cells further down reassign this to False
cscale_cereb = [-1,1] # colorbar scale for cerebellar maps
cscale_cortex = [-3,3] # colorbar scale for cortical maps

# setting defaults for plotting and analysis
smooth = 3 # smoothing kernel for activation maps
# NOTE: `type` shadows the Python builtin in this notebook; the name is kept
# because later cells pass it on as `type=type`.
type = "CondAll" # type of data to use. Another option is "CondHalf"
ses_id = "ses-02" # session id. For the working memory task, it is "ses-02"
subj = "group" # subject id to plot. For group activation maps, set it to "group"


# get atlas objects (keep the info of both atlases instead of letting the
# cortical call overwrite the cerebellar one)
atlas_str_cereb = "SUIT3"
atlas_str_cortex = "fs32k"
atlas_cereb, atlas_info_cereb = am.get_atlas(atlas_str = atlas_str_cereb, atlas_dir = gl.atlas_dir)
atlas_cortex, atlas_info_cortex = am.get_atlas(atlas_str = atlas_str_cortex, atlas_dir = gl.atlas_dir)
# backward compatibility: `atlas_info` previously ended up holding the cortical atlas info
atlas_info = atlas_info_cortex

## Prepare data for plotting

In [None]:
# Get the condition info. The contrast vectors below hold one weight per row of
# `info` and are matrix-multiplied with the loaded data, so `info` must come
# from the same subject/session/type as the data (hence `subj=subj`, not a
# hard-coded "group").
info = Data.get_info(ses_id=ses_id, type=type, subj=subj, fields=None)

# Phase contrasts: weight 1 for every condition of the phase, 0 elsewhere.
# NOTE(review): these SUM over the conditions of a phase rather than averaging
# them (the load/recall contrasts further down use 1/n weights); confirm this
# is the intended scaling for the "avg activation" maps.
idx_enc = info.phase == 0
c_enc = idx_enc.to_numpy(dtype=float)
idx_ret = info.phase == 1
c_ret = idx_ret.to_numpy(dtype=float)

# load the data for cerebellum and cortex with identical settings
dataset_kwargs = dict(dataset=dataset,
                      sess=ses_id,
                      subj=subj,
                      type=type,
                      smooth=smooth)
data_cereb, _, _ = ds.get_dataset(gl.base_dir, atlas=atlas_str_cereb, **dataset_kwargs)
data_cortex, _, _ = ds.get_dataset(gl.base_dir, atlas=atlas_str_cortex, **dataset_kwargs)

### Plot cerebellar data

In [None]:
# Encoding / retrieval phase maps, projected onto the SUIT cerebellar flatmap.
data_cereb_enc = c_enc @ data_cereb
data_cereb_ret = c_ret @ data_cereb

enc_nii = atlas_cereb.data_to_nifti(data_cereb_enc)
ret_nii = atlas_cereb.data_to_nifti(data_cereb_ret)

# settings shared by both panels, so the two maps cannot drift apart
suit_vol2surf_kwargs = dict(stats='nanmean', space='SUIT', ignore_zeros=True)
suit_plot_kwargs = dict(render="plotly",
                        hover='auto',
                        cmap=cmap,
                        colorbar=colorbar,
                        bordersize=1.5,
                        bordercolor="black",
                        cscale=cscale_cereb)

enc_flat = flatmap.vol_to_surf([enc_nii], **suit_vol2surf_kwargs)
ret_flat = flatmap.vol_to_surf([ret_nii], **suit_vol2surf_kwargs)
enc_suit_fig = flatmap.plot(data=enc_flat, **suit_plot_kwargs)
ret_suit_fig = flatmap.plot(data=ret_flat, **suit_plot_kwargs)

# add a centred title to each plotly figure and show it
for fig, title in ((enc_suit_fig, 'Encoding'), (ret_suit_fig, 'Retrieval')):
    fig.update_layout(title={'text': title, 'x': 0.5, 'y': 0.98})
    fig.show()
# to save, uncomment:
# enc_suit_fig.write_image("working_memory_avg_activation_enc_suit.pdf")
# ret_suit_fig.write_image("working_memory_avg_activation_ret_suit.pdf")

### Plot neocortical data

In [None]:
# Phase contrasts on the cortical (fs32k) data, converted to per-hemisphere
# surface data for the plotting cell below.
data_cortex_enc = c_enc @ data_cortex
data_cortex_ret = c_ret @ data_cortex

# inflated fs32k template surfaces, ordered [left, right]; the f-string works
# whether gl.atlas_dir is a str or a pathlib.Path
surfs = [f'{gl.atlas_dir}/tpl-fs32k/tpl_fs32k_hemi-{h}_inflated.surf.gii' for h in ('L', 'R')]

# convert to cifti (data passed as a single row), then to surface data that
# the plotting cell indexes per hemisphere alongside `surfs`
enc_cii = atlas_cortex.data_to_cifti(data_cortex_enc.reshape(1, -1))
ret_cii = atlas_cortex.data_to_cifti(data_cortex_ret.reshape(1, -1))
enc_img = nt.surf_from_cifti(enc_cii)
ret_img = nt.surf_from_cifti(ret_cii)


In [None]:
# Plot the encoding and retrieval phase maps on the inflated cortical surfaces.
view = "lateral"
# NOTE: this overwrites the notebook-wide `colorbar` setting; later cells that
# read `colorbar` (e.g. the cerebellar contrast flatmaps) therefore see False.
colorbar = False
surf_plot_kwargs = dict(view=view,
                        cmap=cmap,
                        engine='plotly',
                        symmetric_cbar=True,
                        vmax=cscale_cortex[1])
enc_fs_fig = []
ret_fs_fig = []
for h, hemi in enumerate(['left', 'right']):
    # the encoding figure always carries a colorbar; retrieval follows `colorbar`
    enc_fs_fig.append(plotting.plot_surf_stat_map(
        surfs[h], enc_img[h], hemi=hemi, colorbar=True, **surf_plot_kwargs
    ).figure)
    ret_fs_fig.append(plotting.plot_surf_stat_map(
        surfs[h], ret_img[h], hemi=hemi, colorbar=colorbar, **surf_plot_kwargs
    ).figure)

# Plotly camera for the lateral view, one per hemisphere ([left, right]):
# the eye sits to the side of the hemisphere (x = -1.5 / +1.5) and raised (z = 0.9).
camera_params = [
    dict(center=dict(x=0, y=0, z=0), eye=dict(x=-1.5, y=0, z=0.9), up=dict(x=0, y=0, z=1)),  # left hemi
    dict(center=dict(x=0, y=0, z=0), eye=dict(x=1.5, y=0, z=0.9), up=dict(x=0, y=0, z=1)),   # right hemi
]

# show left hemisphere (encoding, retrieval), then right hemisphere (encoding, retrieval)
for h, hemi_tag in enumerate(['L', 'R']):
    for fig, phase_tag in ((enc_fs_fig[h], 'enc'), (ret_fs_fig[h], 'ret')):
        if view == "lateral":
            fig.update_layout(scene_camera=camera_params[h])
        fig.show()
        # to save, uncomment:
        # fig.write_image(f"working_memory_avg_activation_{phase_tag}_{view}_{hemi_tag}.pdf")

## Plotting specific contrasts

### Prepare data for load and recall direction contrasts

In [None]:
phase = 0 # choose the phase. 0 for encoding, 1 for retrieval
phase_str = ["encoding", "retrieval"]


def balanced_contrast(idx_eff, idx_base):
    """Build a contrast vector "mean(effect conditions) - mean(baseline conditions)".

    Parameters
    ----------
    idx_eff, idx_base : boolean array-like, one entry per condition (row of ``info``)
        Conditions weighted positively / negatively.

    Returns
    -------
    numpy.ndarray of float
        ``+1/n_eff`` where ``idx_eff`` is True, ``-1/n_base`` where ``idx_base``
        is True, 0 elsewhere.

    Raises
    ------
    ValueError
        If either condition set is empty (the contrast would be undefined).
    """
    idx_eff = np.asarray(idx_eff, dtype=bool)
    idx_base = np.asarray(idx_base, dtype=bool)
    n_eff = idx_eff.sum()
    n_base = idx_base.sum()
    if n_eff == 0 or n_base == 0:
        raise ValueError(f"empty condition set in contrast (n_eff={n_eff}, n_base={n_base})")
    contrast = np.zeros(idx_eff.shape[0])
    contrast[idx_eff] = 1 / n_eff
    contrast[idx_base] = -1 / n_base
    return contrast


# load effect: load 6 vs load 2 within the chosen phase
# NOTE(review): restricted to recall == 1 conditions only; confirm this is intended
idx_load_eff = (info.load == 6) & (info.recall == 1) & (info.phase == phase)
idx_load_base = (info.load == 2) & (info.recall == 1) & (info.phase == phase)
c_load = balanced_contrast(idx_load_eff, idx_load_base)

# recall-direction effect: recall == 0 vs recall == 1 within the chosen phase
idx_recall_eff = (info.recall == 0) & (info.phase == phase)
idx_recall_base = (info.recall == 1) & (info.phase == phase)
c_recall = balanced_contrast(idx_recall_eff, idx_recall_base)

### Plot cerebellar contrast maps for the selected phase (encoding or retrieval, chosen with `phase`)

In [None]:
# Load and recall-direction contrast maps for the selected phase on the SUIT flatmap.
# colour range for the contrast flatmaps (narrower than cscale_cereb used for the phase maps)
cscale_cereb_contrast = [-0.05, 0.05]

# contrast maps -> nifti (computed once each)
data_cereb_load = c_load @ data_cereb
data_cereb_recall = c_recall @ data_cereb
load_nii = atlas_cereb.data_to_nifti(data_cereb_load)
recall_nii = atlas_cereb.data_to_nifti(data_cereb_recall)

# settings shared by both panels
suit_vol2surf_kwargs = dict(stats='nanmean', space='SUIT', ignore_zeros=True)
suit_plot_kwargs = dict(render="plotly",
                        hover='auto',
                        cmap=cmap,
                        colorbar=colorbar,
                        bordersize=1.5,
                        bordercolor="black",
                        cscale=cscale_cereb_contrast)

load_flat = flatmap.vol_to_surf([load_nii], **suit_vol2surf_kwargs)
recall_flat = flatmap.vol_to_surf([recall_nii], **suit_vol2surf_kwargs)
load_suit_fig = flatmap.plot(data=load_flat, **suit_plot_kwargs)
recall_suit_fig = flatmap.plot(data=recall_flat, **suit_plot_kwargs)

# add a centred title (contrast + phase) to each plotly figure and show it
for fig, label in ((load_suit_fig, "load"), (recall_suit_fig, "recall")):
    fig.update_layout(title={'text': f"{label} {phase_str[phase]}", 'x': 0.5, 'y': 0.98})
    fig.show()
    # to save, uncomment:
    # fig.write_image(f"working_memory_avg_activation_{label}_{phase_str[phase]}_suit.pdf")

### Plot cortical contrast maps for the selected phase

In [None]:
# Load and recall-direction contrast maps for the selected phase on the cortical surfaces.
data_cortex_load = c_load @ data_cortex
data_cortex_recall = c_recall @ data_cortex

# convert to cifti (data passed as a single row), then to surface data indexed
# per hemisphere alongside `surfs`
load_cii = atlas_cortex.data_to_cifti(data_cortex_load.reshape(1, -1))
recall_cii = atlas_cortex.data_to_cifti(data_cortex_recall.reshape(1, -1))
load_img = nt.surf_from_cifti(load_cii)
recall_img = nt.surf_from_cifti(recall_cii)

# plot cortical maps
view = "medial"
colorbar = False
# upper colour limit for the contrast maps; with symmetric_cbar=True the lower
# limit mirrors it
vmax_cortex_contrast = 0.1
surf_plot_kwargs = dict(view=view,
                        colorbar=colorbar,
                        cmap=cmap,
                        engine='plotly',
                        symmetric_cbar=True,
                        vmax=vmax_cortex_contrast)
load_fs_fig = []
recall_fs_fig = []
for h, hemi in enumerate(['left', 'right']):
    load_fs_fig.append(plotting.plot_surf_stat_map(
        surfs[h], load_img[h], hemi=hemi, **surf_plot_kwargs
    ).figure)
    recall_fs_fig.append(plotting.plot_surf_stat_map(
        surfs[h], recall_img[h], hemi=hemi, **surf_plot_kwargs
    ).figure)

# Plotly camera used only for the lateral view, one per hemisphere ([left, right]):
# the eye sits to the side of the hemisphere (x = -1.5 / +1.5) and raised (z = 0.9).
camera_params = [
    dict(center=dict(x=0, y=0, z=0), eye=dict(x=-1.5, y=0, z=0.9), up=dict(x=0, y=0, z=1)),  # left hemi
    dict(center=dict(x=0, y=0, z=0), eye=dict(x=1.5, y=0, z=0.9), up=dict(x=0, y=0, z=1)),   # right hemi
]

# show left hemisphere (load, recall), then right hemisphere (load, recall)
for h, hemi_tag in enumerate(['L', 'R']):
    for fig, label in ((load_fs_fig[h], "load"), (recall_fs_fig[h], "recall")):
        # fig.update_layout(title={'text': f"{label} {phase_str[phase]}", 'x': 0.5, 'y': 0.98})
        if view == "lateral":
            fig.update_layout(scene_camera=camera_params[h])
        fig.show()
        # to save, uncomment:
        # fig.write_image(f"working_memory_avg_activation_{label}_{phase_str[phase]}_{view}_{hemi_tag}.pdf")