In [7]:
import os
from os.path import join, expanduser

from keras.models import load_model
from nilearn._utils import check_niimg
from wordcloud import WordCloud

from modl.datasets import get_data_dirs
from modl.hierarchical import HierarchicalLabelMasking, PartialSoftmax
from modl.classification import Reconstructer
from modl.input_data.fmri.unmask import retrieve_components
from modl.utils.system import get_cache_dirs
from nilearn.image import index_img
from nilearn.input_data import MultiNiftiMasker
from nilearn.plotting import plot_stat_map, plot_prob_atlas
from sklearn.externals.joblib import Memory, load
import matplotlib.pyplot as plt

import numpy as np


def get_model(artifact_dir, scale_importance, dictionary_penalty,
              n_components_list,
              mask_img='/home/arthur/data/modl_data/mask_img.nii.gz'):
    """Load the trained hierarchical model and reconstruct its latent maps.

    Parameters
    ----------
    artifact_dir : str
        Directory containing the saved ``model.keras`` file.
    scale_importance : str
        Forwarded to ``Reconstructer(scale_importance=...)``.
    dictionary_penalty : float
        Forwarded to ``retrieve_components`` to select the dictionaries.
    n_components_list : list of int
        Forwarded to ``retrieve_components``; presumably one basis is
        returned per entry -- TODO confirm against ``retrieve_components``.
    mask_img : str, optional
        Path of the mask image used to build the ``MultiNiftiMasker``.
        Defaults to the location that used to be hard-coded here.

    Returns
    -------
    masker : MultiNiftiMasker
        The fitted masker (can map ``imgs`` back to Nifti images).
    imgs : array
        ``Reconstructer.fit_transform`` applied to the transposed kernel of
        the model's ``'latent'`` layer.
    model : keras model
        The loaded model, with the custom layers registered.
    """
    model = load_model(join(artifact_dir, 'model.keras'),
                       custom_objects={'HierarchicalLabelMasking':
                                           HierarchicalLabelMasking,
                                       'PartialSoftmax': PartialSoftmax})
    memory = Memory(cachedir=get_cache_dirs()[0], verbose=2)
    print('Fetch data')
    masker = MultiNiftiMasker(mask_img=mask_img, smoothing_fwhm=0).fit()
    print('Retrieve components')
    # Cached on disk: retrieving the dictionaries is the expensive step.
    bases = memory.cache(retrieve_components)(dictionary_penalty, masker,
                                              n_components_list)
    for i, basis in enumerate(bases):
        # Scale every row of the basis to unit standard deviation.
        row_std = np.std(basis, axis=1)
        # Constant rows have zero spread: divide them by 1 (leave them as-is)
        # instead of producing NaN/inf. The previous ``= 0`` was a no-op.
        row_std[row_std == 0] = 1
        bases[i] = basis / row_std[:, np.newaxis]
    reconstructer = Reconstructer(bases=bases,
                                  scale_importance=scale_importance)
    # Kernel of the 'latent' layer, transposed before reconstruction.
    weights = model.get_layer('latent').get_weights()[0].T
    imgs = reconstructer.fit_transform(weights)
    return masker, imgs, model

In [8]:
dictionary_penalty = 1e-4
n_components_list = [16, 64, 256]
scale_importance = 'sqrt'

# Artifacts and analysis outputs share one pipeline root.
pipeline_dir = join(expanduser('~/data/modl_data'), 'pipeline', 'contrast',
                    'prediction_hierarchical')
artifact_dir = join(pipeline_dir, 'good')
analysis_dir = join(pipeline_dir, 'analysis')
if not os.path.exists(analysis_dir):
    os.makedirs(analysis_dir)

masker, imgs, model = get_model(artifact_dir, scale_importance,
                                dictionary_penalty, n_components_list)

# Persist the raw (not yet normalized) maps, as arrays and as a Nifti image.
np.save(join(analysis_dir, 'imgs'), imgs)
components_imgs = masker.inverse_transform(imgs)
components_imgs.to_filename(join(analysis_dir, 'components.nii.gz'))

# Normalize each row to unit L2 norm, so the Gram matrix holds cosine
# similarities. All-zero rows are left at zero instead of turning into NaN.
row_norms = np.sqrt(np.sum(imgs ** 2, axis=1, keepdims=True))
row_norms[row_norms == 0] = 1
imgs /= row_norms
gram = imgs.dot(imgs.T)

fig, ax = plt.subplots()
gram_im = ax.imshow(gram)
fig.colorbar(gram_im, ax=ax)
ax.set_title('Cosine similarity between components')
fig.savefig('gram.pdf')

Fetch data
Retrieve components
[Memory]    0.0s, 0.0min: Loading retrieve_components...


