In [None]:
%config Completer.use_jedi = False

In [None]:
import os
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import mne
import json
import scipy.stats

# import scipy.io as sio
# from scipy import signal

import pac

import simple_pipeline

# Suffix embedded in the names of the result files read by the loading cells
# (e.g. f'mvls{suffix}.npz'); the trailing comment keeps an alternative value.
suffix = '_delay_corrected'#'_1ch_nv'
# [first, last] value of the high-frequency axis: default `high_freq` of plot_pac,
# and `gamma[1] - gamma[0] + 1` sizes the high-frequency axis of the 2-D MVL arrays.
gamma = [20, 80]
# [first, last] value of the low-frequency axis: default `low_freq` of plot_pac,
# and `beta[1] - beta[0] + 1` sizes the low-frequency axis of the 2-D MVL arrays.
beta  = [ 4, 16]

# Functions

In [None]:
# Channel names come from the JSON configuration file next to the notebook.
with open('config.json') as config_file:
    config = json.load(config_file)
channels = config['channels']

# Display labels; the plotting cells index these lists by group / event position.
groups = ['PD Med Off', 'PD Med On', 'CTL']
event_types = ['Target', 'Standard', 'Novelty']

# One colour and one line style per plotted series (three of each).
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
linestyles = ['-', ':', '--']

In [None]:
def plot_pac(pac, high_freq=gamma, low_freq=beta, ax=None, **kwargs):
    """Draw a 2-D map with ``imshow`` on a low-frequency (x) by high-frequency (y) grid.

    Parameters
    ----------
    pac : 2-D array-like
        Values to display. With ``origin='lower'`` the first row is drawn at the
        bottom of the image.
    high_freq, low_freq : sequence of two numbers
        ``[first, last]`` extent of the y axis (``high_freq``) and of the
        x axis (``low_freq``).
    ax : matplotlib Axes, optional
        Axes to draw into. When omitted a new figure is created and shown.
    **kwargs
        Forwarded to ``Axes.imshow`` (e.g. ``vmin``, ``vmax``, ``cmap``).

    Returns
    -------
    The ``AxesImage`` returned by ``imshow``.
    """
    # Remember whether the figure is ours: `ax` is reassigned below, so testing
    # `ax is None` again after drawing could never be true.
    owns_figure = ax is None
    if owns_figure:
        fig = plt.figure(figsize=(7, 15))
        ax = fig.subplots()

    # Scalar aspect ratio (x span / y span) instead of a 1-element array.
    low_span = float(low_freq[1] - low_freq[0])
    high_span = float(high_freq[1] - high_freq[0])

    im = ax.imshow(pac, origin='lower', interpolation='spline36',
                   # unpacking works for lists, tuples and arrays alike
                   extent=[*low_freq, *high_freq],
                   aspect=low_span / high_span, **kwargs)

    if owns_figure:
        plt.show()

    return im

In [None]:
def get_percent(arr, thr=0.95):
    """Return the edges of the histogram bin where the cumulative fraction first exceeds `thr`.

    The values are flattened and binned into a 100-bin histogram. The function
    returns ``(left_edge, right_edge)`` of the first bin whose cumulative share
    of the samples is strictly greater than `thr`.

    Parameters
    ----------
    arr : array-like
        Input values of any shape; must contain at least one element.
    thr : float, default 0.95
        Cumulative fraction to exceed.

    Returns
    -------
    tuple of two floats
        Left and right edge of the selected bin. If no bin exceeds `thr`
        (e.g. ``thr >= 1``), the last bin is returned.

    Raises
    ------
    ValueError
        If `arr` is empty.
    """
    values = np.asarray(arr).ravel()
    if values.size == 0:
        raise ValueError('get_percent() requires a non-empty array')

    freq, bins = np.histogram(values, bins=100)
    cdf = (freq / freq.sum()).cumsum()  # computed once, reused for both edges

    above = np.flatnonzero(cdf > thr)
    # Fall back to the last bin instead of failing with an IndexError.
    idx = above[0] if above.size else len(freq) - 1
    return bins[idx], bins[idx + 1]

