In [1]:
# Use IPython's built-in completer instead of jedi.
# NOTE(review): presumably a workaround for slow/broken tab completion — confirm it is still needed.
%config Completer.use_jedi = False

In [170]:
import os
import itertools
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import mne
import json

# import scipy.io as sio
# from scipy import signal

import pac

import simple_pipeline

In [263]:
def plot_pac(pac, high_freq=(32, 200), low_freq=(4, 40), ax=None, **kwargs):
    """Draw a 2-D phase-amplitude coupling (PAC) map as an image.

    Parameters
    ----------
    pac : array-like, 2-D
        Values to draw. With ``origin='lower'``, rows run along the y axis
        (``high_freq``) and columns along the x axis (``low_freq``).
    high_freq : (float, float)
        (min, max) extent of the y axis.
    low_freq : (float, float)
        (min, max) extent of the x axis.
    ax : matplotlib Axes, optional
        Axes to draw into. If omitted, a new figure is created and shown.
    **kwargs
        Forwarded to ``Axes.imshow`` (e.g. ``vmin``, ``vmax``, ``cmap``,
        or ``aspect`` to override the default square layout).

    Returns
    -------
    matplotlib.image.AxesImage
        The drawn image, e.g. for building a colorbar.
    """
    # Record ownership *before* `ax` is reassigned; otherwise the show()
    # branch at the end can never be reached.
    own_figure = ax is None
    if own_figure:
        fig = plt.figure(figsize=(7, 15))
        ax = fig.subplots()

    # Unpack explicitly so lists, tuples and numpy arrays all work
    # (`array + array` would add element-wise instead of concatenating).
    x_lo, x_hi = (float(v) for v in low_freq)
    y_lo, y_hi = (float(v) for v in high_freq)

    # aspect = x-span / y-span makes the image box square regardless of the
    # two frequency ranges; a plain float avoids the deprecated
    # 1-element-array -> scalar conversion inside matplotlib.
    kwargs.setdefault('aspect', (x_hi - x_lo) / (y_hi - y_lo))

    im = ax.imshow(pac, origin='lower', interpolation='nearest',
                   extent=[x_lo, x_hi, y_lo, y_hi], **kwargs)

    if own_figure:
        plt.show()

    return im

In [232]:
def get_percent(arr, thr=0.95, n_bins=100):
    """Return the histogram bin ``(lower_edge, upper_edge)`` containing the `thr` quantile.

    The data are binned into `n_bins` equal-width bins. The selected bin is
    the first one at which the cumulative fraction of samples exceeds `thr`.

    Parameters
    ----------
    arr : array-like
        Data of any shape; it is flattened. Non-finite values (NaN, +/-inf,
        e.g. from ``np.log(0)``) are ignored.
    thr : float
        Cumulative fraction in [0, 1]. If no bin exceeds it (e.g. ``thr >= 1``),
        the last bin is returned.
    n_bins : int
        Number of histogram bins (default 100).

    Returns
    -------
    (float, float)
        Lower and upper edge of the selected bin.

    Raises
    ------
    ValueError
        If `arr` contains no finite values.
    """
    values = np.asarray(arr, dtype=float).ravel()
    # np.histogram cannot auto-detect a range that includes inf/NaN.
    values = values[np.isfinite(values)]
    if values.size == 0:
        raise ValueError("get_percent: input contains no finite values")

    counts, edges = np.histogram(values, bins=n_bins)
    cdf = (counts / counts.sum()).cumsum()
    # side='right' -> first index with cdf > thr (cdf is non-decreasing).
    idx = int(np.searchsorted(cdf, thr, side='right'))
    # Clamp so thr at/above the top of the CDF yields the last bin, not IndexError.
    idx = min(idx, len(counts) - 1)
    return edges[idx], edges[idx + 1]

In [276]:
def save_fig(path, fig=None, **kwargs):
    """Save a matplotlib figure to `path`, creating parent directories as needed.

    Parameters
    ----------
    path : str or os.PathLike
        Destination file. A bare file name (no directory part) is written
        relative to the current working directory.
    fig : matplotlib Figure, optional
        Figure to save. Defaults to the current pyplot figure.
    **kwargs
        Forwarded to ``savefig`` (e.g. ``dpi``, ``bbox_inches``).
    """
    directory = os.path.dirname(os.fspath(path))
    # dirname('name.png') == '' and os.makedirs('') raises, so only create a
    # directory when there is one; exist_ok avoids the exists/create race.
    if directory:
        os.makedirs(directory, exist_ok=True)

    if fig is None:
        plt.savefig(path, **kwargs)
    else:
        fig.savefig(path, **kwargs)

# Build the task list (`tasks_df`) and flag which tasks are completed

