In [1]:
import bct
import copy
import matplotlib.cm as cmx
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pickle
import utils

import warnings; warnings.simplefilter('ignore')

from atlases import DesikanAtlas
from hmmlearn import hmm
from sklearn.decomposition import PCA
from sklearn.externals import joblib
from tqdm import tqdm

# Low Dimensional Connectome Dynamics

### Prepare Data - Separate Modalities

In [2]:
# Load the connectomes for all subject ids and all trial ids known to `utils`.
# The result is indexed by channel name (e.g. 'fmri'); the 'fmri' entry is displayed
# below as an array of shape (num_connectomes, num_regions, num_regions).
all_subjects_all_trials_connectomes = utils.load_connectomes(utils.ALL_SUBJECT_IDS, utils.ALL_TRIAL_IDS)
all_subjects_all_trials_connectomes['fmri'].shape

(17906, 68, 68)

Extract a flattened representation of the upper triangle (including the diagonal) of the Pearson correlation matrix for each connectome type.

In [3]:
# NOTE: The logic below assumes every modality has the same number of regions as fMRI
# (true for the Desikan atlas, where EEG and fMRI share the parcellation). It would have
# to change if we moved to an atlas where the region counts differ between modalities.
num_regions = all_subjects_all_trials_connectomes['fmri'].shape[1]
num_regions

68

In [4]:
# (row, column) index arrays of a num_regions x num_regions matrix:
#   k=0  -> upper triangle including the main diagonal
#   k=-1 -> strictly lower triangle (main diagonal excluded)
# Together the two index sets cover every matrix entry exactly once.
upper_triangular_including_diagonal_idxs = np.triu_indices(num_regions, k=0)
lower_triangular_idxs = np.tril_indices(num_regions, k=-1)

In [5]:
# Flatten every connectome to its upper triangle (diagonal included), one row per connectome.
# Indexing the last two axes with the triangle's row/column index arrays handles all connectomes
# of a channel in a single vectorised step, and building a fresh dict avoids deep-copying the
# full (num_connectomes, num_regions, num_regions) arrays only to overwrite them.
triu_rows, triu_cols = upper_triangular_including_diagonal_idxs
all_subjects_all_trials_connectome_upper_triangular_flattened = {
    k: np.ascontiguousarray(np.asarray(all_subjects_all_trials_connectomes[k])[:, triu_rows, triu_cols])
    for k in all_subjects_all_trials_connectomes
}

In [6]:
all_subjects_all_trials_connectome_upper_triangular_flattened['fmri'].shape

(17906, 2346)

### Prepare Data - Combined Modalities

In [7]:
def data_matrix_from_channels(channels):
    """Stack the flattened connectomes of several channels side by side.

    Parameters
    ----------
    channels : iterable of str
        Keys into ``all_subjects_all_trials_connectome_upper_triangular_flattened``,
        in the order in which their feature columns should appear.

    Returns
    -------
    numpy.ndarray
        The per-channel arrays concatenated along axis 1 (feature columns).
    """
    per_channel_features = [
        all_subjects_all_trials_connectome_upper_triangular_flattened[channel]
        for channel in channels
    ]
    return np.concatenate(per_channel_features, axis=1)

In [8]:
# Combined feature matrices: the listed channels' flattened connectomes are concatenated
# column-wise, in exactly the order given here. Code that slices these matrices back into
# per-channel blocks must use the same channel order.
alpha_beta_delta_gamma_theta_matrix = data_matrix_from_channels(['alpha', 'beta', 'delta', 'gamma', 'theta'])
fmri_alpha_beta_delta_gamma_theta_matrix = data_matrix_from_channels(['fmri', 'alpha', 'beta', 'delta', 'gamma', 'theta'])
fmri_broad_matrix = data_matrix_from_channels(['fmri', 'broad'])

## Hidden Markov Models

### Training

(1) TODO: Find a bunch of graph statistics (node-wise or global) and compute the time series of the graph statistics and fit an HMM to that.

(2) TODO: Train HMMs on subset of data and predict likelihood of new data

(3) TODO: Align states between modalities.

In [9]:
def num_parameters_in_hmm(model):
    """Count the parameters of a diagonal-covariance Gaussian HMM.

    With ``k`` states and ``d`` features (read from ``model.means_.shape``) the
    count is ``k*d`` means + ``k*d`` diagonal variances + ``k`` start
    probabilities + ``k**2`` transition probabilities.

    Parameters
    ----------
    model : object
        Anything exposing a 2-D ``means_`` array of shape (k, d).

    Returns
    -------
    int
        The total parameter count described above.
    """
    num_states, num_features = model.means_.shape[:2]
    # NOTE: We are using a diagonal covariance matrix, so means and variances
    # each contribute one value per (state, feature) pair.
    emission_parameters = 2 * (num_states * num_features)
    start_parameters = num_states
    transition_parameters = num_states ** 2
    return emission_parameters + start_parameters + transition_parameters

In [10]:
def train_optimal_hmm_on_data(data, bic=True, pca_variance_retained=0.99, forced_component_count=None):
    """Fit Gaussian HMMs with a growing number of states and return the best-scoring one.

    The data is optionally reduced with PCA, then one diagonal-covariance
    ``GaussianHMM`` is fitted per candidate state count. Candidates are ranked by
    BIC (``ln(n)*k - 2*L``) or AIC (``2*k - 2*L``), where ``n`` is the number of
    samples, ``k`` the value of ``num_parameters_in_hmm`` and ``L`` the
    log-likelihood of the data under the fitted model. The search stops early once
    the two most recent candidates both score worse than the one before them.

    Parameters
    ----------
    data : numpy.ndarray
        2-D array with one sample per row.
    bic : bool
        Rank candidates by BIC when true, by AIC otherwise.
    pca_variance_retained : float
        Fraction of variance PCA must retain. PCA is skipped unless the value
        lies strictly between 0 and 1.
    forced_component_count : int or None
        When truthy, only this number of states is tried.

    Returns
    -------
    tuple
        ``(selected_hmm_model, pca_model)``; ``pca_model`` is ``None`` when PCA
        was skipped.

    Raises
    ------
    ValueError
        If there is no candidate state count to try (too few samples).
    """
    # Apply PCA to dimensionality reduce the data and retain xx% variance
    if 0.0 < pca_variance_retained < 1.0:
        print("\tApplying pca and retaining {0:.2f}% variance...".format(pca_variance_retained*100))
        pca_model = PCA(pca_variance_retained)
        dim_reduced_data = pca_model.fit_transform(data)
        print("\t{0} dimensions -> {1} dimensions".format(data.shape[1], dim_reduced_data.shape[1]))
    else:
        print("Skipping dimensionality reduction...")
        pca_model = None
        dim_reduced_data = data

    # Candidate state counts: 2 .. N-2 (N = number of samples), or only the forced count
    if forced_component_count:
        candidate_component_counts = range(forced_component_count, forced_component_count+1)
    else:
        candidate_component_counts = range(2, data.shape[0]-1)

    scores = []
    candidate_hmm_models = []
    n = dim_reduced_data.shape[0]

    for n_components in candidate_component_counts:

        print("\t\t training hmm with n_components={0}".format(n_components))

        # Train HMM model
        candidate_hmm_model = hmm.GaussianHMM(n_components=n_components,
                                              covariance_type="diag").fit(dim_reduced_data)
        candidate_hmm_models.append(candidate_hmm_model)

        # BIC/AIC are defined on the likelihood of the data under the model, which is what
        # score() returns; decode() returns the log-probability of the single best state
        # path, which is only a lower bound on it.
        k = num_parameters_in_hmm(candidate_hmm_model)
        L = candidate_hmm_model.score(dim_reduced_data)

        print("\t\t\tn={0}, k={1}, L={2}".format(n, k, L))

        score = (np.log(n)*k - 2*L) if bic else (2*k - 2*L)
        scores.append(score)
        print("\t\t\tScore={0}".format(score))

        # Last two increases in number of components yielded worse scores -> exit search
        if len(scores) >= 3 and scores[-1] > scores[-3] and scores[-2] > scores[-3]:
            break

    if not candidate_hmm_models:
        raise ValueError("No candidate HMM could be trained: {0} samples leave no state count to try".format(data.shape[0]))

    # Select the model with the lowest BIC/AIC score
    selected_hmm_model = candidate_hmm_models[int(np.argmin(scores))]
    return selected_hmm_model, pca_model

