**Krakencoder usage example**

This notebook provides an example of how to load connectome data and apply a pretrained Krakencoder model to that data.

The process is as follows:
1.   Load model
2.   Load new data and do mild domain adaptation (map input data mean to training data mean)
3.   Transform each input data flavor into the 128-dimensional latent space
4.   Average latent space across all types ("fusion")
  * Note: the fused latent representation can itself be used for prediction, clustering, etc.
5.   Transform "fusion" averaged latent vectors to output connectomes



In [None]:
import torch
import numpy as np
from krakencoder.model import Krakencoder
from krakencoder.adaptermodel import KrakenAdapter
from krakencoder.utils import square2tri, tri2square, numpyvar
from krakencoder.data import generate_adapt_transformer, load_transformers_from_file
from krakencoder.fetch import fetch_model_data
from scipy.io import loadmat, savemat
import os
import humanize

In [None]:
# Fetch the pretrained model checkpoint (the precomputed PCA transforms are fetched in the next cell).
# On first use these files are downloaded (~1.3GB); by default they go to package_dir/model_data.
# An alternate storage location can be selected with the KRAKENCODER_DATA environment variable;
# this notebook instead passes an explicit data_folder.
checkpoint_file = fetch_model_data(
    'kraken_chkpt_SCFC_fs86+shen268+coco439_pc256_225paths_latent128_20240413_ep002000.pt',
    data_folder="/home/jovyan/shared/krakencoder/model_data/",
)

In [None]:
# Fetch the precomputed input/output transform files, one per atlas (fs86, shen268, coco439).
# They are looked up in (or downloaded to) the same shared model_data folder as the checkpoint.
ioxfm_file_list = fetch_model_data(
    [
        'kraken_ioxfm_SCFC_fs86_pc256_710train.npy',
        'kraken_ioxfm_SCFC_shen268_pc256_710train.npy',
        'kraken_ioxfm_SCFC_coco439_pc256_710train.npy',
    ],
    data_folder="/home/jovyan/shared/krakencoder/model_data",
)

In [None]:
# Restore the pretrained Krakencoder network from the checkpoint file.
# The second return value, checkpoint_info, is used by later cells for its
# 'input_name_list' entry, which gives the order of the model's input flavors.
inner_net, checkpoint_info = Krakencoder.load_checkpoint(
    checkpoint_file,
    eval_mode=True,
)

In [None]:
# Load the precomputed transforms from the fetched files.
# Both return values are indexed by connectivity flavor name in later cells:
# transformer_list feeds the adapter model, transformer_info_list feeds domain adaptation.
(transformer_list,
 transformer_info_list) = load_transformers_from_file(ioxfm_file_list)

In [None]:
# Build the model used for inference: an adapter that wraps the inner kraken model and
# includes the PCA transforms, so it can be fed raw (un-reduced) connectivity data.
# The transforms are passed in the same order as the checkpoint's input_name_list,
# so that position i in the list corresponds to input i of the inner model.
net = KrakenAdapter(
    inner_model=inner_net,
    data_transformer_list=[
        transformer_list[name] for name in checkpoint_info['input_name_list']
    ],
    linear_polynomial_order=0,
    eval_mode=True,
)

In [None]:
# Read the example inputs (10 validation subjects from HCP-YA).
conndata_squaremats = loadmat(
    '/home/jovyan/shared/krakencoder/exampledata_10subj_fs86_inputs.mat',
    simplify_cells=True,
)

# Keys starting with "_" are internal header entries of the .mat file, not data flavors.
conntypes = [key for key in conndata_squaremats if not key.startswith("_")]

# Each flavor is stored as a list of [roi x roi] square matrices.
# Vectorize every matrix with square2tri and stack them into one [subj x edges] array.
conndata = {
    c: {'data': np.stack([square2tri(mat) for mat in conndata_squaremats[c]['data']])}
    for c in conntypes
}

# Keep the triangular indices of each flavor (taken from its first subject)
# so the square matrices can be restored later, and report the shapes before/after.
conndata_triidx = {}
for c in conntypes:
  _, conndata_triidx[c] = square2tri(conndata_squaremats[c]['data'][0], return_indices=True)
  print("conndata_squaremats['%s']['data']" % (c),
        conndata_squaremats[c]['data'].shape,
        conndata_squaremats[c]['data'][0].shape)
  print(" -> conndata['%s']['data']" % (c), conndata[c]['data'].shape)

