# Table of Contents
 <p><div class="lev1 toc-item"><a href="#Initialize-Environment" data-toc-modified-id="Initialize-Environment-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Initialize Environment</a></div><div class="lev1 toc-item"><a href="#Gather-all-data-sets" data-toc-modified-id="Gather-all-data-sets-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Gather all data sets</a></div><div class="lev2 toc-item"><a href="#Get-a-full-subject-list" data-toc-modified-id="Get-a-full-subject-list-21"><span class="toc-item-num">2.1&nbsp;&nbsp;</span>Get a full subject list</a></div><div class="lev2 toc-item"><a href="#Tabulate-Atlas-Data" data-toc-modified-id="Tabulate-Atlas-Data-22"><span class="toc-item-num">2.2&nbsp;&nbsp;</span>Tabulate Atlas Data</a></div><div class="lev2 toc-item"><a href="#Get-structural-connectivity-data" data-toc-modified-id="Get-structural-connectivity-data-23"><span class="toc-item-num">2.3&nbsp;&nbsp;</span>Get structural connectivity data</a></div><div class="lev2 toc-item"><a href="#Get-Channel-information" data-toc-modified-id="Get-Channel-information-24"><span class="toc-item-num">2.4&nbsp;&nbsp;</span>Get Channel information</a></div><div class="lev2 toc-item"><a href="#Get-Baseline-data" data-toc-modified-id="Get-Baseline-data-25"><span class="toc-item-num">2.5&nbsp;&nbsp;</span>Get Baseline data</a></div><div class="lev2 toc-item"><a href="#Stim-data" data-toc-modified-id="Stim-data-26"><span class="toc-item-num">2.6&nbsp;&nbsp;</span>Stim data</a></div><div class="lev2 toc-item"><a href="#Memory-State-data" data-toc-modified-id="Memory-State-data-27"><span class="toc-item-num">2.7&nbsp;&nbsp;</span>Memory State data</a></div><div class="lev1 toc-item"><a href="#Save-all-meta-data-for-downstream-experiments" data-toc-modified-id="Save-all-meta-data-for-downstream-experiments-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Save all meta data for downstream experiments</a></div>

# Initialize Environment

In [None]:
# Enable IPython's autoreload extension so edits to imported modules (e.g. the
# Echobase checkout added to sys.path below) are picked up without restarting
# the kernel, then clear the interactive namespace with %reset.
# NOTE(review): the `%` lines are IPython-only syntax. Under a plain Python
# interpreter the cell fails to parse, so this try/except cannot catch that case;
# the fallback only runs if a magic raises at runtime inside IPython.
# NOTE(review): without `-f`, %reset asks for interactive confirmation.
# NOTE(review): `print 'NOT IPYTHON'` is Python 2 syntax, consistent with the
# rest of this notebook (e.g. `xrange` further down).
try:
    %load_ext autoreload
    %autoreload 2
    %reset
except:
    print 'NOT IPYTHON'

from __future__ import division

import os
import sys
import glob
import time

import numpy as np
import pandas as pd
import seaborn as sns
import nibabel as nib
import scipy.linalg as scialg
import scipy.stats as stats
import scipy.io as io
import h5py
import pickle as pkl
import matplotlib.pyplot as plt
from matplotlib import rcParams

# Make the local (machine-specific) Echobase checkout importable.
sys.path.append('/Users/akhambhati/Developer/hoth_research/Echobase')
import Echobase
# Short aliases for functions from Echobase.Network.Transforms.configuration.
convert_adj_matr_to_cfg_matr = Echobase.Network.Transforms.configuration.convert_adj_matr_to_cfg_matr
convert_conn_vec_to_adj_matr = Echobase.Network.Transforms.configuration.convert_conn_vec_to_adj_matr

# Apply Echobase's figure defaults to matplotlib's live rcParams object.
# Rebinding the name `rcParams` to the helper's return value and then calling
# `rcParams.update(rcParams)` would be a self-update no-op. That pattern only
# affects matplotlib if the helper mutates its argument in place. Merging the
# returned mapping into the live object works whether the helper mutates in
# place or returns a new mapping.
rcParams.update(Echobase.Plotting.fig_format.update_rcparams(rcParams))