In [11]:
def train_optimal_hmms_together(data_channels, bic=True, pca_variance_retained=0.99):
    """Fit one Gaussian HMM per channel, all sharing the same number of states.

    Each channel is reduced with its own PCA. For every candidate state count one
    diagonal-covariance ``GaussianHMM`` is fitted per channel and scored with BIC
    (``ln(n)*k - 2*L``) or AIC (``2*k - 2*L``), where ``L`` is the log-likelihood
    of the channel's data under its model. The state count with the lowest mean
    score across channels wins. The search stops early once the two most recent
    state counts both score worse than the one before them.

    Parameters
    ----------
    data_channels : sequence of numpy.ndarray
        One 2-D array (samples x features) per channel.
    bic : bool
        Rank candidates by BIC when true, by AIC otherwise.
    pca_variance_retained : float
        Passed to ``PCA`` as the fraction of variance to retain.

    Returns
    -------
    tuple
        ``(selected_hmm_models, pca_models)``: two lists with one entry per
        channel, in the order of ``data_channels``.

    Raises
    ------
    ValueError
        If there is no candidate state count to try (too few samples).
    """
    # Apply PCA to dimensionality reduce the data and retain xx% variance
    print("\tApplying pca and retaining {0:.2f}% variance...".format(pca_variance_retained*100))
    pca_models = []
    dim_reduced_data_channels = []

    for data in data_channels:
        pca_model = PCA(pca_variance_retained)
        dim_reduced_data = pca_model.fit_transform(data)

        pca_models.append(pca_model)
        dim_reduced_data_channels.append(dim_reduced_data)
        print("\t{0} dimensions -> {1} dimensions".format(data.shape[1], dim_reduced_data.shape[1]))

    # Compute the mean BIC/AIC score for # states between 2 and N-2 (N = samples in first channel)
    scores = []
    candidate_hmm_models = []

    for n_components in range(2, len(data_channels[0])-1):

        print("\t\t training hmms with n_components={0}".format(n_components))

        # Train HMM model for every data channel
        candidate_hmm_models_with_n_components = []
        scores_with_n_components = []
        for dim_reduced_data in dim_reduced_data_channels:
            candidate_hmm_model = hmm.GaussianHMM(n_components=n_components,
                                                  covariance_type="diag").fit(dim_reduced_data)
            candidate_hmm_models_with_n_components.append(candidate_hmm_model)

            # BIC/AIC are defined on the likelihood of the data under the model, which is
            # what score() returns; decode() returns the log-probability of the single best
            # state path, which is only a lower bound on it.
            n = dim_reduced_data.shape[0]
            k = num_parameters_in_hmm(candidate_hmm_model)
            L = candidate_hmm_model.score(dim_reduced_data)

            print("\t\t\t\tn={0}, k={1}, L={2}".format(n, k, L))

            score = (np.log(n)*k - 2*L) if bic else (2*k - 2*L)
            scores_with_n_components.append(score)
            print("\t\t\t\t\t=> Score={0}".format(score))

        candidate_hmm_models.append(candidate_hmm_models_with_n_components)
        scores.append(np.mean(scores_with_n_components))
        print("\t\t\tScore={0}".format(scores[-1]))

        # Last two increases in number of components yielded worse scores -> exit search
        if len(scores) >= 3 and scores[-1] > scores[-3] and scores[-2] > scores[-3]:
            break

    if not candidate_hmm_models:
        raise ValueError("No candidate HMMs could be trained: too few samples to try any state count")

    # Select the models with the lowest mean BIC/AIC score
    selected_hmm_models = candidate_hmm_models[int(np.argmin(scores))]
    return selected_hmm_models, pca_models

##### COMBINED EEG BANDS (Alpha+Beta+Delta+Gamma+Theta)

In [12]:
combined_alpha_beta_delta_gamma_theta_hmm_model = joblib.load("output/hmm/models/[hmm]-[alpha_beta_delta_gamma_theta]-[{0}].pkl".format("combined"))
combined_alpha_beta_delta_gamma_theta_pca_model = joblib.load("output/hmm/models/[pca]-[alpha_beta_delta_gamma_theta]-[{0}].pkl".format("combined"))

##### COMBINED EEG BANDS (Broad)

In [13]:
broad_hmm_model = joblib.load("output/hmm/models/[hmm]-[broad].pkl")
broad_pca_model = joblib.load("output/hmm/models/[pca]-[broad].pkl")

#####  COMBINED MODALITY HMM MODEL (fMRI+Alpha+Beta+Delta+Gamma+Theta)

In [14]:
combined_fmri_alpha_beta_delta_gamma_theta_hmm_model = joblib.load("output/hmm/models/[hmm]-[fmri_alpha_beta_delta_gamma_theta]-[{0}].pkl".format("combined"))
combined_fmri_alpha_beta_delta_gamma_theta_pca_model = joblib.load("output/hmm/models/[pca]-[fmri_alpha_beta_delta_gamma_theta]-[{0}].pkl".format("combined"))

#####  COMBINED MODALITY HMM MODEL (fMRI+Broad)

In [15]:
combined_fmri_broad_hmm_model = joblib.load("output/hmm/models/[hmm]-[fmri_broad]-[{0}].pkl".format("combined"))
combined_fmri_broad_pca_model = joblib.load("output/hmm/models/[pca]-[fmri_broad]-[{0}].pkl".format("combined"))

##### SEPARATE EEG BANDS (Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [16]:
alpha_beta_delta_gamma_theta_hmm_models = [joblib.load("output/hmm/models/[hmm]-[alpha_beta_delta_gamma_theta]-[{0}].pkl".format(k)) for k in ['alpha', 'beta', 'delta', 'gamma', 'theta']]
alpha_beta_delta_gamma_theta_pca_models = [joblib.load("output/hmm/models/[pca]-[alpha_beta_delta_gamma_theta]-[{0}].pkl".format(k)) for k in ['alpha', 'beta', 'delta', 'gamma', 'theta']]

##### SEPARATE MODALITIES (fMRI+Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [17]:
# Load the saved (PCA, HMM) pair of every channel, channel by channel, then split the
# pairs into two lists that share the channel order below.
fmri_alpha_beta_delta_gamma_theta_loaded_pairs = [
    (joblib.load("output/hmm/models/[pca]-[fmri_alpha_beta_delta_gamma_theta]-[{0}].pkl".format(k)),
     joblib.load("output/hmm/models/[hmm]-[fmri_alpha_beta_delta_gamma_theta]-[{0}].pkl".format(k)))
    for k in ['fmri', 'alpha', 'beta', 'delta', 'gamma', 'theta']
]
fmri_alpha_beta_delta_gamma_theta_pca_models = [pair[0] for pair in fmri_alpha_beta_delta_gamma_theta_loaded_pairs]
fmri_alpha_beta_delta_gamma_theta_hmm_models = [pair[1] for pair in fmri_alpha_beta_delta_gamma_theta_loaded_pairs]

