In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from idtxl.bivariate_pid import BivariatePID
from idtxl.data import Data

from mesostat.utils.decorators import redirect_stdout
from mesostat.visualization.mpl_colors import base_colors_rgb
from mesostat.metric.dim3d.pid_gaussian import pid_kayince_gaussian

# Append base directory
# Locate the repository root by searching the working directory for `rootname`,
# so that `src.*` modules can be imported from any sub-folder of the repo.
import os,sys
rootname = "conservative-tripartite-testing"
thispath = os.getcwd()
if rootname not in thispath:
    # Same exception type as str.index would raise, but with an actionable message
    raise ValueError("Working directory '" + thispath + "' is not inside '" + rootname + "'")
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
# Keep the cell idempotent: re-running it must not append duplicate sys.path entries
if rootpath not in sys.path:
    sys.path.append(rootpath)
print("Appended root directory", rootpath)

import src.null_models_3D as null3D
import src.null_test as nulltest

%load_ext autoreload
%autoreload 2

## PID Functions

In [None]:
# Labels of the four PID atoms. pid() zips them, in this order, with the values
# returned by pid_kayince_gaussian.
decompLabels = [
    'unq_s1',
    'unq_s2',
    'shd_s1_s2',
    'syn_s1_s2',
]

In [None]:
def pid(x, y, z):
    """Gaussian PID of (x, y, z) as a dict keyed by ``decompLabels``.

    Thin wrapper around ``pid_kayince_gaussian``: its outputs are paired
    positionally with the labels in ``decompLabels``.
    """
    atomValues = pid_kayince_gaussian(x, y, z)
    return {label: value for label, value in zip(decompLabels, atomValues)}

In [None]:
# Mapping of model name -> continuous data-generating function from src.null_models_3D.
# The names are used in figure titles and file names below.
# NOTE(review): later cells call each function as f(nData, p1, p2, p3) — confirm the signature in null3D.
contFuncDict = null3D.cont_method_dict()

### Testing binning-dependence

In [None]:
# Optional per-atom value thresholds forwarded to the nulltest plotting helpers.
# None disables them; any non-None value also adds a '_withThr' suffix to saved figure names.
valThrDict = None
# valThrDict = {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16}

In [None]:
# For each noise setting in taskDict and each continuous model, run the null tests
# on the original data and on shuffled data, then plot and save a summary figure.
# Each params array holds the three values passed positionally after the sample size.
taskDict = {
    'yolo': np.array([0.01,0.01,0.01]),
    'norand': np.array([0.01,0.01,0.5]),
    'randx': np.array([0.5,0.01,0.5]),
    'rand': np.array([0.5,0.5,0.5])
}
nDataTask = 10000  # samples generated per test

# The file-name suffix only depends on valThrDict, so compute it once
suffix = '' if valThrDict is None else '_withThr'

for taskName, params in taskDict.items():
    print(taskName)

    # Do continuous tests
    for funcName, func in contFuncDict.items():
        print('-', funcName)

        # Consumed immediately within this iteration, so late binding of func/params is harmless
        f_data   = lambda: func(nDataTask, *params)

        rezDF   = nulltest.run_tests(f_data, pid, decompLabels, nTest=100)
        rezDFsh = nulltest.run_tests(f_data, pid, decompLabels, nTest=100, haveShuffle=True)

        nulltest.plot_test_summary(rezDF, rezDFsh, suptitle=funcName, haveEff=False, valThrDict=valThrDict)
        plt.savefig(funcName + '_cont_kaygau_summary_' + taskName + suffix + '.svg')
        plt.show()

### Effect of variance

Continuous

In [None]:
# Do continuous tests
# Effect of the noise amplitude alpha on every PID atom, for each continuous model.
# Each strategy maps alpha to the three noise parameters passed after nData.
nData = 10000

alphaStratDict = {
    'PureSrc': lambda alpha: [0.001,0.001,alpha],
    'ImpureX': lambda alpha: [alpha,0.001,alpha],
    'Impure' : lambda alpha: [alpha,alpha,alpha],
}

# Null-hypothesis variants. None = original test. The commented alternative supplies
# per-atom thresholds; its keys must match decompLabels to take effect.
thrMetricDictDict = {
    'H0_orig' : None,
#     'H0_adj' : {'unq_s1': 0.1, 'unq_s2': 0.1, 'shd_s1_s2': 0.00061, 'syn_s1_s2': None}
}


