<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Table-of-Contents" data-toc-modified-id="Table-of-Contents-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Table of Contents</a></span></li><li><span><a href="#Initialize-Environment" data-toc-modified-id="Initialize-Environment-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Initialize Environment</a></span></li><li><span><a href="#Load-Toy-Data" data-toc-modified-id="Load-Toy-Data-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Load Toy Data</a></span></li><li><span><a href="#Measure-Functional-Connectivity" data-toc-modified-id="Measure-Functional-Connectivity-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Measure Functional Connectivity</a></span></li><li><span><a href="#Optimize-Dynamic-Subgraphs-Parameters" data-toc-modified-id="Optimize-Dynamic-Subgraphs-Parameters-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Optimize Dynamic Subgraphs Parameters</a></span><ul class="toc-item"><li><span><a href="#Generate-Cross-Validation-Parameter-Sets" data-toc-modified-id="Generate-Cross-Validation-Parameter-Sets-5.1"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Generate Cross-Validation Parameter Sets</a></span></li><li><span><a href="#Run-NMF-Cross-Validation-Parameter-Sets" data-toc-modified-id="Run-NMF-Cross-Validation-Parameter-Sets-5.2"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>Run NMF Cross-Validation Parameter Sets</a></span></li><li><span><a href="#Visualize-Quality-Measures-of-Search-Space" data-toc-modified-id="Visualize-Quality-Measures-of-Search-Space-5.3"><span class="toc-item-num">5.3&nbsp;&nbsp;</span>Visualize Quality Measures of Search Space</a></span></li></ul></li><li><span><a href="#Detect-Dynamic-Subgraphs" data-toc-modified-id="Detect-Dynamic-Subgraphs-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Detect Dynamic Subgraphs</a></span><ul class="toc-item"><li><span><a href="#Stochastic-Factorization-with-Consensus" 
data-toc-modified-id="Stochastic-Factorization-with-Consensus-6.1"><span class="toc-item-num">6.1&nbsp;&nbsp;</span>Stochastic Factorization with Consensus</a></span></li><li><span><a href="#Plot-all-the-subgraphs" data-toc-modified-id="Plot-all-the-subgraphs-6.2"><span class="toc-item-num">6.2&nbsp;&nbsp;</span>Plot all the subgraphs</a></span></li></ul></li></ul></div>

# Initialize Environment

In [None]:
from __future__ import division

import os
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
import sys

# Data manipulation
import numpy as np
import scipy.io as io
import NMF

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Load Toy Data

In [None]:
# `df` holds the dict returned by scipy.io.loadmat (it is not a DataFrame).
# Keys stored in this .mat file:
#   -- evData: ECoG recording with dims n_sample x n_channels
#   -- Fs: sampling frequency, stored as a 1 x 1 array
#   -- channel_lbl: channel label strings with dims n_channels
#   -- channel_ix_soz: indices of seizure-onset channels, n_soz
# Only evData and Fs are read in this cell.
df = io.loadmat('./ToyData/Seizure_ECoG.mat')

evData = df['evData']
n_sample, n_chan = evData.shape   # unpacking also requires evData to be 2-D

fs = int(df['Fs'][0, 0])          # pull the scalar out of the 1 x 1 array

# Measure Functional Connectivity

In [None]:
def compute_dynamic_windows(n_sample, fs, win_dur=1.0, win_shift=1.0):
    """
        Divide samples into bins based on window duration and shift.

        Each window is a run of ``int(fs * win_dur)`` consecutive sample
        indices; consecutive window starts are ``int(fs * win_shift)`` samples
        apart. Only windows lying entirely inside ``[0, n_sample)`` are kept.

        Parameters
        ----------
            n_sample: int
                Number of samples
            fs: int
                Sampling frequency
            win_dur: float
                Duration of the dynamic window (in seconds if fs is in Hz)
            win_shift: float
                Shift between consecutive window starts (same units as win_dur)

        Returns
        -------
            win_ix: ndarray with dims: (n_win, n_ix)
                Row k holds the sample indices of window k, with
                n_ix = int(fs * win_dur). n_win is 0 when no full window fits.

        Raises
        ------
            ValueError
                If the window length or the shift rounds down to fewer than
                one sample (a zero shift would never advance through the data).
    """

    n_samp_per_win = int(fs * win_dur)
    n_samp_per_shift = int(fs * win_shift)
    if n_samp_per_win < 1 or n_samp_per_shift < 1:
        raise ValueError(
            "window duration and shift must each span at least one sample; "
            "got %d and %d samples" % (n_samp_per_win, n_samp_per_shift))

    # Start index of every window that fits entirely inside the recording.
    win_start = np.arange(0, n_sample - n_samp_per_win + 1, n_samp_per_shift)
    # Broadcast starts (column) against offsets (row) -> (n_win, n_samp_per_win);
    # this keeps the 2-D shape even when no window fits.
    win_ix = win_start[:, np.newaxis] + np.arange(n_samp_per_win)

    return win_ix