# Data locations (machine-specific mount points).
path_AtlasData = '/Users/akhambhati/Remotes/CORE.MRI_Atlases'      # Lausanne atlas volumes + label CSVs
path_ImagingData = '/Users/akhambhati/Remotes/CORE.RAM_Imaging'    # structural adjacency matrices
path_CoreData = '/Users/akhambhati/Remotes/CORE.PS_Stim'           # electrode info, events, trials
path_PeriphData = '/Users/akhambhati/Remotes/RSRCH.PS_Stim'
path_ExpData = path_PeriphData + '/e00-Multimodal_Mapping'         # outputs + joblib cache for this experiment

# Create any missing core/output directory, reporting each one that is created.
# The atlas and imaging paths are read-only inputs and are not created here.
for path in (path_CoreData, path_PeriphData, path_ExpData):
    if os.path.exists(path):
        continue
    print('Path: {}, does not exist'.format(path))
    os.makedirs(path)

# Disk cache for the expensive @memory.cache-decorated loaders below.
# NOTE(review): `sklearn.externals.joblib` was removed in scikit-learn 0.23 and
# `cachedir=` is deprecated in joblib >= 0.12 in favour of `location=`.
from sklearn.externals.joblib import Memory
memory = Memory(verbose=0, cachedir=path_ExpData)

# Gather all data sets

## Get a full subject list

In [None]:
@memory.cache
def get_subj_list():
    """Return the IDs of subjects that have an electrode-info file.

    Subjects are discovered from ``<path_CoreData>/Electrode_Info/<subj_id>.mat``.

    Returns
    -------
    numpy.ndarray
        Subject ID strings, sorted so that the subject order (and every table
        built by iterating over it) is reproducible across filesystems.
    """
    path_list = glob.glob('{}/Electrode_Info/*.mat'.format(path_CoreData))

    # splitext strips only the '.mat' suffix, so the ID round-trips to the
    # '{}/Electrode_Info/{}.mat' path even if the ID itself contains a dot
    # (split('.')[0] would truncate it). glob order is arbitrary, so sort.
    subj_list = sorted(os.path.splitext(os.path.basename(path))[0]
                       for path in path_list)
    return np.array(subj_list)
subj_list = get_subj_list()

## Tabulate Atlas Data

In [None]:
@memory.cache
def get_lausanne_label_mni():
    """Map every Lausanne atlas ROI to the world coordinates of its voxels.

    For each scale, the label volume
    ``<path_AtlasData>/Lausanne/lausanne/ROIv_<scale>_dilated.nii.gz`` is read.
    Each nonzero label is converted to the array of its voxel positions in world
    space through the image affine. ROI names are taken from
    ``<path_AtlasData>/Lausanne/<scale>.csv`` as ``'<Hemisphere>_<ROI>'``,
    matched on ``Label_ID``.

    Returns
    -------
    dict
        ``{scale: {roi_label: ndarray of shape (n_voxels, 3)}}`` for the scales
        ``scale33``, ``scale60``, ``scale125``, ``scale250`` and ``scale500``.

    Raises
    ------
    IndexError
        If a label present in the volume has no row in the scale's CSV.
    """
    scales = ['scale33', 'scale60', 'scale125', 'scale250', 'scale500']
    lausanne_label_mni = {scale: {} for scale in scales}

    for scale in scales:
        lausanne_labels = pd.read_csv('{}/Lausanne/{}.csv'.format(path_AtlasData, scale))
        atlas = nib.load('{}/Lausanne/lausanne/ROIv_{}_dilated.nii.gz'.format(path_AtlasData, scale))
        # dataobj keeps the on-disk (integer) dtype; get_data() is deprecated
        # and removed in nibabel 5.
        atlas_data = np.asanyarray(atlas.dataobj)
        rot = atlas.affine[:3, :3]
        trans = atlas.affine[:3, 3]

        # Label 0 is treated as background. Exclude it explicitly rather than
        # dropping the first unique value, which would skip a real ROI if no
        # background voxel existed.
        roi_ids = np.unique(atlas_data)
        roi_ids = roi_ids[roi_ids != 0]

        for roi_id in roi_ids:
            # (n_voxels, 3) voxel indices -> world coordinates, vectorised
            # equivalent of applying `rot.dot(ijk) + trans` to each voxel.
            vox_ijk = np.column_stack(np.nonzero(atlas_data == roi_id))
            roi_coords = vox_ijk.dot(rot.T) + trans

            # Add the ROI's coordinates to the dict under its hemisphere-qualified name
            sel_lbl = (lausanne_labels['Label_ID'] == roi_id)
            roi_lbl = (lausanne_labels[sel_lbl]['Hemisphere'] + '_' +
                       lausanne_labels[sel_lbl]['ROI']).values[0]

            lausanne_label_mni[scale][roi_lbl] = roi_coords

    return lausanne_label_mni