##### SEPARATE MODALITIES (fMRI+Broad HMMs Trained Together)

In [18]:
# Load the saved (PCA, HMM) pair of every channel, channel by channel, then split the
# pairs into two lists that share the channel order below.
fmri_broad_loaded_pairs = [
    (joblib.load("output/hmm/models/[pca]-[fmri_broad]-[{0}].pkl".format(k)),
     joblib.load("output/hmm/models/[hmm]-[fmri_broad]-[{0}].pkl".format(k)))
    for k in ['fmri', 'broad']
]
fmri_broad_pca_models = [pair[0] for pair in fmri_broad_loaded_pairs]
fmri_broad_hmm_models = [pair[1] for pair in fmri_broad_loaded_pairs]

### Decoded Hidden State Sequence Likelihood Comparison

##### COMBINED EEG BANDS (Alpha+Beta+Delta+Gamma+Theta)

In [None]:
combined_alpha_beta_delta_gamma_theta_hmm_model.means_.shape[0]

In [None]:
combined_alpha_beta_delta_gamma_theta_log_likelihood, combined_alpha_beta_delta_gamma_theta_decoded_state_sequence = combined_alpha_beta_delta_gamma_theta_hmm_model.decode(combined_alpha_beta_delta_gamma_theta_pca_model.transform(alpha_beta_delta_gamma_theta_matrix))
combined_alpha_beta_delta_gamma_theta_log_likelihood

##### COMBINED EEG BANDS (Broad)

In [None]:
broad_hmm_model.means_.shape[0]

In [None]:
broad_log_likelihood, broad_decoded_state_sequence = broad_hmm_model.decode(broad_pca_model.transform(all_subjects_all_trials_connectome_upper_triangular_flattened['broad']))
broad_log_likelihood

##### COMBINED MODALITIES (fMRI+Alpha+Beta+Delta+Gamma+Theta)

In [None]:
combined_fmri_alpha_beta_delta_gamma_theta_hmm_model.means_.shape[0]

In [None]:
combined_fmri_alpha_beta_delta_gamma_theta_log_likelihood, combined_fmri_alpha_beta_delta_gamma_theta_decoded_state_sequence = combined_fmri_alpha_beta_delta_gamma_theta_hmm_model.decode(combined_fmri_alpha_beta_delta_gamma_theta_pca_model.transform(fmri_alpha_beta_delta_gamma_theta_matrix))
combined_fmri_alpha_beta_delta_gamma_theta_log_likelihood

#####  COMBINED MODALITIES (fMRI+Broad)

In [None]:
combined_fmri_broad_hmm_model.means_.shape[0]

In [None]:
combined_fmri_broad_log_likelihood, combined_fmri_broad_decoded_state_sequence = combined_fmri_broad_hmm_model.decode(combined_fmri_broad_pca_model.transform(fmri_broad_matrix))
combined_fmri_broad_log_likelihood

##### SEPARATE EEG BANDS (Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [None]:
alpha_beta_delta_gamma_theta_hmm_models[0].means_.shape[0]

In [None]:
# Per-band data in the same order as the loaded per-band HMM/PCA model lists.
alpha_beta_delta_gamma_theta_data_channels = [
    all_subjects_all_trials_connectome_upper_triangular_flattened[band]
    for band in ('alpha', 'beta', 'delta', 'gamma', 'theta')
]
alpha_beta_delta_gamma_theta_decoded_log_likelihoods = []
alpha_beta_delta_gamma_theta_decoded_state_sequences = []

for data_channel, hmm_model, pca_model in zip(alpha_beta_delta_gamma_theta_data_channels,
                                              alpha_beta_delta_gamma_theta_hmm_models,
                                              alpha_beta_delta_gamma_theta_pca_models):
    # Project into the band's PCA space, then decode; decode() returns
    # (log probability of the decoded path, decoded state sequence).
    dim_reduced_data = pca_model.transform(data_channel)
    decoded_state_sequence = hmm_model.decode(dim_reduced_data)
    path_log_probability, state_path = decoded_state_sequence[0], decoded_state_sequence[1]
    alpha_beta_delta_gamma_theta_decoded_log_likelihoods.append(path_log_probability)
    alpha_beta_delta_gamma_theta_decoded_state_sequences.append(state_path)

In [None]:
print(alpha_beta_delta_gamma_theta_decoded_log_likelihoods)
print(np.mean(alpha_beta_delta_gamma_theta_decoded_log_likelihoods))

##### SEPARATE MODALITIES (fMRI+Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [None]:
fmri_alpha_beta_delta_gamma_theta_hmm_models[0].means_.shape[0]

In [None]:
# Per-channel data in the same order as the loaded per-channel HMM/PCA model lists.
fmri_alpha_beta_delta_gamma_theta_data_channels = [
    all_subjects_all_trials_connectome_upper_triangular_flattened[channel]
    for channel in ('fmri', 'alpha', 'beta', 'delta', 'gamma', 'theta')
]
fmri_alpha_beta_delta_gamma_theta_decoded_log_likelihoods = []
fmri_alpha_beta_delta_gamma_theta_decoded_state_sequences = []

for data_channel, hmm_model, pca_model in zip(fmri_alpha_beta_delta_gamma_theta_data_channels,
                                              fmri_alpha_beta_delta_gamma_theta_hmm_models,
                                              fmri_alpha_beta_delta_gamma_theta_pca_models):
    # Project into the channel's PCA space, then decode; decode() returns
    # (log probability of the decoded path, decoded state sequence).
    dim_reduced_data = pca_model.transform(data_channel)
    decoded_state_sequence = hmm_model.decode(dim_reduced_data)
    path_log_probability, state_path = decoded_state_sequence[0], decoded_state_sequence[1]
    fmri_alpha_beta_delta_gamma_theta_decoded_log_likelihoods.append(path_log_probability)
    fmri_alpha_beta_delta_gamma_theta_decoded_state_sequences.append(state_path)

In [None]:
print(fmri_alpha_beta_delta_gamma_theta_decoded_log_likelihoods)
print(np.mean(fmri_alpha_beta_delta_gamma_theta_decoded_log_likelihoods))

##### SEPARATE MODALITIES (fMRI+Broad HMMs Trained Together)

In [None]:
fmri_broad_hmm_models[0].means_.shape[0]

In [None]:
# Per-channel data in the same order as the loaded per-channel HMM/PCA model lists.
fmri_broad_data_channels = [
    all_subjects_all_trials_connectome_upper_triangular_flattened[channel]
    for channel in ('fmri', 'broad')
]
fmri_broad_decoded_log_likelihoods = []
fmri_broad_decoded_state_sequences = []

for data_channel, hmm_model, pca_model in zip(fmri_broad_data_channels,
                                              fmri_broad_hmm_models,
                                              fmri_broad_pca_models):
    # Project into the channel's PCA space, then decode; decode() returns
    # (log probability of the decoded path, decoded state sequence).
    dim_reduced_data = pca_model.transform(data_channel)
    decoded_state_sequence = hmm_model.decode(dim_reduced_data)
    path_log_probability, state_path = decoded_state_sequence[0], decoded_state_sequence[1]
    fmri_broad_decoded_log_likelihoods.append(path_log_probability)
    fmri_broad_decoded_state_sequences.append(state_path)

In [None]:
print(fmri_broad_decoded_log_likelihoods)
print(np.mean(fmri_broad_decoded_log_likelihoods))

### Spatial Representation of Hidden States

