In [None]:
# Demonstration file showing how to compute I_\cap^\alpha
from collections import OrderedDict
import numpy as np
import dit
from dit.pid.distributions import bivariates as dit_bivariates
from dit.pid.distributions import trivariates as dit_trivariates
import istar

In [None]:
# Construct a set of different bivariate distributions
def wrap(f, node=None):
    """Adapt a dit PID class into a ``pjoint -> redundancy`` function.

    Parameters
    ----------
    f : callable
        A PID constructor such as ``dit.pid.PID_WB``. It is called with a copy
        of the joint distribution and must return an object exposing
        ``get_redundancy(node)``.
    node : tuple of tuples, optional
        Redundancy-lattice node to report. When omitted, the node is built
        with one singleton per source, e.g. ``((0,), (1,))`` for two sources
        and ``((0,), (1,), (2,))`` for three, i.e. the redundancy shared by
        *all* sources rather than just the first two.

    Returns
    -------
    callable
        ``cf(pjoint)`` returning ``f(pjoint.copy()).get_redundancy(node)``.

    Notes
    -----
    The default node treats every variable of ``pjoint`` except the last one
    as a source (the last variable is the target in every distribution built
    in this notebook). NOTE(review): assumes ``f`` uses the same
    all-but-last-variable source convention -- confirm for custom PID classes.
    """
    def cf(pjoint):
        query_node = node
        if query_node is None:
            # One singleton source per variable, excluding the final (target) one.
            n_sources = pjoint.outcome_length() - 1
            query_node = tuple((i,) for i in range(n_sources))
        # Work on a copy, presumably so the PID class cannot alter the shared
        # distribution objects stored in the *_dists dictionaries.
        return f(pjoint.copy()).get_redundancy(query_node)
    return cf

# Seeded generator so the randomly weighted CopyC distribution below is
# reproducible under Restart & Run All.
SEED = 0
rng = np.random.default_rng(SEED)

# Bivariate test distributions. In each outcome string the first two
# characters are the sources X1, X2 and the last character is the target Y.
bivariate_dists = OrderedDict()
bivariate_dists['Redund']    = dit_bivariates['redundant']
bivariate_dists['rdn_xor']   = dit_bivariates['rdn xor']
bivariate_dists['imp. rdn']  = dit_bivariates['imp. rdn']
bivariate_dists['wb 1']      = dit_bivariates['wb 1']
bivariate_dists['wb 2']      = dit_bivariates['wb 2']
bivariate_dists['wb 3']      = dit_bivariates['wb 3']
bivariate_dists['Unq1']      = dit_bivariates['unique 1']

# Non-uniform weights (alpha, 1-alpha, 1-alpha, alpha) used by the
# 'Unq1alpha' and 'COPYalpha' variants.
alpha   = 0.7
p_alpha = np.array([alpha, 1-alpha, 1-alpha, alpha])
p_alpha = p_alpha / p_alpha.sum()

bivariate_dists['AND']       = dit_bivariates['and']
bivariate_dists['Unq1alpha'] = dit.Distribution(['000', '011', '100', '111'], p_alpha)
bivariate_dists['COPY']      = dit_bivariates['cat']
bivariate_dists['COPYalpha'] = dit.Distribution(['000', '011', '102', '113'], p_alpha)
bivariate_dists['XOR']       = dit_bivariates['synergy']
bivariate_dists['f1']        = dit_bivariates['f1']

bivariate_dists['erase']     = dit_bivariates['erase']
bivariate_dists['D']         = dit_bivariates['diff'] # differentiates proj and broja
bivariate_dists['boom']      = dit_bivariates['boom']

# CopyC: random integer weights in [1, 51] over five (X1, X2) states; Y is the
# index of the state, so Y determines the pair (X1, X2).
copyc_states   = ['00', '01', '10', '11', '22']
p_copyc        = np.round(rng.random(len(copyc_states)) * 50) + 1
p_copyc        = p_copyc / p_copyc.sum()
copyc_outcomes = [s + str(i) for i, s in enumerate(copyc_states)]
bivariate_dists['CopyC'] = dit.Distribution(copyc_outcomes, p_copyc)