for fName, f_data in contFuncDict.items():
    for alphaStratName, alphaFunc in alphaStratDict.items():
        # Plot constant thresholds for PureSrc
        avgRand = alphaStratName == 'PureSrc'

        # Consumed immediately within this iteration, so late binding is harmless
        f_data_eff = lambda alpha: f_data(nData, *alphaFunc(alpha))

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_param_effect(f_data_eff, pid, decompLabels, fontsize=12,
                                           nStep=1001, nSkipTest=100, nTest=200, alphaRange=(0.001, 1),
                                           avgRand=avgRand, thrMetricDict=thrMetricDict, plotAlphaSq=False)

            suffix = 'n_' + str(nData) + '_' + alphaStratName + '_' + h0type

            plt.savefig(fName + '_cont_kaygau_scatter_vareff_'+suffix+'.svg')
            plt.show()

In [None]:
# Variance-effect summary: all three noise parameters set to the same alpha.
nData=10000
# The file-name suffix only depends on valThrDict, so compute it once
suffix = '' if valThrDict is None else '_withThr'
for fName, f_data in contFuncDict.items():
    print(fName)

    # Pass arguments positionally, as the other cells do, instead of guessing keyword
    # names that may differ between data-generating functions.
    f_data_eff = lambda alpha: f_data(nData, alpha, alpha, alpha)
    nulltest.run_plot_param_effect_test(f_data_eff, pid, decompLabels,
                                        nStep=10, nTest=400, alphaRange=(0, 1), valThrDict=valThrDict)

    plt.savefig(fName + '_r2_vareff_n'+str(nData)+suffix+'.png', dpi=200)
    plt.show()

In [None]:
# Single-model check on noisy XOR with equal noise on all three variables.
nData=10000
# Positional arguments, consistent with how the other cells call the null3D generators
f_data = lambda alpha: null3D.cont_xor_noisy(nData, alpha, alpha, alpha)
# NOTE(review): the meaning of the positional 0 is defined by nulltest.run_plot_param_effect_test_single
nulltest.run_plot_param_effect_test_single(f_data, pid, decompLabels, 0, nTest=400)

### Effect of number of samples
Continuous

In [None]:
# Effect of the sample size n on every PID atom, at a fixed noise amplitude alpha.
alpha=0.25

# Fixed noise parameters (passed after n) for each noise strategy
alphaStratDict = {
    'PureSrc': [0.001, 0.001, alpha],
    'ImpureX': [alpha, 0.001, alpha],
    'Impure' : [alpha, alpha, alpha],
}

# nDataArr = (10**np.linspace(2, 4, 10)).astype(int)
# thrLstUnq = [0.2110814493379376, 0.16689265970521838, 0.13685690308881257, 0.11218595017528467, 0.07860180308826488, 0.06210094806194508, 0.0526950184150174, 0.040739636899229006, 0.029856912086735292, 0.02275646333281051]
# thrLstRed = [0.04195798086977931, 0.03047641966741261, 0.01652790508850032, 0.012095615389103559, 0.0066037386073174624, 0.003432593102097955, 0.0018820360731625003, 0.001559288121841136, 0.0007872659689172739, 0.0004156815533111131]
# thrLstSyn = [0.22514955588715468, 0.21503139471261873, 0.19561667035009186, 0.18203585308711429, 0.172750706469084, 0.16579693515602625, 0.16096080804655502, 0.1590930839954668, 0.1559111008893145, 0.15474468603745337]

# thrDictUnq = dict(zip(nDataArr, thrLstUnq))
# thrDictRed = dict(zip(nDataArr, thrLstRed))
# thrDictSyn = dict(zip(nDataArr, thrLstSyn))

# Null-hypothesis variants; None = original test
thrMetricDictDict = {
    'H0_orig' : None,
#     'H0_adj' : {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16}
#     'H0_adj' : {'unq_s1': thrDictUnq, 'unq_s2': thrDictUnq, 'shd_s1_s2': thrDictRed, 'syn_s1_s2': thrDictSyn}
}


for fName, f_data in contFuncDict.items():
    for alphaStratName, alphaParams in alphaStratDict.items():
        # Only the sample size varies; the noise parameters are fixed per strategy
        f_data_eff = lambda n: f_data(n, *alphaParams)

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_data_effect(
                f_data_eff, pid, decompLabels,
                nStep=101, nSkipTest=10, nTest=200, pVal=0.01,
                thrMetricDict=thrMetricDict, fontsize=12,
            )

            suffix = '_'.join(['alpha', str(alpha), alphaStratName, h0type])
            figName = fName + '_cont_kaygau_scatter_nEff_' + suffix + '.svg'
            plt.savefig(figName)
            plt.show()

