In [1]:
# Display order of the 8 HMM states as 0-based indices.
# NOTE(review): not referenced in the cells shown here — presumably used elsewhere.
state_reordered = [2, 6, 1, 7, 3, 5, 0, 4]

# 1-based state number -> descriptive name used in figure titles and tick labels.
brain_states_dict = dict(
    enumerate(
        [
            "Higher-order Anterior",
            "Higher-order Posterior",
            "Unimodal Visual",
            "Unimodal Sensorimotor",
            "Lateralized Language versus SMN",
            "Lateralized Visual versus DMN",
            "Suppressed DMN-FPN-Visual",
            "Suppressed General",
        ],
        start=1,
    )
)

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
from matplotlib.image import imread

from osl_dynamics.analysis import spectral
from osl_dynamics.analysis import power, connectivity
from osl_dynamics.utils import plotting
from osl_dynamics.utils.plotting import plot_markers
from osl_dynamics.utils.plotting import plot_psd_topo
from osl_dynamics.utils.parcellation import Parcellation

def custom_colormap(n_color, n, colormap):
    """Return a copy of ``colormap`` with a solid white band in the middle.

    The lower and upper halves of ``colormap`` are sampled separately and
    ``n`` opaque white RGBA entries are inserted between them, e.g. to hide
    small values around the centre of a diverging scale.

    Parameters
    ----------
    n_color : int
        Total number of colour entries intended for the bar (roughly the span
        between vmin and vmax). ``int((n_color - n) / 2)`` entries are drawn
        from each half of ``colormap``.
    n : int
        Number of white entries placed at the centre.
    colormap : callable
        A matplotlib colormap such as ``plt.cm.jet`` or ``plt.cm.RdBu_r``;
        it is called on arrays of values in [0, 1] and must return RGBA rows.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        Colormap named ``'map_white'`` built from the stacked colours.
    """
    per_half = int((n_color - n) / 2)
    lower_half = colormap(np.linspace(0, 0.5, per_half))   # start -> midpoint
    upper_half = colormap(np.linspace(0.5, 1, per_half))   # midpoint -> end
    white_band = np.ones((n, 4))                           # opaque white RGBA rows
    return matplotlib.colors.LinearSegmentedColormap.from_list(
        'map_white', np.vstack([lower_half, white_band, upper_half])
    )

import os

n_states = 8

# Project root. Set the MEG_CAMCAN_DIR environment variable to run this
# notebook on another machine; the default is the original workstation path.
BASE_DIR = os.environ.get("MEG_CAMCAN_DIR", "E:/Research_Projects/MEG_CamCAN")

# Frequency axis of the HMM state spectra, (n_freq,)
f = np.load(f"{BASE_DIR}/TDE_HMM/results/inf_params/{n_states:02d}_states/f.npy")

# Source reconstruction files (bare filenames; presumably resolved by
# osl_dynamics against its bundled files — confirm)
mask_file = "MNI152_T1_8mm_brain.nii.gz"
parcellation_file = "Glasser52_binary_space-MNI152NLin6_res-8x8x8.nii.gz"

output_dir = f"{BASE_DIR}/TDE_HMM/output"
state_analysis_dir = f"{output_dir}/2_Brain_state_analysis"
idx_parcels = np.load(f"{state_analysis_dir}/idx_parcels.npy")  # for selecting parcels
idx_parcels_mni_reordered = np.load(f"{state_analysis_dir}/idx_parcels_mni_reordered.npy")  # for selecting parcels in mni order
weights_freqranges = np.load(f"{state_analysis_dir}/weights_freqranges.npy")  # for averaging over frequencies

PLS SPECTRAL

