# Mutual Information

In [39]:
#@title Install Packages
# %%capture
try:
    import sys, os
    from pyprojroot import here

    # spyder up to find the root
    root = here(project_files=[".here"])

    # append to path
    sys.path.append(str(here()))
except ModuleNotFoundError:
    import os
    os.system("pip install git+https://github.com/IPL-UV/rbig_jax.git#egg=rbig_jax")
    
finally:
    # import library functions
    from rbig_jax.data import get_classic
    from rbig_jax.plots import plot_joint, plot_joint_prob, plot_info_loss
    from rbig_jax.information.mi import rbig_mutual_info, rbig_mutual_info_sum

In [8]:
# jax packages
import jax
import jax.numpy as np
from jax.config import config
import chex
config.update("jax_enable_x64", False)

import numpy as onp
from functools import partial

# logging
import tqdm
import wandb

# plot methods
import matplotlib.pyplot as plt
import seaborn as sns
sns.reset_defaults()
sns.set_context(context="talk", font_scale=0.7)

%matplotlib inline
%load_ext autoreload
%autoreload 2

## Demo Data

In [9]:
#Parameters
n_samples = 10000
d_dimensions = 10

seed = 123

rng = onp.random.RandomState(seed)

# Generate random Data
A = rng.rand(2 * d_dimensions, 2 * d_dimensions)

# Covariance Matrix
C = A @ A.T
mu = np.zeros((2 * d_dimensions))

dat_all = rng.multivariate_normal(mu, C, n_samples)

CX = C[:d_dimensions, :d_dimensions]
CY = C[d_dimensions:, d_dimensions:]

X = dat_all[:, :d_dimensions]
Y = dat_all[:, d_dimensions:]



In [10]:
H_X = 0.5 * np.log(2 * np.pi * np.exp(1) * np.abs(np.linalg.det(CX)))
H_Y = 0.5 * np.log(2 * np.pi * np.exp(1) * np.abs(np.linalg.det(CY)))
H = 0.5 * np.log(2 * np.pi * np.exp(1) * np.abs(np.linalg.det(C)))

mi_original = H_X + H_Y - H
mi_original *= np.log(2)

print(f"MI: {mi_original:.4f}")

MI: 8.0710


## Mutual Information with RBIG

In [36]:
%%time

zero_tolerance = 30

X = np.array(X, np.float32)
Y = np.array(Y, np.float32)

mi_XY_rbig = rbig_mutual_info(
    X=X.block_until_ready(),
    Y=Y.block_until_ready(),
    zero_tolerance=zero_tolerance,
)


CPU times: user 33 s, sys: 14.2 s, total: 47.1 s
Wall time: 27.8 s


In [37]:

print(f"RBIG MIXY: {mi_XY_rbig.mi_X:.5f}")
print(f"RBIG MIx: {mi_XY_rbig.mi_Y:.5f}")
print(f"RBIG MIy: {mi_XY_rbig.mi_XY:.5f}")

RBIG MIXY: 6.97406
RBIG MIx: 6.42184
RBIG MIy: 8.46860


### Summation Version


$$
I(\mathbf{X};\mathbf{Y}) = H(\mathbf{X}) + H(\mathbf{Y}) - H(\mathbf{X},\mathbf{Y})
$$

In [45]:
%%time

zero_tolerance = 30

X = np.array(X, np.float32)
Y = np.array(Y, np.float32)

mi_XY_rbig = rbig_mutual_info_sum(
    X=X.block_until_ready(),
    Y=Y.block_until_ready(),
    zero_tolerance=zero_tolerance,
)


CPU times: user 32.7 s, sys: 13.3 s, total: 46 s
Wall time: 27.8 s


In [46]:
print(f"RBIG H_X: {mi_XY_rbig.H_X:.5f}")
print(f"RBIG H_Y: {mi_XY_rbig.H_Y:.5f}")
print(f"RBIG H_XY: {mi_XY_rbig.H_XY:.5f}")
print(f"RBIG I_XY: {mi_XY_rbig.mi_XY:.5f}")

RBIG H_X: 26.95807
RBIG H_Y: 27.42295
RBIG H_XY: 44.69664
RBIG I_XY: 9.68437