In [None]:
def save_fig(path):
    """Save the current matplotlib figure to `path`, creating parent directories if needed.

    Parameters
    ----------
    path : str
        Destination passed unchanged to ``plt.savefig``.
    """
    directory = os.path.dirname(path)
    # dirname is '' for a bare file name, and os.makedirs('') raises.
    if directory:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(directory, exist_ok=True)
    plt.savefig(path)

# Create Task list in `tasks_df`

In [None]:
if __name__ == '__main__':
    tasks_df = simple_pipeline.create_tasks_df()

    # A task is flagged as completed when its `completed{suffix}.json` marker file exists.
    completed = [
        os.path.exists(os.path.join(
            task['dir'], task['file_formatter'].format(f'completed{suffix}.json')))
        for _, task in tasks_df.iterrows()
    ]

    # assign() attaches the flags positionally and overwrites on re-run.
    # pd.concat(axis=1) aligned on the index instead, which breaks when
    # tasks_df is not indexed 0..n-1 and duplicates the column on re-run.
    tasks_df = tasks_df.assign(completed=completed)

# Load Data

In [None]:
# MVL: load per-task MVL results and group them by `pd_drug_type`.

mvl_2ds_time = [[] for k in groups]  # nothing is appended to this list in this cell
mvl_2ds = [[] for k in groups]
mvls = [[] for k in groups]

n_events = len(event_types)
n_high = gamma[1] - gamma[0] + 1
n_low = beta[1] - beta[0] + 1

for _row, task in tasks_df.iterrows():
    mvls_path = os.path.join(task['dir'], task['file_formatter'].format(f'mvls{suffix}.npz'))
    mvl_2ds_path = os.path.join(task['dir'], task['file_formatter'].format(f'mvl_2ds{suffix}.npz'))

    # NpzFile keeps the archive open; the context manager closes it after each task.
    with np.load(mvls_path) as task_mvls, np.load(mvl_2ds_path) as task_mvl_2ds:
        nbchan = task_mvls[task_mvls.files[0]].shape[0]

        # mvls: keep the diagonal of each per-event matrix.
        # NOTE(review): rows follow sorted(archive keys); confirm that order matches `event_types`.
        mvl = np.zeros((n_events, nbchan))
        for i, event_type in enumerate(sorted(task_mvls.files)):
            mvl[i] = task_mvls[event_type].diagonal()

        # mvl_2ds: diagonal over the first two axes, then move that axis to the front.
        mvl_2d = np.zeros((n_events, nbchan, n_high, n_low))
        for i, event_type in enumerate(sorted(task_mvl_2ds.files)):
            mvl_2d[i] = task_mvl_2ds[event_type].diagonal(0, 0, 1).transpose((2, 0, 1))

    mvls[task.pd_drug_type].append(mvl)
    mvl_2ds[task.pd_drug_type].append(mvl_2d)

mvls = np.array(mvls)              # --> (pd_drug_type, subjects, event_types, channels)
mvl_2ds = np.array(mvl_2ds)        # --> (pd_drug_type, subjects, event_types, channels, high_freqs, low_freqs)
mvl_2ds_time = np.array(mvl_2ds_time)        # stays empty here: shape (len(groups), 0)

In [None]:
# ERP: per-task mean and standard deviation over trials, grouped by `pd_drug_type`.

epochs = [[] for k in groups]

for _row, task in tasks_df.iterrows():
    epochs_path = os.path.join(task['dir'], task['file_formatter'].format('epochs.npz'))

    # Close the archive once the arrays have been reduced.
    with np.load(epochs_path) as task_epochs:
        first_epochs = task_epochs[task_epochs.files[0]]
        nbchan = first_epochs.shape[-2]
        nbtime = first_epochs.shape[-1]

        epoch = np.zeros((3, 2, nbchan, nbtime))
        for i, event_type in enumerate(sorted(task_epochs.files)):
            # Read each array once; every NpzFile lookup re-reads it from disk.
            trials = task_epochs[event_type]
            epoch[i, 0] = trials.mean(axis=0)
            epoch[i, 1] = trials.std(axis=0)

    epochs[task.pd_drug_type].append(epoch)

