In [None]:
# Standard libraries
import os,sys,inspect
import copy
import numpy as np
import matplotlib.pyplot as plt

# Report the kernel's working directory (in Jupyter this is normally the
# notebook's directory). inspect.getfile(inspect.currentframe()) cannot be used
# here: a notebook cell has no source file, and the saved output of that
# approach pointed into site-packages instead of the project.
currentdir = os.getcwd()
print(currentdir)

# Disabled: library path relative to the project root (not currently used)
# path1p = os.path.dirname(currentdir)
# libpath = os.path.join(path1p, "lib")

/opt/anaconda3/envs/py36qt5/lib/python3.6/site-packages


In [None]:
# User libraries
from codes.lib.plots.accuracy_plots import testplots
from metrics.graph_lib import offdiag_idx

from models.test_lib import noisePure, noiseLPF, sampleTrials
from models.dyn_sys import DynSys

from fc.te_idtxl_wrapper import idtxlParallelCPU, idtxlResultsParse
from signal_lib import resample
from aux_functions import merge_dicts

%load_ext autoreload
%autoreload 2

### Parameters

In [None]:
def DynSys_func(param):
    """Run a DynSys simulation and return its data with the initial transient removed.

    `burn_in` extra timesteps are added to param['N_DATA'] for the simulation and
    then sliced off the start of axis 1 of the simulated data (presumably the
    time axis — TODO confirm against DynSys). The caller's `param` dict is left
    untouched because a deep copy is modified instead.

    Returns a fresh copy of the trimmed array, not a view into the simulator.
    """
    burn_in = 10   # warm-up steps discarded from the beginning
    sim_param = copy.deepcopy(param)
    sim_param['N_DATA'] = sim_param['N_DATA'] + burn_in

    simulation = DynSys(sim_param)
    simulation.compute()

    trimmed = simulation.data[:, burn_in:]
    return np.copy(trimmed)

idtxl_settings = {
    'dim_order'       : 'ps',
    'method'          : 'MultivariateTE',
    'cmi_estimator'   : 'JidtGaussianCMI',
    'max_lag_sources' : 1,
    'min_lag_sources' : 1}

# Number of channels, shared by every model below and by the ground-truth matrices
N_NODE = 12

# Set parameters
model_param_noisepure = {
    'method'      : noisePure,
    'N_NODE'      : N_NODE,         # Number of channels
    'T_TOT'       : 10,             # seconds, Total simulation time
    'DT'          : 0.2,            # seconds, Binned optical recording resolution
    'STD'         : 1               # Standard deviation of random data
}

# Set parameters
model_param_lpfsub = {
    'method'      : noiseLPF,
    'N_NODE'      : N_NODE,         # Number of channels
    'T_TOT'       : 10,             # seconds, Total simulation time
    'TAU_CONV'    : 0.5,            # seconds, Ca indicator decay constant
    'DT_MICRO'    : 0.001,          # seconds, Neuronal spike timing resolution
    'DT'          : 0.2,            # seconds, Binned optical recording resolution
    'STD'         : 1               # Standard deviation of random data
}

# Set parameters
model_param_dynsys = {
    'method'  : DynSys_func,
    'ALPHA'   : 0.1,     # 1-connectivity strength
    'N_NODE'  : N_NODE,  # Number of variables
    'N_DATA'  : 4000,    # Number of timesteps
    'MAG'     : 0.0,     # Magnitude of input
    'T'       : 20,      # Period of input oscillation
    'STD'     : 0.2      # STD of neuron noise
}

model_param_all = {
    "purenoise"   : model_param_noisepure,
    "lpfsubnoise" : model_param_lpfsub,
    "dynsys"      : model_param_dynsys
}

# True connectivity matrix for this problem.
# transpose() alone returns a view, so NaN-masking it would write back into
# DS_TMP.M; astype(float) yields an independent float copy (an integer array
# could not hold NaN at all). Absent links (zeros) are marked with NaN.
DS_TMP = DynSys(model_param_dynsys)
TRUE_CONN_DS = DS_TMP.M.transpose().astype(float)
TRUE_CONN_DS[TRUE_CONN_DS == 0] = np.nan