# Dynamic functional connectivity: the Pearson correlation (np.corrcoef)
# between every pair of channels, computed separately within each window.
adj = np.array([np.corrcoef(evData[win, :].T)
                for win in compute_dynamic_windows(n_sample, fs)])

# Each window's correlation matrix is symmetric, so keep only its strictly
# upper triangle: the configuration matrix A_hat is (n_window x n_connection).
triu_ix, triu_iy = np.triu_indices(n_chan, k=1)
A_hat = adj[:, triu_ix, triu_iy]
n_win, n_conn = A_hat.shape

# Optimize Dynamic Subgraphs Parameters

## Generate Cross-Validation Parameter Sets

In [None]:
def generate_folds(n_win, n_fold, random_state=None):
    """
        Generate folds for cross-validation by randomly dividing the windows
        into different groups for train/test-set.

        The window indices are shuffled and dealt into n_fold equally sized
        folds. The n_win % n_fold indices left over after the shuffle are
        discarded, so every fold has exactly n_win // n_fold windows.

        Parameters
        ----------
            n_win: int
                Number of windows (observations) in the configuration matrix
            n_fold: int
                Number of folds desired; must satisfy 1 <= n_fold <= n_win
            random_state: None, int, or object with a ``permutation`` method
                Source of randomness for the shuffle. None (default) uses the
                global ``np.random`` state. An int seeds a private
                ``np.random.RandomState``. Any other object (for example a
                ``np.random.RandomState`` or ``np.random.Generator``) is used
                through its ``permutation`` method.

        Returns
        -------
            fold_list: list[list]
                n_fold lists of distinct window indices that can be further
                divided into train and test sets

        Raises
        ------
            ValueError
                If n_fold is not in [1, n_win]; otherwise every fold would
                come back empty.
    """

    if n_fold < 1 or n_fold > n_win:
        raise ValueError(
            "n_fold must be between 1 and n_win (%d); got %r" % (n_win, n_fold))

    if random_state is None:
        permute = np.random.permutation
    elif isinstance(random_state, (int, np.integer)):
        permute = np.random.RandomState(random_state).permutation
    else:
        permute = random_state.permutation

    # Integer division discards the incomplete trailing fold.
    n_win_per_fold = n_win // n_fold

    win_list = permute(n_win)[:n_win_per_fold * n_fold]
    win_list = win_list.reshape(n_fold, n_win_per_fold)
    fold_list = [list(ff) for ff in win_list]

    return fold_list

# Seed NumPy's global RNG so the random fold assignment below is reproducible
# under Restart & Run All. generate_folds shuffles with np.random.
# NOTE(review): if NMF.optimize also draws from np.random (for parameter
# sampling or consensus NMF), those results become reproducible as well;
# confirm before relying on it.
XVAL_SEED = 0
np.random.seed(XVAL_SEED)

fold_list = generate_folds(n_win, n_fold=5)

# Set the bounds of the search space
# Random sampling scheme
param_search_space = {'rank_range': (2, 10),
                      'alpha_range': (0.01, 1.0),
                      'beta_range': (0.01, 1.0),
                      'n_param': 20}

# Get parameter search space
# Each sampled parameter set will be evaluated n_fold times
param_list = NMF.optimize.gen_random_sampling_paramset(
    fold_list=fold_list,
    **param_search_space)

## Run NMF Cross-Validation Parameter Sets

In [None]:
# Cross-validate every sampled parameter set. qmeas_list gets one entry
# (a list of quality measures) per element of param_list, in the same order.
# **Slow: every parameter set is evaluated independently, so this loop is a
# natural candidate for parallelization.**
qmeas_list = []
for pdict in param_list:
    qmeas_list.append(NMF.optimize.run_xval_paramset(A_hat, pdict))

## Visualize Quality Measures of Search Space

In [None]:
all_param, opt_param = NMF.optimize.find_optimum_xval_paramset(param_list, qmeas_list, search_pct=5)