lausanne_label_mni = get_lausanne_label_mni()

## Get structural connectivity data

In [None]:
def _communicability(adj):
    """Return expm(D^-1/2 A D^-1/2), the strength-normalised communicability of `adj`.

    D is the diagonal matrix of column sums (node strengths) of `adj`.
    Zero-strength nodes get a scaling factor of 0 instead of 1/sqrt(0) = inf.
    Otherwise inf * 0 = NaN entries would propagate through the products and
    poison the matrix exponential.
    """
    strength = np.asarray(np.sum(adj, axis=0), dtype=float).ravel()
    inv_sqrt = np.zeros_like(strength)
    has_edges = strength > 0
    inv_sqrt[has_edges] = 1.0 / np.sqrt(strength[has_edges])
    D_adj = np.diag(inv_sqrt)
    return scialg.expm(np.dot(np.dot(D_adj, adj), D_adj))


@memory.cache
def get_subj_structadj(subj_list):
    """Load each subject's structural adjacency matrices and their communicability.

    Parameters
    ----------
    subj_list : iterable of str
        Subject IDs. The part before the first '_' names the subject's folder
        under ``<path_ImagingData>/Adjacency_Matrices/Epilepsy_Subjects/lausanne``.

    Returns
    -------
    dict
        ``{subj_id: {scale: {stream_type: {'adj': A, 'adj_comm': C}}}}`` with
        stream types 'QA' and 'GFA' and scales taken from `lausanne_label_mni`.
        A is the matrix stored under 'connectivity' in the .mat file.
        Subjects without an adjacency folder are omitted.

    Raises
    ------
    IOError
        If a subject folder exists but holds no matrix for a scale/stream type.
    """
    structadj_dict = {}
    for subj_id in subj_list:
        subj_id_abbr = subj_id.split('_')[0]
        adj_dir = '{}/Adjacency_Matrices/Epilepsy_Subjects/lausanne/{}'.format(path_ImagingData, subj_id_abbr)
        if not os.path.exists(adj_dir):
            continue

        structadj_dict[subj_id] = {}
        for scale in lausanne_label_mni.keys():
            # Create the scale entry once. Re-creating it inside the stream-type
            # loop would discard every stream type except the last ('GFA').
            structadj_dict[subj_id][scale] = {}

            for stream_type in ['QA', 'GFA']:
                adj_glob = '{}/*ROIv_{}*{}*.mat'.format(adj_dir, scale, stream_type.lower())
                # Sorted so the choice is deterministic if several files match.
                adj_matches = sorted(glob.glob(adj_glob))
                if not adj_matches:
                    raise IOError('No adjacency matrix matches: {}'.format(adj_glob))
                adj = io.loadmat(adj_matches[0])['connectivity']

                structadj_dict[subj_id][scale][stream_type] = {
                    'adj': adj,
                    'adj_comm': _communicability(adj),
                }

    return structadj_dict
structadj_dict = get_subj_structadj(subj_list)

## Get Channel information