In [4]:
if __name__ == '__main__':
    tasks_df = simple_pipeline.create_tasks_df()

    # A task counts as completed when its `completed.json` marker exists in
    # the task directory (file name built with the task's `file_formatter`).
    # `assign` attaches the flags positionally. The previous
    # pd.concat(axis=1) aligned on the index, so any non-default index from
    # create_tasks_df() would have produced misaligned / NaN rows.
    tasks_df = tasks_df.assign(completed=[
        os.path.exists(
            os.path.join(task['dir'], task['file_formatter'].format('completed.json'))
        )
        for _, task in tasks_df.iterrows()
    ])

In [182]:
groups = ['CTL', 'PD Med On', 'PD Med Off']
event_types = ['Target', 'Standard', 'Novelty']


def _load_task_pac(task):
    """Load one task's per-event PAC arrays from its `mvls.npz` / `mvl_2ds.npz`.

    Returns
    -------
    mvl : ndarray, float64
        One row per event: the diagonal of that event's `mvls` array.
    mvl_2d : ndarray, float64
        One entry per event: the diagonal of that event's `mvl_2ds` array over
        its first two axes, with the diagonal axis moved to the front.

    Events are stacked in ``sorted(npz.files)`` order.
    NOTE(review): the plotting cell labels event columns with `event_types`
    (Target, Standard, Novelty). If the npz keys are those names, sorted order
    would be Novelty, Standard, Target — confirm the mapping.
    """
    def _path(name):
        return os.path.join(task['dir'], task['file_formatter'].format(name))

    # NpzFile keeps its zip archive open until closed; `with` releases it.
    # `.astype(float)` copies each diagonal, so the full per-event array can be
    # freed right away instead of being kept alive by a view.
    with np.load(_path('mvls.npz')) as task_mvls:
        mvl = np.stack([task_mvls[key].diagonal().astype(float)
                        for key in sorted(task_mvls.files)])
    with np.load(_path('mvl_2ds.npz')) as task_mvl_2ds:
        mvl_2d = np.stack([task_mvl_2ds[key].diagonal(0, 0, 1).transpose((2, 0, 1)).astype(float)
                           for key in sorted(task_mvl_2ds.files)])
    return mvl, mvl_2d


mvl_2ds = [[] for _ in groups]
mvls = [[] for _ in groups]

for _, task in tasks_df.iterrows():
    mvl, mvl_2d = _load_task_pac(task)
    # NOTE(review): pd_drug_type is used as an index into `groups`; assumed to
    # be an int in range(len(groups)) following that order — confirm.
    mvls[task.pd_drug_type].append(mvl)
    mvl_2ds[task.pd_drug_type].append(mvl_2d)

# Stacking into one regular array needs the same number of tasks per group;
# fail with a clear message instead of a ragged/object-array error.
_tasks_per_group = {g: len(v) for g, v in zip(groups, mvl_2ds)}
assert len(set(_tasks_per_group.values())) <= 1, \
    f"groups have different numbers of tasks: {_tasks_per_group}"

# Resulting layout: (group, task, event, ...).
mvls = np.array(mvls)
mvl_2ds = np.array(mvl_2ds)

In [293]:
with open('config.json') as f:
    # Channel names; `channels[ch]` names axis-3 index `ch` of `mvl_2ds`.
    channels = json.load(f)['channels']

for ch in range(mvl_2ds.shape[3]):
    # Average over tasks (axis 1) for every (group, event) pair of this
    # channel -> panel_means[i, j] is the map drawn in row i, column j.
    panel_means = mvl_2ds[:, :, :, ch].mean(axis=1)

    # One colour scale shared by all panels, so the single colorbar is valid
    # for every panel (with per-panel autoscaling it described only the last).
    vmin, vmax = np.nanmin(panel_means), np.nanmax(panel_means)

    fig, axs = plt.subplots(len(groups), len(event_types), sharex=True, sharey=True,
                            figsize=(15, 15), squeeze=False)
    for i, j in itertools.product(range(len(groups)), range(len(event_types))):
        im = plot_pac(panel_means[i, j], ax=axs[i, j], vmin=vmin, vmax=vmax)
        # Hide every axis first; the outer column/row are re-enabled below.
        axs[i, j].xaxis.set_visible(False)
        axs[i, j].yaxis.set_visible(False)

    for i, group in enumerate(groups):
        axs[i, 0].set_ylabel(group)
        axs[i, 0].yaxis.set_visible(True)
    for j, event_type in enumerate(event_types):
        axs[0, j].set_title(event_type)
        axs[-1, j].xaxis.set_visible(True)

    fig.suptitle(channels[ch])
    cbar_ax = fig.add_axes([0.95, 0.15, 0.02, 0.7])
    fig.colorbar(im, cax=cbar_ax, label='MVL')

    save_fig(os.path.join('plots', f'pac_{channels[ch]}'))
    plt.close(fig)  # free memory: one figure per channel