In [1]:
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import dit
from dit.pid.distributions import bivariates as dit_bivariates
from dit.pid.distributions import trivariates as dit_trivariates
import istar
import i_GH

def wrap_redundancy(f, n_sources):
    """Adapt a PID constructor into a function returning only the full redundancy term.

    Parameters
    ----------
    f : callable
        Called as ``f(pjoint_copy)``; its result must be indexable by a PID
        lattice node, e.g. ``dit.pid.PID_WB``.
    n_sources : int
        Number of sources. The returned value is the entry at node
        ``((0,), (1,), ..., (n_sources - 1,))``, where every source is its own singleton.

    Returns
    -------
    callable
        ``cf(pjoint)``, which runs ``f`` on ``pjoint.copy()`` and returns that node's value.

    Raises
    ------
    ValueError
        If ``n_sources`` is smaller than 1.
    """
    if n_sources < 1:
        raise ValueError('n_sources must be a positive integer')
    node = tuple((i,) for i in range(n_sources))

    def cf(pjoint):
        # Hand f a copy so the caller's distribution object is not the one f works on.
        d = f(pjoint.copy())
        return d[node]
    return cf


def wrap2(f): # return 2-way redundancy
    """Return ``cf(pjoint)`` giving the redundancy of sources 0 and 1 under PID ``f``."""
    return wrap_redundancy(f, 2)


def wrap3(f): # return 3-way redundancy
    """Return ``cf(pjoint)`` giving the redundancy of sources 0, 1 and 2 under PID ``f``."""
    return wrap_redundancy(f, 3)



In [2]:
# Construct a set of different bivariate distributions.
# Outcome strings read "X1 X2 Y": source 1, source 2, then the target symbol.

# Weights shared by the correlated "alpha" variants. Outcomes with X1 == X2 get
# weight alpha and the others get 1 - alpha, so after normalising P(X1 == X2) = alpha.
alpha = 0.7
p = np.array([alpha, 1-alpha, 1-alpha, alpha])
p = p / p.sum()

bivariate_dists = OrderedDict([
    ('sum',       dit.Distribution(['000', '011', '101', '112'], [0.25] * 4)),  # Y = X1 + X2
    ('Unq1',      dit_bivariates['unique 1']),
    ('AND',       dit_bivariates['and']),
    ('Unq1alpha', dit.Distribution(['000', '011', '100', '111'], p)),  # Y = X2, correlated sources
    ('COPY',      dit_bivariates['cat']),
    ('COPYalpha', dit.Distribution(['000', '011', '102', '113'], p)),  # Y = (X1, X2), correlated sources
])

# Redundancy measures to compare. Each entry maps a joint distribution to a
# single redundancy value.
bivariate_methods = OrderedDict([
    ('Istar',  lambda pjoint: istar.get_Istar(pjoint)[0]),
    ('WB',     wrap2(dit.pid.PID_WB)),     # Williams and Beer 2010
    ('BROJA',  wrap2(dit.pid.PID_BROJA)),  # Bertschinger et al. 2014
    ('Proj',   wrap2(dit.pid.PID_Proj)),   # Harder et al. 2013
    ('Iwedge', wrap2(dit.pid.PID_GK)),     # Griffith et al. 2014
    ('CCS',    wrap2(dit.pid.PID_CCS)),    # Ince 2017
    ('Ipm',    wrap2(dit.pid.PID_PM)),     # Finn and Lizier
    ('Idep',   wrap2(dit.pid.PID_dep)),    # James (enabled, although noted as slow)
])

# The following can be very slow, so it is commented out for now.
# bivariate_methods['GH']     = lambda pjoint: i_GH.get_I_GH(pjoint, n_q=4)[0]  # Griffith and Ho 2015


In [3]:
# Table of 2-way redundancy values: one row per entry of `bivariate_dists`,
# one column per entry of `bivariate_methods`.
print('%10s' % '', end=' ')
for method_name in bivariate_methods:
    print('%7s' % method_name, end=' ')
print()

