In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from idtxl.bivariate_pid import BivariatePID
from idtxl.data import Data

from mesostat.utils.decorators import redirect_stdout
from mesostat.visualization.mpl_colors import base_colors_rgb

# Append base directory
# The repository root is put on sys.path so that the project-local `lib` package
# imported below can be resolved.
import os,sys #,inspect
rootname = "pub-2020-exploratory-analysis"
#thispath = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
thispath = os.getcwd()
# Keep everything in the working directory path up to the first occurrence of rootname,
# then re-append rootname. str.index raises ValueError if the working directory path
# does not contain rootname.
rootpath = os.path.join(thispath[:thispath.index(rootname)], rootname)
sys.path.append(rootpath)
print("Appended root directory", rootpath)

import lib.nullmodels.null_models_3D as null3D
import lib.nullmodels.null_test as nulltest

%load_ext autoreload
%autoreload 2

## PID Functions

- [ ] TODO: Move to library class

In [None]:
# Keys of the four PID terms that are read out of the estimator result
# and used as metric names throughout the notebook.
decompLabels = [
    'unq_s1',
    'unq_s2',
    'shd_s1_s2',
    'syn_s1_s2',
]

In [None]:
def bin_data_1D(data, nBins):
    """
    Discretize a 1D sample into quantile-based bins.

    Bin edges are the 0, 1/nBins, ..., 1 quantiles of ``data`` itself.

    :param data:   array-like of real values
    :param nBins:  number of bins
    :return:       integer array of the same shape as ``data`` with bin indices in [0, nBins - 1]
    """
    boundaries = np.quantile(data, np.linspace(0, 1, nBins + 1))

    # Only the interior edges are needed: everything below the first interior edge maps to 0
    # and everything at or above the last one maps to nBins - 1. This keeps the maximum inside
    # the last bin for data of any magnitude; nudging the top edge by a fixed 1e-10 is lost to
    # rounding for large values and would push the maximum into bin index nBins.
    return np.digitize(data, boundaries[1:-1], right=False)


def pid_bin(x, y, z, nBins=4):
    """
    Bin three 1D samples independently with bin_data_1D and run pid() on the result.

    :param x, y, z:  1D samples; they are stacked in this order as rows 0, 1, 2
    :param nBins:    number of bins applied to each variable
    :return:         whatever pid() returns for the stacked binned data
    """
    binnedRows = [bin_data_1D(sample, nBins) for sample in (x, y, z)]
    return pid(np.array(binnedRows))


@redirect_stdout
def pid(dataPS):
    """
    Run the IDTxl bivariate PID estimator ('TartuPID') with zero lags.

    Row 2 of ``dataPS`` is used as the target, rows 0 and 1 as the two sources.

    :param dataPS:  2D array handed to IDTxl with dim_order='ps' and no normalisation;
                    callers in this notebook pass np.array([x, y, z])
    :return:        dict mapping each label in decompLabels to its estimate, clipped from below at 1e-6
    """
    estimatorSettings = {'pid_estimator': 'TartuPID', 'lags_pid': [0, 0]}
    targetIdx = 2
    sourceIdxs = [0, 1]

    idtxlData = Data(dataPS, dim_order='ps', normalise=False)
    estimator = BivariatePID()
    rezAll = estimator.analyse_single_target(settings=estimatorSettings, data=idtxlData,
                                             target=targetIdx, sources=sourceIdxs)
    rezTrg = rezAll.get_single_target(targetIdx)

    # Remove negative and very small positive PID values.
    # Statistical tests behave unexpectedly on them - perhaps low values are contaminated by round-off errors?
    clippedRez = {}
    for label in decompLabels:
        clippedRez[label] = np.clip(rezTrg[label], 1.0E-6, None)
    return clippedRez

## Models
### Noisy Redundant Scenario

We want to check if white noise added to a purely redundant scenario results in correct identification of redundancy