In [None]:
# Sample-size-effect summary with all three noise parameters equal to alpha.
alpha=0.5
# The file-name suffix only depends on valThrDict, so compute it once
suffix = '' if valThrDict is None else '_withThr'
for fName, f_data in contFuncDict.items():
    print(fName)

    # Positional arguments, consistent with how the other cells call the generators
    f_data_eff = lambda n: f_data(n, alpha, alpha, alpha)
    nulltest.run_plot_data_effect_test(f_data_eff, pid, decompLabels,
                                       nStep=10, nTest=400, valThrDict=valThrDict)

    # The file name records the noise amplitude alpha (`sig` was never defined -> NameError)
    plt.savefig(fName + '_mmi_nEff_sig'+str(alpha)+suffix+'.png', dpi=200)
    plt.show()

### Test relationship of synergy and redundancy for fixed data size

#### 1. Finding max synergy parameters - GridSearch3D

In [None]:
# Grid search over the noise parameters of cont_red_noisy for the synergy atom,
# repeated for several sample sizes.
for nData in (1000, 3000, 5000, 7000, 10000):
    print(nData)
    nulltest.run_gridsearch_3D(
        null3D.cont_red_noisy, pid, 'syn_s1_s2',
        varLimits=(0, 2), nData=nData, nStep=20,
    )

#### 2. Finding max synergy parameters - GridSearch1D

The previous analysis found that, in all cases, maximal synergy is located on the diagonal $\alpha_x = \alpha_y$.

In [None]:
# Colour palette from mesostat's base_colors_rgb (key='tableau').
# Entries are picked by index when building tableauMap.
tableauColors = base_colors_rgb(key='tableau')

In [None]:
# Colour for each atom family, keyed by the short labels used in loopLst
tableauMap = dict(
    unq=tableauColors[0],
    red=tableauColors[2],
    syn=tableauColors[3],
)

In [None]:
# Each row: [labelA, labelB, atomA, atomB, f_data_1D]
#   labelA/labelB - short names used in output file names and as tableauMap keys
#   atomA/atomB   - PID atom labels passed to nulltest.run_plot_1D_scan
#   f_data_1D     - callable (nData, alpha) -> data, with a single noise parameter alpha
# NOTE(review): the 'red'/'syn' row fixes the last cont_red_noisy parameter to 0 — presumably intentional.
loopLst = [
    ['red', 'unq', 'shd_s1_s2', 'unq_s1',    lambda nData, alpha: null3D.cont_red_noisy(nData, alpha, alpha, alpha)],
    ['red', 'syn', 'shd_s1_s2', 'syn_s1_s2', lambda nData, alpha: null3D.cont_red_noisy(nData, alpha, alpha, 0)],
    ['unq', 'red', 'unq_s1',    'shd_s1_s2', lambda nData, alpha: null3D.cont_unq_noisy(nData, alpha, alpha, alpha)],
    ['unq', 'syn', 'unq_s1',    'syn_s1_s2', lambda nData, alpha: null3D.cont_unq_noisy(nData, alpha, alpha, alpha)],
    ['syn', 'red', 'syn_s1_s2', 'shd_s1_s2', lambda nData, alpha: null3D.cont_xor_noisy(nData, alpha, alpha, alpha)],
    ['syn', 'unq', 'syn_s1_s2', 'unq_s1',    lambda nData, alpha: null3D.cont_xor_noisy(nData, alpha, alpha, alpha)]
]

In [None]:
# For each (model, atom pair) and each sample size:
#   1. scan the single noise parameter with run_plot_1D_scan, which returns (alphaMax, thr);
#   2. sample atomB on shuffled data at alphaMax and take its 99th percentile.
# Then plot both thresholds ("adjusted" vs "shuffle") against the sample size.
nDataLst = (10**np.linspace(2, 4, 10)).astype(int)  # log-spaced, 100 .. 10000

for labelA, labelB, atomA, atomB, f_data_1D in loopLst:
    prefix = labelA+'_cont_pid_1Dscan_'+labelB
    print(prefix)

    alphaMaxLst = []
    thrAdjLst = []
    thrRandLst = []

    for nData in nDataLst:
        print('--', nData)
        alphaMax, thr = nulltest.run_plot_1D_scan(f_data_1D, pid, atomA, atomB,
                                                  varLimits=(0, 1), nData=nData, nStep=100, nTest=100,
                                                  colorA = tableauMap[labelA], colorB = tableauMap[labelB])
        plt.savefig(prefix+'_n_'+str(nData)+'.svg')
        plt.show()

        # Get also shuffle distribution at this alpha
        # (the lambda is consumed immediately, so late binding of alphaMax is harmless)
        datagen_func_noparam = lambda nData: f_data_1D(nData, alphaMax)
        randValues = nulltest.sample_decomp(datagen_func_noparam, pid, atomB,
                                            nData=nData, nSample=10000, haveShuffle=True)

        alphaMaxLst.append(alphaMax)
        thrAdjLst.append(thr)
        thrRandLst.append(np.quantile(randValues, 0.99))

    # Summary figure. nDataLst is log-spaced, so a log x-axis spreads the points evenly.
    fig, ax = plt.subplots()
    # ax.plot(nDataLst, alphaMaxLst)
    ax.plot(nDataLst, thrAdjLst, label='adjusted', color='purple')
    ax.plot(nDataLst, thrRandLst, label='shuffle')
    ax.set_xscale('log')
    ax.set_xlabel('number of samples')
    ax.set_ylabel('threshold for ' + atomB)
    ax.set_title(prefix)
    ax.legend()
    ax.set_ylim(bottom=0)
    fig.savefig(prefix + '_summary.svg')
    plt.show()

    print(thrAdjLst)

