In [None]:
import numpy as np
import os
from scipy import io as sio
from scipy import stats
import pandas as pd
import torch
from IPython.display import Image 
import matplotlib.pyplot as plt
%matplotlib inline

# Load the saved model weights (state dict)

In [None]:
# Checkpoint saved for outer fold 1 / inner fold 4 of this run.
MODEL_PATH = '/users/hailey/data/05_conference_workshop_grant/2023_KNU_bootcamp/results/result_20230205_202502/Outer_fold_1/model_fold_4.pt'

# The checkpoint is a state dict (parameter name -> tensor), as the indexing
# below shows. So it can be loaded with weights_only=True, which restricts
# unpickling to tensors and plain containers instead of running arbitrary
# pickled code from a shared file. This requires torch >= 1.13.
model = torch.load(MODEL_PATH, map_location='cpu', weights_only=True)

# Shapes of the three weight matrices combined below into the weight feature map.
print(model['ext_1.weight'].shape, model['prd_1.weight'].shape, model['prd_2.weight'].shape)

# Build the weight feature (WF) map from the learned weight matrices

In [None]:
# Weight feature (WF): collapse the three linear layers into a single linear map
# by chaining their transposed weight matrices (ext_1 -> prd_1 -> prd_2).
# Only the weights enter this product; biases and any nonlinearities between
# the layers are ignored.
W_ext, W_prd1, W_prd2 = (model[key].T for key in ('ext_1.weight', 'prd_1.weight', 'prd_2.weight'))

# Same association order as before: (ext_1 @ prd_1) first, then @ prd_2.
# squeeze() drops the singleton output dimension, leaving one value per input feature.
WF = np.asarray((W_ext @ W_prd1) @ W_prd2).squeeze()

# Visualize WF as a 2D ROI-by-ROI matrix

In [None]:
n_rois = 200

# WF holds one value per ROI pair: the strict upper triangle of an
# n_rois x n_rois matrix, n_rois*(n_rois-1)/2 = 19900 entries for 200 ROIs.
# np.triu_indices(k=1) visits that triangle row by row.
# NOTE(review): this assumes the model's input vector was built in the same
# row-major upper-triangle order — TODO confirm against the preprocessing.
triu_rows, triu_cols = np.triu_indices(n_rois, k=1)

WF_2d = np.zeros((n_rois, n_rois))
WF_2d[triu_rows, triu_cols] = WF

# Mirror the upper triangle into the lower one; the diagonal stays zero.
WF_2d = WF_2d + WF_2d.T

In [None]:
# Symmetric colour scale around zero (red = positive, blue = negative in RdBu_r).
# Values beyond ±0.07 are clipped, so a colorbar is needed to read the scale.
fig, ax = plt.subplots(figsize=(5, 5))
im = ax.imshow(WF_2d, cmap='RdBu_r', vmin=-0.07, vmax=0.07)
ax.set(title='WF (original ROI order)', xlabel='ROI', ylabel='ROI')
fig.colorbar(im, ax=ax)
plt.show()

# Sort ROIs by the 7 Yeo networks

In [None]:
# Reference figure of the 7-network layout. Passing filename= explicitly
# raises a clear error if the file is missing, instead of letting IPython
# treat the path string as raw image data.
Image(filename="/users/hailey/data/05_conference_workshop_grant/2023_KNU_bootcamp/to_share/7network.jpg")

In [None]:
# Network ordering for the ROIs. 'order_idx' is used below to reorder ROIs so
# that networks are contiguous, and 'length' gives the size of each network block.
# 'label' is loaded but not used in this notebook.
# An .npz opened with np.load keeps the archive open until closed; the context
# manager closes it once the arrays have been read into memory.
with np.load('/users/hailey/data/05_conference_workshop_grant/2023_KNU_bootcamp/to_share/7network_order.npz') as network_info:
    network_label = network_info['label']
    network_orderidx = network_info['order_idx']
    network_length = network_info['length']

# The reordering must be a permutation (no ROI dropped or duplicated), and the
# network block sizes must exactly tile it.
assert np.array_equal(np.sort(network_orderidx), np.arange(len(network_orderidx))), \
    'order_idx is not a permutation of 0..N-1'
