# Table of Contents
 <p><div class="lev1 toc-item"><a href="#Initialize-Environment" data-toc-modified-id="Initialize-Environment-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Initialize Environment</a></div><div class="lev1 toc-item"><a href="#Load-Toy-Data" data-toc-modified-id="Load-Toy-Data-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Load Toy Data</a></div><div class="lev1 toc-item"><a href="#Measure-Functional-Connectivity" data-toc-modified-id="Measure-Functional-Connectivity-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Measure Functional Connectivity</a></div><div class="lev1 toc-item"><a href="#Optimize-Dynamic-Subgraphs-Parameters" data-toc-modified-id="Optimize-Dynamic-Subgraphs-Parameters-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Optimize Dynamic Subgraphs Parameters</a></div><div class="lev2 toc-item"><a href="#Generate-Cross-Validation-Parameter-Sets" data-toc-modified-id="Generate-Cross-Validation-Parameter-Sets-41"><span class="toc-item-num">4.1&nbsp;&nbsp;</span>Generate Cross-Validation Parameter Sets</a></div><div class="lev2 toc-item"><a href="#Run-NMF-Cross-Validation-Parameter-Sets" data-toc-modified-id="Run-NMF-Cross-Validation-Parameter-Sets-42"><span class="toc-item-num">4.2&nbsp;&nbsp;</span>Run NMF Cross-Validation Parameter Sets</a></div><div class="lev2 toc-item"><a href="#Visualize-Quality-Measures-of-Search-Space" data-toc-modified-id="Visualize-Quality-Measures-of-Search-Space-43"><span class="toc-item-num">4.3&nbsp;&nbsp;</span>Visualize Quality Measures of Search Space</a></div><div class="lev1 toc-item"><a href="#Detect-Dynamic-Subgraphs" data-toc-modified-id="Detect-Dynamic-Subgraphs-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Detect Dynamic Subgraphs</a></div><div class="lev2 toc-item"><a href="#Stochastic-Factorization-with-Consensus" data-toc-modified-id="Stochastic-Factorization-with-Consensus-51"><span class="toc-item-num">5.1&nbsp;&nbsp;</span>Stochastic Factorization with Consensus</a></div><div class="lev2 
toc-item"><a href="#Plot--Subgraphs-and-Spectrotemporal-Dynamics" data-toc-modified-id="Plot--Subgraphs-and-Spectrotemporal-Dynamics-52"><span class="toc-item-num">5.2&nbsp;&nbsp;</span>Plot  Subgraphs and Spectrotemporal Dynamics</a></div>

# Initialize Environment

In [None]:
from __future__ import division

import os
os.environ['MKL_NUM_THREADS'] = '1'
os.environ['NUMEXPR_NUM_THREADS'] = '1'
os.environ['OMP_NUM_THREADS'] = '1'
import sys

# Data manipulation
import numpy as np
import scipy.io as io
import NMF

# Echobase
sys.path.append('../Echobase/')
import Echobase

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns

# Load Toy Data

In [None]:
# The MAT-file loaded into `df` provides the following keys:
#   -- evData: ECoG recording with dims: n_sample x n_channels
#   -- Fs: sampling frequency with dims: 1 x 1
#   -- channel_lbl: strings of channel labels with dims: n_channels
#   -- channel_ix_soz: indices of seizure-onset channels with dims: n_soz

df = io.loadmat('./ToyData/Seizure_ECoG.mat')

# Unwrap the 1 x 1 sampling-frequency array into a plain integer
fs = int(df['Fs'][0, 0])

# Recording matrix and its dimensions
evData = df['evData']
n_sample, n_chan = evData.shape

# Measure Functional Connectivity

In [None]:
def compute_dynamic_windows(n_sample, fs, win_dur=1.0, win_shift=1.0):
    """
        Divide samples into bins based on window duration and shift.

        Only complete windows are returned; trailing samples that do not
        fill a whole window are left out.

        Parameters
        ----------
            n_sample: int
                Number of samples
            fs: int
                Sampling frequency
            win_dur: float
                Duration of the dynamic window
            win_shift: float
                Shift of the dynamic window

        Returns
        -------
            win_ix: ndarray with dims: (n_win, n_samp_per_win)
                Row w holds the consecutive sample indices of window w.
                If not even one complete window fits, the array has
                dims: (0, n_samp_per_win)

        Raises
        ------
            ValueError
                If win_dur or win_shift spans less than one sample at the
                given sampling frequency
    """

    n_samp_per_win = int(fs * win_dur)
    n_samp_per_shift = int(fs * win_shift)

    if n_samp_per_win < 1:
        raise ValueError(
            'win_dur=%r spans less than one sample at fs=%r' % (win_dur, fs))
    if n_samp_per_shift < 1:
        # A shift of zero samples would never advance through the recording
        raise ValueError(
            'win_shift=%r spans less than one sample at fs=%r' % (win_shift, fs))

    # First index of every window whose last sample still lies inside
    # the recording
    last_start = int(n_sample) - n_samp_per_win
    win_start = np.arange(0, last_start + 1, n_samp_per_shift)

    # Broadcast each start against the within-window offsets
    win_offset = np.arange(n_samp_per_win)
    win_ix = win_start[:, np.newaxis] + win_offset[np.newaxis, :]

    return win_ix

