In [1]:
import numpy as np
import sys

In [2]:
# define log and divide functions to avoid numerical errors
# Machine epsilon of double precision. The builtin `float` is used because the
# `np.float` alias was removed from NumPy (it was only ever an alias of `float`).
EPS = np.finfo(float).eps


def safelog(x):
    """Return the natural log of ``x + EPS``.

    The EPS offset keeps the result finite when ``x`` is exactly zero, so
    products of the form ``p * safelog(p)`` evaluate to 0 instead of NaN.
    Works element-wise on scalars and arrays.
    """
    return np.log(x + EPS)


def safedivide(x, y):
    """Return ``x / (y + EPS)``.

    The EPS offset avoids a division by zero when ``y`` is exactly zero
    (``0 / 0`` becomes 0). Works element-wise with NumPy broadcasting.
    """
    return x / (y + EPS)

In [3]:
# Seed NumPy's legacy global RNG so the simulated data is the same on every run.
np.random.seed(420)
# Print floats in fixed-point notation instead of scientific notation.
np.set_printoptions(suppress=True)

In [4]:
# to simulate our data, we will have p = product(dims) >> n
dims = [100,100]  # number of categories of X and of Y
n = 100  # number of observations drawn from the joint distribution
support_size = np.prod(dims)  # number of cells in the joint table
# Full probability tables are only printed for toy-sized supports (<= 50 cells).
# The flag is derived from support_size directly: `np.product` no longer exists
# in NumPy 2.x and merely recomputed the same number.
print_probs = bool(support_size <= 50)
print(">>> SETTINGS")
print("> dims = {}".format(dims))
print("> size of support = {}".format(support_size))
print("> n = {}".format(n))

>>> SETTINGS
> dims = [100, 100]
> size of support = 10000
> n = 100


In [5]:
# Ground-truth joint p(x, y): a single draw from a symmetric Dirichlet whose
# per-cell concentration is 2/support_size, folded back into a dims[0] x dims[1] table.
# (A disabled earlier variant presumably used a normalised np.random.rand table
# instead -- the old comment was garbled, so this is an assumption.)
true_p_xy = np.random.dirichlet(np.ones(support_size) * (2 / support_size)).reshape(dims)

# Marginals: rows index x, columns index y.
true_p_x = true_p_xy.sum(axis=1)
true_p_y = true_p_xy.sum(axis=0)
# Conditional p(y | x): every row of the joint divided by its row total.
true_p_ygx = safedivide(true_p_xy, true_p_x[:, np.newaxis])

# Entropies in nats (safelog wraps the natural log) and the mutual information
# I(X;Y) = H(X) + H(Y) - H(X,Y).
true_h_x = -(true_p_x * safelog(true_p_x)).sum()
true_h_y = -(true_p_y * safelog(true_p_y)).sum()
true_h_xy = -(true_p_xy * safelog(true_p_xy)).sum()
true_mi_xy = true_h_x + true_h_y - true_h_xy

if print_probs:
    print(">>> TRUE", true_p_xy, true_p_x, true_p_y, true_p_ygx, sep="\n")

In [6]:
# Draw n observations from the true joint: one multinomial draw over the
# flattened table yields the per-cell counts, which are folded back to 2-D.
obs_xy = np.random.multinomial(n, true_p_xy.ravel()).reshape(dims)
# Marginal counts: row totals for x, column totals for y.
obs_x = obs_xy.sum(axis=1)
obs_y = obs_xy.sum(axis=0)

if print_probs:
    print(">>> OBS", obs_xy, obs_x, obs_y, sep="\n")

In [7]:
# Maximum-likelihood (plug-in) estimates: counts divided by their totals.
mle_p_xy = safedivide(obs_xy, obs_xy.sum())
mle_p_x = safedivide(obs_x, obs_x.sum())
mle_p_y = safedivide(obs_y, obs_y.sum())
# Conditional p(y | x): each row of counts divided by that row's total.
mle_p_ygx = safedivide(obs_xy, obs_x[:, np.newaxis])

# Sanity checks: joint and marginals must sum to one. A conditional row sums
# to one when its x was observed, and to zero when it was never observed.
assert np.isclose(mle_p_xy.sum(), 1.)
assert np.isclose(mle_p_x.sum(), 1.)
assert np.isclose(mle_p_y.sum(), 1.)
_mle_row_mass = mle_p_ygx.sum(axis=1)
assert np.all(np.isclose(_mle_row_mass, 1.) | np.isclose(_mle_row_mass, 0.))

if print_probs:
    print(">>> MLE", mle_p_xy, mle_p_x, mle_p_y, mle_p_ygx, sep="\n")

In [8]:
def _js_shrink(p_mle, n_obs, axis=None):
    """James-Stein shrinkage of ML cell frequencies towards the uniform target.

    The shrinkage intensity is
        lambda = (1 - sum(p_mle**2)) / ((n_obs - 1) * sum((target - p_mle)**2))
    truncated to [0, 1], and the estimate is
        lambda * target + (1 - lambda) * p_mle.

    Parameters
    ----------
    p_mle : ndarray
        Maximum-likelihood frequencies.
    n_obs : int or ndarray
        Number of observations behind ``p_mle``. With ``axis`` set it may be an
        array holding one count per distribution (broadcast against the
        keepdims-reduced sums).
    axis : int or None
        ``None`` treats the whole array as one distribution; an int treats
        every slice along that axis as its own distribution.

    Returns
    -------
    (target, lam, p_js)
        The uniform target probability, the shrinkage intensity (keepdims
        shape when ``axis`` is given) and the shrunken frequencies.
    """
    keepdims = axis is not None
    n_cells = p_mle.size if axis is None else p_mle.shape[axis]
    target = 1. / n_cells
    numer = 1. - np.sum(p_mle**2, axis=axis, keepdims=keepdims)
    denom = (n_obs - 1) * np.sum((target - p_mle)**2, axis=axis, keepdims=keepdims)
    # A non-positive denominator means the intensity is undefined: either
    # fewer than two observations, or p_mle already equals the target. Fall
    # back to the target (lambda = 1) instead of producing inf/NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        lam = np.where(denom > 0, numer / denom, 1.)
    # Truncate to [0, 1] so the result stays a convex combination, i.e. a
    # proper probability distribution (non-negative, sums to one).
    lam = np.clip(lam, 0., 1.)
    return target, lam, lam * target + (1. - lam) * p_mle