epochs = np.array(epochs)          # --> (pd_drug_type, subjects, event_types, (mean, std), channels, time)

# Plot ERP

In [None]:
# Plot for subjects: one ERP figure per (channel, group, subject), written to plots/erps/Sub<k>/.

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])

for ch in range(epochs.shape[4]):
    # y-limits depend only on the channel (mean -/+ std over all groups, subjects
    # and event types), so compute them once per channel.
    ch_mean = epochs[:, :, :, 0, ch, :]
    ch_std = epochs[:, :, :, 1, ch, :]
    ymin = (ch_mean - ch_std).min()
    ymax = (ch_mean + ch_std).max()

    for drug_type in range(epochs.shape[0]):
        for sub in range(epochs.shape[1]):
            fig, ax = plt.subplots(figsize=(15, 7))

            for event_type in range(epochs.shape[2]):
                # Axis 3 holds (mean, std) for this subject / event / channel.
                erp, std = epochs[drug_type, sub, event_type, :, ch, :]
                color = colors[event_type]
                ax.plot(times, erp, linewidth=2, color=color, label=event_types[event_type])
                for band in (erp + std, erp - std):
                    ax.plot(times, band, linestyle='--', linewidth=0.5, color=color, alpha=0.5)

            ax.grid()
            ax.set_ylim(ymin, ymax)
            ax.set_xlim(-200, 1000)
            # Vertical reference lines at fixed positions on the time axis.
            for x_ref in (0, 100, 250):
                ax.plot([x_ref, x_ref], [ymin, ymax], color='#999999', linestyle='--')
            ax.set_title(f'Sub {sub + 1}, {groups[drug_type]}, Channel {channels[ch]}')
            ax.legend()

            out_dir = os.path.join('plots', 'erps', f'Sub{sub + 1}')
            os.makedirs(out_dir, exist_ok=True)
            fig.savefig(os.path.join(out_dir,
                                     f'Sub{sub + 1}_{groups[drug_type]}_{channels[ch]}.png'))
            # Many figures are produced; close each one to keep memory bounded.
            plt.close(fig)
    

In [None]:
# Plot events: one figure per channel, one panel per group, one line per event type.

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
out_dir = os.path.join('plots', 'erps', 'channels')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, axs = plt.subplots(epochs.shape[0], 1, sharex=True, sharey=True, figsize=(15, 15))
    for drug_type in range(epochs.shape[0]):
        ax = axs[drug_type]

        for event_type in range(epochs.shape[2]):
            # Average over subjects of the per-subject mean (index 0 of the (mean, std) axis).
            erp = epochs[drug_type, :, event_type, 0, ch, :].mean(axis=0)
            ax.plot(times, erp, linewidth=2,
                    color=colors[event_type], label=event_types[event_type])

        ax.legend()
        ax.set_xlim(-200, 1000)
        ax.grid()
        ax.set_ylabel(groups[drug_type])

    fig.tight_layout(pad=4, w_pad=0.5, h_pad=1.0)
    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

In [None]:
# Plot stimuli: one figure per channel, one panel per event type, one line per group.

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
out_dir = os.path.join('plots', 'erps', 'Events')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, axs = plt.subplots(epochs.shape[2], 1, sharex=True, sharey=True, figsize=(15, 15))
    for event_type in range(epochs.shape[2]):
        ax = axs[event_type]

        for drug_type in range(epochs.shape[0]):
            # Average over subjects of the per-subject mean (index 0 of the (mean, std) axis).
            erp = epochs[drug_type, :, event_type, 0, ch, :].mean(axis=0)
            ax.plot(times, erp, linewidth=2, color=colors[drug_type], label=groups[drug_type])

        ax.legend()
        ax.set_xlim(-200, 1000)
        ax.grid()
        ax.set_ylabel(event_types[event_type])

    fig.tight_layout(pad=4, w_pad=0.5, h_pad=1.0)
    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

