# Gaussian Information Bottleneck
Chechik et al. (2005), *Information Bottleneck for Gaussian Variables*, JMLR — http://www.jmlr.org/papers/v6/chechik05a.html

In [None]:
%matplotlib inline

import numpy as np
import matplotlib
import matplotlib.pyplot as plt

from numpy.linalg import inv, det
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

# Render all figure text through LaTeX, using a serif font family.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['text.usetex'] = True

Parameters from the caption of Figure 1:

$\Sigma_x = I_2$

$\Sigma_{xy} = [0.1, 0.2]$, the cross-covariance between $X$ and the scalar $Y$ (the code takes $\Sigma_y = 1$)

$\Sigma_{\zeta} = 0.3 I_{2 \times 2}$

In [None]:
# Figure 1 setup: 2-D X, scalar Y (so the cross-covariance is a length-2 vector).
cov_zeta = np.eye(2) * 0.3      # covariance of the additive noise zeta
cov_xy = np.array((0.1, 0.2))   # cross-covariance Sigma_xy
cov_x = np.identity(2)          # Sigma_x = I_2

In [None]:
# Conditional covariance of X given the scalar Y:
#   Sigma_{x|y} = Sigma_x - Sigma_xy Sigma_y^{-1} Sigma_yx
# cov_xy is 1-D, so np.outer(cov_xy, cov_xy) is already the rank-one term
# Sigma_xy Sigma_yx (transposing a 1-D array would be a no-op).
cov_y = 1.0  # Var(Y), made explicit; 1.0 reproduces the values used in this notebook
cov_x_g_y = cov_x - np.outer(cov_xy, cov_xy) / cov_y
# A valid (non-degenerate) conditional covariance must be positive definite.
assert np.all(np.linalg.eigvalsh(cov_x_g_y) > 0), 'Sigma_{x|y} is not positive definite'
cov_x_g_y

$L = I(X; T) - \beta I(T; Y)$

$L = h(T) - h(T|X) - (\beta h(T) - \beta h(T|Y))$

$L = h(T) - h(T|X) - \beta h(T) + \beta h(T|Y)$

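Because $T$ is Gaussian, $h(T) = \frac{1}{2}\log\left((2\pi e)^d |\Sigma_t|\right)$, and likewise for the conditional entropies. The $(2\pi e)^d$ terms cancel within each difference. Dropping the common factor $\frac{1}{2}$ rescales $L$ without moving its minimizer, which gives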

$L =  \log(|\Sigma_t|) - \log(|\Sigma_{t|x}|) - \beta \log(|\Sigma_t|) + \beta \log(|\Sigma_{t|y}|)$

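With $T = AX + \zeta$ and noise $\zeta$ independent of $X$, we have $\Sigma_t = A \Sigma_x A^T + \Sigma_{\zeta}$, $\Sigma_{t|x} = \Sigma_{\zeta}$ and $\Sigma_{t|y} = A \Sigma_{x|y} A^T + \Sigma_{\zeta}$, so

$L = (1-\beta) \log(|A \Sigma_x A^T + \Sigma_{\zeta}|) - \log(|\Sigma_{\zeta}|) + \beta\log(|A \Sigma_{x|y} A^T + \Sigma_{\zeta}|)$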

In [None]:
# Evaluate the (rescaled) IB Lagrangian L over a grid of projections
# A = [[A1, A2], [0, 0]]. Only the first row is free: the second component of T
# is pure noise, its log(0.3) terms have coefficients (1-beta) - 1 + beta = 0,
# so they cancel in L.
beta = 100  # try 15
a = np.arange(-5, 5, 0.4)
d = len(a)  # grid size per axis (25), derived from `a` so the two cannot drift apart
z = np.zeros((d, d))

# indexing='ij' gives xv[i, j] == a[i] (A1) and yv[i, j] == a[j] (A2), matching
# z[i, j] below; the default 'xy' indexing would put A2 on the plotted x-axis.
xv, yv = np.meshgrid(a, a, indexing='ij')

# log|Sigma_zeta| does not depend on A, so compute it once. slogdet avoids
# under/overflow in det(); all matrices here are SPD, so the sign is +1.
log_det_zeta = np.linalg.slogdet(cov_zeta)[1]

for i in range(d):
    for j in range(d):
        A = np.array([[a[i], a[j]],
                      [0.,   0.]])
        # log|Sigma_t| = log|A Sigma_x A^T + Sigma_zeta|
        log_det_t = np.linalg.slogdet(A.dot(cov_x).dot(A.T) + cov_zeta)[1]
        # log|Sigma_{t|y}| = log|A Sigma_{x|y} A^T + Sigma_zeta|
        log_det_t_given_y = np.linalg.slogdet(A.dot(cov_x_g_y).dot(A.T) + cov_zeta)[1]
        z[i, j] = (1. - beta) * log_det_t - log_det_zeta + beta * log_det_t_given_y

In [None]:
# 3-D surface of L over the projection grid (xv, yv, z) from the computation cell.
fig = plt.figure()
# Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot works on
# both old and new releases.
ax = fig.add_subplot(111, projection='3d')
ax.plot_surface(xv, yv, z, cmap=cm.hot, linewidth=0, antialiased=True)
ax.set_xticks([-5, 0, 5])
ax.set_yticks([-5, 0, 5])
# Axes are left unlabeled. To label them use ax.set_xlabel / ax.set_ylabel /
# ax.set_zlabel; set_xticklabels would overwrite the tick labels instead.
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.tick_params(labelsize='x-large')
ax.view_init(elev=35., azim=-15)
plt.tight_layout()
plt.show()

In [None]:
# Write the surface figure to disk; the file name records which beta produced it.
out_path = 'GIB_fig1_beta%d.png' % (beta,)
fig.savefig(out_path)

In [None]:
# Covariances of T = A X + zeta for one example projection. The projection is
# defined explicitly rather than read from whatever `A` the grid loop left
# behind; it has the same value as that leftover: A1 = A2 = a[-1].
A_example = np.array([[a[-1], a[-1]],
                      [0.,    0.]])

cov_tx = A_example.dot(cov_x)                                  # Sigma_tx = A Sigma_x
cov_ty = A_example.dot(cov_xy)                                 # Sigma_ty = A Sigma_xy
cov_t = A_example.dot(cov_x).dot(A_example.T) + cov_zeta       # Sigma_t = A Sigma_x A^T + Sigma_zeta
cov_tx, cov_ty, cov_t

In [None]:
# Render an asymmetric-error label, z = 0.27 (+0.01 / -0.01), with LaTeX.
# The 'text.latex.unicode' rcParam was removed in Matplotlib 3.2 (setting it
# raises KeyError); Unicode is always supported now, so it is not set here.
matplotlib.rcParams['text.usetex'] = True

value, upper, lower = 0.27, 0.01, 0.01
# In str.format, '{{' and '}}' are literal braces, so each number becomes a TeX
# group: z=${0.27}^{+0.01}_{-0.01}$
string = r'z=${{{}}}^{{+{}}}_{{-{}}}$'.format(value, upper, lower)
print(string)

# NOTE: this rebinds `fig`; re-run the surface-plot cell before saving GIB_fig1 again.
fig = plt.figure(figsize=(3, 1))
fig.text(0.1, 0.5, string, size=24, va='center');