#### 3. Determining Scatter Relationship

In [None]:
# Data-generating function dictionaries from src.null_models_3D, by data type
discrDataMethodDict = dict(
    Cont=null3D.cont_method_dict(),
    Discr=null3D.discr_method_dict(),
)

# (atomA, atomB) pairs to explore in the scatter analysis.
# Must be a list: a set literal of lists raises "TypeError: unhashable type: 'list'".
atomCombList = [
    ['shd_s1_s2', 'unq_s1'],
    ['shd_s1_s2', 'syn_s1_s2'],
    ['unq_s1',    'shd_s1_s2'],
    ['unq_s1',    'syn_s1_s2'],
    ['syn_s1_s2', 'shd_s1_s2'],
    ['syn_s1_s2', 'unq_s1']
]

In [None]:
# Scatter exploration of each atom pair for every continuous and discrete data model.
for discrKey, dataMethodsDict in discrDataMethodDict.items():
    for fDataLabel, f_data_3D in dataMethodsDict.items():
        for atomA, atomB in atomCombList:
            print(discrKey, fDataLabel, atomA, atomB)
            # Use this notebook's PID metric `pid`; `f_metric_cont` is not defined here (NameError)
            nulltest.run_plot_scatter_explore(f_data_3D, pid,
                                              atomA, atomB, 3,
                                              varLimits=(0, 1), nData=1000, nTestDim=20)

# Determining testing thresholds for real data

In [None]:
import h5py

In [None]:
# Only test combinations that matter
# These are the first three rows of the 1D-scan loopLst: red/unq, red/syn, unq/red.
# Row format: [labelA, labelB, atomA, atomB, f_data_1D(nData, alpha)]
loopLst = [
    ['red', 'unq', 'shd_s1_s2', 'unq_s1',    lambda nData, alpha: null3D.cont_red_noisy(nData, alpha, alpha, alpha)],
    ['red', 'syn', 'shd_s1_s2', 'syn_s1_s2', lambda nData, alpha: null3D.cont_red_noisy(nData, alpha, alpha, 0)],
    ['unq', 'red', 'unq_s1',    'shd_s1_s2', lambda nData, alpha: null3D.cont_unq_noisy(nData, alpha, alpha, alpha)]
]

# TEX + AUD
# Sample sizes for which null distributions are computed in the next cell.
# NOTE(review): presumably the sample counts of the real TEX/AUD datasets — confirm their source.
nDataLst = [1315, 1209, 3967, 1910, 1724, 4784, 1307, 1324, 5191, 1132, 1014, 3111] + \
           [1070, 510, 2498, 1274, 735, 3407, 1918, 953, 4472, 1008, 630, 2320] + \
           [564, 591, 605, 643, 812, 1040, 1131, 1166, 1263, 1317, 1406, 1412, 1448,
            1525, 1668, 1974, 2438, 2767, 2891, 3228, 3278, 7106, 8209]

In [None]:
# For each model/atom pair and each real-data sample size, compute randValues with
# nulltest.run_1D_scan_bare and store them in an HDF5 file under '<labelA>_<labelB>_<nData>'.
# Keys already present are skipped, so an interrupted run can be resumed.
h5name = 'pid_rand_dist.h5'
for labelA, labelB, atomA, atomB, f_data_1D in loopLst:
    for nData in nDataLst:
        key = labelA + '_' + labelB + '_' + str(nData)
        # Check for an existing result, then release the file before the long computation
        with h5py.File(h5name, 'a') as h5f:
            alreadyDone = key in h5f
        if alreadyDone:
            print(key, 'already done')
            continue

        print(key)

        # Use this notebook's PID metric `pid`; `f_metric_cont` is not defined here (NameError).
        # Element [1] of run_1D_scan_bare's result holds the values to store.
        randValues = nulltest.run_1D_scan_bare(f_data_1D, pid, atomB,
                                               varLimits=(0, 1), nData=nData,
                                               nStep=100, nTest=100, nTestResample=10000)[1]

        with h5py.File(h5name, 'a') as h5f:
            h5f[key] = randValues