In [None]:
# Standard libraries
import numpy as np
import h5py
import matplotlib.pyplot as plt
from ipywidgets import interactive

# Append base directory
import os,sys,inspect
# Jupyter starts the kernel in the notebook's directory, so the working directory
# is the notebook directory. inspect.getfile(inspect.currentframe()) is unreliable
# in a notebook: recent ipykernel versions report a temporary file such as
# /tmp/ipykernel_<pid>/<hash>.py, which would point every path below elsewhere.
currentdir = os.getcwd()
path1p = os.path.dirname(currentdir)   # parent of the notebook directory
path2p = os.path.dirname(path1p)       # grandparent; expected to contain data/
libpath = os.path.join(path1p, "lib")
pwd_mat = os.path.join(path2p, "data", "sim_ds_mat")
pwd_h5 = os.path.join(path2p, "data", "sim_ds_h5")

# Put lib/ at the front of the import path (ahead of installed packages), without
# stacking duplicate entries when this cell is re-run.
while libpath in sys.path:
    sys.path.remove(libpath)
sys.path.insert(0, libpath)
print("Appended library directory", libpath)

# User libraries
from matlab.matlab_yaro_lib import read_mat
from signal_lib import downsample
from corr_lib import sprMat
from qt_wrapper import gui_fnames
#from idtxl_wrapper import idtxlResultsParse

## Processing single point data

In [None]:
def getData(fname):
    """Load the IDTXL result tables stored in one HDF5 file.

    Parameters
    ----------
    fname : str
        File name relative to ``<pwd_h5>/real_data``. ``os.path.join`` discards
        earlier components when a later one is absolute, so an absolute path
        (presumably what the file dialog returns) is opened as-is.

    Returns
    -------
    TE, lag, p : numpy.ndarray
        In-memory copies of the ``results/TE_table``, ``results/delay_table``
        and ``results/p_table`` datasets.
    """
    filename = os.path.join(pwd_h5, "real_data", fname)
    # The context manager closes the file even if one of the datasets is missing.
    with h5py.File(filename, "r") as h5f:
        results = h5f['results']
        # np.copy reads each dataset into memory before the file is closed.
        TE = np.copy(results['TE_table'])
        lag = np.copy(results['delay_table'])
        p = np.copy(results['p_table'])
    return TE, lag, p

In [None]:
# Ask the user for the IDTXL result files; keep the full paths and the bare
# file names as two parallel arrays.
datafilenames = np.array(
    gui_fnames("IDTXL swipe result files...", directory='./', filter="HDF5 Files (*.h5)")
)
basenames = np.array(list(map(os.path.basename, datafilenames)))
print("Total user files", len(datafilenames))

In [None]:
##################################
# Parse metadata from filenames
##################################
# Every flag array below is a 0/1 int array aligned with `basenames`; the
# analysis cells select files by combining these flags.

def _has_token(token):
    """Return a 0/1 int array: 1 where the file's basename contains `token`.

    This is a plain substring test, so a token must not appear by accident in
    other parts of the file names.
    """
    return np.array([token in name for name in basenames], dtype=int)

def _none_set(*flags):
    """Return a 0/1 int array: 1 where none of the given 0/1 flag arrays is set."""
    return (sum(flags) == 0).astype(int)

# By analysis type ('all' = neither swipe nor range)
isAnalysis = {}
isAnalysis['swipe'] = _has_token("swipe")
isAnalysis['range'] = _has_token("range")
isAnalysis['all'] = _none_set(isAnalysis['swipe'], isAnalysis['range'])

# Determine if file uses GO, NOGO, or all trials ('ALL' = neither tag present)
isTrial = {}
isTrial['GO'] = _has_token("iGO")
isTrial['NOGO'] = _has_token("iNOGO")
isTrial['ALL'] = _none_set(isTrial['GO'], isTrial['NOGO'])

# Determine range types ('none' = no range tag present)
isRange = {}
isRange['CUE'] = _has_token("CUE")
isRange['TEX'] = _has_token("TEX")
isRange['LIK'] = _has_token("LIK")
isRange['none'] = _none_set(isRange['CUE'], isRange['TEX'], isRange['LIK'])