assert network_length.sum() == len(network_orderidx), \
    'network lengths do not sum to the number of ROIs'

In [None]:
# Reorder rows and columns with the same permutation so ROIs of one network
# form contiguous blocks along both axes.
WF_2d_sort = WF_2d[np.ix_(network_orderidx, network_orderidx)]

# imshow centres pixel i at coordinate i, so pixel edges lie at i - 0.5.
# Network boundaries are therefore drawn at cumulative block sizes minus 0.5.
# Only interior boundaries are drawn; the axes frame marks the outer edge.
# The lines span exactly the image extent [-0.5, n - 0.5], so they do not
# run past the matrix and enlarge the axes limits.
n_sorted = WF_2d_sort.shape[0]
net_edges = np.cumsum(network_length)[:-1] - 0.5

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(WF_2d_sort, cmap='RdBu_r', vmin=-0.07, vmax=0.07)
ax.hlines(y=net_edges, xmin=-0.5, xmax=n_sorted - 0.5)
ax.vlines(x=net_edges, ymin=-0.5, ymax=n_sorted - 0.5)
plt.show()

In [None]:
# Keep only edges whose WF exceeds the 99.95th percentile of the 1D WF vector,
# where each ROI pair appears once. For the 19900 ROI pairs this leaves about
# the top 10 edges.
# NOTE(review): only positive WF survives this mask. If "largest" should mean
# largest magnitude, threshold np.abs(WF) instead — TODO confirm intent.
pct_thr = 99.95
thr = np.percentile(WF, pct_thr)

WF_2d_sort_largest = WF_2d_sort * (WF_2d_sort > thr)

# imshow centres pixel i at coordinate i, so network boundaries go on pixel
# edges (cumulative block sizes minus 0.5). They span exactly the image extent
# so the axes are not enlarged beyond the matrix.
n_sorted = WF_2d_sort.shape[0]
net_edges = np.cumsum(network_length)[:-1] - 0.5

fig, ax = plt.subplots(figsize=(5, 5))
ax.imshow(WF_2d_sort_largest, cmap='RdBu_r', vmin=-thr, vmax=thr)
ax.hlines(y=net_edges, xmin=-0.5, xmax=n_sorted - 0.5)
ax.vlines(x=net_edges, ymin=-0.5, ymax=n_sorted - 0.5)
ax.set_title('Largest WF')
plt.show()

# Visualize in 3D brain space

In [None]:
from nilearn import datasets
from nilearn import plotting
from nilearn import image

# Schaefer 2018 parcellation with n_rois parcels, 7 Yeo networks, 2 mm resolution.
# One cut coordinate per parcel gives the node positions for the connectome plots.
atlas_schaefer = datasets.fetch_atlas_schaefer_2018(
    n_rois=n_rois,
    yeo_networks=7,
    resolution_mm=2,
)
coord_schaefer = plotting.find_parcellation_cut_coords(atlas_schaefer['maps'])

# Per-ROI network info table. The R, G, B columns are divided by 255
# (presumably 0-255 channel values) to give matplotlib RGB node colours in [0, 1].
roi_net_info = pd.read_table(
    '/users/hailey/data/05_conference_workshop_grant/2023_KNU_bootcamp/to_share/7networks_info.txt',
    names=['index', 'name', 'R', 'G', 'B', 'etc'],
)
roi_RGB = roi_net_info[['R', 'G', 'B']].to_numpy() / 255

In [None]:
# Connectome of edges whose WF exceeds thr, with nodes at the parcel cut
# coordinates and coloured by the per-ROI RGB values.
WF_2d_largest = WF_2d * (WF_2d > thr)

fig = plt.figure(figsize=(10, 5))
plotting.plot_connectome(
    WF_2d_largest,
    coord_schaefer,
    node_color=roi_RGB,
    node_size=50,
    display_mode='lzr',
    colorbar=True,
    title='Largest WF',
    figure=fig,
)
plotting.show()

In [None]:
# Interactive 3D connectome of edges whose WF exceeds thr. The view object is
# the cell's last expression, so Jupyter renders it inline.
WF_2d_largest_view = WF_2d * (WF_2d > thr)
connectome_view = plotting.view_connectome(
    WF_2d_largest_view, coord_schaefer, title='Largest WF', colorbar=True
)
connectome_view