In [19]:
def plot_spatial_representation(hmm_models, pca_models, modalities, title, path):
    """Plot the hidden-state means of several HMMs as connectomes and save the figure.

    The figure has one row per model and one column per hidden state. The number
    of columns is the state count of the largest model, capped at 10; a model
    with fewer states gets empty panels in its remaining columns.

    Each state mean is mapped back to connectome space with the model's PCA
    (``inverse_transform``), written into the upper triangle (diagonal included)
    of a ``num_regions x num_regions`` matrix and mirrored into the lower triangle.

    Parameters
    ----------
    hmm_models : sequence
        Fitted HMMs exposing ``means_``.
    pca_models : sequence
        PCA models matching ``hmm_models`` one-to-one.
    modalities : sequence of str
        Row labels used in the panel titles, matching ``hmm_models`` one-to-one.
    title : str
        Figure title.
    path : str
        Where the figure is saved.
    """
    num_models_to_plot = len(hmm_models)
    # Size the grid by the largest model so that the empty-panel branch below can
    # actually pad the rows of smaller models.
    num_states_to_plot = min(max(hmm_model.means_.shape[0] for hmm_model in hmm_models), 10)

    fig = plt.figure(figsize=(40*num_states_to_plot, 10+30*num_models_to_plot))
    fig.suptitle(title, fontsize=50)
    subplot_idx = 1

    for hmm_model, pca_model, modality in zip(hmm_models, pca_models, modalities):
        for state_idx in range(0, num_states_to_plot):

            ax = fig.add_subplot(num_models_to_plot, num_states_to_plot, subplot_idx)
            subplot_idx += 1

            # Leave the panel empty if this model has no such state
            if hmm_model.means_.shape[0] <= state_idx:
                continue

            # Extract connectome representation of the hidden state
            hidden_state = np.zeros((num_regions, num_regions))
            hidden_state[upper_triangular_including_diagonal_idxs] = pca_model.inverse_transform(hmm_model.means_[state_idx])
            # Mirror the upper triangle into the lower one to make the matrix symmetric
            hidden_state[lower_triangular_idxs] = hidden_state.T[lower_triangular_idxs]

            # Plot connectome representation of the hidden state
            DesikanAtlas.plot(connectome=hidden_state,
                              title='{0} HiddenState-{1} Connectome'.format(modality, state_idx+1),
                              axes=ax)

    # Save/close this figure explicitly rather than whatever pyplot considers current
    fig.savefig(path)
    plt.close(fig)

In [20]:
def plot_combined_spatial_representation(combined_hmm_model, combined_pca_model, modalities, modality_feature_idxs, title, path):
    """Plot the hidden states of one multi-modality HMM as per-modality connectomes.

    The figure has one row per modality and one column per hidden state (at most
    10). Each state mean is mapped back to the combined feature space with
    ``combined_pca_model.inverse_transform``; the slice belonging to a modality is
    written into the upper triangle (diagonal included) of a
    ``num_regions x num_regions`` matrix and mirrored into the lower triangle.

    Parameters
    ----------
    combined_hmm_model : object
        Fitted HMM exposing ``means_``.
    combined_pca_model : object
        PCA model whose ``inverse_transform`` maps a state mean to combined features.
    modalities : sequence of str
        Row labels used in the panel titles.
    modality_feature_idxs : sequence
        For each modality, the indices of its features within the combined
        feature vector; paired with ``modalities`` by position.
    title : str
        Figure title.
    path : str
        Where the figure is saved.
    """
    num_modalities_to_plot = len(modalities)
    num_states_to_plot = min(combined_hmm_model.means_.shape[0], 10)

    fig = plt.figure(figsize=(40*num_states_to_plot, 10+30*num_modalities_to_plot))
    fig.suptitle(title, fontsize=50)
    subplot_idx = 1

    # The reconstruction of a state does not depend on the modality, so compute it
    # once per state instead of once per (modality, state) pair.
    features_per_state = [
        combined_pca_model.inverse_transform(combined_hmm_model.means_[state_idx])
        for state_idx in range(0, num_states_to_plot)
    ]

    for modality, feature_idxs in zip(modalities, modality_feature_idxs):
        for state_idx, features in enumerate(features_per_state):

            # Extract connectome representation of the hidden state
            hidden_state = np.zeros((num_regions, num_regions))
            hidden_state[upper_triangular_including_diagonal_idxs] = features[feature_idxs]
            # Mirror the upper triangle into the lower one to make the matrix symmetric
            hidden_state[lower_triangular_idxs] = hidden_state.T[lower_triangular_idxs]

            # Plot connectome representation of the hidden state
            ax = fig.add_subplot(num_modalities_to_plot, num_states_to_plot, subplot_idx)
            DesikanAtlas.plot(connectome=hidden_state,
                              title='{0} HiddenState-{1} Connectome'.format(modality, state_idx+1),
                              axes=ax)
            subplot_idx += 1

    # Save/close this figure explicitly rather than whatever pyplot considers current
    fig.savefig(path)
    plt.close(fig)

##### COMBINED EEG BANDS (Alpha+Beta+Delta+Gamma+Theta)

In [21]:
# Each modality occupies one contiguous slice of the combined feature vector, as long as a
# flattened upper triangle (diagonal included). Derive that length from the index arrays
# instead of hard-coding it, so it follows num_regions. The modality order must match the
# channel order used to build the combined data matrix.
num_features_per_modality = len(upper_triangular_including_diagonal_idxs[0])
combined_eeg_band_names = ["Alpha", "Beta", "Delta", "Gamma", "Theta"]
plot_combined_spatial_representation(combined_hmm_model=combined_alpha_beta_delta_gamma_theta_hmm_model,
                                    combined_pca_model=combined_alpha_beta_delta_gamma_theta_pca_model,
                                    modalities=combined_eeg_band_names,
                                    modality_feature_idxs=[
                                        range(num_features_per_modality*i, num_features_per_modality*(i+1))
                                        for i in range(len(combined_eeg_band_names))
                                    ],
                                    title="Connectome Representation of States of Combined Alpha+Beta+Delta+Gamma+Theta Hidden Markov Model",
                                    path="output/hmm/spatial_representation-[alpha_beta_delta_gamma_theta]-[combined].png")

##### COMBINED EEG BANDS (Broad)

In [22]:
# The single "Broad" modality spans the whole feature vector: one flattened upper triangle
# (diagonal included). Derive its length from the index arrays instead of hard-coding it.
num_features_per_modality = len(upper_triangular_including_diagonal_idxs[0])
plot_combined_spatial_representation(combined_hmm_model=broad_hmm_model,
                                    combined_pca_model=broad_pca_model,
                                    modalities=["Broad"],
                                    modality_feature_idxs=[
                                        range(0, num_features_per_modality),
                                    ],
                                    title="Connectome Representation of States of Broad Hidden Markov Model",
                                    path="output/hmm/spatial_representation-[broad]-[combined].png")

#####  COMBINED MODALITY HMM MODEL (fMRI+Alpha+Beta+Delta+Gamma+Theta)

In [23]:
# Each modality occupies one contiguous slice of the combined feature vector, as long as a
# flattened upper triangle (diagonal included). Derive that length from the index arrays
# instead of hard-coding it, so it follows num_regions. The modality order must match the
# channel order used to build the combined data matrix.
num_features_per_modality = len(upper_triangular_including_diagonal_idxs[0])
combined_fmri_eeg_band_names = ["fMRI", "Alpha", "Beta", "Delta", "Gamma", "Theta"]
plot_combined_spatial_representation(combined_hmm_model=combined_fmri_alpha_beta_delta_gamma_theta_hmm_model,
                                    combined_pca_model=combined_fmri_alpha_beta_delta_gamma_theta_pca_model,
                                    modalities=combined_fmri_eeg_band_names,
                                    modality_feature_idxs=[
                                        range(num_features_per_modality*i, num_features_per_modality*(i+1))
                                        for i in range(len(combined_fmri_eeg_band_names))
                                    ],
                                    title="Connectome Representation of States of Combined fMRI+Alpha+Beta+Delta+Gamma+Theta Hidden Markov Model",
                                    path="output/hmm/spatial_representation-[fmri_alpha_beta_delta_gamma_theta]-[combined].png")

