# Explaining cortical thickness development: Modeled data animation

In [1]:
import sys
import os
from os.path import join
from glob import glob
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
from matplotlib import cm
from matplotlib.cm import get_cmap
from matplotlib.colors import Normalize
import seaborn as sns
import nilearn as nl
from surfplot import Plot
import imageio.v2 as iio

# custom functions
from scripts.templates import get_destrieux

# working path
wd = os.path.dirname(os.path.realpath("__file__"))
print("Working directory:", wd)

# JuSpyce
sys.path.append(os.path.dirname(join(wd, "scripts", "juspyce")))
from juspyce.api import JuSpyce

# plot directories
plot_dir_gif = join(wd, "plots", "prediction_dominance", "animation")

Working directory: /Users/llotter/projects/ntct


## Get data

### Parcellation

In [2]:
# parcellation
parc_destrieux, destrieux_idps = get_destrieux()

### Timepoints

In [3]:
# time points we look at:
tps_whole = (5,30)
tps_steps = [(i,i+5) for i in np.arange(5,86,1)]
tps_index = [f"Δ({tps_whole[0]},{tps_whole[1]})"]+[f"Δ({s[0]},{s[1]})" for s in tps_steps]
print("Age-differences to look at:", tps_steps)

Age-differences to look at: [(5, 10), (6, 11), (7, 12), (8, 13), (9, 14), (10, 15), (11, 16), (12, 17), (13, 18), (14, 19), (15, 20), (16, 21), (17, 22), (18, 23), (19, 24), (20, 25), (21, 26), (22, 27), (23, 28), (24, 29), (25, 30), (26, 31), (27, 32), (28, 33), (29, 34), (30, 35), (31, 36), (32, 37), (33, 38), (34, 39), (35, 40), (36, 41), (37, 42), (38, 43), (39, 44), (40, 45), (41, 46), (42, 47), (43, 48), (44, 49), (45, 50), (46, 51), (47, 52), (48, 53), (49, 54), (50, 55), (51, 56), (52, 57), (53, 58), (54, 59), (55, 60), (56, 61), (57, 62), (58, 63), (59, 64), (60, 65), (61, 66), (62, 67), (63, 68), (64, 69), (65, 70), (66, 71), (67, 72), (68, 73), (69, 74), (70, 75), (71, 76), (72, 77), (73, 78), (74, 79), (75, 80), (76, 81), (77, 82), (78, 83), (79, 84), (80, 85), (81, 86), (82, 87), (83, 88), (84, 89), (85, 90)]


### Dominance analyses

In [4]:
juspyce_ct_slr = JuSpyce.from_pickle(\
    join(wd, "data_rutherford", f"juspyce_ct_slr_fm_500_5.pkl.gz"))
juspyce_ct_dominance = JuSpyce.from_pickle(\
    join(wd, "data_rutherford", f"juspyce_ct_dominance_fm_500_5.pkl.gz"))

# identify significant predictors
slr_sig = (juspyce_ct_slr.p_predictions["slr--fdr_bh"] < 0.05).any()
predictors_sig = slr_sig[slr_sig==True].index.to_list()
print(f"p_fdr < 0.05: n = {len(predictors_sig)}\n", predictors_sig)
dom_sig = (juspyce_ct_dominance.p_predictions["dominance_total--fdr_bh"] < 0.05).any()
predictors_sig_dom = dom_sig[dom_sig==True].index.to_list()
print(f"p_fdr < 0.05: n = {len(predictors_sig_dom)}\n", predictors_sig_dom)

INFO:juspyce.api:Loaded complete object from /Users/llotter/projects/ntct/data_rutherford/juspyce_ct_slr_fm_500_5.pkl.gz.
INFO:juspyce.api:Loaded complete object from /Users/llotter/projects/ntct/data_rutherford/juspyce_ct_dominance_fm_500_5.pkl.gz.


p_fdr < 0.05: n = 9
 ['ni3-FDOPA-DAT-D1-NMDA', 'ni4-GI-5HT1b-MU-A4B2', 'ni5-VAChT-NET', 'ni6-CBF-CMRglu', 'ni9-D2', 'ce3-Micro-OPC', 'ce4-In3-In2-Astro', 'ce5-In6-Ex2', 'ce9-In8']