# Redundancy measures to compare, in table-column order. Every entry maps a
# joint distribution to a single redundancy value: Istar takes the first
# element returned by istar.get_Istar, the others go through `wrap`.
bivariate_methods = OrderedDict([
    ('Istar',  lambda pjoint: istar.get_Istar(pjoint, n_q=5)[0]),
    ('WB',     wrap(dit.pid.PID_WB)),     # Williams and Beer 2010
    ('BROJA',  wrap(dit.pid.PID_BROJA)),  # Bertschinger et al. 2014
    ('Proj',   wrap(dit.pid.PID_Proj)),   # Harder et al. 2013
    ('Iwedge', wrap(dit.pid.PID_GK)),     # Griffith et al. 2014
    ('CCS',    wrap(dit.pid.PID_CCS)),    # Ince 2017
    ('Ipm',    wrap(dit.pid.PID_PM)),     # Finn and Lizier
])


In [None]:
# Results table: one row per distribution, one column per redundancy measure.
# Header: a blank 10-wide label cell, then each method name right-aligned in 7.
header_cells = ['%10s' % ''] + ['%7s' % name for name in bivariate_methods]
print(' '.join(header_cells), end=' \n')

# Each value is printed as soon as it is computed, so progress shows up
# cell by cell even when a measure is slow.
for dist_name, pjoint in bivariate_dists.items():
    print('%10s' % dist_name, end=' ')
    for method in bivariate_methods.values():
        print(' % .3f' % method(pjoint), end=' ')
    print()

In [None]:
# Trivariate distributions. In each outcome string the first three
# characters are the sources X1, X2, X3 and the last character is Y.
trivariate_dists = OrderedDict()

# 3-way AND: X1, X2, X3 are uniform bits and Y = X1 AND X2 AND X3.
and3_outcomes = []
for i in range(2**3):
    bits = format(i, '03b')
    and3_outcomes.append(bits + ('1' if bits == '111' else '0'))
trivariate_dists['AND3'] = dit.Distribution(and3_outcomes, np.ones(len(and3_outcomes))/len(and3_outcomes))

# X1 = (A,B), X2 = (A,C), X3 = (A,D) (one variable in common)
# Y = (X1,X2,X3)
# Each Xi takes values 0..3, with the shared bit A encoded as (Xi < 2); only
# outcomes where all three sources agree on A are kept (uniformly weighted).
# Y is written as one unique character per kept (X1, X2, X3) triple.
statenames = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!@'
copy3_sources = [(x1, x2, x3)
                 for x1 in range(4)
                 for x2 in range(4)
                 for x3 in range(4)
                 if (x1 < 2) == (x2 < 2) == (x3 < 2)]
# 2 values of A x 2^3 private-bit combinations = 16 outcomes.
assert len(copy3_sources) == 16 and len(copy3_sources) <= len(statenames)
copy3_outcomes = ['%d%d%d%s' % (x1, x2, x3, statenames[i])
                  for i, (x1, x2, x3) in enumerate(copy3_sources)]
trivariate_dists['COPY3'] = dit.Distribution(copy3_outcomes, np.ones(len(copy3_outcomes))/len(copy3_outcomes))

# Trivariate methods: a subset of the bivariate measures (BROJA and Proj are
# not included here -- NOTE(review): presumably because they are only used
# with two sources; confirm).
trivariate_methods = OrderedDict([
    ('Istar',  lambda pjoint: istar.get_Istar(pjoint, n_q=5)[0]),
    ('WB',     wrap(dit.pid.PID_WB)),
    ('Iwedge', wrap(dit.pid.PID_GK)),
    ('CCS',    wrap(dit.pid.PID_CCS)),
    ('Ipm',    wrap(dit.pid.PID_PM)),
])

# Header: a blank 10-wide label cell, then each method name right-aligned in 7.
header_cells = ['%10s' % ''] + ['%7s' % name for name in trivariate_methods]
print(' '.join(header_cells), end=' \n')

# One row per trivariate distribution; values printed as they are computed.
for dist_name, pjoint in trivariate_dists.items():
    print('%10s' % dist_name, end=' ')
    for method in trivariate_methods.values():
        print(' % .3f' % method(pjoint), end=' ')
    print()