In [None]:
# Plot all: one figure per channel with every (group, event type) average on the same axes.
# Colour encodes the group, line style the event type.

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
out_dir = os.path.join('plots', 'erps', 'all')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, ax = plt.subplots(figsize=(15, 7))
    for drug_type in range(epochs.shape[0]):
        for event_type in range(epochs.shape[2]):
            # Average over subjects of the per-subject mean (index 0 of the (mean, std) axis).
            erp = epochs[drug_type, :, event_type, 0, ch, :].mean(axis=0)
            ax.plot(times, erp, linewidth=1, color=colors[drug_type],
                    label=f'{groups[drug_type]} {event_types[event_type]}',
                    linestyle=linestyles[event_type])

    ax.legend()
    ax.set_xlim(-200, 1000)
    ax.grid()

    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

In [None]:
# TTest events: within each group, independent t-test across subjects between every
# pair of event types at each time point; plotted as -log2(p).

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
threshold = -np.log2(0.05)
out_dir = os.path.join('plots', 'ttest', 'channels')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, axs = plt.subplots(epochs.shape[0], 1, sharex=True, sharey=True, figsize=(15, 15))
    for drug_type in range(epochs.shape[0]):
        ax = axs[drug_type]

        for i, grp in enumerate(itertools.combinations(range(epochs.shape[2]), 2)):
            # (subjects, time) matrices of per-subject mean ERPs for the two event types.
            erp_a = epochs[drug_type, :, grp[0], 0, ch, :]
            erp_b = epochs[drug_type, :, grp[1], 0, ch, :]
            t, p = scipy.stats.ttest_ind(erp_a, erp_b)
            ax.plot(times, -np.log2(p), linewidth=2, color=colors[i],
                    label=f'{event_types[grp[0]]} vs {event_types[grp[1]]}')

        # Horizontal reference at p = 0.05.
        ax.plot(times, np.full(times.shape, threshold),
                linewidth=1, linestyle='--', color='black', label='Threshold')

        ax.legend()
        ax.set_xlim(-200, 1000)
        ax.grid()
        ax.set_ylabel(groups[drug_type])

    fig.tight_layout(pad=4, w_pad=0.5, h_pad=1.0)
    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

In [None]:
# TTest stimuli: for each event type, independent t-test across subjects between every
# pair of groups at each time point; plotted as -log2(p).
# epochs axes: (pd_drug_type, subjects, event_types, (mean, std), channels, time)

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
threshold = -np.log2(0.05)
out_dir = os.path.join('plots', 'ttest', 'Events')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, axs = plt.subplots(epochs.shape[2], 1, sharex=True, sharey=True, figsize=(15, 15))
    # Event types live on axis 2 and groups on axis 0. The previous code had the
    # two ranges swapped, which only worked because both axes have length 3.
    for event_type in range(epochs.shape[2]):
        ax = axs[event_type]

        for i, grp in enumerate(itertools.combinations(range(epochs.shape[0]), 2)):
            # (subjects, time) matrices of per-subject mean ERPs for the two groups.
            erp_a = epochs[grp[0], :, event_type, 0, ch, :]
            erp_b = epochs[grp[1], :, event_type, 0, ch, :]
            t, p = scipy.stats.ttest_ind(erp_a, erp_b)
            ax.plot(times, -np.log2(p), linewidth=2, color=colors[i],
                    label=f'{groups[grp[0]]} vs {groups[grp[1]]}')

        # Horizontal reference at p = 0.05.
        ax.plot(times, np.full(times.shape, threshold),
                linewidth=1, linestyle='--', color='black', label='Threshold')

        ax.legend()
        ax.set_xlim(-200, 1000)
        ax.grid()
        ax.set_ylabel(event_types[event_type])

    fig.tight_layout(pad=4, w_pad=0.5, h_pad=1.0)
    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

In [None]:
# Debug cell: shapes of the ERP array and of the variables left over from the
# last iteration of the preceding t-test cell (`erp_a`, `erp_b`, `p`).
pair_indices = list(itertools.combinations(range(epochs.shape[2]), 2))
print(epochs.shape)
print(pair_indices)
print(erp_a.shape, erp_b.shape)
print(p.shape)