$$X = T + \nu_X$$
$$Y = T + \nu_Y$$
$$Z = T + \nu_Z$$

where $Y$ is the target of $X$ and $Z$, and

$$T \sim \mathcal{N}(0, 1)$$
$$\nu_X, \nu_Y, \nu_Z \sim \mathcal{N}(0, \sigma)$$

and $\sigma$ is a free parameter, denoting the Noise-To-Signal ratio. So the signal should be a mixture of redundant signal and white noise.

Since the signal is continuous, we bin it using different bin counts.

### Noisy Unique Scenario

Same as before, but

$$X = T + \nu_X$$
$$Y = T + \nu_Y$$
$$Z = \nu_Z$$

### Noisy Redundant Scenario - Discrete Case

It is important to test if false positives are caused by binning, or are an intrinsic property of the noise in the covariate. Here I propose a discretized noisy redundancy model. Instead of added noise, each variable has a random chance to produce the redundant outcome or a purely random outcome.

$$X \sim A_X \nu_X + (1 - A_X) T $$
$$Y \sim A_Y \nu_Y + (1 - A_Y) T $$
$$Z \sim A_Z \nu_Z + (1 - A_Z) T $$

where

$$T, \nu_X, \nu_Y, \nu_Z \sim Ber(0.5) $$
$$A_X \sim Ber(\alpha_X)$$
$$A_Y \sim Ber(\alpha_Y)$$
$$A_Z \sim Ber(\alpha_Z)$$

and $\alpha_X, \alpha_Y, \alpha_Z \in [0, 1]$ are flexible.

So, $\alpha = 0$ means purely redundant signal, and $\alpha=1$ means purely noisy signal.

In [None]:
# Model generators provided by the project null-model library. Both objects are iterated
# with .items() below as name -> callable, and the callables are invoked with the sample
# count first, followed by three noise parameters.
discrFuncDict = null3D.discr_method_dict()
contFuncDict = null3D.cont_method_dict()

### Testing binning-dependence

In [None]:
# Per-label values handed to the nulltest helpers as `valThrDict`; figures produced with a
# non-None valThrDict get the '_withThr' filename suffix. Uncomment the next line (and comment
# out the dict) to run without them.
# NOTE(review): the exact way nulltest interprets these values is not visible here - confirm.
# valThrDict = None
valThrDict = {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16}

In [None]:
# One triple of noise parameters per scenario; they are passed positionally to the model
# functions right after the sample count.
taskDict = {
    'yolo':    0.5*np.array([0,0,0]),
    'norand':  0.5*np.array([0,0,1]),
    'randx':   0.5*np.array([1,0,1]),
    'rand':    0.5*np.array([1,1,1])
}

for taskName, params in taskDict.items():
    print(taskName)
    # (model name, bin count) -> (results of run_tests, results of run_tests with haveShuffle=True)
    rezDict = {}

    # Do continuous tests
    for funcName, func in contFuncDict.items():
        print('-', funcName)

        # Only binary binning is explored at the moment; use range(2, 6) to scan bin counts
        for nBins in [2]:
            # These lambdas look up func / params / nBins when called, which is fine
            # because they are consumed within the same loop iteration.
            f_data   = lambda: func(10000, *params)
            f_metric = lambda x, y, z: pid_bin(x,y,z, nBins)

            rezDF   = nulltest.run_tests(f_data, f_metric, decompLabels, nTest=100)
            rezDFsh = nulltest.run_tests(f_data, f_metric, decompLabels, nTest=100, haveShuffle=True)

            rezDict[(funcName, nBins)] = (rezDF, rezDFsh)

    # Do discrete tests
    f_metric = lambda x, y, z: pid(np.array([x,y,z]))
    for funcName, func in discrFuncDict.items():
        # NOTE(review): params is already scaled by 0.5 in taskDict, so the discrete models
        # receive 0.25 times the raw 0/1 pattern - confirm this is intended.
        f_data = lambda: func(10000, *(0.5*params))
        rezDF   = nulltest.run_tests(f_data, f_metric, decompLabels, nTest=100)
        rezDFsh = nulltest.run_tests(f_data, f_metric, decompLabels, nTest=100, haveShuffle=True)

        # Store under the actual model name. A fixed key here would make every discrete
        # model overwrite the previous one, leaving only the last model's results.
        # Assumes discrete model names differ from the continuous ones - TODO confirm.
        rezDict[(funcName, 2)] = (rezDF, rezDFsh)

    for k, v in rezDict.items():
        print(k)
        funcName, nBin = k
        rezDF, rezDFsh = v

        nulltest.plot_test_summary(rezDF, rezDFsh, suptitle=funcName, haveEff=False, valThrDict=valThrDict)
        suffix = '' if valThrDict is None else '_withThr'
        plt.savefig(funcName + '_pid_nbin'+str(nBin)+'_summary_'+taskName+suffix+'.png', dpi=200)
        plt.show()