#####  COMBINED MODALITY HMM MODEL (fMRI+Broad)

In [24]:
# Each modality occupies one contiguous slice of the combined feature vector, as long as a
# flattened upper triangle (diagonal included). Derive that length from the index arrays
# instead of hard-coding it, so it follows num_regions. The modality order must match the
# channel order used to build the combined data matrix.
num_features_per_modality = len(upper_triangular_including_diagonal_idxs[0])
combined_fmri_broad_names = ["fMRI", "Broad"]
plot_combined_spatial_representation(combined_hmm_model=combined_fmri_broad_hmm_model,
                                    combined_pca_model=combined_fmri_broad_pca_model,
                                    modalities=combined_fmri_broad_names,
                                    modality_feature_idxs=[
                                        range(num_features_per_modality*i, num_features_per_modality*(i+1))
                                        for i in range(len(combined_fmri_broad_names))
                                    ],
                                    title="Connectome Representation of States of Combined fMRI+Broad Hidden Markov Model",
                                    path="output/hmm/spatial_representation-[fmri_broad]-[combined].png")

##### SEPARATE EEG BANDS (Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [25]:
plot_spatial_representation(hmm_models=alpha_beta_delta_gamma_theta_hmm_models,
                            pca_models=alpha_beta_delta_gamma_theta_pca_models,
                            modalities=["Alpha", "Beta", "Delta", "Gamma", "Theta"],
                            title="Connectome Representation of Jointly Trained Alpha+Beta+Delta+Gamma+Theta Hidden Markov Model States",
                            path="output/hmm/spatial_representation-[alpha_beta_delta_gamma_theta]-[separate].png")

##### SEPARATE MODALITIES (fMRI+Alpha+Beta+Delta+Gamma+Theta HMMs Trained Together)

In [26]:
plot_spatial_representation(hmm_models=fmri_alpha_beta_delta_gamma_theta_hmm_models,
                            pca_models=fmri_alpha_beta_delta_gamma_theta_pca_models,
                            modalities=["fMRI", "Alpha", "Beta", "Delta", "Gamma", "Theta"],
                            title="Connectome Representation of Jointly Trained fMRI+Alpha+Beta+Delta+Gamma+Theta Hidden Markov Model States",
                            path="output/hmm/spatial_representation-[fmri_alpha_beta_delta_gamma_theta]-[separate].png")

##### SEPARATE MODALITIES (fMRI+Broad HMMs Trained Together)

In [27]:
plot_spatial_representation(hmm_models=fmri_broad_hmm_models,
                            pca_models=fmri_broad_pca_models,
                            modalities=["fMRI", "Broad"],
                            title="Connectome Representation of Jointly Trained fMRI+Broad Hidden Markov Model States",
                            path="output/hmm/spatial_representation-[fmri_broad]-[separate].png")

### Decoded Hidden State Sequence Statistics

Plot fractional occupancy - the fraction of time spent in each state relative to the total duration

In [None]:
# SEPARATE
# NOTE(review): `hmm_models` and `decoded_state_sequences` are not defined in any cell
# visible above this one - confirm where they come from before running on a fresh kernel.
fig = plt.figure(figsize=(20, 4))
fig.suptitle('Fractional Occupancy')
num_plots = len(decoded_state_sequences)

# One bar chart per channel; zip pairs the channel names (dict keys) with models and sequences
for subplot_idx, (k, hmm_model, state_seq) in enumerate(
        zip(all_subjects_all_trials_connectome_upper_triangular_flattened, hmm_models, decoded_state_sequences),
        start=1):

    num_states = hmm_model.means_.shape[0]

    # Fraction of time points spent in each state
    fractional_occupancies_per_state = [
        np.count_nonzero(state_seq == state_idx)/len(state_seq)
        for state_idx in range(num_states)
    ]

    # Plot fractional occupancy per state
    ax = fig.add_subplot(1, num_plots, subplot_idx)
    ax.set_title(k)
    x = np.arange(num_states)
    ax.bar(x, height=fractional_occupancies_per_state)
    ax.set_xticks(x)
    ax.set_xticklabels([str(x_i) for x_i in x])

fig.savefig('output/hmm/fractional_occupancy.png')
plt.close(fig)

In [None]:
# COMBINED
# NOTE(review): `combined_modality_hmm_model` and `combined_modality_decoded_state_sequence`
# are not defined in any cell visible above this one - confirm where they come from.
num_combined_states = combined_modality_hmm_model.means_.shape[0]
num_combined_time_points = len(combined_modality_decoded_state_sequence)

# Fraction of time points spent in each state
fractional_occupancies_per_state = [
    np.count_nonzero(combined_modality_decoded_state_sequence == state_idx)/num_combined_time_points
    for state_idx in range(num_combined_states)
]

# Plot fractional occupancy per state
x = np.arange(num_combined_states)
plt.title('Combined Fractional Occupancy')
plt.bar(x, height=fractional_occupancies_per_state)
plt.xticks(x, [str(x_i) for x_i in x])
plt.savefig('output/hmm/combined_fractional_occupancy.png')
plt.close()

Plot mean life time - the time spent in a state before transitioning to a new state on average

In [None]:
# SEPARATE
# NOTE(review): `hmm_models` and `decoded_state_sequences` are not defined in any cell
# visible above this one - confirm where they come from before running on a fresh kernel.
fig = plt.figure(figsize=(20, 4))
fig.suptitle('Mean Life Time')
num_plots = len(decoded_state_sequences)

for subplot_idx, (k, hmm_model, state_seq) in enumerate(
        zip(all_subjects_all_trials_connectome_upper_triangular_flattened, hmm_models, decoded_state_sequences),
        start=1):

    num_states = hmm_model.means_.shape[0]
    state_seq = np.asarray(state_seq)

    # A visit to a state starts at the first time point and wherever the state differs
    # from the previous time point. Counting visits (rather than transitions out of a
    # state) also accounts for the run the sequence ends in.
    starts_new_visit = np.ones(len(state_seq), dtype=bool)
    starts_new_visit[1:] = state_seq[1:] != state_seq[:-1]
    num_visits_per_state = np.bincount(state_seq[starts_new_visit], minlength=num_states)[:num_states]

    # Total number of time points spent in each state
    num_time_points_per_state = np.bincount(state_seq, minlength=num_states)[:num_states]

    # Mean life time = time points in state / visits to state; states that are never
    # visited get 0 instead of raising a ZeroDivisionError.
    mean_life_time_per_state = np.divide(num_time_points_per_state, num_visits_per_state,
                                         out=np.zeros(num_states), where=num_visits_per_state > 0)

    # Plot mean life time per state
    ax = fig.add_subplot(1, num_plots, subplot_idx)
    ax.set_title(k)
    x = np.arange(num_states)
    ax.bar(x, height=mean_life_time_per_state)
    ax.set_xticks(x)
    ax.set_xticklabels([str(x_i) for x_i in x])

fig.savefig('output/hmm/mean_life_time.png')
plt.close(fig)