In [None]:
# TTest stimuli dotted: same group-pair t-tests as above, drawn as dots, with
# p-values above 0.05 forced to 1 so they land on the zero line.
# --> (pd_drug_type, subjects, event_types, (mean, std), channels, time) (3, 25, 3, 2, 63, 601)

# Time axis sized from the data instead of a hard-coded 601 samples.
times = np.linspace(-200, 1000, epochs.shape[-1])
threshold = -np.log2(0.05)
out_dir = os.path.join('plots', 'ttest_dot', 'Events')
os.makedirs(out_dir, exist_ok=True)

for ch in range(epochs.shape[4]):
    fig, axs = plt.subplots(epochs.shape[2], 1, sharex=True, sharey=True, figsize=(15, 15))
    for event_type in range(epochs.shape[2]):
        ax = axs[event_type]

        for i, grp in enumerate(itertools.combinations(range(epochs.shape[0]), 2)):
            erp_a = epochs[grp[0], :, event_type, 0, ch, :]
            erp_b = epochs[grp[1], :, event_type, 0, ch, :]
            t, p = scipy.stats.ttest_ind(erp_a, erp_b)
            # -log2(1) == 0, so time points above 0.05 are drawn at zero.
            p[p > 0.05] = 1
            ax.plot(times, -np.log2(p), '.',
                    linewidth=2, color=colors[i],
                    label=f'{groups[grp[0]]} vs {groups[grp[1]]}')

        # Horizontal reference at p = 0.05.
        ax.plot(times, np.full(times.shape, threshold),
                linewidth=1, linestyle='--', color='black', label='Threshold')

        ax.legend()
        ax.set_xlim(-200, 1000)
        ax.grid()
        ax.set_ylabel(event_types[event_type])

    fig.tight_layout(pad=4, w_pad=0.5, h_pad=1.0)
    fig.suptitle(f'Channel {channels[ch]}', fontsize=16)

    fig.savefig(os.path.join(out_dir, f'{channels[ch]}.png'))
    plt.show()
    # Release the figure once displayed and saved.
    plt.close(fig)

# Plot 2D PAC

In [None]:
# Reload the MVL results (same files as the "Load Data" cell) grouped by `pd_drug_type`.
mvl_2ds = [[] for k in groups]
mvls = [[] for k in groups]

n_events = len(event_types)
n_high = gamma[1] - gamma[0] + 1
n_low = beta[1] - beta[0] + 1

for _row, task in tasks_df.iterrows():
    mvls_path = os.path.join(task['dir'], task['file_formatter'].format(f'mvls{suffix}.npz'))
    mvl_2ds_path = os.path.join(task['dir'], task['file_formatter'].format(f'mvl_2ds{suffix}.npz'))

    # NpzFile keeps the archive open; the context manager closes it after each task.
    with np.load(mvls_path) as task_mvls, np.load(mvl_2ds_path) as task_mvl_2ds:
        nbchan = task_mvls[task_mvls.files[0]].shape[0]

        # mvls: keep the diagonal of each per-event matrix.
        mvl = np.zeros((n_events, nbchan))
        for i, event_type in enumerate(sorted(task_mvls.files)):
            mvl[i] = task_mvls[event_type].diagonal()

        # mvl_2ds: diagonal over the first two axes, then move that axis to the front.
        mvl_2d = np.zeros((n_events, nbchan, n_high, n_low))
        for i, event_type in enumerate(sorted(task_mvl_2ds.files)):
            mvl_2d[i] = task_mvl_2ds[event_type].diagonal(0, 0, 1).transpose((2, 0, 1))

    mvls[task.pd_drug_type].append(mvl)
    mvl_2ds[task.pd_drug_type].append(mvl_2d)

mvls = np.array(mvls)              # --> (pd_drug_type, subjects, event_types, channels)
mvl_2ds = np.array(mvl_2ds)        # --> (pd_drug_type, subjects, event_types, channels, high_freqs, low_freqs)