# The pure-noise models have no true links: all-NaN matrices
TRUE_CONN_DICT = {
    "purenoise"   : np.full((N_NODE, N_NODE), np.nan),
    "lpfsubnoise" : np.full((N_NODE, N_NODE), np.nan),
    "dynsys"      : TRUE_CONN_DS
}

# Width / Depth analysis

This analysis tests for Type I errors (false positives) — how frequently a spurious link is found. The two flavours are:
* **Width** - A dataset with a single repetition, but increasing time duration
* **Depth** - A dataset with fixed time duration, but increasing repetition count

In [None]:
%%time
#####################
# Width/Depth Analysis
#####################

N_NODE = 12  # Number of channels
N_STEP = 40  # Number of different data sizes to pick
T_STEP = idtxl_settings['max_lag_sources'] + 1   # Data quantity multiplier
ndata_lst = (2 * 10**(np.linspace(1.6, 2.9, N_STEP))).astype(int)

idtxl_methods = ['BivariateMI', 'MultivariateMI', 'BivariateTE', 'MultivariateTE']

# Work on private copies of the shared configuration. Previously this cell wrote
# 'dim_order'/'method' into idtxl_settings and N_NODE/T_TOT/N_DATA into the
# model_param_all dicts, so results here and in later cells depended on which
# cells had been executed before.
settings = copy.deepcopy(idtxl_settings)

for analysis in ['width', 'depth']:
    # 'width': one long recording, (node, time) -> IDTXL 'ps'
    # 'depth': many short trials stacked as (trial, time, node) -> IDTXL 'rsp'
    settings['dim_order'] = 'ps' if analysis == 'width' else 'rsp'

    for modelname, model_param_shared in model_param_all.items():
        TRUE_CONN = TRUE_CONN_DICT[modelname]
        model_param = copy.deepcopy(model_param_shared)
        model_param['N_NODE'] = N_NODE

        ###################################
        # Generate data - takes some time
        ###################################
        data_lst = []
        ndata_eff = np.zeros(N_STEP, dtype=int)
        for i, ndata in enumerate(ndata_lst):
            print("Generating Data", analysis, modelname, ndata)

            # Samples per generator call: the whole recording for 'width',
            # a single short trial for 'depth'.
            n_samples = ndata * T_STEP if analysis == 'width' else T_STEP

            # Models with a 'DT' key are sized by duration, the others by sample count
            if 'DT' in model_param:
                model_param['T_TOT'] = n_samples * model_param['DT']
            else:
                model_param['N_DATA'] = n_samples

            if analysis == 'width':
                # Generate whole data once
                data = model_param['method'](model_param)
                ndata_eff[i] = data.shape[1]
            else:
                # Generate each trial independently, then stack
                data = np.array([model_param['method'](model_param) for _ in range(ndata)]).transpose((0, 2, 1))
                ndata_eff[i] = data.shape[0]

            data_lst.append(data)

        for method in idtxl_methods:
            settings['method'] = method
            te_results = np.zeros((3, N_NODE, N_NODE, N_STEP))

            for i, ndata in enumerate(ndata_lst):
                print("Processing Data", analysis, method, modelname, ndata)

                # Run calculation
                rez = idtxlParallelCPU(data_lst[i], settings)

                # Parse Data
                te_results[..., i] = np.array(idtxlResultsParse(rez, N_NODE, method=method, storage='matrix'))

            # Plot
            fname = modelname + '_' + str(N_NODE) + method + '_' + analysis
            testplots(ndata_eff, te_results, TRUE_CONN, logx=True, percenty=True, pTHR=0.01, h5_fname=fname+'.h5', fig_fname=fname+'.png')

# Signal-to-Noise Analysis

This analysis tests for Type II errors (false negatives) — failure to find true links — as a function of signal quality. The two flavours are:
* **Observational Randomness** — white noise is mixed into the final dataset. Studies Type II errors as a function of SNR.
* **Occurrence Randomness** — a fraction of the trials is replaced by noise-only trials, keeping the total number of trials fixed. Simulates the effect being absent in some trials, or occurring at a different time. Studies Type II errors as a function of the ratio of good trials.

In [None]:
################
# Generate data
################
print("Generating Data")