In [None]:
"""
contains:  
    - Jacksheet #: int
    - Label: (good_channels x 1)
    - Atlas_Label: {'scale': (good_channels x 1)}    
    - MNI_Coord: (good_channels x 3)
    - Dist_Matrix: spatial distances between electrodes
    - Struct_Adj: scale and track_type, electrode adjacency matrix
"""
@memory.cache
def get_subj_channel_info(subj_list):

    channel_dict = {}
    for subj_id in subj_list:
        print(subj_id)
        channel_path = '{}/Electrode_Info/{}.mat'.format(path_CoreData, subj_id)
        df_chan = io.loadmat(channel_path)
            
        channel_dict[subj_id] = {}
        for ecog_reref in ['Bipolar', 'CommonAverage']:    
            
            ### Reorder bipolar electrode id
            electrode_id = df_chan['electrode_id']
            electrode_label = np.array([lbl[0][0] for lbl in df_chan['electrode_label']])       
            electrode_id_reref = np.array(sorted(df_chan['electrode_id_bp'].tolist(),
                                              key=lambda x: (x[0], x[1])))
            if ecog_reref == 'CommonAverage':
                electrode_id_reref = np.unique(electrode_id_reref.reshape(-1))
            n_node = electrode_id_reref.shape[0]

            ### Regenerate re-referenced electrode labels
            electrode_label_reref = []
            try:
                for reref_id in electrode_id_reref:
                    if ecog_reref == 'Bipolar':
                        anode_ix = np.flatnonzero(df_chan['electrode_id'][0, :] == reref_id[0])
                        cathode_ix = np.flatnonzero(df_chan['electrode_id'][0, :] == reref_id[1])

                        electrode_label_reref.append([df_chan['electrode_label'][anode_ix[0], 0][0],
                                                   df_chan['electrode_label'][cathode_ix[0], 0][0]])
                    else:
                        monopolar_ix = np.flatnonzero(df_chan['electrode_id'][0, :] == reref_id)
                        electrode_label_reref.append(df_chan['electrode_label'][monopolar_ix[0], 0][0])
            except:
                print(subj_id + ' -- no channel labels')
            electrode_label_reref = np.array(electrode_label_reref)

            ### MNI Coords
            mni_coords = []
            for reref_id in electrode_id_reref:
                if ecog_reref == 'Bipolar':
                    anode_ix = np.flatnonzero(df_chan['electrode_coords'][:, 0] == reref_id[0])
                    cathode_ix = np.flatnonzero(df_chan['electrode_coords'][:, 0] == reref_id[1])

                    if (len(anode_ix) == 0) or (len(cathode_ix) == 0):
                        mni_coords.append(tuple([np.nan, np.nan, np.nan]))
                    else:
                        mni_coords.append(tuple(0.5*(df_chan['electrode_coords'][anode_ix[0], 1:] +
                                                     df_chan['electrode_coords'][cathode_ix[0], 1:])))            
                else:
                    monopolar_ix = np.flatnonzero(df_chan['electrode_coords'][:, 0] == reref_id)
                    if len(monopolar_ix) == 0:
                        mni_coords.append(tuple([np.nan, np.nan, np.nan]))
                    else:
                        mni_coords.append(tuple((df_chan['electrode_coords'][monopolar_ix[0], 1:])))
            mni_coords = np.array(mni_coords)
            
            ### Check MNI Coordinates are valid
            if len(np.flatnonzero(np.isnan(mni_coords))) > 0:
                print(subj_id + ' -- contains unlocalized channels')

            ### Compute Inter-Channel Distances
            dist_matr = np.zeros((n_node, n_node))
            triu_ix, triu_iy = np.triu_indices(n_node, k=1)
            for ix, iy in zip(triu_ix, triu_iy):
                ix_chan_loc = mni_coords[ix, :]
                iy_chan_loc = mni_coords[iy, :]            
                dist_matr[ix, iy] = np.sqrt(np.sum((ix_chan_loc-iy_chan_loc)**2))            
            dist_matr += dist_matr.T

            ### Assign channel to an ROI using greedy approach
            atlas_label = {}
            atlas_index = {}        
            for scale in lausanne_label_mni.keys():
                roi_names = lausanne_label_mni[scale].keys()
                chan_assign = []
                chan_index = []
                for ch_crd in mni_coords:
                    roi_voxel_dist = []
                    for roi in roi_names:
                        roi_crd = lausanne_label_mni[scale][roi]
                        nearest_voxel = np.min(np.sqrt(np.sum((ch_crd - roi_crd)**2, axis=1)))
                        roi_voxel_dist.append(nearest_voxel)
                    roi_ix = np.argmin(roi_voxel_dist)
                    chan_assign.append(roi_names[roi_ix])
                    chan_index.append(roi_ix) 
                atlas_label[scale] = np.array(chan_assign)
                atlas_index[scale] = np.array(chan_index)            

            ### Populate dict
            channel_dict[subj_id][ecog_reref] = {}
            channel_dict[subj_id][ecog_reref]['Jacksheet'] = electrode_id_reref
            channel_dict[subj_id][ecog_reref]['Jacksheet_Label'] = electrode_label_reref
            channel_dict[subj_id][ecog_reref]['Atlas_Label'] = atlas_label
            channel_dict[subj_id][ecog_reref]['Atlas_Index'] = atlas_index
            channel_dict[subj_id][ecog_reref]['MNI_Coord'] = mni_coords
            channel_dict[subj_id][ecog_reref]['Dist_Matr'] = dist_matr
        
    return channel_dict
channel_info = get_subj_channel_info(subj_list)

## Get Baseline data