In [9]:
# Parameters of the first supervised layer, unpacked as (weights, bias).
supervised_layer = model.get_layer('supervised_depth_1')
weights, bias = supervised_layer.get_weights()

# NOTE(review): joblib ``load`` unpickles its input -- only use it on
# trusted, locally produced artifacts.
labels, X, y_pred = [load(join(artifact_dir, fname))
                     for fname in ('labels.pkl', 'X.pkl', 'y_pred.pkl')]
# Targets are stored in the index of X (presumably a pandas DataFrame).
y = X.index.values

In [12]:
# Same tuple as X.values.shape for a pandas object, without building the array.
X.shape

(27952, 336)

In [15]:
# Transposed kernel of the 'latent' layer (the same matrix get_model reconstructs).
latent_kernel = model.get_layer('latent').get_weights()[0]
weights = latent_kernel.T

In [18]:
# Project inputs through the latent kernel (weights.T undoes the transpose above).
X_proj = np.dot(X.values, weights.T)

In [19]:
# Peek at the projected data; numpy abbreviates the repr of large arrays.
X_proj

array([[-0.47945663, -0.08212911,  0.89589018, ...,  1.22268927,
        -1.26734388,  2.94306493],
       [-1.95100904,  0.95480269, -0.73394781, ...,  0.15132065,
        -0.49374482,  1.26228547],
       [-1.73693407,  0.08936246, -0.75502312, ...,  0.12293071,
        -0.3658635 ,  0.85336179],
       ..., 
       [-1.58841646,  2.36796784, -1.1477381 , ..., -2.6932292 ,
         4.13038826,  0.9163034 ],
       [-3.31059718,  2.35946369, -5.35217428, ..., -2.19593811,
         2.41828012, -3.31701946],
       [-2.54856086,  0.41908887, -3.8307519 , ..., -3.22793317,
         2.60602593, -3.0742774 ]], dtype=float32)

In [22]:
# Re-load and index in one step so that re-running this cell is idempotent
# (``y_pred = y_pred[1]`` would index an already-indexed value on a re-run).
# NOTE(review): entry 1 is used as the CCA target below -- presumably the
# predictions at one level of the hierarchy; TODO confirm what index 1 means.
y_pred = load(join(artifact_dir, 'y_pred.pkl'))[1]

In [27]:
from sklearn.cross_decomposition import CCA

In [33]:
# One canonical pair per column of X_proj (the recorded x_rotations_ shape is (25, 25)).
N_CCA_COMPONENTS = 25
cca = CCA(n_components=N_CCA_COMPONENTS)

In [34]:
# Canonical correlation between the latent projections and the selected
# predictions; the echoed repr shows the sklearn default scale=True.
cca.fit(X_proj, y_pred)

CCA(copy=True, max_iter=500, n_components=25, scale=True, tol=1e-06)

In [32]:
# NOTE(review): this cell is In [32] but sits after In [33]/In [34], so the
# recorded shape came from an earlier ``cca`` object -- re-run in order.
np.shape(cca.coef_)

(25, 97)

In [41]:
# X-side rotations: (n_features of X_proj, n_components).
np.shape(cca.x_rotations_)

(25, 25)

In [42]:
# Y-side rotations: (n_features of y_pred, n_components).
np.shape(cca.y_rotations_)

(97, 25)

In [44]:
# First canonical direction on the y_pred side (column 0 of the rotations).
# NOTE(review): the recorded entries are all ~1e-6 -- check the scaling of
# y_pred and whether the CCA fit converged.
cca.y_rotations_[:, 0]

array([  1.33135021e-06,   1.94834300e-06,   1.46288861e-06,
         1.23994699e-06,   1.25514912e-06,   1.39589840e-06,
        -1.56606408e-06,  -2.84905245e-07,   8.57217841e-07,
        -6.99229154e-07,   1.32039003e-07,  -1.28426759e-06,
        -2.86942260e-07,   2.42201646e-07,   2.97577410e-07,
         6.02436118e-07,  -1.29402590e-06,   9.54423650e-07,
         9.05203240e-07,   1.69977510e-06,   7.35574273e-07,
        -3.05512752e-07,  -1.69780240e-06,  -1.04857563e-08,
        -1.48296924e-06,   9.94358609e-08,  -1.62118588e-06,
        -1.95220812e-06,  -8.25235567e-07,  -9.66353833e-07,
         3.29447266e-07,  -1.12688900e-07,  -2.60468409e-07,
        -5.51104425e-07,  -5.14304428e-07,  -2.65501600e-07,
         1.83463432e-06,   1.79866274e-07,  -4.74715456e-07,
        -7.61601739e-08,  -1.61851720e-08,   1.85837969e-07,
         2.47650186e-07,  -2.28810605e-07,   1.84374468e-07,
        -1.80205860e-08,   9.85837278e-09,   1.08157698e-06,
         5.72863799e-08,