### Effect of variance

Continuous

In [None]:
# Bin count used for continuous data; it is also embedded in output filenames further down.
nBin = 2
# Metric for continuous data: bin each variable, then run PID.
# nBin is looked up when the lambda is called, so reassigning nBin later changes the binning.
f_metric_cont = lambda x, y, z: pid_bin(x,y,z, nBin)
# Metric for discrete data: stack the three variables as rows and run PID directly.
f_metric_discr = lambda x, y, z: pid(np.array([x,y,z]))

In [None]:
# Continuous models: sweep the noise parameter alpha for every model / noise strategy / H0 variant
nSample = 10000

# Each strategy turns one alpha into the three noise arguments of the model function
alphaStratDict = {
    'PureSrc': lambda alpha: [0, 0, alpha],
    'ImpureX': lambda alpha: [alpha, 0, alpha],
    'Impure' : lambda alpha: [alpha, alpha, alpha],
}

# Values forwarded as thrMetricDict; None means no per-label values are supplied
thrMetricDictDict = {
    'H0_orig' : None,
    'H0_adj' : {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16},
}

for fName, f_data in contFuncDict.items():
    for alphaStratName, alphaFunc in alphaStratDict.items():

        # Data generator of a single argument alpha for the current model and strategy
        def f_data_eff(alpha):
            return f_data(nSample, *alphaFunc(alpha))

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_param_effect(
                f_data_eff, f_metric_cont, decompLabels,
                nStep=1001, nSkipTest=100, nTest=200, alphaRange=(0, 1),
                thrMetricDict=thrMetricDict, plotAlphaSq=False, fontsize=12,
            )

            suffix = f'n_{nSample}_{alphaStratName}_{h0type}'

            plt.savefig(f'{fName}_pid_cont_nBin_{nBin}_scatter_vareff_{suffix}.png', dpi=300)
            plt.show()

In [None]:
# Continuous models: test the effect of equal noise on all three variables
nSample = 10000
for fName, f_data in contFuncDict.items():
    print(fName)

    # Same alpha is used for all three noise arguments
    def f_data_eff(alpha):
        return f_data(n=nSample, sigX=alpha, sigY=alpha, sigZ=alpha)

    nulltest.run_plot_param_effect_test(
        f_data_eff, f_metric_cont, decompLabels,
        nStep=10, nTest=400, alphaRange=(0, 2), valThrDict=valThrDict,
    )

    if valThrDict is None:
        suffix = ''
    else:
        suffix = '_withThr'
    plt.savefig(f'{fName}_pid_nBin2_vareff_n{nSample}{suffix}.png', dpi=200)
    plt.show()

In [None]:
# Single-model variant of the noise test for the continuous noisy xor model,
# with the same alpha used for all three noise arguments.
nSample=10000
f_data = lambda alpha: null3D.cont_xor_noisy(n=nSample, sigX=alpha, sigY=alpha, sigZ=alpha)
# NOTE(review): the meaning of the positional argument 0 is defined inside nulltest - confirm.
nulltest.run_plot_param_effect_test_single(f_data, f_metric_cont, decompLabels, 0, nTest=400)

