# Comparison of Blackwell redundancy $I_\cap^\prec$ to other measures

This notebook recreates the comparison tables (Tables 2 and 3) from:

* A. Kolchinsky, "A Novel Approach to the Partial Information Decomposition", *Entropy*, 2022.



In [None]:
from __future__ import print_function
from collections import OrderedDict
import numpy as np
import warnings
warnings.filterwarnings("ignore")
import dit

from blackwell_redundancy import get_Iprec
from i_GH                 import get_I_GH

def wrap(f):
    """Turn a dit PID constructor into a function returning only its redundancy.

    The returned callable builds ``f`` on a *copy* of the joint distribution
    (so the caller's distribution is left untouched) and then indexes the
    resulting PID at the node in which every source is its own singleton,
    ``((0,), (1,), ..., (n-1,))``. Here ``n`` is the number of random
    variables minus one, i.e. the last random variable is excluded from the
    sources and treated as the target.
    """
    def redundancy(pjoint):
        pid = f(pjoint.copy())
        n_sources = len(pjoint.rvs) - 1
        bottom_node = tuple((src,) for src in range(n_sources))
        return pid[bottom_node]
    return redundancy

def calculate_GH(pjoint, *, tol=1e-4, n_q=3):
    """Redundancy measure proposed by Griffith and Ho (2015), I_GH.

    Computing I_GH directly can be very slow, so we use the sandwich
    I_wedge <= I_GH <= I^prec. Here I_wedge is the measure of Griffith et
    al. (2014), computed via dit's ``PID_GK``, and I^prec is the Blackwell
    redundancy. If the two bounds agree to within ``tol``, the upper bound is
    returned as the value of I_GH; otherwise I_GH is computed explicitly with
    ``get_I_GH``.

    Parameters
    ----------
    pjoint : dit.Distribution
        Joint distribution whose last random variable is the target.
    tol : float, keyword-only
        Maximum absolute gap between the bounds for which the shortcut is
        taken (default 1e-4).
    n_q : int, keyword-only
        Forwarded unchanged to ``get_I_GH`` when the explicit computation is
        needed (default 3).

    Returns
    -------
    float
        The (approximate) value of I_GH.
    """
    lb = wrap(dit.pid.PID_GK)(pjoint)   # lower bound: I_wedge
    ub = get_Iprec(pjoint)[0]           # upper bound: I^prec
    if np.abs(ub - lb) < tol:
        # Bounds coincide (numerically), so I_GH is pinned down without the
        # expensive optimization.
        return ub
    return get_I_GH(pjoint, n_q=n_q)[0]

# Redundancy measures to compare. Keys are the column labels used in the
# printed tables, and insertion order fixes the column order. Each value
# takes a dit joint distribution and returns a single redundancy value.
methods = OrderedDict([
    ('≺',      lambda pjoint: get_Iprec(pjoint)[0]),  # Our measure of blackwell redundancy
    ('WB',     wrap(dit.pid.PID_WB)),     # Williams and Beer, 2010
    ('MMI',    wrap(dit.pid.PID_MMI)),    # (Minimum Mutual Information) Barrett, 2015
    ('∧',      wrap(dit.pid.PID_GK)),     # Griffith et al, 2014
    ('GH',     calculate_GH),             # Griffith and Ho, 2015
    ('Ince',   wrap(dit.pid.PID_CCS)),    # Ince, 2017
    ('FL',     wrap(dit.pid.PID_PM)),     # Finn and Lizier, 2018
    ('BROJA',  wrap(dit.pid.PID_BROJA)),  # Bertschinger et al., 2014
    ('Harder', wrap(dit.pid.PID_Proj)),   # Harder et al. 2013
    ('dep',    wrap(dit.pid.PID_dep)),    # James et al, 2018
])



In [None]:
# Construct a set of different bivariate distributions.
# This recreates Table 2 from the paper.
# In the distributions built explicitly below, each outcome string is
# 'x1 x2 y': the first two characters are the sources, the last is the target.