In [7]:
# Bootstrap sampling ratios (BSR) of the spectral PLS
OUTPUT_DIR = "TDE_HMM/output/3_Neurocognitive_analysis/output_PLS_spectral_new"
BSR_PLS = np.load(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_PLS_spectral.npy")

# Keep only salient ratios: entries with |BSR| < 3 are set to 0.
BSR_PLS_salient = np.copy(BSR_PLS)
BSR_PLS_salient[np.abs(BSR_PLS_salient) < 3] = 0  # (2, 8, 20, 88)

# Take the dimensions from the data rather than hard-coding them
# (originally assumed: 2 components, 8 states, 20 selected parcels, 88 freqs).
n_components, n_bsr_states, nparcels, n_freq = BSR_PLS_salient.shape
n_parcels_total = 52  # parcels in the Glasser52 parcellation
assert idx_parcels.shape[0] == n_bsr_states, "idx_parcels and BSR disagree on the number of states"
assert n_freq == len(f), "BSR frequency axis does not match f"

# LC1 is sign-flipped, LC2 kept as is (presumably so that positive BSR aligns
# with a positive correlation with age — confirm).
component_signs = (-1, 1)
assert n_components == len(component_signs), "expected exactly 2 latent components"

# Scatter each state's selected parcels (the last `nparcels` entries of
# idx_parcels[state]) back to their positions in the full parcellation;
# parcels that were not selected stay at 0.
expanded_BSR_PLS_salient = np.zeros([n_components, n_bsr_states, n_parcels_total, n_freq])
for comp, sign in enumerate(component_signs):
    for state in range(n_bsr_states):
        selected_indices_state = idx_parcels[state, -nparcels:]
        expanded_BSR_PLS_salient[comp, state, selected_indices_state, :] = sign * BSR_PLS_salient[comp, state]

Power spatial maps

In [None]:
# One cortical power map of the salient BSRs per (latent component, state),
# each written to its own PNG. np.ndindex walks comp-major, state-minor.
for comp, state in np.ndindex(2, 8):
    bsr_spectra = expanded_BSR_PLS_salient[comp, state, :, :]
    pow_state = power.variance_from_spectra(f, bsr_spectra, method="mean")

    # Rebuilt every iteration so each power.save call gets a fresh dict.
    surface_plot_kwargs = dict(
        cmap="RdBu_r",
        bg_on_data=True,
        darkness=1,
        alpha=1,
        symmetric_cbar=True,
        views=['lateral'],
    )
    power.save(
        pow_state,
        mask_file=mask_file,
        parcellation_file=parcellation_file,
        plot_kwargs=surface_plot_kwargs,
    )
    plt.savefig(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_powermap_state_{state+1}_comp{comp+1}.png", dpi=400)
    plt.close()

Parcel centroids of the MEG atlas used in the study, coloured by their position along the anterior–posterior axis

In [None]:
# Reorder parcels along the y (second) coordinate of their centroids.
# NOTE(review): np.argsort sorts ascending and, in MNI/RAS space, y increases
# towards anterior, so `order` presumably runs posterior -> anterior (an
# earlier comment said anterior-to-posterior) — confirm against roi_centers().
parcellation = Parcellation(parcellation_file)
roi_centers = parcellation.roi_centers()
order = np.argsort(roi_centers[:, 1])

# Parcel order[r] gets value r, i.e. each marker is coloured by its rank along y.
# The marker count follows the parcellation instead of a hard-coded 52.
plot_markers(
    np.arange(len(order)),
    roi_centers[order],
    node_cmap=plt.cm.viridis_r,
    node_size=50,  # type: ignore
    colorbar=False,
    display_mode='z',  # type: ignore
)

Main plot: salient BSR spectra for each state and latent component

In [None]:
# |BSR| below this value is treated as non-salient throughout the figure.
BSR_SALIENCE_THRESHOLD = 3


def _salient_mean_sem(bsr_state, threshold=BSR_SALIENCE_THRESHOLD):
    """Mean and SEM of the salient BSRs across parcels, per frequency.

    Parameters
    ----------
    bsr_state : np.ndarray, (n_parcels, n_freq)
        Bootstrap ratios of one latent component and one state.
    threshold : float
        Entries with |BSR| < threshold are excluded (masked).

    Returns
    -------
    mean, sem : np.ndarray, (n_freq,)
        Mean and standard error over the salient parcels at each frequency.
        Frequencies with no salient parcel (mean) or fewer than two (SEM)
        are filled with 0.
    """
    salient = np.ma.masked_where(np.abs(bsr_state) < threshold, bsr_state)
    mean = np.ma.mean(salient, axis=-2).filled(0)
    # Divide by the number of parcels actually averaged at each frequency,
    # not by the total parcel count: most rows are masked zeros.
    n_salient = salient.count(axis=-2)
    sem = np.ma.divide(np.ma.std(salient, axis=-2, ddof=1), np.sqrt(n_salient)).filled(0)
    return mean, sem


expanded_comp_mni_reordered = expanded_BSR_PLS_salient[:, :, order, :]
display_nparcels = 20
n_parcels_total = expanded_comp_mni_reordered.shape[2]
# Same per-parcel colours as the atlas marker plot (viridis_r over all parcels).
parcel_colors = plt.cm.viridis_r(np.linspace(0, 1, n_parcels_total))
# Set to True to overlay the spatial power maps saved by the power-map cell.
show_power_maps = False

for comp in range(expanded_comp_mni_reordered.shape[0]):
    # A fresh grid per component, so the previous component's panels are not
    # left underneath the new ones and each saved grid holds one component.
    grid_fig, axs = plt.subplots(2, 4, figsize=(15, 15))
    grid_fig.subplots_adjust(wspace=0, hspace=-0.75)
    n_cols = axs.shape[1]

    for state, ax_subplot in enumerate(axs.flat):
        #################################################################################
        # Main plot: per-parcel BSR spectra with the salient mean +/- SEM on top
        #################################################################################
        bsr_state = expanded_comp_mni_reordered[comp, state, ...]
        fig, ax = plot_psd_topo(
            f,
            bsr_state,
            only_show=idx_parcels_mni_reordered[state, -display_nparcels:],
        )  # type: ignore
        saliency_mean, saliency_sem = _salient_mean_sem(bsr_state)
        ax.fill_between(f, saliency_mean - saliency_sem, saliency_mean + saliency_sem,
                        color='black', alpha=0.7, zorder=11)
        ax.plot(f, saliency_mean, color='r', zorder=12)

        ax.set_title(f"State {state+1} ({brain_states_dict[state+1]})", loc='left', fontsize=18)
        # Only the left-most column of the grid carries the y label.
        if state % n_cols == 0:
            ax.set_ylabel("Salient BSR", size=18)
        else:
            ax.set_ylabel(None)
        ax.set_ylim(-20, 20)
        thr = BSR_SALIENCE_THRESHOLD
        ax.set_yticks(ticks=[-20, -15, -10, -thr, thr, 10, 15, 20])

        # Blank out the non-salient band and mark its edges.
        ax.axhspan(-thr, thr, color='w', alpha=1, zorder=12)
        for edge in (-thr, thr):
            ax.axhline(y=edge, color='black', linestyle='--', linewidth=2 / 3, zorder=12)

        # Inset brain showing the state's selected parcels in their atlas colours.
        plot_markers(
            np.arange(display_nparcels),
            roi_centers[idx_parcels[state, -display_nparcels:]],  # ordered by state-relevant parcels
            node_cmap=matplotlib.colors.LinearSegmentedColormap.from_list(
                'custom_cmap',
                parcel_colors[idx_parcels_mni_reordered[state, -display_nparcels:]],
            ),
            node_size=12,  # type: ignore
            colorbar=False,
            display_mode='lyrz',
            axes=ax.inset_axes([0.44, -0.14, 0.55, 0.6]),
        )

        # Save the individual state figure before embedding it in the grid.
        fig.tight_layout()
        psd_img_path = f"{BASE_DIR}/{OUTPUT_DIR}/BSR_state_{state+1}_{comp+1}.png"
        fig.savefig(psd_img_path, dpi=400)
        plt.close(fig)

        #################################################################################
        # Subplot embedding
        #################################################################################
        main_img = imread(psd_img_path, format="png")
        ax_subplot.imshow(main_img)
        ax_subplot.axis('off')

        if show_power_maps:
            png_path = f"{BASE_DIR}/{OUTPUT_DIR}/BSR_powermap_state_{state+1}_comp{comp+1}.png"
            img_power_map = imread(png_path, format="png")
            # Zero the top 200 and bottom 500 pixel rows to crop them away.
            mask = np.ones_like(img_power_map)
            mask[:200, :] = 0
            mask[-500:, :] = 0
            inside_ax_subplot2 = ax_subplot.inset_axes([0.52, 0, 0.45, 0.45])
            inside_ax_subplot2.imshow(img_power_map * mask)
            inside_ax_subplot2.axis('off')

    grid_fig.savefig(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_all_states_comp{comp+1}.png", dpi=400)

In [None]:
# from matplotlib.pyplot import contourf, pcolormesh
# from matplotlib.ticker import MaxNLocator

# data = np.mean(expanded_BSR_PLS_salient[1],axis=-2)
# data[abs(data)<3] = 0

# x= f
# y=[f"{brain_states_dict[i+1]}\n (State {i+1})" for i in range(8)] # Get the corresponding label
# z=data # Average across parcels
# ax = contourf(x,y[::-1],z[::-1,:],
#                 vmin=-4,vmax=4, 
#                 # Get the custom colormap
#                 cmap = plt.cm.RdBu_r, 
#                 levels=MaxNLocator(nbins=9,symmetric=True).tick_values(z.min(), z.max()),
#                 )
# # Colorbar
# cbar = plt.colorbar(mappable=ax,ticks=[-8,-5, -3, 3, 5,8], extend="both", extendrect=True)
# cbar.ax.tick_params(labelsize=8)
# cbar.set_label('BSR', rotation=270, fontsize = 9)

# plt.tight_layout()

PLS Transition

In [None]:
# Bootstrap sampling ratios of the transition PLS, (2, 8, 8):
# latent component x state at t x state at t+1
OUTPUT_DIR = "TDE_HMM/output/3_Neurocognitive_analysis/output_PLS_transition"
BSR_PLS = np.load(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_PLS_transition.npy")

BSR_PLS_salient = BSR_PLS.copy()
# Flip LC1 so it stays consistent with a positive correlation with age.
BSR_PLS_salient[0] *= -1
# Keep only salient ratios (|BSR| >= 3); everything else becomes 0.
BSR_PLS_salient[np.abs(BSR_PLS_salient) < 3] = 0
# BSR_PLS_salient = np.where(BSR_PLS_salient == 0, np.nan, BSR_PLS_salient)

In [None]:
import seaborn as sns

# "State k \n(<name>)" labels, shared by both heatmap axes.
state_labels = [f"State {key} \n({name})" for key, name in brain_states_dict.items()]

for comp in range(2):
    # Rows (y) index the state at time t, columns (x) the state at t+1.
    data = BSR_PLS_salient[comp]
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(data, annot=True, cmap='RdBu_r', center=0, cbar=False, ax=ax)

    # Dotted grid lines on the cell borders.
    for edge in range(1, 9):
        ax.axhline(edge, color='black', linestyle='dotted', linewidth=0.5)
        ax.axvline(edge, color='black', linestyle='dotted', linewidth=0.5)

    ax.set_title(f'Salient state-to-state transitions - LC{comp+1}')
    ax.set_xlabel('States at t+1')
    ax.set_ylabel('States at t')

    cell_centres = np.arange(len(data)) + 0.5
    ax.set_xticks(cell_centres)
    ax.set_xticklabels(state_labels, size=9)
    ax.set_yticks(cell_centres)
    ax.set_yticklabels(state_labels, size=9, rotation=25)

    fig.savefig(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_transitions_comp{comp+1}.png", dpi=400)

In [None]:
import itertools
import numpy as np

def most_likely_3_state_cycles(matrix, k=1, cycle_length=4):
    """Rank directed state cycles whose transitions all share the same sign.

    Despite the historical name, cycles through ``cycle_length`` distinct
    states are searched (4 by default, which is what this function has
    always used). Each directed cycle is reported once, starting at its
    smallest state index; rotations of the same cycle are not repeated.

    Parameters
    ----------
    matrix : 2-D array-like, (n_states, n_states)
        Salient transition BSRs; ``matrix[i][j]`` is the transition from state
        ``i`` (at t) to state ``j`` (at t+1). Zeros count as non-salient.
    k : int
        Maximum number of cycles returned for each sign.
    cycle_length : int
        Number of distinct states visited by each cycle (must be >= 1).

    Returns
    -------
    old_cycles, young_cycles : list of (tuple, float)
        Cycles whose transitions are all positive ("old") and all negative
        ("young"), each as ``(states, score)``. ``states`` repeats the first
        state at the end; ``score`` is the geometric mean of the absolute
        transition values (not a probability). Sorted by score, descending.

    Raises
    ------
    ValueError
        If ``cycle_length`` is smaller than 1.
    """
    import math

    if cycle_length < 1:
        raise ValueError(f"cycle_length must be >= 1, got {cycle_length}")

    old_cycles = []
    young_cycles = []

    for cycle in itertools.permutations(range(len(matrix)), cycle_length):
        # permutations() yields every rotation of the same directed cycle;
        # keep only the rotation starting at its smallest state.
        if cycle[0] != min(cycle):
            continue
        # Close the cycle by returning to the first state.
        full_cycle = cycle + (cycle[0],)
        weights = [matrix[a][b] for a, b in zip(full_cycle, full_cycle[1:])]
        score = math.prod(abs(w) for w in weights) ** (1 / cycle_length)

        if all(w > 0 for w in weights):
            old_cycles.append((full_cycle, score))
        elif all(w < 0 for w in weights):
            young_cycles.append((full_cycle, score))

    # Highest geometric-mean |BSR| first.
    old_cycles.sort(key=lambda item: item[1], reverse=True)
    young_cycles.sort(key=lambda item: item[1], reverse=True)

    return old_cycles[:k], young_cycles[:k]

# Strongest same-sign 4-state cycles in the LC1 transition BSRs.
old_cycles, young_cycles = most_likely_3_state_cycles(BSR_PLS_salient[0], k = 8)

# States are printed 1-based to match the "State k" labels used in the figures;
# the score is the geometric mean of |BSR| along the cycle, not a probability.
for age_label, sign_note, cycles in (
    ("old", "positive only", old_cycles),
    ("young", "negative only", young_cycles),
):
    print(f"Most likely 4-state directed {age_label} cycles ({sign_note}):")
    for cycle, score in cycles:
        states_1based = tuple(s + 1 for s in cycle)
        print(f"Cycle: {states_1based}, Score (geometric mean |BSR|): {score}")

PLS connectivity with coherence (COH) — the cells below are currently commented out

In [None]:
# # Bootstrap sampling ratios
# OUTPUT_DIR = "TDE_HMM/output/3_Neurocognitive_analysis/output_PLS_connectivity"
# BSR_PLS = np.load(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_PLS_connectivity.npy")
# BSR_PLS_salient = np.copy(BSR_PLS)
# BSR_PLS_salient[abs(BSR_PLS_salient) < 3] = 0 # (2, 8, 52, 52, 88)

In [None]:
# ################
# # # Find state-relevant frequency ranges with NNMF
# BSR_COH_NNMF = np.zeros([2,4,88])
# for comp in range(2):
#     BSR_COH_NNMF[comp,:,:] = spectral.decompose_spectra(abs(BSR_PLS_salient[comp]), n_components=4)
#     np.save(f"{networks_dir}/BSR_COH_NNMF_comp{comp+1}.npy", BSR_COH_NNMF[comp])

#     fig, ax = plotting.plot_line(
#         [f,f,f,f],  # we need to repeat f times because we fitted f components
#         BSR_COH_NNMF[comp],
#         x_label="Frequency (Hz)",
#         y_label="Spectral Component",
#         labels=[f"Mode {mode+1}" for mode in range(BSR_COH_NNMF[comp].shape[0])],
#         )  # type: ignore
#     ax.set_xticks(np.arange(0,50,5)) # type: ignore
# ################

# BSR_COH_networks = np.zeros([2,4,8,52,52])
# for comp in range(2):
#     BSR_COH_networks[comp,...] = connectivity.mean_coherence_from_spectra(f, BSR_PLS_salient[comp],
#                                                                 components=BSR_COH_NNMF[comp])

# band_dict = {
#     1: "Mode 1",
#     2: "Mode 2",
#     3: "Mode 3",
#     4: "Mode 4"
# }

In [None]:
# # Connectivity networks
# # 2 components, 4 bands, 8 states = 64 plots organized into 2 figures
# for comp in range(2):
#     fig, axs = plt.subplots(8, 4, figsize=(30, 30))

#     for band in range(4):  
#         for state in range(8):
#             ax_subplot = axs[state,band]
#             if comp == 0:
#                 connectivity.save(
#                     BSR_COH_networks[comp,band,state,:,:],
#                     parcellation_file=parcellation_file,
#                     plot_kwargs={
#                         'title': f"{brain_states_dict[state+1]} - {band_dict[band+1]}",
#                         'axes': ax_subplot,
#                         "edge_vmin": -7,
#                         "edge_vmax": 7,
#                         "edge_cmap": custom_colormap(10,2,plt.cm.RdBu_r),
#                         'colorbar': False
#                         }
#                 )
#             elif comp == 1:
#                 connectivity.save(
#                     BSR_COH_networks[comp,band,state,:,:]*(-1),
#                     parcellation_file=parcellation_file,
#                     plot_kwargs={
#                         'title': f"{brain_states_dict[state+1]} - {band_dict[band+1]}",
#                         'axes': ax_subplot,
#                         "edge_vmin": -3,
#                         "edge_vmax": 3,
#                         "edge_cmap": custom_colormap(10,2,plt.cm.RdBu_r),
#                         'colorbar': False
#                         }
#                 )
    
#     # plt.savefig(f"{BASE_DIR}/{OUTPUT_DIR}/BSR_COH_network_allbands_allstates_comp{comp+1}.png", dpi = 400)

In [None]:
# # Chord diagrams
# from mne.viz.circle import circular_layout, _plot_connectivity_circle

# labels = np.genfromtxt('../3_Neurocognitive_analysis/parcellation_labels.txt',
#                      delimiter='\t', dtype="str")

# lh_labels = [name for name in labels if name.endswith("LH")]
# rh_labels = [name for name in labels if name.endswith("RH")]

# node_order = list()
# node_order.extend(lh_labels[::-1])  # reverse the order
# node_order.extend(rh_labels)
# node_angles = circular_layout(
#     labels, node_order, start_pos=90, group_boundaries=[0, len(labels) / 2]
# )

# for comp in range(2):
#     for band in range(4):  
#         for state in range(8):

#             fig_tmp, ax_tmp = plt.subplots(figsize=(8, 8), facecolor="black", subplot_kw=dict(polar=True))
#             _plot_connectivity_circle(
#                 BSR_COH_networks[comp,band,state,:,:],
#                 labels,
#                 n_lines=np.count_nonzero(
#                     np.triu(BSR_COH_networks[comp,band,state,:,:],k=1)
#                         ),
#                 node_angles=node_angles,
#                 colormap= "RdBu_r",
#                 vmin = -5,
#                 vmax = 5,
#                 ax=ax_tmp,
#                 node_colors= 52 * [(0, 0, 0, 0.)],
#                 fontsize_names=10,
#                 fontsize_title=16,
#                 title=f"{brain_states_dict[state+1]} - {band_dict[band+1]}",
#                 facecolor='black', 
#                 textcolor='white',
#                 # fontsize_colorbar=6,
#                 # colorbar_pos=(2,1),
#                 colorbar=False,
#             )
#              # Save individual state before embedding in subplots #############################
#             fig_tmp.tight_layout()
#             chord_img_path = f"{BASE_DIR}/{OUTPUT_DIR}/BSR_chord_state_{state+1}_band_{band+1}_{comp+1}.png"
#             fig_tmp.savefig(chord_img_path, dpi = 400)

In [None]:
# # Creating the entire figure
# for comp in range(2):
#     fig, axs = plt.subplots(8, 4, figsize=(8, 8))
#     fig.subplots_adjust(wspace=-0.75)

#     for band in range(4):  
#         for state in range(8):
#             chord_img_path = f"{BASE_DIR}/{OUTPUT_DIR}/BSR_chord_state_{state+1}_band_{band+1}_{comp+1}.png"
#             #################################################################################
#             # Subplot Embedding
#             #################################################################################
#             main_img = imread(chord_img_path, format="png")
#             # Display the image in the subplot
#             ax_subplot = axs[state,band]
#             ax_subplot.imshow(main_img)
#             ax_subplot.axis('off')