p_fdr < 0.05: n = 6
 ['ni3-FDOPA-DAT-D1-NMDA', 'ni5-VAChT-NET', 'ni9-D2', 'ce3-Micro-OPC', 'ce4-In3-In2-Astro', 'ce9-In8']


### Regional influence

In [5]:
# prediction error function
def pe(x, y):    
    X = np.c_[x, np.ones(x.shape[0])] 
    beta = np.linalg.pinv((X.T).dot(X)).dot(X.T.dot(y))
    y_hat = np.dot(X, beta)
    pe = y_hat - y
    return pe

# dict to store results
regional_influence = dict()

# exclude missing rows
no_nan = np.array(~juspyce_ct_dominance._nan_bool)

# all predictors
x = juspyce_ct_dominance.X.values[:,no_nan].T

# iterate time points
for i_tp, tp in enumerate(tps_index):
    
    # empty array
    pe_diff = np.full((juspyce_ct_dominance.X.shape[1],juspyce_ct_dominance.X.shape[0]), np.nan)
    
    # iterate predictors/rows
    for i_pred, pred in enumerate(list(juspyce_ct_dominance.X.index)):
        # get all predictors w/o current predictor
        x_red = np.delete(x, i_pred, axis=1)
        # CT differences of current tp
        y = juspyce_ct_dominance.Y.loc[f"{tp}"].values[no_nan]
        # calculate pe
        pe_all = pe(x=x, y=y)
        pe_red = pe(x=x_red, y=y)
        # save
        pe_diff[no_nan,i_pred] = np.abs(pe_red) - np.abs(pe_all)
    # to dataframe
    regional_influence[tp] = pd.DataFrame(
        pe_diff, 
        columns=juspyce_ct_dominance.X.index,
        index=juspyce_ct_dominance.X.columns)

## Plot functions

### Dominance plot

In [6]:
## p-to-asterisk function
def p_to_ast(p_data, pc_data):
    p_1d = np.array(p_data).flatten()
    pc_1d = np.array(pc_data).flatten()
    ast = list()
    for (p, pc) in zip(p_1d, pc_1d):
        if (pc < 0.05): ast.append("★")
        elif (p < 0.05) & (pc >= 0.05): ast.append("☆")
        else: ast.append("")
    return pd.DataFrame(np.array(ast).reshape(p_data.shape), 
                        index=p_data.index, columns=p_data.columns)

## single plot function
def plot_diffs(data, data_p, data_null, ax, title="", colors=None,
               size_text=11,
               title_size=12, title_color="k", 
               legend_color="linecolor", legend_size=12, legend=True):
    pred = list(data.columns)
    if colors is None: colors = get_cmap("tab10")(range(len(pred)))
    alpha_med = 0.7
    alpha_text = 0.4
    pos_text = [-0.02,0.015]

    ## plot null
    for pz, pz_color in zip([(1,99), (5,95), (25,75)], ["0.97", "0.93", "0.89"]):
        # left
        ax.fill_between(
            x=[-1,0,1],
            y1=np.percentile(data_null, pz[0], axis=0)[0],
            y2=np.percentile(data_null, pz[1], axis=0)[0],
            lw=0,
            color=pz_color
        )
        # right
        ax.fill_between(
            x=tps_index[1:],
            y1=np.percentile(data_null, pz[0], axis=0)[1:],
            y2=np.percentile(data_null, pz[1], axis=0)[1:],
            lw=0,
            color=pz_color
        )

    ## plot predictor-wise
    for i, p in enumerate(pred):
        
        if data_p is not None:
            if "★" in data_p[p].to_list():
                label = p + " ★"
            elif "☆" in data_p[p].to_list():
                label = p + " ☆"
            else:
                label = p  
        else:
            label = p
                
        ax.plot(data.index[1:], data[p][1:], color=colors[i], alpha=alpha_med, label=label)
        
        if data_p is not None:
            for x, (y, p) in enumerate(zip(data[p][1:], data_p[p][1:])):
                ax.text(x+pos_text[0], y+pos_text[1], p, ha="center", va="top", 
                        color=colors[i], alpha=alpha_text+0.3, size=size_text)
                
    if legend:
        legend = ax.legend(loc="upper right", ncol=1, prop=dict(size=legend_size), 
                           labelcolor=legend_color)
        legend.set_title(title)
        plt.setp(legend.get_title(), color=title_color, size=title_size, weight="semibold")
    
    # general
    ax.set_xlim(-0.6,0.6)
    ax.set_xticks(list(np.arange(0,86,5)))
    ax.set_xticks(list(range(90)), minor=True)
    ax.set_xlim(-0.8,data.shape[0]-1.2)