Discrete

In [None]:
# Discrete models: sweep the noise parameter alpha for every model / noise strategy / H0 variant
nSample = 10000

# Each strategy turns one alpha into the three noise arguments of the model function
alphaStratDict = {
    'PureSrc': lambda alpha: [0, 0, alpha],
    'ImpureX': lambda alpha: [alpha, 0, alpha],
    'Impure' : lambda alpha: [alpha, alpha, alpha],
}

# Only the original H0 is run for discrete data; the adjusted variant is kept for reference
thrMetricDictDict = {
    'H0_orig' : None,
#     'H0_adj' : {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16}
}

for fName, f_data in discrFuncDict.items():
    for alphaStratName, alphaFunc in alphaStratDict.items():

        # Data generator of a single argument alpha for the current model and strategy
        def f_data_eff(alpha):
            return f_data(nSample, *alphaFunc(alpha))

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_param_effect(
                f_data_eff, f_metric_discr, decompLabels,
                nStep=1001, nSkipTest=100, nTest=200, alphaRange=(0, 1),
                thrMetricDict=thrMetricDict, fontsize=12,
            )

            suffix = f'n_{nSample}_{alphaStratName}_{h0type}'

            plt.savefig(f'{fName}_pid_discr_nBin_{nBin}_scatter_vareff_{suffix}.png', dpi=300)
            plt.show()

In [None]:
# Discrete models: test the effect of equal noise on all three variables
nSample=10000
for fName, f_data in discrFuncDict.items():
    # Arguments are passed positionally (sample count first, then the three noise parameters),
    # the calling convention the scatter sweeps in this notebook already rely on. The keyword
    # spellings used for these functions differ between cells, so positional is the safe choice.
    f_data_eff = lambda alpha: f_data(nSample, alpha, alpha, alpha)
    nulltest.run_plot_param_effect_test(f_data_eff, f_metric_discr, decompLabels,
                                        nStep=10, nTest=400, alphaRange=(0, 1), valThrDict=valThrDict)

    suffix = '' if valThrDict is None else '_withThr'
    plt.savefig(fName + '_pid_vareff_n'+str(nSample)+suffix+'.png', dpi=200)
    plt.show()

### Effect of number of samples
Continuous

In [None]:
# Continuous models: sweep the sample count at a fixed noise level
alpha = 0.5

# Fixed noise arguments (three values per strategy) passed after the sample count
alphaStratDict = {
    'PureSrc': [0, 0, alpha],
    'ImpureX': [alpha, 0, alpha],
    'Impure' : [alpha, alpha, alpha],
}

# Values forwarded as thrMetricDict; None means no per-label values are supplied
thrMetricDictDict = {
    'H0_orig' : None,
    'H0_adj' : {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16},
}


for fName, f_data in contFuncDict.items():
    for alphaStratName, alphaVals in alphaStratDict.items():

        # Data generator of a single argument n for the current model and noise triple
        def f_data_eff(n):
            return f_data(n, *alphaVals)

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_data_effect(
                f_data_eff, f_metric_cont, decompLabels,
                nStep=101, nSkipTest=10, nTest=200, pVal=0.01,
                thrMetricDict=thrMetricDict, fontsize=12,
            )

            suffix = f'alpha_{alpha}_{alphaStratName}_{h0type}'

            plt.savefig(f'{fName}_pid_cont_nBin_{nBin}_scatter_nEff_{suffix}.png', dpi=300)
            plt.show()