# For every (quality measure, parameter) pair, draw a joint KDE of the sampled
# parameter values against the measure. A dashed vertical line marks the
# optimum selected above.
# NOTE(review): assumes all_param[name] are equal-length array-likes and
# opt_param[name] is a scalar; confirm against NMF.optimize.
for qmeas in ['error', 'pct_sparse_subgraph', 'pct_sparse_coef']:
    for param in ['rank', 'alpha', 'beta']:

        # Pass x=/y= as keywords: accepted by old seaborn, required by >= 0.12.
        ax_jp = sns.jointplot(x=all_param[param], y=all_param[qmeas], kind='kde',
                              space=0, n_levels=60, shade_lowest=False)
        ax = ax_jp.ax_joint
        # The optimum is in `opt_param`, as unpacked on the first line of this cell.
        ax.axvline(opt_param[param], lw=1.0, alpha=0.75, linestyle='--')

        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_xlabel(param)
        ax.set_ylabel(qmeas)

        plt.show()
        plt.close()

# Detect Dynamic Subgraphs

## Stochastic Factorization with Consensus

In [None]:
def refactor_connection_vector(conn_vec):
    """
        Rebuild a symmetric adjacency matrix from its upper-triangle vector.

        Entries of conn_vec are placed at ``np.triu_indices(n_node, k=1)``
        (row-major strict upper triangle) and mirrored into the lower
        triangle. The diagonal stays zero.

        Parameters
        ----------
            conn_vec: array_like, 1-D
                Connection weights. Its length must equal
                n_node * (n_node - 1) / 2 for some integer n_node.

        Returns
        -------
            adj: ndarray with dims: (n_node, n_node)
                Symmetric float matrix with a zero diagonal.

        Raises
        ------
            ValueError
                If len(conn_vec) is not n_node * (n_node - 1) / 2 for any
                integer n_node.
    """
    # np.asarray lets plain lists/tuples through as well as ndarrays.
    conn_vec = np.asarray(conn_vec)
    n_conn = len(conn_vec)

    # For n_node >= 2, sqrt(n_node * (n_node - 1)) lies strictly between
    # n_node - 1 and n_node, so the ceiling recovers n_node exactly.
    n_node = int(np.ceil(np.sqrt(2 * n_conn)))
    if n_node * (n_node - 1) // 2 != n_conn:
        raise ValueError(
            "conn_vec length %d is not n_node*(n_node-1)/2 for any integer "
            "n_node" % n_conn)

    triu_ix, triu_iy = np.triu_indices(n_node, k=1)

    adj = np.zeros((n_node, n_node))
    adj[triu_ix, triu_iy] = conn_vec
    # The diagonal is zero, so adding the transpose only mirrors the triangle.
    adj += adj.T

    return adj


# Consensus NMF on the configuration matrix, using the cross-validated
# optimum for rank, alpha and beta (n_seed=2 runs on a single process).
# NOTE(review): the returned triple is unpacked as subgraphs, coefficients
# and an error term; the exact shapes come from NMF.optimize.consensus_nmf.
fac_subgraph, fac_coef, err = NMF.optimize.consensus_nmf(
    A_hat,
    n_seed=2,
    n_proc=1,
    opt_alpha=opt_param['alpha'],
    opt_beta=opt_param['beta'],
    opt_rank=opt_param['rank'])

# Fold each subgraph's connection vector back into a symmetric adjacency
# matrix (presumably n_chan x n_chan if each row has n_conn entries).
fac_subgraph = np.array(list(map(refactor_connection_vector, fac_subgraph)))

## Plot all the subgraphs

In [None]:
n_subg = fac_subgraph.shape[0]
n_row = int(np.ceil(np.sqrt(n_subg)))
n_col = int(np.ceil(n_subg / n_row))

# Normalize by the largest weight across *all* subgraphs and share the colour
# limits between panels so that colours are comparable. Without explicit
# vmin/vmax each matshow autoscales to its own data, and the global
# normalization would have no visible effect. Guard a non-positive max to
# avoid dividing by zero.
subg_max = fac_subgraph.max()
subg_norm = fac_subgraph / subg_max if subg_max > 0 else fac_subgraph
subg_vmin, subg_vmax = subg_norm.min(), subg_norm.max()

fig, axes = plt.subplots(n_row, n_col, figsize=(12, 12), squeeze=False)
for ii, ax in enumerate(axes.flat):
    if ii < n_subg:
        ax.matshow(subg_norm[ii, ...], cmap='rainbow',
                   vmin=subg_vmin, vmax=subg_vmax)
        ax.set_title('Subgraph %d' % (ii + 1))
    # Unused grid slots stay blank.
    ax.set_axis_off()
plt.show()