### Brain & scatter plot

In [7]:
# function to make parcellated gifti from input vector
def get_parc_gifti(data, parc_gifti=parc_destrieux):
    lh = np.zeros(parc_gifti[0].darrays[0].data.shape)
    rh = np.zeros(parc_gifti[1].darrays[0].data.shape)
    for parcel_idx, parcel_val  in enumerate(data, start=1):
        lh[parc_gifti[0].darrays[0].data==parcel_idx] = parcel_val
        rh[parc_gifti[1].darrays[0].data==parcel_idx] = parcel_val
    lh[parc_gifti[0].darrays[0].data==0] = np.nan
    rh[parc_gifti[1].darrays[0].data==0] = np.nan
    return lh, rh

# surfplot function    
def plot_surf_ax(lh, fig, ax, template="pial", views=["lateral", "medial"], size=(1000,400),
                 layout="row", c="viridis_r", c_lims=None, 
                 cbar_symm=False, cbar=True, rotate_labels=False, 
                 zoom=1.6):
    fsaverage = nl.datasets.fetch_surf_fsaverage()
    if c_lims is None:
        if cbar_symm==False:
            c_lims = (np.nanmin(lh), np.nanmax(lh))
        elif cbar_symm==True:
            c_lims = (-np.nanmax(np.abs(lh)), np.nanmax(np.abs(lh)))
    p = Plot(fsaverage[f"{template}_left"], layout=layout, size=size, zoom=zoom, views=views)
    p.add_layer(dict(left=lh), cmap=c, color_range=c_lims)
    p.build(fig=fig, ax=ax, colorbar=False)
    # legend
    if cbar==True:
        cbar_width = 0.45
        cbar = fig.colorbar(
            cm.ScalarMappable(
                norm=Normalize(c_lims[0],c_lims[1]), 
                cmap=c), 
            cax=ax.inset_axes([(1-cbar_width)/2, 0.05, cbar_width, 0.06]),
            orientation="horizontal")
        if rotate_labels:
            plt.setp(cbar.ax.get_xticklabels(), rotation=-40, ha="left", rotation_mode="anchor")
            cbar.ax.tick_params(axis="x", which="major", pad=2)

def scatter(x, y, hue, fig, ax, hue_lims=None, r=None, c=None):
    # limits
    if hue_lims is None:
        hue_max = np.nanmax(np.abs(np.array(hue)))
        hue_lims = (-hue_max, hue_max)
    # scatter
    sns.scatterplot(
        x=x, 
        y=y, 
        ax=ax, 
        hue=hue, palette="RdBu_r", alpha=0.7, hue_norm=hue_lims, edgecolor='k', legend=None)
    sns.regplot(
        x=x, 
        y=y, 
        ax=ax, 
        color=c, scatter=False)
    if r:
        ax.annotate(
            text=f"${r:.2f}$",
            xy=(0.05,0.81) if r>0 else (0.96,0.81),
            xycoords="axes fraction",
            ha="left" if r>0 else "right",
            c=c,
            bbox=dict(boxstyle="round,pad=0.1", fc="white", alpha=0.5)
        )
    ax.set_xlabel("")
    ax.set_ylabel("")

## Plot