# The square matrices are no longer needed once vectorized.
del conndata_squaremats

In [None]:
# SIMPLE domain adaptation: for each flavor, fit a transform that maps
# mean(input subjects) onto mean(training subjects), then apply it to the input data.
#
# For these example data the step is not really needed: they come from HCP-YA, which the
# model was trained on, so the fits all come out as roughly "y = 1.0*x + 0".
# It is included for demonstration purposes only.
adxfm_dict = {}
conndata_adapted = {}
for c, flavor_inputs in conndata.items():
  adapter = generate_adapt_transformer(
      input_data=flavor_inputs['data'],
      target_data=transformer_info_list[c],
      adapt_mode='meanfit+meanshift',
  )
  adxfm_dict[c] = adapter
  conndata_adapted[c] = {'data': adapter.transform(flavor_inputs['data'])}

In [None]:
# Transform the (adapted) input data into the krakencoder latent space,
# one latent matrix per input flavor.
encoded_data={}

# Loop through all of the input names from the saved checkpoint,
# because the encoder/decoder indices are in this order.
for encidx, c in enumerate(checkpoint_info['input_name_list']):
  if c not in conndata_adapted:
    # this model input type was not in the example data, so skip it
    continue
  with torch.no_grad():
    # decoder_index=-1: only encode, returning the latent representation
    encoded_data[c]=net(conndata_adapted[c]['data'],encoder_index=encidx, decoder_index=-1)

# Fail early with a readable message if no flavor could be encoded; otherwise
# torch.stack below would raise an opaque error about an empty tensor list.
if not encoded_data:
  raise ValueError("None of the input data flavors %s match the model inputs %s"
                   % (sorted(conndata_adapted), list(checkpoint_info['input_name_list'])))

# "Fusion": average the latent representations across all encoded flavors
encoded_fusion=torch.mean(torch.stack(list(encoded_data.values())),dim=0)

print("fusion latent space representation: ", encoded_fusion.shape)

# Now predict output connectomes from fusion latent representation
# Predictions are stored in predicted_alltypes[inputtype][outputtype]
predicted_alltypes={'fusion':{}}

for decidx, c in enumerate(checkpoint_info['input_name_list']):
  with torch.no_grad():
    # encoder_index=-1: the input is already a latent representation, so only decode
    _,pred=net(encoded_fusion,encoder_index=-1, decoder_index=decidx)
  predicted_alltypes['fusion'][c]=numpyvar(pred) #convert back to numpy for analysis
  print("predicted_alltypes['fusion']['%s']: " % (c), predicted_alltypes['fusion'][c].shape)

# Add the fusion latent representation itself to the output
predicted_alltypes['fusion']['encoded']=numpyvar(encoded_fusion)
print("predicted_alltypes['fusion']['%s']: " % ('encoded'), predicted_alltypes['fusion']['encoded'].shape)



In [None]:
# Write all predictions (and the fused latent vectors) to a compressed .mat file
# for later analysis, then report the resulting file size in human-readable form.
outfile = 'exampledata_outputs.mat'
savemat(
    outfile,
    {'predicted_alltypes': predicted_alltypes},
    format='5',
    do_compression=True,
)
outfile_size = humanize.naturalsize(os.path.getsize(outfile))
print("Saved %s (%s)" % (outfile, outfile_size))

In [None]:
# Show one subject's observed and fusion-predicted connectome side by side.

import matplotlib.pyplot as plt

conntype = 'FCcorr_fs86_hpf'
isubj = 0

# Convert the upper-triangular vectors back to square matrices (diagval=1 for FC),
# and make sure the result is on CPU/numpy before trying to display it.
Cobs_square = numpyvar(tri2square(conndata_adapted[conntype]['data'][isubj, :],
                                  tri_indices=conndata_triidx[conntype],
                                  diagval=1))
Cpred_square = numpyvar(tri2square(predicted_alltypes['fusion'][conntype][isubj, :],
                                   tri_indices=conndata_triidx[conntype],
                                   diagval=1))

# One panel per matrix, sharing the same colour scale and colormap.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
panels = (
    ('Obs. %s: Subj %d', Cobs_square),
    ('Pred. %s: Subj %d', Cpred_square),
)
for ax, (title_format, matrix) in zip(axes, panels):
  im = ax.imshow(matrix, vmin=-1, vmax=1, cmap='Spectral_r')
  fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
  ax.set_title(title_format % (conntype, isubj))

fig.tight_layout()
plt.show()