In [None]:
@memory.cache
def get_subj_baseline_info(subj_list):
    """Tabulate the baseline-trial files available for each subject.

    Files are matched as
    ``<path_CoreData>/Base_Trials/<subj_id>.Base_Event.<base_id>.mat``.

    Returns
    -------
    pandas.DataFrame
        One row per file, with columns ``Subject_ID``, ``Base_ID`` (int) and
        ``Raw_Path``. Rows follow subject order, then sorted file path.
    """
    columns = ['Subject_ID', 'Base_ID', 'Raw_Path']
    base_dict = {col: [] for col in columns}

    for subj_id in subj_list:
        print(subj_id)

        # glob order is filesystem-dependent; sort for reproducible row order.
        base_paths = sorted(glob.glob('{}/Base_Trials/{}.Base_Event.*.mat'.format(path_CoreData, subj_id)))

        for path in base_paths:
            # '<subj_id>.Base_Event.<base_id>.mat' -> <base_id>
            base_id = int(os.path.basename(path).split('.')[-2])

            ### Populate the dictionary
            base_dict['Subject_ID'].append(subj_id)
            base_dict['Base_ID'].append(base_id)
            base_dict['Raw_Path'].append(path)

    # Explicit column order: a plain dict's key order is arbitrary on Python 2.
    df_baseline = pd.DataFrame(base_dict, columns=columns)

    return df_baseline
baseline_info = get_subj_baseline_info(subj_list)

## Stim data

In [None]:
@memory.cache
def get_subj_stim_info(subj_list):
    """Tabulate stimulating and sham trials together with their stimulation parameters.

    For each subject with an event table at
    ``<path_CoreData>/Event_Table/<subj_id>_events.mat``, every trial file
    ``<path_CoreData>/Stim_Trials/<subj_id>.Stim_Event.<event_id>.mat`` is
    matched to entry ``event_id - 1`` of that table. Only events whose
    lower-cased type is 'stimulating' or 'sham' are kept.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject_ID, Event_ID, Experiment_ID, Stim_Type, Stim_Freq,
        Stim_Amp, Stim_Dur, Stim_Anode, Stim_Cathode, Raw_Path.
    """
    columns = ['Subject_ID', 'Event_ID', 'Experiment_ID', 'Stim_Type',
               'Stim_Freq', 'Stim_Amp', 'Stim_Dur', 'Stim_Anode',
               'Stim_Cathode', 'Raw_Path']
    stim_dict = {col: [] for col in columns}

    for subj_id in subj_list:
        print(subj_id)

        ### Get event table for subject
        event_path = '{}/Event_Table/{}_events.mat'.format(path_CoreData, subj_id)
        if not os.path.exists(event_path):
            continue
        df_event = io.loadmat(event_path)

        # glob order is filesystem-dependent; sort for reproducible row order.
        stim_paths = sorted(glob.glob('{}/Stim_Trials/{}.Stim_Event.*.mat'.format(path_CoreData, subj_id)))

        ### Iterate over all trials for the subject
        for path in stim_paths:
            # '<subj_id>.Stim_Event.<event_id>.mat' -> <event_id> (1-based)
            event_id = int(os.path.basename(path).split('.')[-2])

            ### Get Trial Parameters
            sel_event = df_event['events'][0, event_id-1]
            stim_type = sel_event['type'][0].lower()
            if stim_type not in ('stimulating', 'sham'):
                continue

            row = {'Subject_ID': subj_id,
                   'Event_ID': event_id,
                   'Experiment_ID': sel_event['experiment'][0],
                   'Stim_Type': stim_type,
                   'Stim_Freq': sel_event['pulse_frequency'][0, 0],
                   'Stim_Amp': sel_event['amplitude'][0, 0],
                   'Stim_Dur': sel_event['pulse_duration'][0, 0],
                   'Stim_Anode': sel_event['stimAnode'][0, 0],
                   'Stim_Cathode': sel_event['stimCathode'][0, 0],
                   'Raw_Path': path}
            for col in columns:
                stim_dict[col].append(row[col])

    # Explicit column order: a plain dict's key order is arbitrary on Python 2.
    df_stim = pd.DataFrame(stim_dict, columns=columns)

    return df_stim
stim_info = get_subj_stim_info(subj_list)

## Memory State data

