In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from collections import defaultdict
from idtxl.bivariate_pid import BivariatePID
from idtxl.data import Data

In [None]:
def bernoulli(n, p):
    """Draw ``n`` independent 0/1 samples, each equal to 1 with probability ``p``.

    Randomness comes from NumPy's global state (``np.random.uniform``), so the
    result depends on any earlier ``np.random.seed`` call.

    Returns an integer array of length ``n``.
    """
    draws = np.random.uniform(0, 1, n)
    hits = draws < p  # a uniform draw on [0, 1) lands below p with probability p
    return hits.astype(int)

def gen_discrete_random(nSample, alphaX=0.5, alphaY=0.5, alphaZ=0.5):
    """Generate a toy triple (x, y, z) of discrete variables with a common driver T.

    T and the per-variable noise terms nuX, nuY, nuZ are independent fair coins
    scaled to {0, 2}. For each sample, x takes the value of T with probability
    1 - alphaX and the value of its own noise nuX with probability alphaX. z is
    built the same way from T, nuZ and alphaZ. y is pure noise (nuY), drawn
    independently of T, x and z.

    Parameters
    ----------
    nSample : int
        Number of samples to draw.
    alphaX, alphaZ : float
        Probability that the corresponding sample is taken from noise instead of T.
    alphaY : float
        Ignored: y is always pure noise. Kept so existing calls that pass it still work.

    Returns
    -------
    x, y, z : integer arrays of length nSample with values in {0, 2}.
    """
    T = 2 * bernoulli(nSample, 0.5)
    nuX = 2 * bernoulli(nSample, 0.5)
    nuY = 2 * bernoulli(nSample, 0.5)
    nuZ = 2 * bernoulli(nSample, 0.5)
    # Per-sample switches: 1 -> take the noise value, 0 -> copy T.
    # (The original also drew a switch for y from alphaY, but never used it.)
    aX = bernoulli(nSample, alphaX)
    aZ = bernoulli(nSample, alphaZ)

    x = (1 - aX) * T + aX * nuX
    y = nuY
    z = (1 - aZ) * T + aZ * nuZ
    return x, y, z

def shuffle(x):
    """Return a randomly permuted copy of ``x``, leaving ``x`` itself unchanged.

    ``np.random.shuffle`` permutes in place (along the first axis for arrays), so
    the permutation is applied to a copy. Uses NumPy's global random state.
    """
    permuted = x.copy()
    np.random.shuffle(permuted)
    return permuted

In [None]:
# Keys read from the IDTxl single-target PID result (see `pid`). By their names,
# presumably: unique info of source 1, unique info of source 2, shared info, synergy.
decompLabels = [
    'unq_s1',
    'unq_s2',
    'shd_s1_s2',
    'syn_s1_s2',
]

In [None]:
def pid(dataPS):
    """Estimate the bivariate PID of process 2 (target) from processes 0 and 1 (sources).

    Parameters
    ----------
    dataPS : array-like
        Discrete data in IDTxl 'ps' order (processes x samples). It needs at least
        3 rows: rows 0 and 1 are the sources and row 2 is the target.

    Returns
    -------
    dict
        Maps each key in ``decompLabels`` to the corresponding raw value from the
        IDTxl single-target result.
    """
    # TartuPID estimator; lags_pid [0, 0] puts both sources at lag 0 relative to the target.
    settings = {'pid_estimator': 'TartuPID', 'lags_pid': [0, 0]}

    dataIDTxl = Data(dataPS, dim_order='ps', normalise=False)
    # Named `estimator` rather than `pid` so the local does not shadow this function.
    estimator = BivariatePID()
    rez = estimator.analyse_single_target(settings=settings, data=dataIDTxl, target=2, sources=[0, 1])
    rezTrg = rez.get_single_target(2)

    # NOTE: estimates can come back negative or as very small positive values
    # (perhaps round-off). Clipping them with np.clip(rezTrg[k], 1.0E-6, None) was
    # tried because statistical tests behaved unexpectedly. That clipping is
    # currently disabled, so raw values are returned.
    return {k: rezTrg[k] for k in decompLabels}

In [None]:
# Sweep the sample size log-uniformly and, for each size, estimate the shared
# information ('shd_s1_s2') of target z from sources x and y twice:
# - once on the generated data;
# - once on a surrogate whose target z is shuffled, which breaks its dependence
#   on the sources.
# Cost note: this runs 2 * N_SIZES PID estimations, on up to 10**LOG10_MAX_SAMPLES samples each.
SEED = 42                # fixed seed so Restart & Run All reproduces the figure
N_SIZES = 1000           # number of sample sizes in the sweep
LOG10_MIN_SAMPLES = 3    # smallest sample size = 10**3
LOG10_MAX_SAMPLES = 5    # largest sample size = 10**5
ALPHA_NOISE = 0.1        # noise probability passed to the generator
                         # (gen_discrete_random does not use alphaY: y is pure noise)

np.random.seed(SEED)
nDataLst = []
rezLst = []
rezSh = []
for nData in np.logspace(LOG10_MIN_SAMPLES, LOG10_MAX_SAMPLES, N_SIZES).astype(int):
    x, y, z = gen_discrete_random(nData, alphaX=ALPHA_NOISE, alphaY=ALPHA_NOISE, alphaZ=ALPHA_NOISE)
    pidRez = pid(np.array([x, y, z]))
    pidSh = pid(np.array([x, y, shuffle(z)]))
    nDataLst.append(nData)
    rezLst.append(pidRez['shd_s1_s2'])
    rezSh.append(pidSh['shd_s1_s2'])

In [None]:
# Shared-information estimate vs. sample size, for the data and for the
# shuffled-target surrogate. Both axes are logarithmic, so non-positive estimates
# cannot be drawn. Count them so their absence from the plot is visible.
nNonPosData = int(np.sum(np.asarray(rezLst) <= 0))
nNonPosSh = int(np.sum(np.asarray(rezSh) <= 0))

fig, ax = plt.subplots()
ax.loglog(nDataLst, rezLst, '.', label=f'data ({nNonPosData} non-positive not shown)')
ax.loglog(nDataLst, rezSh, '.', label=f'sh ({nNonPosSh} non-positive not shown)')
ax.set_xlabel('number of samples')
ax.set_ylabel('shared information (shd_s1_s2)')
ax.set_title('PID shared information: data vs. shuffled target')
ax.legend()
plt.show()