bivariate_dists = OrderedDict()
bivariate_dists['Y=X1 AND X2'] = dit.pid.distributions.bivariates['and']
bivariate_dists['Y=X1 + X2']   = dit.Distribution(['000', '011', '101', '112'], [0.25, 0.25, 0.25, 0.25])
bivariate_dists['Y=X1']        = dit.pid.distributions.bivariates['unique 1']
bivariate_dists['Y=(X1,X2)']   = dit.pid.distributions.bivariates['cat']

# Also introduce correlations (c) between the sources into the Y=X1 and
# COPY (Y=(X1,X2)) gates: (X1, X2) takes the values 00, 01, 10, 11 with
# probabilities proportional to alpha, 1-alpha, 1-alpha, alpha.
alpha = 0.7
p     = np.array([alpha, 1-alpha, 1-alpha, alpha])
p    /= p.sum()
# y = x1 in every outcome, matching the 'Y=X1' label.
bivariate_dists['Y=X1/c'] = dit.Distribution(['000', '010', '101', '111'], p)
# y enumerates the pair (x1, x2): 00->0, 01->1, 10->2, 11->3.
bivariate_dists['Y=(X1,X2)/c'] = dit.Distribution(['000', '011', '102', '113'], p)

# (distribution, method) pairs to leave out of the table, e.g. because they
# are too slow to compute. Entries must use the keys of `bivariate_dists`
# and `methods`, e.g. ('Y=(X1,X2)', 'dep'). Skipped cells print as '-----'.
skipped_pairs = set()

# Header row: one right-aligned column per method.
print('%12s' % '', end=' ')
for m in methods:
    print('%7s' % m, end=' ')
print()

# One row per distribution. Values are printed as soon as they are computed,
# so slow measures show progress incrementally.
for k, pjoint in bivariate_dists.items():
    print('%12s' % k, end=' ')
    for m, method in methods.items():
        if (k, m) in skipped_pairs:
            v = ' -----'
        else:
            v = '% .3f' % method(pjoint)
        print(' ' + v, end=' ')
    print()
    
    

In [None]:
# Trivariate distributions
# This recreates Table 3 from our paper
trivariate_methods = ['≺', 'WB', 'MMI', '∧', 'Ince', 'FL']

trivariate_dists = OrderedDict()

# 3-way AND: all 8 source patterns x1x2x3 equally likely, y = 1 only for 111.
states = [bits + ('1' if bits == '111' else '0')
          for bits in (format(i, '03b') for i in range(2**3))]
trivariate_dists['Y=X1 AND X2 AND X3'] = dit.Distribution(states, np.ones(len(states))/len(states))
trivariate_dists['Y=X1 + X2 + X3'] = dit.pid.distributions.trivariates['sum']

# Overlap gate: X1=(A,B), X2=(A,C), X3=(A,D), sharing the bit A = (x < 2).
# Only source triples that agree on A are kept, and Y receives a distinct
# symbol for every kept triple, i.e. Y=(X1,X2,X3).
statenames = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!@'
states = []
for x1 in range(4):
    for x2 in range(4):
        for x3 in range(4):
            shared_bits = {int(x1 < 2), int(x2 < 2), int(x3 < 2)}
            if len(shared_bits) == 1:
                states.append('%d%d%d%s' % (x1, x2, x3, statenames[len(states)]))

trivariate_dists['Y=((A,B),(A,C),(A,D))'] = dit.Distribution(states, np.ones(len(states))/len(states))

# Header row: blank label column, then one right-aligned column per method.
print(' ' * 22, end=' ')
for name in trivariate_methods:
    print('%7s' % name, end=' ')
print()

# One row per distribution, printing each value as soon as it is computed.
for dist_name, pjoint in trivariate_dists.items():
    print(dist_name.rjust(22), end=' ')
    for name in trivariate_methods:
        value = methods[name](pjoint)
        print(' % .3f' % value, end=' ')
    print()