In [8]:
# ct limits
ct_lims = (juspyce_ct_dominance.Y.iloc[1:,:].min().min() * 100, 
           juspyce_ct_dominance.Y.iloc[1:,:].max().max() * 100)
ct_lims_abs = (-np.abs(ct_lims).max(), np.abs(ct_lims).max())

# pe limits
pe_lim = list()
for tp in tps_index[1:]:
    pe_lim.append(np.abs(regional_influence[tp]).max().max())
pe_lims = (-max(pe_lim), max(pe_lim))

for tp in tps_index[73:]:
    print(tp)
    delta = juspyce_ct_dominance.Y.loc[tp,:] * 100
    
    pe_lim = (np.abs(regional_influence[tp]).max().max())   # ACTIVATE IF TP-WISE INDIVIDUAL LIMITS
    pe_lims = (-pe_lim, pe_lim)                             # ACTIVATE IF TP-WISE INDIVIDUAL LIMITS
    
    ## figure
    fig = plt.figure(figsize=(16.5,15))
    colors = get_cmap("tab10")(range(10))
    gs = fig.add_gridspec(
        5, 8+3, 
        height_ratios=(0.25,0.25,0.25,1.1,0.5), 
        width_ratios=(0.5,1,0.025, 0.5,1,0.025, 0.5,1,0.025, 0.5,1),
        wspace=0.1,
        hspace=0.5)
    fig.patch.set_facecolor('w')

    ## brains
    # CT
    ax_ct = fig.add_subplot(gs[0,:2])
    plot_surf_ax(get_parc_gifti(delta)[0], fig, ax_ct, c="RdBu_r", 
                 cbar_symm=True, rotate_labels=True,
                 #c_lims=ct_lims_abs                        # DEACTIVATE IF TP-WISE INDIVIDUAL LIMITS
                 )
    ax_ct.set_title("CT "+tp, size=13, weight="semibold", c="k")
    ax_ct.annotate("Cortical thickness development [%-change]", 
                   xy=(-0.5,-1.4), xycoords="axes fraction", 
                   ha="center", va="center", rotation=90, size=13)
    
    # whitespace
    ax_ws1 = fig.add_subplot(gs[:3,2])
    ax_ws1.axis("off")
    ax_ws2 = fig.add_subplot(gs[:3,5])
    ax_ws2.axis("off")
    
    # predictors    
    for i, pred, gs_i in zip(    
        range(len(predictors_sig)),
        predictors_sig,
        [(0,3), (0,6), (0,9), (1,0), (1,3), (1,6), (1,9), (2,0), (2,3)]
    ):
        
        # data
        z = juspyce_ct_dominance.X.loc[pred,:]
        perror = regional_influence[tp][pred]
        # scatter
        ax_pred1 = fig.add_subplot(gs[gs_i])
        scatter(z, delta, perror, fig, ax_pred1, c=colors[i], hue_lims=pe_lims,
                r=juspyce_ct_dominance.predictions["spearman"].loc[tp,pred])
        #ax_pred1.set_ylim(ct_lims[0], ct_lims[1])          # DEACTIVATE IF CONSTANT INDIVIDUAL LIMITS
        # brain
        ax_pred2 = fig.add_subplot(gs[gs_i[0], gs_i[1]+1])
        plot_surf_ax(get_parc_gifti(perror)[0], fig, ax_pred2, c="RdBu_r", c_lims=pe_lims, 
                     rotate_labels=True, cbar=False)
        # title
        ax_pred1.set_title(pred, loc="left", ha="left", size=13, weight="semibold", c=colors[i])
        # labels
        #if gs_i[1]==0:
            #ax_pred1.set_ylabel("CT (%-change)", size=12)
        if gs_i in [(2,0), (2,3), (1,6), (1,9)]:
            ax_pred1.set_xlabel("Predictors [Z]", size=12)
            
    # colorbar
    ax_cbar = fig.add_subplot(gs[2,7])
    ax_cbar.axis("off")
    cbar = fig.colorbar(
        cm.ScalarMappable(
            norm=Normalize(pe_lims[0], pe_lims[1]), 
            cmap="RdBu_r"
        ), 
        ax=ax_cbar,
        orientation="horizontal",
        fraction=0.25,
        pad=-1)
    plt.setp(cbar.ax.get_xticklabels(), rotation=-40, ha="left", rotation_mode="anchor")
    cbar.ax.tick_params(axis="x", which="major", pad=2)
    cbar.ax.set_title("Residual difference")
    
    ## dominance analysis
    ax_dom = fig.add_subplot(gs[3,:])
    plot_diffs(
        data=juspyce_ct_dominance.predictions["dominance_total"],
        data_p=p_to_ast(
            juspyce_ct_dominance.p_predictions["dominance_total"], 
            juspyce_ct_dominance.p_predictions["dominance_total--fdr_bh"]),
        data_null=pd.DataFrame(
            data=np.column_stack(
                [juspyce_ct_dominance.nulls["predictions-dominance"][i]["dominance_total"] \
                    for i in range(10000)]).T,
            index=range(10000*len(predictors_sig)),
            columns=tps_index
        ),
        ax=ax_dom,
        legend=True,
        legend_size=12,
    )
    ax_dom.axvline(tp, c="k")
    
    # total explained R2
    ax_dom.annotate(
        text="$\\bfTotal\ adjusted\ R^2="
             f"{juspyce_ct_dominance.predictions['dominance_full_r2'].loc[tp].values[0]:.02f}$",
        xy=(1,0.28),
        size=15,
        c="0.4"
    )

    ## spearman correlations
    ax_cor = fig.add_subplot(gs[4,:])
    max_pred = juspyce_ct_dominance.predictions["dominance_total"].iloc[1:,:].idxmax()
    sig_correlations = juspyce_ct_dominance.predictions["spearman"]
    colors = get_cmap("tab10")(range(len(predictors_sig)))
    for i, p in enumerate(predictors_sig):
        ax_cor.plot(
            list(sig_correlations.index[1:]), 
            sig_correlations[p][1:], 
            color=colors[i], alpha=0.7, label=p
        )

    ax_cor.axhline(0, c="k", linewidth=1)
    ax_cor.set_xticks(ax_dom.get_xticks())
    ax_cor.set_xticks(ax_dom.get_xticks(minor=True), minor=True)
    ax_cor.set_xticklabels(juspyce_ct_dominance.y_lab.to_list()[1::5] + [""], 
                           rotation=-30, ha="left", rotation_mode="anchor")
    ax_cor.set_xlim(ax_dom.get_xlim())
    ax_cor.axvline(tp, c="k")
    ax_cor.tick_params(axis="both", which="major", labelsize=11)

    # finish
    ax_dom.set_ylim(-0.01, 0.32)
    ax_cor.set_ylim(-0.85,0.85)
    ax_dom.set_ylabel("CT change explained [$adjusted\ R^2$]", size=13)
    ax_cor.set_ylabel("Z $[Spearman's\ rho]$", size=13)
    ax_cor.set_xlabel("Age windows [5-year steps]", size=13, labelpad=10)
    ax_dom.set_xticklabels(ax_cor.get_xticklabels(),
                           rotation=-30, ha="left", rotation_mode="anchor")
    ax_dom.tick_params(axis="both", which="major", labelsize=11)
    
    save_path = join(plot_dir_gif, f"dev_ct_animation_fm_500_5_{tps_index.index(tp):02d}.png")
    fig.savefig(save_path,  bbox_inches="tight", dpi=100, transparent=False)
    plt.close()

Δ(77,82)
Δ(78,83)
Δ(79,84)
Δ(80,85)
Δ(81,86)


Context leak detected, msgtracer returned -1


Δ(82,87)
Δ(83,88)
Δ(84,89)


Context leak detected, msgtracer returned -1


Δ(85,90)


## Gif

In [9]:
imgs = glob(join(plot_dir_gif, "*.png"))
imgs.sort()
frames = []
[frames.append(iio.imread(i)) for i in imgs],
iio.mimwrite(join(plot_dir_gif, "dev_ct_animation_fm_500_5.gif"), frames, fps=1.5)