# Transform to a configuration matrix (n_window x n_connection)
# Each connection is one channel pair from the strict upper triangle
triu_ix, triu_iy = np.triu_indices(n_chan, k=1)
n_conn = len(triu_ix)

# Measure dynamic functional connectivity using Echobase
# Only the first 100 seconds are windowed; the span is clamped to n_sample
# so that a shorter recording is never indexed past its last sample
win_bin = compute_dynamic_windows(min(n_sample, fs*100), fs)
n_win = win_bin.shape[0]
n_fft = win_bin.shape[1] // 2

# Notch out the line-noise: drop the frequency bins within 5 Hz of 60 Hz
# and of its harmonics at 120 Hz and 180 Hz
# NOTE(review): assumes mt_coherence returns its nf coherence values on this
# evenly spaced 0..fs//2 grid -- confirm against the Echobase documentation
fft_freq = np.linspace(0, fs // 2, n_fft)
notch_60hz = ((fft_freq > 55.0) & (fft_freq < 65.0))
notch_120hz = ((fft_freq > 115.0) & (fft_freq < 125.0))
notch_180hz = ((fft_freq > 175.0) & (fft_freq < 185.0))
fft_freq_ix = np.flatnonzero(~(notch_60hz | notch_120hz | notch_180hz))
fft_freq = fft_freq[fft_freq_ix]
n_freq = len(fft_freq_ix)

# Compute dFC: one coherence spectrum per (window, connection)
A_tensor = np.zeros((n_win, n_freq, n_conn))
for w_ii, w_ix in enumerate(win_bin):
    # Re-reference once per window, outside of the channel-pair loop
    evData_hat = evData[w_ix, :]
    evData_hat = Echobase.Sigproc.reref.common_avg_ref(evData_hat)

    for tr_ii, (tr_ix, tr_iy) in enumerate(zip(triu_ix, triu_iy)):
        # NOTE(review): tbp/kspec/p/iadapt are passed through unchanged;
        # their meaning is defined by Echobase's mt_coherence -- confirm there
        out = Echobase.Pipelines.ecog_network.coherence.mt_coherence(
            df=1.0/fs, xi=evData_hat[:, tr_ix], xj=evData_hat[:, tr_iy],
            tbp=5.0, kspec=9, nf=n_fft,
            p=0.95, iadapt=1,
            cohe=True, freq=True)

        # Keep only the frequency bins outside of the notched bands
        A_tensor[w_ii, :, tr_ii] = out['cohe'][fft_freq_ix]

# Stack windows and frequencies along the rows: (n_win * n_freq) x n_conn
A_hat = A_tensor.reshape(-1, n_conn)

# Optimize Dynamic Subgraphs Parameters

## Generate Cross-Validation Parameter Sets

In [None]:
def generate_folds(n_win, n_fold):
    """
        Randomly split the window indices into equally sized groups that
        serve as cross-validation folds.

        Parameters
        ----------
            n_win: int
                Number of windows (observations) in the configuration matrix
            n_fold: int
                Number of folds desired

        Returns
        -------
            fold_list: list[list]
                One list of window indices per fold; the folds can be
                further divided into train and test sets. Windows that
                do not fit into an even split are left out.
    """

    # Size of each fold; the remainder of an uneven split is discarded
    n_per_fold = int(np.floor(n_win / n_fold))
    n_keep = n_per_fold * n_fold

    # Shuffle all window indices, then keep only as many as fill the folds
    shuffled = np.random.permutation(np.arange(n_win))
    grouped = shuffled[:n_keep].reshape(n_fold, -1)

    fold_list = []
    for fold in grouped:
        fold_list.append(list(fold))

    return fold_list

# Randomly partition the windows into 5 cross-validation folds
fold_list = generate_folds(n_win, n_fold=5)

# Bounds of the search space for the random sampling scheme
param_search_space = dict(
    rank_range=(2, 20),
    alpha_range=(0.01, 1.0),
    beta_range=(0.01, 1.0),
    n_param=20)

# Draw the parameter sets from the search space
# Each sampled parameter set will be evaluated n_fold times
param_list = NMF.optimize.gen_random_sampling_paramset(
    fold_list=fold_list, **param_search_space)

## Run NMF Cross-Validation Parameter Sets

In [None]:
# **This cell block should be parallelized. Takes time to run**

# Collect the quality measures of each parameter set in param_list,
# in the same order as param_list
qmeas_list = []
for pdict in param_list:
    qmeas_list.append(NMF.optimize.run_xval_paramset(A_hat, pdict))

## Visualize Quality Measures of Search Space

In [None]:
# Pool the cross-validation results and pick the optimal parameter set
all_param, opt_params = NMF.optimize.find_optimum_xval_paramset(
    param_list, qmeas_list, search_pct=5)

# Generate quality measure plots: one joint density plot for every
# (quality measure, parameter) pair
for qmeas in ['error', 'pct_sparse_subgraph', 'pct_sparse_coef']:
    for param in ['rank', 'alpha', 'beta']:

        # x and y are passed by keyword: recent seaborn releases no longer
        # accept them positionally, while older releases accept both forms
        # NOTE(review): n_levels and shade_lowest are forwarded to the KDE
        # plot and may be deprecated in recent seaborn -- confirm on upgrade
        ax_jp = sns.jointplot(x=all_param[param], y=all_param[qmeas], kind='kde',
                              space=0, n_levels=60, shade_lowest=False)
        ax = ax_jp.ax_joint

        # Mark the selected optimum with a dashed vertical line spanning
        # the current y-limits of the joint axes
        ax.plot([opt_params[param], opt_params[param]],
                [ax.get_ylim()[0], ax.get_ylim()[1]],
                lw=1.0, alpha=0.75, linestyle='--')

        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        ax.set_xlabel(param)
        ax.set_ylabel(qmeas)

        plt.show()
        plt.close()

# Detect Dynamic Subgraphs

## Stochastic Factorization with Consensus

In [None]:
def refactor_connection_vector(conn_vec):
    """
        Rebuild a symmetric adjacency matrix from the vector of its strict
        upper-triangle entries.

        Parameters
        ----------
            conn_vec: array-like with dims: (n_conn,)
                Connection weights ordered like np.triu_indices(n_node, k=1),
                where n_conn = n_node * (n_node - 1) / 2

        Returns
        -------
            adj: ndarray with dims: (n_node, n_node)
                Symmetric matrix with zeros along the diagonal

        Raises
        ------
            ValueError
                If conn_vec is not one-dimensional, or if its length does
                not equal n_node * (n_node - 1) / 2 for any integer n_node
    """

    conn_vec = np.asarray(conn_vec)
    if conn_vec.ndim != 1:
        raise ValueError(
            'conn_vec must be one-dimensional, got %d dimensions' % conn_vec.ndim)

    # Invert n_conn = n_node * (n_node - 1) / 2; sqrt(2 * n_conn) lies
    # strictly between n_node - 1 and n_node, so rounding up recovers n_node
    n_conn = len(conn_vec)
    n_node = int(np.ceil(np.sqrt(2*n_conn)))
    if n_node*(n_node - 1) // 2 != n_conn:
        raise ValueError(
            'conn_vec length %d is not the size of a strict upper triangle' % n_conn)

    triu_ix, triu_iy = np.triu_indices(n_node, k=1)

    adj = np.zeros((n_node, n_node))
    adj[triu_ix, triu_iy] = conn_vec

    # Mirror into the lower triangle. A new array is built rather than
    # using `adj += adj.T`, because adding a transposed view of the same
    # buffer in place is not reliable on older NumPy releases
    adj = adj + adj.T

    return adj


# Factorize the configuration matrix with the optimal parameter set
fac_subgraph, fac_coef, err = NMF.optimize.consensus_nmf(
    A_hat,
    opt_rank=opt_params['rank'],
    opt_alpha=opt_params['alpha'],
    opt_beta=opt_params['beta'],
    n_seed=2,
    n_proc=1)

# Expand the connection vector of every subgraph into an adjacency matrix
fac_subgraph = np.array(list(map(refactor_connection_vector, fac_subgraph)))

# Unstack the coefficients into (subgraph, window, frequency)
fac_coef = fac_coef.reshape(-1, n_win, n_freq)

## Plot  Subgraphs and Spectrotemporal Dynamics

In [None]:
# One row per subgraph: adjacency matrix on the left, coefficients on the right
n_row = fac_subgraph.shape[0]
n_col = 2

# Normalize every panel by the global maximum so that the color scales are
# shared across subgraphs; computed once rather than on every iteration
subgraph_max = fac_subgraph.max()
coef_max = fac_coef.max()

plt.figure(figsize=(12,36))
# range instead of the Python-2-only xrange keeps this cell runnable on
# both Python 2 and Python 3
for fac_ii in range(n_row):
    # Left panel: subgraph adjacency matrix
    ax = plt.subplot(n_row, n_col, 2*fac_ii+1)
    ax.matshow(fac_subgraph[fac_ii, ...] / subgraph_max, cmap='viridis')
    ax.set_axis_off()

    # Right panel: coefficients, transposed so that frequencies run along
    # the rows and windows along the columns. The float cast keeps the
    # aspect ratio a true division even without the __future__ import
    ax = plt.subplot(n_row, n_col, 2*fac_ii+2)
    ax.matshow(fac_coef[fac_ii, ...].T / coef_max,
               aspect=float(n_win)/n_freq, cmap='inferno')

plt.show()