idtxl_methods = ['BivariateMI', 'MultivariateMI', 'BivariateTE', 'MultivariateTE']

# Private copies so this cell neither reads leftovers from nor leaks into other cells
settings = copy.deepcopy(idtxl_settings)
settings['dim_order'] = 'rsp'  # Use multiple repetitions

N_NODE = 12
T_STEP = settings['max_lag_sources'] + 1   # Number of timesteps
N_TRIAL = 400  # number of trials

modelname = "dynsys"
model_param = copy.deepcopy(model_param_all[modelname])
model_param['N_NODE'] = N_NODE   # Number of channels
model_param['N_DATA'] = T_STEP   # Number of time steps
TRUE_CONN = TRUE_CONN_DICT[modelname]

# Trials stacked as (trial, time, node), matching dim_order 'rsp'
data = np.array([model_param['method'](model_param) for _ in range(N_TRIAL)]).transpose((0, 2, 1))
data /= np.std(data)  # Normalize all data to have unit variance

################
# Analyse
################

N_STEP = 40  # Number of parameter values to sweep

# Flavours
flavours = ['observational', 'occurence']

paramRanges = {
    # Weight of the clean signal in the mixture: 0 ... (N_STEP-1)/N_STEP
    'observational'  : np.arange(N_STEP) / (N_STEP),
    # Fraction of clean trials: 0 ... 1 inclusive
    'occurence'      : np.arange(N_STEP) / (N_STEP-1)
}


def _replace_trials_with_noise(data, freq):
    """Keep the first int(freq * n_trial) trials of `data`; fill the rest with N(0, 1) noise trials.

    The noise-trial count is n_trial - n_clean, so every dataset has exactly
    n_trial trials. Computing it as int((1 - freq) * n_trial) instead drops one
    trial whenever freq * n_trial is not an integer.
    """
    n_trial = data.shape[0]
    n_clean = int(freq * n_trial)
    noise = np.random.normal(0, 1, (n_trial - n_clean,) + data.shape[1:])
    return np.concatenate((data[:n_clean], noise), axis=0)


dataLst = {
    'observational'  : [snr * data + (1 - snr) * np.random.normal(0, 1, data.shape) for snr in paramRanges['observational']],
    'occurence'      : [_replace_trials_with_noise(data, freq) for freq in paramRanges['occurence']]
}

for flavour in flavours:
    for method in idtxl_methods:
        settings['method'] = method
        te_results = np.zeros((3, N_NODE, N_NODE, N_STEP))

        # Loop variable is `data_i` so the clean `data` generated above is not overwritten
        for i, data_i in enumerate(dataLst[flavour]):
            print("Processing Data", flavour, method, paramRanges[flavour][i])

            # Run calculation
            rez = idtxlParallelCPU(data_i, settings)

            # Parse Data
            te_results[..., i] = np.array(idtxlResultsParse(rez, N_NODE, method=method, storage='matrix'))

        # Plot
        fname = modelname + '_' + str(N_NODE) + method + '_' + flavour
        testplots(paramRanges[flavour], te_results, TRUE_CONN, logx=False, percenty=True, pTHR=0.01, h5_fname=fname+'.h5', fig_fname=fname+'.png')

# Window / Lag / Downsample Analysis

This analysis studies Type I error as a function of implicit parameters of the analysis. Those parameters are:
* **Window** - How many timesteps are grouped together in a sweep-window to estimate time-dependent TE
* **Lag** - How many timesteps of past history to consider when estimating FC. Lag $<$ Window
* **Downsampling** - Type I error as a function of the downsampling rate ($\Delta t_2 / \Delta t_1$)