# Determine which method was used
isMethod = {}
isMethod['BTE'] = _has_token("BivariateTE")
isMethod['MTE'] = _has_token("MultivariateTE")

# Mouse name = first two '_'-separated fields of the file name (e.g. "mtp_15")
mouse_names = ["_".join(name.split('_')[:2]) for name in basenames]

# This is the number of all selected files, not only swipe files
print("Total files :::", len(datafilenames))
print("By Analysis :::", {k: np.sum(v) for k,v in isAnalysis.items()})
print("By Trial    :::", {k: np.sum(v) for k,v in isTrial.items()})
print("By Range    :::", {k: np.sum(v) for k,v in isRange.items()})
print("By Method   :::", {k: np.sum(v) for k,v in isMethod.items()})
print("By Mouse    :::", {k: mouse_names.count(k) for k in set(mouse_names)})

## Swipe Analysis

In [None]:
# 1 where the file belongs to the analysed mouse
isCorrectMouse = np.array([mname == 'mtp_15' for mname in mouse_names], dtype=int)

# Time between the starts of consecutive swipe windows, in seconds (the 0.2
# factor of the original plot; presumably the window shift — TODO confirm)
SWIPE_STEP_SEC = 0.2

fig1, ax1 = plt.subplots(nrows=1, ncols=2, figsize=(15,7.5))
# totalConnPerConn[trial][method] -> list (one entry per file) of flattened
# per-connection activity counts
totalConnPerConn = []

for i, trial in enumerate(["ALL"]):#["GO","NOGO"]):
    totalConnPerConn.append([])
    for j, method in enumerate(["BTE","MTE"]):
        ax = ax1[j]
        ax.set_title(method)
        ax.set_xlabel("start_time, seconds")
        ax.set_ylabel(method)
        # Reference level (4 for BTE, 1 for MTE), drawn once per panel instead of
        # once per file. Meaning of these levels is not documented here — TODO confirm.
        ax.axhline(y=4 if method=="BTE" else 1, linestyle="--")

        totalConnPerConn[-1].append([])
        # A file is used only if all 0/1 flags are set
        idxs_ths = np.logical_and.reduce([isCorrectMouse, isAnalysis["swipe"], isTrial[trial], isMethod[method]])
        print("For trials", trial,"method", method, "have", np.sum(idxs_ths), "files")

        for fname, basename in zip(datafilenames[idxs_ths], basenames[idxs_ths]):
            print("Processing file: ", basename)
            te, lag, p = getData(fname)

            # te is indexed as te[:, :, window]; NaN entries count as inactive
            nChannel, _, nTime = te.shape
            isActive = ~np.isnan(te)
            # One point per window with a uniform step; np.linspace(0, nTime, nTime)
            # stretched the axis by a factor nTime/(nTime-1).
            times = SWIPE_STEP_SEC * np.arange(nTime)
            totalConnPerTime = isActive.sum(axis=(0, 1))
            # NOTE(review): the per-connection count over windows is divided by the
            # number of off-diagonal connections, not by nTime; for a per-connection
            # "frequency" nTime looks like the intended denominator — confirm.
            totalConnPerConn[-1][-1].append(isActive.sum(axis=2).ravel() / (nChannel**2 - nChannel))

            ax.plot(times, totalConnPerTime)

fig2, ax2 = plt.subplots(nrows=1, ncols=2, figsize=(15,7.5))
for i, trial in enumerate(["ALL"]):#["GO","NOGO"]):
    for j, method in enumerate(["BTE","MTE"]):
        ax2[j].set_title(method)
        ax2[j].set_xlabel("connection index, sorted")
        ax2[j].set_ylabel("Frequency of occurrence")

        connPerFile = totalConnPerConn[i][j]
        if not connPerFile:
            continue  # no files for this trial/method: nothing to sort or plot
        thisConn = np.array(connPerFile)
        # Order connections by total activity over all files, most active first
        sortedArgs = np.argsort(thisConn.sum(axis=0))[::-1]

        for conn in thisConn:
            ax2[j].plot(conn[sortedArgs], '.')

## Range Analysis

In [None]:
# 1 where the file belongs to the analysed mouse
isCorrectMouse = np.array([mname == 'mtp_15' for mname in mouse_names], dtype=int)