In [None]:
# Reload the channel names so this cell does not depend on earlier edits to `channels`.
with open('config.json') as config_file:
    config = json.load(config_file)
channels = config['channels']

# One 3x3 figure per channel: rows are groups, columns are event types; each
# panel shows the subject-averaged 2-D MVL map.
for ch in range(mvl_2ds.shape[3]):
    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))

    for i in range(3):
        for j in range(3):
            panel = axs[i, j]
            subject_mean = mvl_2ds[i, :, j, ch].mean(axis=0)
            im = plot_pac(subject_mean, ax=panel)
            # Hide every axis first; the outer ones are re-enabled below.
            panel.xaxis.set_visible(False)
            panel.yaxis.set_visible(False)

    # The colorbar is attached to the image of the last panel drawn.
    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    for k in range(3):
        axs[k, 0].set_ylabel(groups[k])
        axs[0, k].set_title(event_types[k])

        axs[-1, k].xaxis.set_visible(True)
        axs[k, 0].yaxis.set_visible(True)

    fig.suptitle(f'{channels[ch]}', fontsize=16)

    out_path = os.path.join('plots', f'pac{suffix}', f'pac_{channels[ch]}{suffix}')
    save_fig(out_path)
    plt.close(fig)

In [None]:
# NOTE(review): this cell repeats the preceding 2-D PAC plotting cell and writes
# to the same output files; confirm whether both copies are needed.
with open('config.json') as config_file:
    config = json.load(config_file)
channels = config['channels']

panel_indices = list(itertools.product(range(3), range(3)))

for ch in range(mvl_2ds.shape[3]):
    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))

    # Rows are groups, columns are event types; each panel is the subject average.
    for i, j in panel_indices:
        im = plot_pac(mvl_2ds[i, :, j, ch].mean(axis=0), ax=axs[i, j])
        for axis in (axs[i, j].xaxis, axs[i, j].yaxis):
            axis.set_visible(False)

    # Dedicated axes for the colorbar of the last image drawn.
    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    for i, (group_label, event_label) in enumerate(zip(groups, event_types)):
        axs[i, 0].set_ylabel(group_label)
        axs[0, i].set_title(event_label)

        # Only the bottom row / left column keep their tick axes.
        axs[-1, i].xaxis.set_visible(True)
        axs[i, 0].yaxis.set_visible(True)

    fig.suptitle(f'{channels[ch]}', fontsize=16)

    save_fig(os.path.join('plots', f'pac{suffix}', f'pac_{channels[ch]}{suffix}'))
    plt.close(fig)

In [None]:
# Reload the channel names from the configuration file.
with open('config.json') as config_file:
    config = json.load(config_file)
channels = config['channels']

# Row i of every figure compares the i-th pair of groups.
group_pairs = list(itertools.combinations(range(3), 2))

for ch in range(mvl_2ds.shape[3]):
    fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))

    for i, (a, b) in enumerate(group_pairs):
        for j in range(3):
            # Independent t-test across subjects for every (high, low) frequency cell.
            t, p = scipy.stats.ttest_ind(mvl_2ds[a, :, j, ch], mvl_2ds[b, :, j, ch])
            # p-values above 0.05 become 1, i.e. 0 after -log.
            p[p > 0.05] = 1
            im = plot_pac(-np.log(p), ax=axs[i, j])
            axs[i, j].xaxis.set_visible(False)
            axs[i, j].yaxis.set_visible(False)

    # The colorbar is attached to the image of the last panel drawn.
    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax)

    for i, (a, b) in enumerate(group_pairs):
        axs[i, 0].set_ylabel(f'{groups[a]} vs {groups[b]}')
        axs[0, i].set_title(event_types[i])

        axs[-1, i].xaxis.set_visible(True)
        axs[i, 0].yaxis.set_visible(True)

    fig.suptitle(f'{channels[ch]}', fontsize=16)

    out_path = os.path.join('plots', f'pac_pv{suffix}', f'pac_pv_{channels[ch]}{suffix}')
    save_fig(out_path)

    plt.show()