for dist_name in bivariate_dists:
    pjoint = bivariate_dists[dist_name]
    print('%10s' % dist_name, end=' ')
    for measure in bivariate_methods.values():
        # A leading space plus '% .3f' (which reserves a sign slot) gives a
        # 7-character cell, matching the '%7s' headers, for values below 10 in magnitude.
        print(' % .3f' % measure(pjoint), end=' ')
    print()

             Istar      WB   BROJA    Proj  Iwedge     CCS     Ipm    Idep 
       sum   0.500   0.500   0.500   0.500   0.000   0.000   0.500   0.189 
      Unq1   0.000   0.000   0.000   0.000   0.000   0.000   1.000   0.000 
       AND   0.311   0.311   0.311   0.311   0.000   0.104   0.561   0.082 
 Unq1alpha   0.119   0.119   0.119   0.119   0.000   0.340   1.000   0.119 
      COPY   0.000   1.000   0.000  -0.000   0.000   0.000   1.000   0.000 
 COPYalpha   0.000   1.000   0.119   0.119   0.000   0.340   1.000   0.119 


In [4]:
# Trivariate distributions.
# Outcome strings read "X1 X2 X3 Y": three sources followed by the target symbol.
trivariate_dists = OrderedDict()

# 3-way AND: uniform over the 8 source patterns; Y = 1 only for '111'.
states = []
for i in range(2**3):
    s = format(i, '03b')
    states.append(s + ('0' if s != '111' else '1'))
trivariate_dists['AND3'] = dit.Distribution(states, np.ones(len(states))/len(states))

trivariate_dists['SUM3'] = dit_trivariates['sum']

# Overlap: each source Xi in {0,1,2,3} carries a component A = int(Xi < 2) that
# must agree across all three sources (the variable they have in common); the
# remaining bit of Xi is free per source. This leaves 2 * 2**3 = 16 equiprobable
# outcomes. Y = (X1, X2, X3) is encoded by giving every outcome its own target symbol.
statenames = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!@'
states = []
ndx = 0
for x1 in range(4):
    for x2 in range(4):
        for x3 in range(4):
            # Drop outcomes whose sources disagree on the shared component A.
            if not (int(x1 < 2) == int(x2 < 2) == int(x3 < 2)):
                continue
            states.append(str(x1) + str(x2) + str(x3) + statenames[ndx])
            ndx += 1
trivariate_dists['Overlap'] = dit.Distribution(states, np.ones(len(states))/len(states))


# Trivariate methods
trivariate_methods = OrderedDict()
trivariate_methods['Istar']  = lambda pjoint: istar.get_Istar(pjoint)[0]
trivariate_methods['WB']     = wrap3(dit.pid.PID_WB)   # Williams and Beer 2010
trivariate_methods['Iwedge'] = wrap3(dit.pid.PID_GK)   # Griffith et al. 2014
trivariate_methods['Ipm']    = wrap3(dit.pid.PID_PM)   # Finn and Lizier
trivariate_methods['CCS']    = wrap3(dit.pid.PID_CCS)  # Ince 2017

# (distribution, method) pairs that are not computed and are reported as nan.
# PID_CCS gets stuck on 'Overlap'.
TRIVARIATE_SKIP = {('Overlap', 'CCS')}


# Table: one row per distribution (local constructions plus dit's SUM3),
# one column per redundancy measure.
print('%10s' % '', end=' ')
for m in trivariate_methods.keys():
    print('%7s' % m, end=' ')
print()

for k, pjoint in trivariate_dists.items():
    print('%10s' % k, end=' ')
    for m, method in trivariate_methods.items():
        v = np.nan if (k, m) in TRIVARIATE_SKIP else method(pjoint)
        # '%7.3f' keeps every cell, including nan, 7 characters wide so it lines
        # up with the '%7s' headers ('% .3f' renders nan as only ' nan').
        print('%7.3f' % v, end=' ')
    print()


             Istar      WB  Iwedge     Ipm     CCS 
      AND3   0.138   0.138   0.000   0.294   0.024 
      SUM3   0.311   0.311   0.000   0.561   0.000 
   Overlap   1.000   2.000   1.000   2.000   nan 