In [None]:
# COMBINED
# NOTE(review): `combined_modality_hmm_model` and `combined_modality_decoded_state_sequence`
# are not defined in any cell visible above this one - confirm where they come from.
num_combined_states = combined_modality_hmm_model.means_.shape[0]
combined_state_seq = np.asarray(combined_modality_decoded_state_sequence)

# A visit to a state starts at the first time point and wherever the state differs from
# the previous time point. Counting visits (rather than transitions out of a state) also
# accounts for the run the sequence ends in.
starts_new_visit = np.ones(len(combined_state_seq), dtype=bool)
starts_new_visit[1:] = combined_state_seq[1:] != combined_state_seq[:-1]
num_visits_per_state = np.bincount(combined_state_seq[starts_new_visit], minlength=num_combined_states)[:num_combined_states]

# Total number of time points spent in each state
num_time_points_per_state = np.bincount(combined_state_seq, minlength=num_combined_states)[:num_combined_states]

# Mean life time = time points in state / visits to state; states that are never visited
# get 0 instead of raising a ZeroDivisionError.
mean_life_time_per_state = np.divide(num_time_points_per_state, num_visits_per_state,
                                     out=np.zeros(num_combined_states), where=num_visits_per_state > 0)

# Plot mean life time per state
plt.title('Combined Mean Life Time')
x = np.arange(num_combined_states)
plt.bar(x, height=mean_life_time_per_state)
plt.xticks(x, [str(x_i) for x_i in x])

plt.savefig('output/hmm/combined_mean_life_time.png')
plt.close()

Plot transition probabilities.

In [None]:
# SEPARATE - Transition Probabilities
# NOTE(review): `hmm_models` is not defined in any cell visible above this one - confirm
# where it comes from before running on a fresh kernel.
fig = plt.figure(figsize=(20, 4))
fig.suptitle('Transition Probabilities')
num_plots = len(hmm_models)

# One heat map of the state transition matrix per channel
for subplot_idx, (k, hmm_model) in enumerate(
        zip(all_subjects_all_trials_connectome_upper_triangular_flattened, hmm_models),
        start=1):
    ax = fig.add_subplot(1, num_plots, subplot_idx)
    transition_image = ax.imshow(hmm_model.transmat_, cmap='gist_heat')
    ax.set_title(k)
    fig.colorbar(transition_image, ax=ax)

fig.savefig('output/hmm/transition_probabilities.png')
plt.close(fig)

In [None]:
# COMBINED - Transition Probabilities    
# Heat map of the combined model's state transition matrix (`transmat_`), saved to disk.
# NOTE(review): `combined_modality_hmm_model` is not defined in any cell visible above
# this one - confirm where it comes from before running on a fresh kernel.
plt.imshow(combined_modality_hmm_model.transmat_, cmap='gist_heat')
plt.title('Combined Transition Probabilities')
plt.colorbar()
plt.savefig('output/hmm/combined_transition_probabilities.png')
plt.close()

Plot transitions between decoded states as a time series.

In [None]:
def get_n_colors(n):
    """Return a list of ``n`` colours from the rainbow colormap.

    Colour ``i`` (for ``i`` in ``0..n-1``) is the colormap evaluated at
    position ``i/n``.
    """
    palette = []
    for color_idx in range(n):
        position = float(color_idx)/n
        palette.append(cmx.rainbow(position))
    return palette

In [None]:
# Separate
# NOTE(review): `hmm_models` and `decoded_state_sequences` are not defined in any cell
# visible above this one - confirm where they come from before running on a fresh kernel.
fig = plt.figure(figsize=(300, 25))
fig.suptitle('Decoded Hidden State Sequences')

num_plots = len(hmm_models)

for subplot_idx, (k, hmm_model, state_seq) in enumerate(
        zip(all_subjects_all_trials_connectome_upper_triangular_flattened, hmm_models, decoded_state_sequences),
        start=1):

    print(k)
    ax = fig.add_subplot(num_plots, 1, subplot_idx)
    component_colors = get_n_colors(hmm_model.means_.shape[0])

    # One vertical line per time point, coloured by decoded state (first 6000 points only)
    for x, state in enumerate(state_seq[:6000]):
        ax.axvline(x=x, color=component_colors[state])

    ax.set_title("{0}".format(k))
    ax.set_yticks([])

fig.subplots_adjust(hspace=0.5)
fig.savefig('output/hmm/decoded_hidden_state_sequence.png')
plt.close(fig)

In [None]:
# Combined
# NOTE(review): `combined_modality_hmm_model` and `combined_modality_decoded_state_sequence`
# are not defined in any cell visible above this one - confirm where they come from.
fig = plt.figure(figsize=(300, 5))

component_colors = get_n_colors(combined_modality_hmm_model.means_.shape[0])

# One vertical line per time point, coloured by decoded state (first 6000 points only)
for x, state in enumerate(combined_modality_decoded_state_sequence[:6000]):
    plt.axvline(x=x, color=component_colors[state])

plt.title("{0}".format("Combined Decoded Hidden State Sequence"))
plt.yticks([])
plt.subplots_adjust(hspace=0.5)
plt.savefig('output/hmm/combined_decoded_hidden_state_sequence.png')
plt.close()

# Brain Graph Statistics

In [None]:
def controllability_statistic(connectome, edge_threshold=0.9):
    """Fraction of nodes left unmatched by a maximum matching of the sparsified connectome.

    The connectome is sparsified by zeroing every entry that lies strictly
    between a lower and an upper cut-off taken from the sorted entries, so
    only the most extreme weights on either side survive. A maximum
    cardinality, maximum weight matching is then computed on the resulting
    graph, and the share of nodes not covered by that matching is returned.

    Parameters
    ----------
    connectome : 2-D numpy array
        Square weight matrix; it is not modified (a copy is thresholded).
    edge_threshold : float, default 0.9
        Central fraction of the sorted entries to discard. Must satisfy
        0.5 < edge_threshold <= 1.0; the remaining (1 - edge_threshold) is
        split evenly between the low and the high tail.

    Returns
    -------
    float
        (number of unmatched nodes) / (number of nodes).

    Raises
    ------
    AssertionError
        If `edge_threshold` is outside (0.5, 1.0].

    NOTE(review): the cut-offs are taken over *all* entries of the matrix,
    including the diagonal and both triangles -- confirm this is intended for
    symmetric connectomes with a constant diagonal.
    """
    assert 0.5 < edge_threshold <= 1.0, "edge_threshold must be in the interval (0.5, 1.0]"

    # Sparsify connectome by only keeping top x/2% and bottom x/2% most extreme edge weights
    sorted_edge_weights = np.sort(connectome, axis=None)
    num_edge_weights = len(sorted_edge_weights)
    tail_fraction = (1 - edge_threshold) / 2

    # Clamp to the last valid index: with edge_threshold == 1.0 the upper
    # index would otherwise land one past the end of the array.
    last_idx = num_edge_weights - 1
    upper_threshold_idx = min(int((edge_threshold + tail_fraction) * num_edge_weights), last_idx)
    lower_threshold_idx = min(int(tail_fraction * num_edge_weights), last_idx)

    upper_threshold_val = sorted_edge_weights[upper_threshold_idx]
    lower_threshold_val = sorted_edge_weights[lower_threshold_idx]

    # Zero everything strictly between the two cut-offs in one vectorised step.
    thresholded_connectome = np.copy(connectome)
    is_unremarkable_edge = (thresholded_connectome > lower_threshold_val) & (thresholded_connectome < upper_threshold_val)
    thresholded_connectome[is_unremarkable_edge] = 0.0

    # Compute maximum matching from sparsified graph.
    # `from_numpy_matrix` was removed in NetworkX 3.0; prefer its replacement
    # when the installed version provides it and fall back otherwise.
    graph_from_array = getattr(nx, 'from_numpy_array', None) or nx.from_numpy_matrix
    nx_graph = graph_from_array(thresholded_connectome)
    max_matching = nx.max_weight_matching(nx_graph, maxcardinality=True)

    # Compute controllability stat from maximum matching
    num_nodes = connectome.shape[0]
    matched_nodes = {node for node_pair in max_matching for node in node_pair}
    num_unmatched_nodes = num_nodes - len(matched_nodes)

    return num_unmatched_nodes/num_nodes

