In [None]:
# %matplotlib inline
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
import matplotlib as mpl
from matplotlib import gridspec
from scipy import stats, signal
from scipy.signal import hilbert
from numpy.lib.recfunctions import append_fields, merge_arrays
from nilearn import plotting
from pycircstat import mean as circmean
import pycircstat
from matplotlib.backends.backend_pdf import PdfPages

import sys
import pickle
import glob

from tarjan import tarjan
from sklearn.decomposition import PCA
from sklearn.metrics.pairwise import euclidean_distances
from sklearn.linear_model import LinearRegression as LR
import matplotlib.colors as clrs
import matplotlib.cm as cmx
import nilearn.plotting as ni_plot
from mpl_toolkits.axes_grid1 import make_axes_locatable
import pylab as pyl
from PIL import Image
import os
from scipy.spatial.distance import pdist, squareform
from SubjectLevel.par_funcs_fine import *
from scipy.io import loadmat

from ptsa.data.readers import EEGReader
from ptsa.data.filters import ResampleFilter

from joblib import Parallel, delayed
import multiprocessing
import cluster_helper.cluster
import time
from collections import defaultdict
from types import SimpleNamespace

In [2]:
# Select the 'pymms' subjects whose monopolar montage contains at least one grid ('G') contact.
# Subjects may also have contacts of other types; this cell does not exclude them.
# NOTE(review): RAM_helpers is not imported explicitly; presumably it comes from
# `from SubjectLevel.par_funcs_fine import *` — confirm.
pd.set_option('display.max_columns', 500)
pd.set_option('display.max_rows', 1000)

subjs = RAM_helpers.get_subjs_and_montages('pymms')

grid_counts = []
for subject, montage in zip(subjs['subject'], subjs['montage']):
    elec_types = RAM_helpers.load_elec_info(subject, montage, bipolar=False)['type']
    # Number of 'G' characters across all contact types. A missing type makes str.count return
    # NaN, so the subject's total becomes NaN and it is excluded by the mask below.
    grid_counts.append(sum(np.asarray(elec_types.str.count('G'))))
Count_Grids = np.asarray(grid_counts, dtype=float)

# Keep subjects with a non-zero, non-NaN grid count.
has_grids = (Count_Grids != 0) & ~np.isnan(Count_Grids)
subjs_grids = subjs.iloc[np.flatnonzero(has_grids)].reset_index(drop=True)
del subjs
stsubjs = subjs_grids  # alias: the 'task' column added below is visible under both names
stsubjs['task'] = 'pymms'
stsubjs

Unnamed: 0,subject,montage,task
0,TJ006,0,pymms
1,TJ007,0,pymms
2,TJ008,0,pymms
3,TJ012,0,pymms
4,TJ015,0,pymms
5,TJ020,0,pymms
6,TJ021,0,pymms
7,UP010,0,pymms
8,UP014,0,pymms
9,UP015,0,pymms


In [3]:
# Native sampling rate of every grid subject, read from a 1 s EEG snippet of its first event.
# The minimum (Samp_Freq) is used later as the common rate the traveling-wave EEG is resampled to.
sample_rates = []
for subject, montage in zip(stsubjs['subject'], stsubjs['montage']):
    events = RAM_helpers.load_subj_events('pymms', subject, montage, as_df=True, remove_no_eeg=True)
    monopol_chans = RAM_helpers.load_elec_info(subject, montage, bipolar=False)
    eeg = RAM_helpers.load_eeg(events.iloc[[0]], rel_start_ms=0, rel_stop_ms=1000, buf_ms=1000,
                               elec_scheme=monopol_chans, noise_freq=[58., 62.], resample_freq=None,
                               pass_band=None, use_mirror_buf=False, demean=True, do_average_ref=True)
    sample_rates.append(float(eeg.samplerate))
Samp_Freq_All = np.asarray(sample_rates, dtype=float)
Samp_Freq = np.min(Samp_Freq_All)
Samp_Freq_All

array([ 400.,  400.,  400.,  400.,  400., 1000., 1000.,  400.,  512.,
        512.,  512.,  512.,  512.,  512.,  512.,  512.,  512.])