In [None]:
# Continuous models: test the effect of the sample count at a fixed noise level
alpha=0.5
for fName, f_data in contFuncDict.items():
    print(fName)

    # Arguments are passed positionally (sample count first, then the three noise parameters).
    # The continuous models are called with sigX/sigY/sigZ keywords elsewhere in this notebook,
    # so the aX/aY/aZ keywords previously used here did not match.
    f_data_eff = lambda n: f_data(n, alpha, alpha, alpha)
    nulltest.run_plot_data_effect_test(f_data_eff, f_metric_cont, decompLabels,
                                       nStep=10, nTest=400, valThrDict=valThrDict)

    suffix = '' if valThrDict is None else '_withThr'
    # The noise level of this cell is `alpha`; there is no variable called `sig`.
    plt.savefig(fName + '_pid_nBin2_nEff_sig'+str(alpha)+suffix+'.png', dpi=200)
    plt.show()

Discrete

In [None]:
# Discrete models: sweep the sample count at a fixed noise level
alpha = 0.5

# Fixed noise arguments (three values per strategy) passed after the sample count
alphaStratDict = {
    'PureSrc': [0, 0, alpha],
    'ImpureX': [alpha, 0, alpha],
    'Impure' : [alpha, alpha, alpha],
}

# Only the original H0 is run for discrete data; the adjusted variant is kept for reference
thrMetricDictDict = {
    'H0_orig' : None,
#     'H0_adj' : {'unq_s1': 0.08, 'unq_s2': 0.08, 'shd_s1_s2': None, 'syn_s1_s2': 0.16}
}


for fName, f_data in discrFuncDict.items():
    for alphaStratName, alphaVals in alphaStratDict.items():

        # Data generator of a single argument n for the current model and noise triple
        def f_data_eff(n):
            return f_data(n, *alphaVals)

        for h0type, thrMetricDict in thrMetricDictDict.items():
            print(fName, alphaStratName, h0type)

            nulltest.run_plot_data_effect(
                f_data_eff, f_metric_discr, decompLabels,
                nStep=101, nSkipTest=10, nTest=200, pVal=0.01,
                thrMetricDict=thrMetricDict, fontsize=12,
            )

            suffix = f'alpha_{alpha}_{alphaStratName}_{h0type}'

            plt.savefig(f'{fName}_pid_discr_nBin_{nBin}_scatter_nEff_{suffix}.png', dpi=300)
            plt.show()

In [None]:
# Discrete models: test the effect of the sample count at a fixed noise level
alpha=0.5
for fName, f_data in discrFuncDict.items():
    # Arguments are passed positionally (sample count first, then the three noise parameters),
    # the calling convention the scatter sweeps in this notebook already rely on. The keyword
    # spellings used for these functions differ between cells, so positional is the safe choice.
    f_data_eff = lambda n: f_data(n, alpha, alpha, alpha)
    nulltest.run_plot_data_effect_test(f_data_eff, f_metric_discr, decompLabels,
                                       nStep=10, nTest=400, valThrDict=valThrDict)

    suffix = '' if valThrDict is None else '_withThr'
    # Prefix with the model name so that each model gets its own file
    # instead of all of them overwriting a single fixed filename.
    plt.savefig(fName + '_pid_nEff_alpha'+str(alpha)+suffix+'.png', dpi=200)
    plt.show()

### Test relationship of synergy and redundancy for fixed data size

#### 1. Finding max synergy parameters - GridSearch3D

In [None]:
# Grid search for the 'syn_s1_s2' term on the continuous noisy-redundancy model,
# repeated for several sample counts; parameters are scanned within (0, 2).
for nSample in [1000, 3000, 5000, 7000, 10000]:
    print(nSample)
    nulltest.run_gridsearch_3D(null3D.cont_red_noisy, f_metric_cont, 'syn_s1_s2',
                              varLimits=(0, 2), nSample=nSample, nStep=20)