In [None]:
# Each dictionary below maps a connectome type (the keys of
# all_subjects_all_trials_connectomes) to a list with one value per time point.

# Assortativity: correlation coefficient between the strengths (weighted degrees) of the nodes at the
# two ends of a link. A positive coefficient indicates that nodes tend to link to other nodes of the
# same or similar strength.
assortativity_statistic_time_series = {}

# Global efficiency: the average inverse shortest path length in the network; inversely related to
# the characteristic path length.
global_efficiency_statistic_time_series = {}

# Modularity: quantifies how well the network can be subdivided into nonoverlapping groups of nodes
# (the optimal community structure) such that the number of within-group edges is maximised and the
# number of between-group edges is minimised.
modularity_statistic_time_series = {}

# Controllability: the fraction of nodes required to 'control' the state of the network:
#  https://www.barabasilab.com/publications/controllability-of-complex-networks
controllability_statistic_time_series = {}

num_statistic_types = 3

for k in all_subjects_all_trials_connectomes:

    print(k)

    # The three bct-based statistics are currently disabled, so their lists stay empty.
    # Per time point they would be computed as:
    #     bct.assortativity_wei(time_pt)
    #     bct.efficiency_wei(time_pt)
    #     bct.modularity_und(time_pt)[1]
    assortativity_statistic_time_series[k] = []
    global_efficiency_statistic_time_series[k] = []
    modularity_statistic_time_series[k] = []

    controllability_statistic_time_series[k] = [
        controllability_statistic(time_pt)
        for time_pt in all_subjects_all_trials_connectomes[k]
    ]

Plot timeseries of graph statistics.

In [None]:
# Grid of line plots: one row per modality, one column per statistic
# (modularity, assortativity, global efficiency).
fig = plt.figure(figsize=(100, 20))
fig.suptitle('Graph Statistic Timeseries')

# Total number of (statistic, modality) series. Not used for the layout here,
# but kept because the correlation-matrix cell sizes its matrix with it.
num_unique_statistic_modality_pairs = num_statistic_types*len(modularity_statistic_time_series)

# The grid needs one row per modality. Using the number of
# (statistic, modality) pairs as the row count left most of the figure empty.
num_modalities = len(assortativity_statistic_time_series)

# (title suffix, per-modality series, line colour), in column order.
plotted_statistics = [
    ("modularity", modularity_statistic_time_series, "green"),
    ("assortativity", assortativity_statistic_time_series, "red"),
    ("global efficiency", global_efficiency_statistic_time_series, "blue"),
]

subplot_idx = 1

for k in assortativity_statistic_time_series:

    for statistic_name, statistic_time_series, line_color in plotted_statistics:

        series = statistic_time_series[k]
        t = np.arange(len(series))

        fig.add_subplot(num_modalities, num_statistic_types, subplot_idx)
        plt.plot(t, series, color=line_color)
        plt.title("{0} {1}".format(k, statistic_name))
        subplot_idx += 1

plt.subplots_adjust(hspace=0.5)
plt.savefig('output/hmm/graph_statistic_timeseries.png')
plt.close()

Compute correlation between all possible pairs of statistics across modalities.

In [None]:
# Pearson correlation between every pair of (modality, statistic) time series.
# Rows and columns share one ordering: all modalities for modularity, then
# assortativity, then global efficiency; `labels` records that ordering.
statistic_timeseries_correlation_matrix = np.zeros((num_unique_statistic_modality_pairs, num_unique_statistic_modality_pairs))

named_statistics = [
    ("modularity", modularity_statistic_time_series),
    ("assortativity", assortativity_statistic_time_series),
    ("global efficiency", global_efficiency_statistic_time_series),
]

# Flatten to a single list of (label, series) so one index addresses one series.
labelled_series = [
    ("{0} {1}".format(modality, statistic_name), series_by_modality[modality])
    for statistic_name, series_by_modality in named_statistics
    for modality in series_by_modality
]

labels = [label for label, series in labelled_series]

for i, (label_a, series_a) in enumerate(labelled_series):
    for j, (label_b, series_b) in enumerate(labelled_series):
        # Off-diagonal entry of the 2x2 correlation matrix of the two series.
        statistic_timeseries_correlation_matrix[i, j] = np.corrcoef(series_a, series_b)[0, 1]

In [None]:
# Heat-map of the cross-modality correlation matrix with the series labels on
# both axes.
f = plt.figure(figsize=(15, 20))

plt.imshow(statistic_timeseries_correlation_matrix, cmap='gist_heat')

# The title is positioned well above the axes (y=1.3 in axes coordinates)
# because the x tick labels are moved to the top edge below.
plt.title("Cross Modality Correlations in Brain Graph Statistic Timeseries").set_position([.5, 1.3])

# Use the Axes that holds the image. A bare plt.axes() creates a *new* Axes on
# current Matplotlib versions, which would move the ticks of an empty overlay
# instead of the heat-map.
plt.gca().xaxis.set_ticks_position('top')
plt.xticks(range(len(labels)), labels, rotation='vertical')
plt.yticks(range(len(labels)), labels)
plt.colorbar()

f.savefig('output/hmm/graph_statistic_cross_modality_correlations.png')
plt.close()

##### Fit an HMM to the time series of the brain graph statistics

In [None]:
# Human-readable labels for the brain graph statistics, grouped by statistic
# (modularity, assortativity, global efficiency) and, within each group,
# listed as fMRI followed by the five EEG bands.
# NOTE(review): these labels are used positionally as x tick labels, so their
# order presumably has to match the column order of the pickled statistics -- confirm.
modality_descriptions = [
    "fMRI",
    "Alpha-Band EEG",
    "Beta-Band EEG",
    "Delta-Band EEG",
    "Gamma-Band EEG",
    "Theta-Band EEG",
]
statistic_descriptions = ["Modularity", "Assortativity", "Global Efficiency"]

brain_graph_statistics_descriptions = [
    "{0} {1}".format(modality_description, statistic_description)
    for statistic_description in statistic_descriptions
    for modality_description in modality_descriptions
]

In [None]:
# Load the precomputed brain graph statistic time series together with the
# per-statistic means and variances used later for z-scoring.
# Each file is opened in a `with` block so its handle is closed right after
# loading instead of being left open for the lifetime of the kernel.
# NOTE: pickle can execute arbitrary code while loading -- only load files
# produced by a trusted run of this project.
with open('output/hmm/graph_statistics/brain_graph_statistics_time_series.pkl', 'rb') as pickle_file:
    brain_graph_statistics = pickle.load(pickle_file)

with open('output/hmm/graph_statistics/brain_graph_statistics_means.pkl', 'rb') as pickle_file:
    brain_graph_statistics_means = pickle.load(pickle_file)

with open('output/hmm/graph_statistics/brain_graph_statistics_vars.pkl', 'rb') as pickle_file:
    brain_graph_statistics_vars = pickle.load(pickle_file)