# Plot topographic PAC

In [None]:
# Topographic maps of the subject-averaged MVL: rows are groups, columns are event types.
with open('config.json') as config_file:
    config = json.load(config_file)
# Build a filtered copy instead of mutating the list stored in `config`.
channels = [name for name in config['channels'] if name != 'VEOG']

# Only the standard cap layout is used. The per-task `electrodes.elc` montage that
# used to be read first was overwritten immediately and depended on the loop
# variable `task` leaking from an earlier cell.
montage = mne.channels.read_custom_montage('Standard-10-20-Cap81.locs')
n_channels = mvls.shape[-1]
mne_info = mne.create_info(ch_names=channels, sfreq=500., ch_types='eeg')

# Loop-invariant statistics: grand mean for centring, and colour limits taken
# from the subject-averaged maps.
grand_mean = mvls.mean()
subject_mean = mvls.mean(axis=1)
vmin, vmax = subject_mean.min(), subject_mean.max()

fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))
for i, j in itertools.product(range(3), range(3)):
    data = mvls[i, :, j, :].mean(axis=0).reshape((-1, 1))
    mvl_evoked = mne.EvokedArray(data - grand_mean, mne_info)
    mvl_evoked.set_montage(montage)
    # NOTE(review): newer MNE releases replace `vmin`/`vmax` with `vlim` and drop
    # `show_names`; confirm against the installed version.
    mne.viz.plot_topomap(mvl_evoked.data[:, 0], mvl_evoked.info, axes=axs[i, j], show=False,
                         names=channels, show_names=True,
                         vmin=vmin - grand_mean, vmax=vmax - grand_mean)

for i in range(3):
    axs[i, 0].set_ylabel(groups[i])
    axs[0, i].set_title(event_types[i])

    axs[-1, i].xaxis.set_visible(True)
    axs[i, 0].yaxis.set_visible(True)

plt.show()


In [None]:
# Topographic maps of between-group t-test p-values on the MVL: rows are group
# pairs, columns are event types.
with open('config.json') as config_file:
    config = json.load(config_file)
# Build a filtered copy instead of mutating the list stored in `config`.
channels = [name for name in config['channels'] if name != 'VEOG']

# Only the standard cap layout is used. The per-task `electrodes.elc` montage that
# used to be read first was overwritten immediately and depended on the loop
# variable `task` leaking from an earlier cell.
montage = mne.channels.read_custom_montage('Standard-10-20-Cap81.locs')
n_channels = mvls.shape[-1]
mne_info = mne.create_info(ch_names=channels, sfreq=500., ch_types='eeg')

group_pairs = list(itertools.combinations(range(3), 2))

fig, axs = plt.subplots(3, 3, sharex=True, sharey=True, figsize=(15, 15))
for i, (a, b) in enumerate(group_pairs):
    for j in range(3):
        # Independent t-test across subjects, one p-value per channel.
        t, p = scipy.stats.ttest_ind(mvls[a, :, j, :], mvls[b, :, j, :])
        # p-values above 0.05 become 1, i.e. 0 after -log.
        p[p > 0.05] = 1
        p = p.reshape((-1, 1))
        neg_log_p = -np.log(p)
        # Centre the map on its own mean before plotting.
        mvl_evoked = mne.EvokedArray(neg_log_p - neg_log_p.mean(), mne_info)
        mvl_evoked.set_montage(montage)
        # NOTE(review): newer MNE releases drop `show_names`; confirm against the installed version.
        mne.viz.plot_topomap(mvl_evoked.data[:, 0], mvl_evoked.info, axes=axs[i, j], show=False,
                             names=channels, show_names=True)

for i, (a, b) in enumerate(group_pairs):
    axs[i, 0].set_ylabel(f'{groups[a]} vs {groups[b]}')
    axs[0, i].set_title(event_types[i])

    axs[-1, i].xaxis.set_visible(True)
    axs[i, 0].yaxis.set_visible(True)

plt.show()

# Plot topographic Time PAC