In [None]:
# Traveling-wave analysis of PRES* events for every grid subject in `stsubjs`.
#
# Per subject:
#   1. Compute log power (200 log-spaced frequencies, 3-40 Hz) averaged over 0-1000 ms of each event.
#   2. Normalize spectra within session, average over events, and find per-electrode spectral peaks.
#   3. Group adjacent grid electrodes with peaks at similar frequencies into clusters (> 3 electrodes).
#   4. For each cluster, band-pass every member around its peak frequency, take the Hilbert
#      phase/amplitude, and fit local wave models (circ_lin_regress) on a 2-D PCA projection of
#      the cluster's electrode coordinates.
#   5. Pickle the results under TW_OUTPUT_DIR (subjects without clusters are skipped).
#
# NOTE(review): RAM_helpers, par_find_peaks_by_chan2, circ_lin_regress and tqdm are not imported
# explicitly; presumably they come from `from SubjectLevel.par_funcs_fine import *` — confirm.

TW_OUTPUT_DIR = '/home1/anup.das/Results/tw_st_files_grids_2D/tw_files/tw_pres_files'
N_JOBS = 40  # worker processes for the per-time-sample wave fits

# Candidate wave parameters handed to circ_lin_regress; identical for every subject, so built once.
# Directions: 0-355 deg in 5 deg steps; magnitudes: radians(0..17).
thetas = np.radians(np.arange(0, 360, 5))
rs = np.radians(np.arange(0, 18, 1))
theta_r = np.stack([(x, y) for x in thetas for y in rs])
params = np.stack([theta_r[:, 1] * np.cos(theta_r[:, 0]), theta_r[:, 1] * np.sin(theta_r[:, 0])], -1)


def _elec_adjacency(elec_info, elec_pos_column='', separate_hemis=True, min_elec_dist=25.):
    """Return a boolean (n_elecs, n_elecs) matrix of distinct electrode pairs closer than min_elec_dist.

    When separate_hemis is True, left-hemisphere x coordinates are shifted by -100 so that pairs
    straddling the hemispheres end up far apart. Left is taken from elec_info['Loc1'] == 'Left Cerebrum'
    when that column exists, otherwise from x < 0. Coordinates are truncated to int before the
    Euclidean distances are computed.
    """
    coord_cols = ['{}{}'.format(elec_pos_column, coord) for coord in ['x', 'y', 'z']]
    xyz = elec_info[coord_cols].to_numpy(copy=True)
    if separate_hemis:
        if 'Loc1' in elec_info:
            xyz[(elec_info['Loc1'] == 'Left Cerebrum').to_numpy(), 0] -= 100
        else:
            xyz[xyz[:, 0] < 0, 0] -= 100
    elec_dists = squareform(pdist(xyz.astype('int')))
    return (elec_dists < min_elec_dist) & (elec_dists > 0.)


def _normalize_spectra_by_session(subject_data, event_dim_str='event'):
    """Normalize power within each session: subtract the session mean, divide by the session std.

    Statistics are taken over axis 1 and then averaged over axis 0 of each session's slice; the
    result is a numpy array with subject_data's shape. Sessions are selected along the first axis,
    which assumes `event_dim_str` is subject_data's leading dimension — TODO confirm.
    """
    sessions = subject_data[event_dim_str].data['session']
    norm_spectra = np.empty(subject_data.shape)
    for sess in np.unique(sessions):
        sess_inds = sessions == sess
        sess_data = subject_data[sess_inds]
        sess_mean = np.mean(np.mean(sess_data, axis=1), axis=0)
        sess_std = np.mean(np.std(sess_data, axis=1), axis=0)
        norm_spectra[sess_inds] = (sess_data - sess_mean) / sess_std
    return norm_spectra