In [None]:
# Fit an HMM to the brain graph statistic time series. The call returns a pair;
# only its first element (the model) is kept.
# NOTE(review): pca_variance_retained=1.0 presumably means no variance is
# discarded by PCA -- confirm against train_optimal_hmm_on_data, which is
# defined outside this section.
brain_graph_statistics_hmm_model, _ = train_optimal_hmm_on_data(brain_graph_statistics, pca_variance_retained=1.0)

In [None]:
# Display the shape of the fitted state means; its first dimension is used
# below as the number of hidden states.
brain_graph_statistics_hmm_model.means_.shape

In [None]:
# Hidden State Analysis
# One bar chart per hidden state: each bar is the z-score of that state's mean
# for one statistic, relative to the statistic's overall mean and variance.
state_means = brain_graph_statistics_hmm_model.means_
num_hidden_states = state_means.shape[0]

f = plt.figure(figsize=(10*num_hidden_states, 20))
f.suptitle('Brain Graph Statistic HMM Hidden State Analysis')

for i, hidden_state in enumerate(state_means):

    f.add_subplot(1, num_hidden_states, i + 1)

    per_statistic_values = zip(hidden_state, brain_graph_statistics_means, brain_graph_statistics_vars)
    for x, (stat, stat_mean, stat_var) in enumerate(per_statistic_values):
        z_stat = (stat - stat_mean)/np.sqrt(stat_var)
        # Blue for above the overall mean, red otherwise.
        bar_color = 'blue' if z_stat > 0 else 'red'
        plt.bar(x, z_stat, width=1, color=bar_color)

    # Bars beyond +/-1.96 are clipped at the axis limits.
    plt.ylim([-1.96, 1.96])
    plt.ylabel("Z-Score of Statistic in Hidden State Compared to Overall Timeseries")
    plt.title("Hidden State {0}".format(i))

    tick_positions = range(len(brain_graph_statistics_descriptions))
    plt.xticks(tick_positions, brain_graph_statistics_descriptions, rotation='vertical')

plt.subplots_adjust(hspace=0.5)
f.savefig('output/hmm/graph_statistic_hmm_hidden_state_analysis.png')
plt.close()

In [None]:
# Transition Probabilities
# Heat-map of the brain graph statistic HMM's transition matrix with one
# integer tick per hidden state on both axes.
graph_statistic_transition_matrix = brain_graph_statistics_hmm_model.transmat_
state_ticks = np.arange(graph_statistic_transition_matrix.shape[0])

plt.imshow(graph_statistic_transition_matrix, cmap='gist_heat')
plt.title('Brain Graph Statistic HMM Transition Probabilities')
plt.xticks(state_ticks)
plt.yticks(state_ticks)
plt.colorbar()

plt.savefig('output/hmm/graph_statistic_transition_probabilities.png')
plt.close()

In [None]:
# Decoded State Sequence
# Decode the same observations the model was trained on. The two results are
# unpacked into a log likelihood and the hidden state sequence that the
# mean-life-time and fractional-occupancy cells below operate on.
brain_graph_statistics_hmm_decoded_log_likelihood, brain_graph_statistics_hmm_decoded_state_sequence = brain_graph_statistics_hmm_model.decode(brain_graph_statistics)

In [None]:
# Mean Life Time Per State
# Mean life time of a state = (time points spent in the state) / (number of
# separate visits to the state), where a visit is a run of consecutive time
# points decoded as that state.
decoded_states = np.asarray(brain_graph_statistics_hmm_decoded_state_sequence)
num_states = brain_graph_statistics_hmm_model.means_.shape[0]

# True at the last time point of every run. The final time point always closes
# a run, so a visit that is still ongoing when the sequence ends is counted too
# (counting only transitions *out* of a state misses it and can divide by zero).
is_run_end = np.ones(decoded_states.shape, dtype=bool)
is_run_end[:-1] = decoded_states[1:] != decoded_states[:-1]

mean_life_time_per_state = []
for state_id in range(num_states):

    in_state = decoded_states == state_id

    # Count total number of time points spent in state with state_id
    num_time_points_in_state_with_state_id = np.count_nonzero(in_state)

    # Count number of separate visits to state with state_id
    num_visits_to_state_with_state_id = np.count_nonzero(in_state & is_run_end)

    # Compute mean life time; a state that is never visited gets 0 rather than
    # raising ZeroDivisionError.
    if num_visits_to_state_with_state_id > 0:
        mean_life_time = float(num_time_points_in_state_with_state_id)/num_visits_to_state_with_state_id
    else:
        mean_life_time = 0.0
    mean_life_time_per_state.append(mean_life_time)


# Plot mean life time per state
plt.title('Brain Graph Statistic HMM Hidden State Mean Life Time')
x = np.arange(num_states)
plt.bar(x, height=mean_life_time_per_state)
plt.xticks(x, [str(x_i) for x_i in x])

plt.savefig('output/hmm/graph_statistic_hmm_mean_life_time.png')
plt.close()

In [None]:
# Fractional Occupancy Per State
# Share of all decoded time points that were assigned to each hidden state.
num_decoded_time_points = len(brain_graph_statistics_hmm_decoded_state_sequence)
num_states = brain_graph_statistics_hmm_model.means_.shape[0]

fractional_occupancies_per_state = [
    np.count_nonzero(brain_graph_statistics_hmm_decoded_state_sequence == state_idx)/num_decoded_time_points
    for state_idx in range(0, num_states)
]

# Plot fractional occupancy per state
x = np.arange(num_states)
state_tick_labels = [str(x_i) for x_i in x]

plt.title('Brain Graph Statistic HMM Hidden State Fractional Occupancy')
plt.bar(x, height=fractional_occupancies_per_state)
plt.xticks(x, state_tick_labels)
plt.savefig('output/hmm/graph_statistic_hmm_fractional_occupancy.png')
plt.close()

In [None]:
# TODO: Compute Controllability Statistic
# (Written as a comment so that this code cell no longer raises a SyntaxError
# when the notebook is run top to bottom.)

# Visualize fMRI/EEG Connectome Dynamics

In [None]:
# for subject_id in ALL_SUBJECT_IDS:
#     for trial_id in ALL_TRIAL_IDS:
        
#         # Attempt to load all connectome types
#         connectomes = load_all_connectome_types(subject_id, trial_id,
#                                                atlas='desikan', 
#                                                seconds_used_to_compute_fmri_connectome=60,
#                                                exclude_bad_fmri_frames=True,
#                                                filter_artifact_timepoints=True)
        
#         if connectomes is None:
#             continue

#         # Plot connectomes through time
#         for t in range(0, connectomes['fmri'].shape[0]):
            
#             # Create figure and set title
#             fig = plt.figure(figsize=(30, 35))
#             fig.suptitle('Subject: "{0}" | Trial: {1} | Time: {2}'.format(subject_id, trial_id, t), fontsize=50)
            
#             # Plot connectomes
#             subplot_idx = 1
#             for connectome_id, connectome in connectomes.items():

#                 ax = fig.add_subplot(len(connectomes), 2, subplot_idx)
#                 plotting.plot_connectome(connectome[t], desikan_atlas_coordinates(), title='{0} Connectome'.format(connectome_id),
#                                          edge_threshold='95%', node_size=20, colorbar=True, axes=ax)
#                 subplot_idx += 1
            
#                 ax = fig.add_subplot(len(connectomes), 2, subplot_idx)
#                 plotting.plot_matrix(connectome[t], vmin=-1., vmax=1., colorbar=True, axes=ax)
#                 subplot_idx += 1
    
#             plt.savefig('output/connectomes_through_time/subject={0}_trial={1}_t={2}.png'.format(subject_id, trial_id, t))