In [None]:
def stat_te(te_results, TRUE_CONN, pTHR=0.01):
    """Summarise a time series of connectivity estimates by mean/std of accuracy metrics.

    Parameters
    ----------
    te_results : np.ndarray
        4-D array whose axis 0 has length 3 and unpacks as (te, lag, p); the
        last axis indexes time points. The input is not modified.
    TRUE_CONN : np.ndarray
        Ground-truth connectivity passed unchanged to `accuracyTests`.
    pTHR : float
        Significance threshold. Links with p > pTHR are dropped (set to NaN);
        NaN p-values are first replaced by 100 so they are dropped as well.

    Returns
    -------
    (mu_dict, sig_dict)
        Per-metric mean and standard deviation across time points.

    NOTE(review): `accuracyTests` is not imported in any visible cell of this
    notebook — confirm which module provides it.
    """
    n_times = te_results.shape[3] if len(te_results.shape) == 4 else None
    _, _, _, n_times = te_results.shape

    te_vals, _, p_vals = np.copy(te_results)

    # Missing p-values are treated as "not significant"
    p_vals[np.isnan(p_vals)] = 100
    te_vals[p_vals > pTHR] = np.nan

    per_time = [accuracyTests(te_vals[:, :, t], TRUE_CONN) for t in range(n_times)]
    merged = merge_dicts(per_time)

    mu_dict, sig_dict = {}, {}
    for key, values in merged.items():
        mu_dict[key] = np.mean(values)
        sig_dict[key] = np.std(values)

    return mu_dict, sig_dict

def statplots(xval, stat, xlabel, fig_fname=None):
    """Plot false-positive and false-negative rates (mean with std error bars) over a swept parameter.

    Parameters
    ----------
    xval : sequence
        Values of the swept parameter (x axis), one per entry of `stat`.
    stat : list
        One (mu_dict, sig_dict) pair per x value; entries are indexed [0]/[1]
        and each dict must contain "FalsePositiveRate" and "FalseNegativeRate".
    xlabel : str
        Label for the x axis.
    fig_fname : str or None
        If given, the figure is saved to this path at 300 dpi before being shown.
    """
    stat_mu = merge_dicts([pair[0] for pair in stat])
    stat_sig = merge_dicts([pair[1] for pair in stat])

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_title("Error rates")
    ax.set_xlabel(xlabel)

    curves = (("FalsePositiveRate", "FP_RATE"),
              ("FalseNegativeRate", "FN_RATE"))
    for metric, legend_label in curves:
        ax.errorbar(xval, stat_mu[metric], yerr=stat_sig[metric], label=legend_label)
    ax.legend()

    if fig_fname is not None:
        plt.savefig(fig_fname, dpi=300)

    plt.show()

In [None]:
%%time
###############################
# Window / Lag / Downsampling
###############################
# TODO:
# +) Import resampling
# 2) Implement stat collection
# 3) Implement statplots

N_NODE = 12
N_TRIAL = 400
N_DATA = 200
DT = 0.05     # seconds, natural delay time
LOG_FNAME = "log.txt"

# Start every run with an empty log file
with open(LOG_FNAME, "w"):
    pass


def write_log(fname, s):
    """Append the line `s` to the text file `fname`."""
    with open(fname, "a") as f:
        f.write(s + "\n")


def _sweep_windows(data, window, settings, method, log_prefix=()):
    """Run IDTXL on every slice data[:, t:t+window, :] and stack the parsed results along the last axis.

    Parameters
    ----------
    data : np.ndarray
        Trials stacked as (trial, time, node), i.e. IDTXL dim_order 'rsp'.
    window : int
        Number of timesteps per sliding window.
    settings : dict
        IDTXL settings for idtxlParallelCPU; settings['method'] should equal `method`.
    method : str
        IDTXL method name used to parse the results.
    log_prefix : tuple
        Extra fields written to each per-window log line before the window index.

    Returns
    -------
    np.ndarray of shape (3, N_NODE, N_NODE, n_time - window + 1)
    """
    # A window of length `window` fits at every start index 0 .. n_time - window;
    # the previous range(n_time - window) skipped the last one.
    n_windows = data.shape[1] - window + 1
    results = np.zeros((3, N_NODE, N_NODE, n_windows))
    for iTime in range(n_windows):
        write_log(LOG_FNAME, str(["--- ", *log_prefix, iTime]))

        # Run calculation
        rez = idtxlParallelCPU(data[:, iTime:iTime+window, :], settings)

        # Parse Data
        results[..., iTime] = np.array(idtxlResultsParse(rez, N_NODE, method=method, storage='matrix'))
    return results