def _find_peak_clusters(peaks, near_adj_matr, freqs, step_length=5, window_length=15,
                        n_rounds=10, n_grow_passes=10, max_bin_dist=15):
    """Greedily group adjacent electrodes that have spectral peaks at similar frequencies.

    Each round:
      * picks the sliding frequency window holding the most (still unassigned) peaks;
      * restricts adjacency to electrodes with a peak inside that window and takes the largest
        strongly connected component (tarjan) as the seed; each seed electrode is assigned its
        lowest-index peak inside the window;
      * grows the seed: a non-member with at least two adjacent members joins if it has a peak
        strictly closer than max_bin_dist bins to the median bin of those members (the closest
        such peak is assigned).
    Assigned peaks are consumed, so later rounds find other clusters.

    Parameters
    ----------
    peaks : array, shape (n_freqs, n_elecs)
        Truthy where an electrode has a peak in a frequency bin. Not modified (a copy is used).
    near_adj_matr : bool array, shape (n_elecs, n_elecs)
        Electrode adjacency.
    freqs : array, shape (n_freqs,)
        Frequency of each bin; used to convert assigned bins in the output.
    step_length, window_length : int
        Step and width of the sliding frequency window, in bins.
    n_rounds : int
        Number of seed-and-grow rounds.
    n_grow_passes : int
        Sweeps over all electrodes while growing a seeded cluster.
    max_bin_dist : int
        Maximum (exclusive) bin distance for joining a cluster while growing.

    Returns
    -------
    list of dict
        One dict per round whose seed had at least two electrodes, mapping electrode index to the
        frequency (freqs[bin]) of its assigned peak.
    """
    peakid = np.array(peaks, copy=True)
    n_freqs = peakid.shape[0]
    windows = np.stack([[lo, lo + window_length]
                        for lo in range(0, n_freqs + 1 - step_length, step_length)])
    clusters = []
    for _ in range(n_rounds):
        peak_counts = np.sum(peakid, axis=1)
        window_counts = np.array([peak_counts[lo:hi].sum() for lo, hi in windows])
        lo, hi = windows[np.argmax(window_counts)]

        # Only electrodes with a peak inside the winning window may be linked.
        in_window = np.any(peakid[lo:hi], axis=0)
        near_this_ev = near_adj_matr.copy()
        near_this_ev[~in_window, :] = False
        near_this_ev[:, ~in_window] = False
        graph = {elec: np.where(row)[0] for elec, row in enumerate(near_this_ev)}
        # max() returns the first largest component, same as a stable sort by descending size.
        cluster_seed = sorted(max(tarjan(graph), key=len))

        window_mask = np.zeros(n_freqs, dtype=bool)
        window_mask[lo:hi] = True
        cluster = {}
        if len(cluster_seed) > 1:
            for elec in cluster_seed:
                bin_idx = int(np.flatnonzero(np.logical_and(peakid[:, elec], window_mask))[0])
                cluster[elec] = bin_idx
                peakid[bin_idx, elec] = False
        if len(cluster) < 2:
            continue

        for _ in range(n_grow_passes):
            for elec in range(len(near_adj_matr)):
                if elec in cluster:
                    continue
                member_bins = np.array(list(cluster.values()))
                near_bins = member_bins[near_adj_matr[elec, list(cluster.keys())]]
                if len(near_bins) > 1:
                    near_bin = int(np.median(near_bins))
                    elec_bins = np.flatnonzero(peakid[:, elec])
                    if np.any(np.abs(elec_bins - near_bin) < max_bin_dist):
                        bin_idx = int(min(elec_bins, key=lambda b: abs(b - near_bin)))
                        cluster[elec] = bin_idx
                        peakid[bin_idx, elec] = False
        clusters.append(cluster)

    return [{elec: freqs[bin_idx] for elec, bin_idx in cluster.items()} for cluster in clusters]