# Joint and marginals are each estimated from all n observations.
t_xy, lambda_xy, js_p_xy = _js_shrink(mle_p_xy, n)
t_x, lambda_x, js_p_x = _js_shrink(mle_p_x, n)
t_y, lambda_y, js_p_y = _js_shrink(mle_p_y, n)

# Each conditional row p(y | x) is estimated only from the obs_x[x] samples that
# fell into that row, so the per-row count (not the global n) sets its intensity.
# Rows with 0 or 1 observations fall back to the uniform target.
t_ygx, lambda_ygx, js_p_ygx = _js_shrink(mle_p_ygx, obs_x.reshape(-1,1), axis=1)

# verify that all of the probs are still probs
assert np.isclose(np.sum(js_p_xy), 1.)
assert np.isclose(np.sum(js_p_x), 1.)
assert np.isclose(np.sum(js_p_y), 1.)
# with lambda truncated to [0, 1] every conditional row is a proper distribution
assert np.allclose(np.sum(js_p_ygx,axis=1), 1.)
assert min(js_p_xy.min(), js_p_x.min(), js_p_y.min(), js_p_ygx.min()) >= 0.

if print_probs:
    print(">>> JS")
    print(js_p_xy)
    print(js_p_x)
    print(js_p_y)
    print(js_p_ygx)

In [9]:
# compare estimators for various information quantities

# Reference values: entropies and mutual information of the true distribution,
# against which the estimators in the following cells are compared.
print(">>> TRUE")
print("\n".join(
    f"{label} = {value:.2f}"
    for label, value in (
        ("H(X)", true_h_x),
        ("H(Y)", true_h_y),
        ("H(X,Y)", true_h_xy),
        ("I(X;Y)", true_mi_xy),
    )
))

>>> TRUE
H(X) = 1.15
H(Y) = 1.15
H(X,Y) = 1.15
I(X;Y) = 1.14


In [10]:
# first approach: MLE
# Plug the ML frequencies into the entropy formulas; I(X;Y) = H(X) + H(Y) - H(X,Y).

mle_h_x = -(mle_p_x * safelog(mle_p_x)).sum()
mle_h_y = -(mle_p_y * safelog(mle_p_y)).sum()
mle_h_xy = -(mle_p_xy * safelog(mle_p_xy)).sum()
mle_mi_xy = mle_h_x + mle_h_y - mle_h_xy

# One line per quantity: the estimate and its signed error against the truth.
print(">>> MLE")
print("\n".join(
    f"{label} = {estimate:.2f}, Δ={estimate-truth:.2f}"
    for label, estimate, truth in (
        ("H(X)", mle_h_x, true_h_x),
        ("H(Y)", mle_h_y, true_h_y),
        ("H(X,Y)", mle_h_xy, true_h_xy),
        ("I(X;Y)", mle_mi_xy, true_mi_xy),
    )
))

>>> MLE
H(X) = 1.04, Δ=-0.10
H(Y) = 1.04, Δ=-0.11
H(X,Y) = 1.04, Δ=-0.11
I(X;Y) = 1.04, Δ=-0.10


In [11]:
# second approach: James-Stein shrinkage
# Plug the shrunken frequencies into the entropy formulas.

js_h_x = -np.sum(js_p_x * safelog(js_p_x))
js_h_y = -np.sum(js_p_y * safelog(js_p_y))
js_h_xy = -np.sum(js_p_xy * safelog(js_p_xy))
# first MI estimate: I(X;Y) = H(X) + H(Y) - H(X,Y)
js_mi_xy_1 = js_h_x + js_h_y - js_h_xy

# Conditional entropy H(Y|X) = sum_x p(x) * H(Y | X=x). The rows of js_p_ygx are
# p(y | x) and the weights are p(x), so this quantity is H(Y|X), not H(X|Y).
js_h_ygx = -np.sum(js_p_x * np.sum(js_p_ygx * safelog(js_p_ygx), axis=1))
# Legacy (misleading) name kept so existing references keep working; it holds H(Y|X).
js_h_xgy = js_h_ygx
# second MI estimate: I(X;Y) = H(Y) - H(Y|X); available for inspection, not
# part of the printed report below.
js_mi_xy_2 = js_h_y - js_h_ygx

print(">>> JS")
print(f"H(X) = {js_h_x:.2f}, Δ={js_h_x-true_h_x:.2f}")
print(f"H(Y) = {js_h_y:.2f}, Δ={js_h_y-true_h_y:.2f}")
print(f"H(X,Y) = {js_h_xy:.2f}, Δ={js_h_xy-true_h_xy:.2f}")
print(f"I(X;Y) = {js_mi_xy_1:.2f}, Δ={js_mi_xy_1-true_mi_xy:.2f}")

>>> JS
H(X) = 1.12, Δ=-0.03
H(Y) = 1.12, Δ=-0.03
H(X,Y) = 1.17, Δ=0.02
I(X;Y) = 1.07, Δ=-0.07


In [12]:
# support_size = 3*5
# d = np.random.dirichlet((2/support_size)*np.ones(support_size))
# print(d)