# Private copy so this cell neither depends on nor leaks into the shared settings
settings = copy.deepcopy(idtxl_settings)
settings['dim_order'] = 'rsp'
idtxl_methods = ['BivariateMI', 'MultivariateMI', 'BivariateTE', 'MultivariateTE']
param_down = {'method' : 'averaging', 'kind' : 'kernel'}

for modelname, model_param_shared in model_param_all.items():
    TRUE_CONN = TRUE_CONN_DICT[modelname]

    ###################################
    # Generate data - takes some time
    ###################################
    write_log(LOG_FNAME, str(["Generating Data", modelname]))

    model_param = copy.deepcopy(model_param_shared)
    model_param['N_NODE'] = N_NODE
    model_param['DT'] = DT
    model_param['N_DATA'] = N_DATA
    model_param['T_TOT' ] = N_DATA * DT

    # Generate each trial independently, then stack as (trial, time, node)
    data = np.array([model_param['method'](model_param) for _ in range(N_TRIAL)]).transpose((0, 2, 1))

    # FIXME: Crop data to NDATA in case there is tail
    data = data[:, :N_DATA, :]

    ###################################
    # Window
    ###################################
    settings['min_lag_sources'] = 1
    settings['max_lag_sources'] = 1

    wlen = np.arange(2, 11)
    for method in idtxl_methods:
        settings['method'] = method
        stat = []

        for window in wlen:
            write_log(LOG_FNAME, str(["Processing Window Data", modelname, method]))
            te_results = _sweep_windows(data, window, settings, method)
            stat += [stat_te(te_results, TRUE_CONN, pTHR=0.01)]

        # Plot
        fname = "window_" + modelname + '_' + str(N_NODE) + method
        statplots(wlen, stat, xlabel="window", fig_fname=fname+'.png')

    ###################################
    # Lag
    ###################################
    window = 6
    settings['min_lag_sources'] = 1

    maxlag_lst = np.arange(1, 6)
    for method in idtxl_methods:
        settings['method'] = method
        stat = []

        for maxlag in maxlag_lst:
            write_log(LOG_FNAME, str(["Processing Lag Data", modelname, method, maxlag]))
            # Plain int rather than np.int64, in case IDTXL type-checks lag settings
            settings['max_lag_sources'] = int(maxlag)
            te_results = _sweep_windows(data, window, settings, method)
            stat += [stat_te(te_results, TRUE_CONN, pTHR=0.01)]

        # Plot
        fname = "lag_" + modelname + '_' + str(N_NODE) + method
        statplots(maxlag_lst, stat, xlabel="lag", fig_fname=fname+'.png')

    ###################################
    # Downsampling
    ###################################
    window = 6
    settings['min_lag_sources'] = 1
    settings['max_lag_sources'] = 5

    downsample_times = [DT] + list(np.arange(0.1, 0.6, 0.1))

    # One timestamp per stored sample, so times_orig always matches data.shape[1]
    # (a float np.arange(0, T_TOT, DT) can be one element off)
    times_orig = np.arange(data.shape[1]) * DT

    stat_dict = {method : [] for method in idtxl_methods}
    for dt_down in downsample_times:
        write_log(LOG_FNAME, str(["Generating Downsample Data for", modelname]))

        # Downsample data; the first entry is the native resolution
        if dt_down == DT:
            data_downsampled = np.copy(data)
        else:
            times_down = np.arange(0, model_param['T_TOT'], dt_down)
            data_downsampled = np.zeros((N_TRIAL, len(times_down), N_NODE))

            for iTr in range(N_TRIAL):
                for iNode in range(N_NODE):
                    data_downsampled[iTr, :, iNode] = resample(times_orig, data[iTr, :, iNode], times_down, param_down)

        for method in idtxl_methods:
            # Must be set per method; previously all four reused a stale method
            settings['method'] = method
            te_results = _sweep_windows(data_downsampled, window, settings, method, log_prefix=(method,))
            stat_dict[method] += [stat_te(te_results, TRUE_CONN, pTHR=0.01)]

    # Plot
    for method in idtxl_methods:
        fname = "downsample_" + modelname + '_' + str(N_NODE) + method
        statplots(downsample_times, stat_dict[method], xlabel="timestep", fig_fname=fname+'.png')