In [None]:
# Grid search for the 'syn_s1_s2' term on the discrete noisy-redundancy model,
# repeated for several sample counts; parameters are scanned within (0, 1).
for nSample in [1000, 3000, 5000, 7000, 10000]:
    print(nSample)
    nulltest.run_gridsearch_3D(null3D.discr_red_noisy, f_metric_discr, 'syn_s1_s2',
                              varLimits=(0, 1), nSample=nSample, nStep=20)

#### 2. Finding max synergy parameters - GridSearch1D

Previous analysis found that in all cases maximal synergy is located at the diagonal $\alpha_x = \alpha_y$

In [None]:
# Colour list from mesostat; entries are picked by index as colorA / colorB in the 1D scans below.
tableauColors = base_colors_rgb(key='tableau')

In [None]:
# 1D scan of 'shd_s1_s2' vs 'syn_s1_s2' on the continuous noisy-redundancy model, one figure per
# sample count, followed by a summary of the two values returned by each scan.
nSampleLst = 1000 * np.arange(1, 11)
alphaMaxLst = []
thrLst = []

for nSample in nSampleLst:
    print(nSample)
    alphaMax, thr = nulltest.run_plot_1D_scan(
        null3D.cont_red_noisy, f_metric_cont, 'shd_s1_s2', 'syn_s1_s2',
        varLimits=(0, 1), nSample=nSample, nStep=100, nTest=100,
        colorA=tableauColors[2], colorB=tableauColors[3],
    )
    plt.savefig(f'redCont_pid_nbin2_1Dscan_syn_n_{nSample}.png', dpi=200)
    plt.show()

    alphaMaxLst.append(alphaMax)
    thrLst.append(thr)

# Summary figure: both returned values as a function of the sample count
plt.figure()
for yVals, lineLabel in ((alphaMaxLst, 'param'), (thrLst, 'thr')):
    plt.plot(nSampleLst, yVals, label=lineLabel)
plt.legend()
plt.savefig('redCont_pid_nbin2_1Dscan_syn_summary.png', dpi=200)
plt.show()

In [None]:
# 1D scan of 'shd_s1_s2' vs 'syn_s1_s2' on the discrete noisy-redundancy model, one figure per
# sample count, followed by a summary of the two values returned by each scan.
nSampleLst = 1000 * np.arange(1, 11)
alphaMaxLst = []
thrLst = []

for nSample in nSampleLst:
    print(nSample)
    alphaMax, thr = nulltest.run_plot_1D_scan(
        null3D.discr_red_noisy, f_metric_discr, 'shd_s1_s2', 'syn_s1_s2',
        varLimits=(0, 1), nSample=nSample, nStep=100, nTest=100,
        colorA=tableauColors[2], colorB=tableauColors[3],
    )
    plt.savefig(f'redDiscr_pid_1Dscan_syn_n_{nSample}.png', dpi=200)
    plt.show()

    alphaMaxLst.append(alphaMax)
    thrLst.append(thr)

# Summary figure: both returned values as a function of the sample count
plt.figure()
for yVals, lineLabel in ((alphaMaxLst, 'param'), (thrLst, 'thr')):
    plt.plot(nSampleLst, yVals, label=lineLabel)
plt.legend()
plt.savefig('redDiscr_pid_1Dscan_syn_summary.png', dpi=200)
plt.show()

#### 3. Determining Synergy-Redundancy Relationship

In [None]:
# Scatter exploration of 'shd_s1_s2' against 'syn_s1_s2' for the continuous noisy-redundancy model.
# NOTE(review): the meaning of the positional argument 3 is defined inside nulltest - confirm.
nulltest.run_plot_scatter_explore(null3D.cont_red_noisy, f_metric_cont, 'shd_s1_s2', 'syn_s1_s2', 3,
                         varLimits=(0, 0.5), nSample=1000, nTestDim=20)

In [None]:
# Scatter exploration of 'shd_s1_s2' against 'syn_s1_s2' for the discrete noisy-redundancy model.
# NOTE(review): the meaning of the positional argument 3 is defined inside nulltest - confirm.
nulltest.run_plot_scatter_explore(null3D.discr_red_noisy, f_metric_discr, 'shd_s1_s2', 'syn_s1_s2', 3,
                         varLimits=(0, 1), nSample=1000, nTestDim=20)