In [None]:
@memory.cache
def get_subj_memory_info(subj_list):
    """Tabulate memory-state probabilities aligned to events in the subject's event table.

    For each subject, every
    ``<path_CoreData>/Memory_States_bPolar/<subj_id>.*.memory_states.mat`` file
    is scanned. Each memory event is matched to the entry of
    ``<path_CoreData>/Event_Table/<subj_id>_events.mat`` with the same
    ``mstime``. Memory events without exactly one match are skipped, as are
    subjects whose event table cannot be loaded.

    Returns
    -------
    pandas.DataFrame
        Columns: Subject_ID, Event_ID (1-based event-table index),
        Experiment_ID, Pre_Stim_Prob, Post_Stim_Prob, Stim_Type, Stim_Freq,
        Stim_Amp, Stim_Dur, Stim_Anode, Stim_Cathode, Raw_Path.
    """
    columns = ['Subject_ID', 'Event_ID', 'Experiment_ID', 'Pre_Stim_Prob',
               'Post_Stim_Prob', 'Stim_Type', 'Stim_Freq', 'Stim_Amp',
               'Stim_Dur', 'Stim_Anode', 'Stim_Cathode', 'Raw_Path']
    memory_dict = {col: [] for col in columns}

    for subj_id in subj_list:
        print(subj_id)

        ### Get event table for subject (best effort: skip subjects without one).
        # `except Exception` instead of a bare except, so Ctrl-C/SystemExit
        # still propagate; the skip is now reported instead of silent.
        event_path = '{}/Event_Table/{}_events.mat'.format(path_CoreData, subj_id)
        try:
            df_stim_event = io.loadmat(event_path)
            ev_mstime = np.array([event['mstime'][0, 0]
                                  for event in df_stim_event['events'][0, :]])
        except Exception as err:
            print('{} -- no usable event table ({})'.format(subj_id, err))
            continue

        # glob order is filesystem-dependent; sort for reproducible row order.
        memory_paths = sorted(glob.glob('{}/Memory_States_bPolar/{}.*.memory_states.mat'.format(path_CoreData, subj_id)))

        ### Iterate over all memory-state files for the subject
        for path in memory_paths:
            df_mem_event = io.loadmat(path)
            n_mem_event = df_mem_event['eegoffset'].shape[1]

            ### Iterate over all memory trials in the file
            for ix in range(n_mem_event):
                mem_mstime = df_mem_event['mstime'][0, ix]
                mem_event_ix = np.flatnonzero(ev_mstime == mem_mstime)
                # Skip unmatched or ambiguous (duplicate-mstime) events.
                if len(mem_event_ix) != 1:
                    continue

                ## Get the event parameters
                sel_event = df_stim_event['events'][0, mem_event_ix[0]]
                row = {'Subject_ID': subj_id,
                       'Event_ID': mem_event_ix[0] + 1,
                       'Experiment_ID': sel_event['experiment'][0],
                       'Pre_Stim_Prob': df_mem_event['pre_stim_mem_prob'][0, ix],
                       'Post_Stim_Prob': df_mem_event['post_stim_mem_prob'][0, ix],
                       'Stim_Type': sel_event['type'][0].lower(),
                       'Stim_Freq': sel_event['pulse_frequency'][0, 0],
                       'Stim_Amp': sel_event['amplitude'][0, 0],
                       'Stim_Dur': sel_event['pulse_duration'][0, 0],
                       'Stim_Anode': sel_event['stimAnode'][0, 0],
                       'Stim_Cathode': sel_event['stimCathode'][0, 0],
                       'Raw_Path': path}
                for col in columns:
                    memory_dict[col].append(row[col])

    # Explicit column order: a plain dict's key order is arbitrary on Python 2.
    df_memory = pd.DataFrame(memory_dict, columns=columns)

    return df_memory
memory_info = get_subj_memory_info(subj_list)

# Save all meta data for downstream experiments

In [None]:
# Electrode/atlas/connectivity metadata goes into a single .npz archive. The
# dict-valued entries are stored as 0-d object arrays, so readers need
# np.load(..., allow_pickle=True) and `.item()` to recover the dicts.
electrode_meta = {'subj_list': subj_list,
                  'channel_info': channel_info,
                  'lausanne_label_mni': lausanne_label_mni,
                  'structadj_dict': structadj_dict}
np.savez('{}/Meta.Electrode_Loc.npz'.format(path_ExpData), **electrode_meta)

# Trial-level tables are written as pickled DataFrames, one file per table.
for pkl_template, table_df in [('{}/Meta.Memory_Info.pkl', memory_info),
                               ('{}/Meta.Baseline_Info.pkl', baseline_info),
                               ('{}/Meta.Stim_Info.pkl', stim_info)]:
    table_df.to_pickle(pkl_template.format(path_ExpData))