for i_subj in range(len(stsubjs)):
    subject = stsubjs['subject'].iloc[i_subj]
    montage = stsubjs['montage'].iloc[i_subj]

    # Analysis settings and results for this subject (kept under the name `self`).
    self = SimpleNamespace(
        subject=subject,
        montage=montage,
        task=stsubjs['task'].iloc[i_subj],
        # load monopolar contacts (not bipolar pairs) and average-reference them before computing power
        bipolar=False,
        mono_avg_ref=True,
        # power computation settings
        start_time=0,
        end_time=1000,
        wave_num=5,
        buf_ms=1000,
        noise_freq=[58., 62.],
        resample_freq=400.,
        log_power=True,
        freqs=np.logspace(np.log10(3), np.log10(40), 200),
        mean_over_time=True,
        time_bins=None,
        use_mirror_buf=False,
        pool=None,
        # Hilbert window, ms relative to each event (passed as load_eeg's start/stop)
        hilbert_start=0,
        hilbert_end=1000,
        # window size to find clusters (in Hz); NOTE(review): not used below
        cluster_freq_range=2.,
        # contact types allowed in clusters (D: depths, G: grids, S: strips)
        elec_types_allowed=['G'],
        # spatial distance considered near (mm)
        min_elec_dist=25.,
        # if True, oscillation clusters can't cross hemispheres
        separate_hemis=True,
        # NOTE(review): never consulted; the effective cluster-size threshold below is > 3 electrodes
        min_num_elecs=6,
        # elec_info column prefix from which to extract x, y, z coordinates
        elec_pos_column='',
    )
    self.elec_info = RAM_helpers.load_elec_info(self.subject, self.montage, self.bipolar)

    events_tmp = RAM_helpers.load_subj_events('pymms', subject, montage, as_df=True, remove_no_eeg=True)
    is_pres = events_tmp['type'].str.startswith('PRES', na=False)
    events_for_computation = events_tmp[is_pres].reset_index(drop=True)
    # All PRES* subtypes are treated as a single event type.
    events_for_computation['type'] = 'PRES'

    subject_data = RAM_helpers.compute_power(events_for_computation,
                                             self.freqs,
                                             self.wave_num,
                                             self.start_time,
                                             self.end_time,
                                             buf_ms=self.buf_ms,
                                             cluster_pool=self.pool,
                                             log_power=self.log_power,
                                             noise_freq=self.noise_freq,
                                             elec_scheme=self.elec_info,
                                             resample_freq=self.resample_freq,
                                             do_average_ref=self.mono_avg_ref,
                                             mean_over_time=self.mean_over_time,
                                             use_mirror_buf=self.use_mirror_buf,
                                             loop_over_chans=True)
    self.subject_data = subject_data

    # ---- oscillation clusters ----
    near_adj_matr = _elec_adjacency(self.elec_info, self.elec_pos_column,
                                    self.separate_hemis, self.min_elec_dist)
    allowed_elecs = np.array([e in self.elec_types_allowed for e in self.elec_info['type']])

    p_spect = _normalize_spectra_by_session(self.subject_data)
    mean_p_spect = np.mean(p_spect, axis=self.subject_data.get_axis_num('event'))
    peaks = par_find_peaks_by_chan2(mean_p_spect, self.freqs, 1)
    peaks[:, ~allowed_elecs] = False
    clusters_hz = _find_peak_clusters(peaks, near_adj_matr, self.freqs)

    kept_clusters = [c for c in clusters_hz if len(c) > 3]
    cluster_names = ['cluster{}'.format(k) for k in range(1, len(kept_clusters) + 1)]
    df = None
    if kept_clusters:
        n_elecs = peaks.shape[1]
        cluster_cols = {}
        for name, cluster in zip(cluster_names, kept_clusters):
            # One column per cluster: member electrodes hold their peak frequency, others NaN.
            col = np.full(n_elecs, np.nan)
            for elec, freq in cluster.items():
                col[elec] = freq
            cluster_cols[name] = col
        df = pd.concat([pd.DataFrame(cluster_cols), self.elec_info], axis='columns')
        # Localization tag: avg.region for new patients, Loc3 for old.
        if 'avg.region' in self.elec_info:
            df['tag'] = self.elec_info['avg.region']
        else:
            df['tag'] = self.elec_info['Loc3']
    self.res = {'clusters': df}

    if df is None:
        # Previously this case crashed on `None.columns` and aborted all remaining subjects.
        print('No oscillation clusters with more than 3 electrodes for {}; skipping.'.format(subject))
        continue

    # ---- traveling-wave analysis ----
    resample_rate = Samp_Freq
    task_phase = 'PRES'
    # Subjects with an 'FR' prefix get a 48-52 Hz notch, all others 58-62 Hz.
    noise = [48., 52.] if self.subject[:2] == 'FR' else [58., 62.]
    buf = 1000

    eeg = RAM_helpers.load_eeg(events_for_computation, self.hilbert_start, self.hilbert_end, buf_ms=buf,
                               elec_scheme=self.elec_info, noise_freq=noise, resample_freq=None,
                               pass_band=None, use_mirror_buf=False, demean=True, do_average_ref=True)
    eeg = ResampleFilter(eeg, resample_rate).filter()

    print('Band pass EEG')
    # Sample range that skips the leading buffer and keeps hilbert_start..hilbert_end
    # (assumes load_eeg pads `buf` ms before the window — TODO confirm).
    time_frame_start = int((buf / 1000) * resample_rate)
    time_frame_end = int(resample_rate * (self.hilbert_end - self.hilbert_start + buf) / 1000)
    eeg = eeg.transpose('channel', 'event', 'time')

    print('Traveling wave analysis')
    coord_cols = ['{}{}'.format(self.elec_pos_column, coord) for coord in ['x', 'y', 'z']]
    for cluster in cluster_names:
        cluster_freqs = df[cluster].to_numpy()
        in_cluster = ~np.isnan(cluster_freqs)

        phases, powers, filtered = [], [], []
        for j in np.flatnonzero(in_cluster):
            freq = cluster_freqs[j]
            feeg = RAM_helpers.band_pass_eeg(eeg[j], [freq * .85, freq / .85], order=3)
            analytic = hilbert(feeg, N=feeg.shape[-1], axis=-1)
            phases.append(np.angle(analytic))
            powers.append(np.abs(analytic))
            filtered.append(feeg)

        # Arrays are (cluster electrode, event, time), trimmed to the analysis window.
        res = {}
        res['power'] = np.stack(powers)[:, :, time_frame_start:time_frame_end].astype('float32')
        res['phase'] = np.stack(phases)[:, :, time_frame_start:time_frame_end].astype('float32')
        res['eeg'] = np.stack(filtered)[:, :, time_frame_start:time_frame_end].astype('float32')

        # Centre the cluster's coordinates and project them onto their first two principal components.
        xyz = self.elec_info[coord_cols].to_numpy()[in_cluster]
        xyz = xyz - np.mean(xyz, axis=0)
        xyz_PCA = PCA(n_components=3).fit_transform(xyz)[:, :2]
        # Unlike the clustering adjacency, each electrode counts as its own neighbour here.
        near_adj_matr = squareform(pdist(xyz_PCA)) < self.min_elec_dist

        local_angle, local_sf, local_rs, local_off = {}, {}, {}, {}
        for j in tqdm(range(len(xyz))):
            neighbours = near_adj_matr[j]
            if sum(neighbours) > 3:
                norm_coords = xyz_PCA[neighbours, :]
                # Transposed to (time, event, neighbour): one fit per time sample.
                local_phase = res['phase'][neighbours].T
                fits = Parallel(n_jobs=N_JOBS, verbose=0)(
                    delayed(circ_lin_regress)(phase_t, norm_coords, theta_r, params)
                    for phase_t in local_phase)
                local_angle[j] = np.stack([fit[0] for fit in fits], axis=0).astype('float32')
                local_sf[j] = np.stack([fit[1] for fit in fits], axis=0).astype('float32')
                local_rs[j] = np.stack([fit[2] for fit in fits], axis=0).astype('float32')
                local_off[j] = np.stack([fit[3] for fit in fits], axis=0).astype('float32')
        res['direction'] = local_angle
        # NOTE: the misspelled key is kept; presumably readers of these pickles depend on it.
        res['spatial_freuency'] = local_sf
        res['rs'] = local_rs
        res['offs'] = local_off
        self.res.setdefault(cluster, {})[task_phase] = res

    out_path = os.path.join(TW_OUTPUT_DIR, 'tw_subj_' + subject + '_grids_2D_fine.pkl')
    with open(out_path, 'wb') as out_file:
        pickle.dump(self.res, out_file)