### Test relationship of unique and redundancy for fixed data size

#### 2. Finding max unique information parameters - GridSearch1D

In [None]:
# 1D scan of 'shd_s1_s2' vs 'unq_s1' on the continuous noisy-redundancy model, one figure per
# sample count, followed by a summary of the two values returned by each scan.
nSampleLst = 1000 * np.arange(1, 11)
alphaMaxLst = []
thrLst = []

for nSample in nSampleLst:
    print(nSample)
    alphaMax, thr = nulltest.run_plot_1D_scan(
        null3D.cont_red_noisy, f_metric_cont, 'shd_s1_s2', 'unq_s1',
        varLimits=(0, 1), nSample=nSample, nStep=100, nTest=100,
        colorA=tableauColors[2], colorB=tableauColors[0],
    )
    plt.savefig(f'redCont_pid_1Dscan_unq_n_{nSample}.png', dpi=200)
    plt.show()

    alphaMaxLst.append(alphaMax)
    thrLst.append(thr)

# Summary figure: both returned values as a function of the sample count
plt.figure()
for yVals, lineLabel in ((alphaMaxLst, 'param'), (thrLst, 'thr')):
    plt.plot(nSampleLst, yVals, label=lineLabel)
plt.legend()
plt.savefig('redCont_pid_1Dscan_unq_summary.png', dpi=200)
plt.show()

In [None]:
# 1D scan of 'shd_s1_s2' vs 'unq_s1' on the discrete noisy-redundancy model, one figure per
# sample count, followed by a summary of the two values returned by each scan.
nSampleLst = 1000 * np.arange(1, 11)
alphaMaxLst = []
thrLst = []

for nSample in nSampleLst:
    print(nSample)
    alphaMax, thr = nulltest.run_plot_1D_scan(null3D.discr_red_noisy, f_metric_discr, 'shd_s1_s2', 'unq_s1',
                                              varLimits=(0, 1), nSample=nSample, nStep=100, nTest=100,
                                              colorA = tableauColors[2], colorB = tableauColors[0])
    # Filenames must not depend on `suffix`: it is never set in this cell, so it would be
    # whatever an unrelated earlier cell left behind (or undefined on a fresh kernel).
    plt.savefig('redDiscr_pid_1Dscan_unq_n_'+str(nSample)+'.png', dpi=200)
    plt.show()

    alphaMaxLst += [alphaMax]
    thrLst += [thr]

# Summary figure: both returned values as a function of the sample count
plt.figure()
plt.plot(nSampleLst, alphaMaxLst, label='param')
plt.plot(nSampleLst, thrLst, label='thr')
plt.legend()
plt.savefig('redDiscr_pid_1Dscan_unq_summary.png', dpi=200)
plt.show()

#### 3. Determining Unique-Redundancy Relationship

In [None]:
# Scatter exploration of 'shd_s1_s2' against 'unq_s1' for the continuous noisy-redundancy model.
# NOTE(review): the meaning of the positional argument 3 is defined inside nulltest - confirm.
nulltest.run_plot_scatter_explore(null3D.cont_red_noisy, f_metric_cont, 'shd_s1_s2', 'unq_s1', 3,
                         varLimits=(0, 0.5), nSample=1000, nTestDim=20)

In [None]:
# Scatter exploration of 'shd_s1_s2' against 'unq_s1' for the discrete noisy-redundancy model.
# NOTE(review): the meaning of the positional argument 3 is defined inside nulltest - confirm.
nulltest.run_plot_scatter_explore(null3D.discr_red_noisy, f_metric_discr, 'shd_s1_s2', 'unq_s1', 3,
                         varLimits=(0, 1), nSample=1000, nTestDim=20)