fig1, ax1 = plt.subplots(ncols=2, figsize=(10, 5))
fig2, ax2 = plt.subplots(nrows=2, ncols=3, figsize=(15, 5), tight_layout=True)

for i, trial in enumerate(["ALL"]):
    for j, method in enumerate(["BTE","MTE"]):
        ax2[j][0].set_ylabel(method)

        for k, rng in enumerate(["CUE", "TEX", "LIK"]):
            ax2[0][k].set_title(rng)

            # A file is used only if all 0/1 flags are set
            idxs_ths = np.logical_and.reduce([isCorrectMouse, isAnalysis["range"], isTrial[trial], isMethod[method], isRange[rng]])
            print("For trials", trial,"range", rng,"method", method, "have", np.sum(idxs_ths), "files")

            # Visit files in name order, so the "day/file" x axis does not depend on
            # the order in which files were picked in the dialog. Date-stamped names
            # such as mtp_15_2018_05_19_... sort chronologically if zero-padded.
            order = np.argsort(basenames[idxs_ths])

            nConn = []
            actSum = None  # per-connection number of files in which the link is active
            for fname in datafilenames[idxs_ths][order]:
                te, lag, p = getData(fname)

                # NaN entries count as inactive connections
                nChannel = te.shape[0]
                isActive = (~np.isnan(te)).astype(int)
                # Size the accumulator from the data rather than assuming 12 channels
                actSum = isActive if actSum is None else actSum + isActive
                nConn.append(np.sum(isActive) / (nChannel**2 - nChannel))

            ax1[j].plot(nConn, label=rng)
            if actSum is not None:
                ax2[j][k].imshow(actSum)

        ax1[j].set_title(method)
        ax1[j].set_xlabel("day/file")
        ax1[j].set_ylabel("ratio of active connections")
        ax1[j].legend()


        

In [None]:
# for model in ["BivariateTE", "MultivariateTE"]:
#     for i in range(16, 20):
#         for rng in ["CUE", "TEX", "LIK"]:
#             fname = "mtp_15_2018_05_" + str(i) + "_a_" + model + "_range_" + rng + ".h5"
#             te, lag, p = getData(fname)
            
#             fig, ax = plt.subplots(ncols = 3, figsize=(15, 5))
#             ax[0].imshow(te[:,:], cmap="jet", vmin=0, vmax=1)
#             ax[1].imshow(lag[:,:], cmap="jet", vmin=1, vmax=5)
#             ax[2].imshow(p[:,:], cmap="jet", vmin=0, vmax=1)
#             ax[0].set_title("TE")
#             ax[1].set_title("delay")
#             ax[2].set_title("p-value")
            
#             fig.suptitle(fname)
#             plt.show()

## Processing swipe data

In [None]:
# Swipe results of one example session, browsed window by window with a slider
bte_te, bte_lag, bte_p = getData("mtp_15_2018_05_19_a_BivariateTE_swipe.h5")
mte_te, mte_lag, mte_p = getData("mtp_15_2018_05_19_a_MultivariateTE_swipe.h5")

# Interactive
def f(i):
    """Plot the TE, delay and p-value matrices of swipe window `i`.

    Rows are Bivariate / Multivariate TE; columns are TE, delay, p-value.
    Colour limits are fixed per column so that windows can be compared.
    """
    fig, ax = plt.subplots(nrows = 2, ncols = 3, figsize=(12, 8))
    rows = [("Bivariate", bte_te, bte_lag, bte_p),
            ("Multivariate", mte_te, mte_lag, mte_p)]
    # (column title, vmin, vmax)
    columns = [("TE", 0, 1), ("delay", 1, 5), ("p-value", 0, 1)]

    for r, (label, *tables) in enumerate(rows):
        ax[r][0].set_ylabel(label)
        for c, (table, (title, vmin, vmax)) in enumerate(zip(tables, columns)):
            ax[r][c].imshow(table[:,:,i], cmap="jet", vmin=vmin, vmax=vmax)
            if r == 0:
                ax[r][c].set_title(title)
    plt.show()

# Only offer windows present in both files
nWindows = min(bte_te.shape[2], mte_te.shape[2])
interactive_plot = interactive(f, i=(0, nWindows-1, 1))
output = interactive_plot.children[-1]
output.layout.height = '500px'
interactive_plot