In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

In [2]:
# convert class to one hot
def convert_to_one_hot(y):
    dict_size = np.unique(y).shape[0]
    y_hot = np.eye(dict_size)[y.astype('int32')]
    return y_hot

def make_one_hot(y, dims):
    y_hot = []
    for i in range(dims):
        y_hot.append(convert_to_one_hot(y[:, i]))
    return y_hot

In [3]:
# Discretize to equi-probability bins
def discretize(data, bins):
    split = np.array_split(np.sort(data), bins)
    cutoffs = [x[-1] for x in split]
    cutoffs = cutoffs[:-1]
    discrete = np.digitize(data, cutoffs, right=True)
    return discrete, cutoffs

def discretize_batch(data, bins, batch_size):
    z_disc = np.zeros((data.shape[0], data.shape[1]))
    for d in range(data.shape[1]):
        z_disc[:, d], _ = discretize(data[:, d], bins)
    return z_disc

In [4]:
# Creating the covariance matrix
def making_cov(rho, dims):
    cov = np.zeros((2 * dims, 2 * dims))
    for i in range(dims):
        cov[i, i] = 1
        cov[i + dims, i + dims] = 1
        cov[i, i + dims] = rho
        cov[i + dims, i] = rho
    return cov

def generate_gaussian(rho, batch_size, dims):
    cov = making_cov(rho, dims)
    z = np.random.multivariate_normal(mean=np.repeat(0, dims * 2), cov=cov, size=batch_size)
    return torch.tensor(z, dtype=torch.float32)

In [5]:
class ModelBasicClassification(nn.Module):
    def __init__(self, input_shape, class_size):
        super(ModelBasicClassification, self).__init__()
        self.l1 = nn.Linear(input_shape, 500)
        self.l2 = nn.Linear(500, 500)
        self.l3 = nn.Linear(500, 500)
        self.output = nn.Linear(500, class_size)

    def forward(self, x):
        x = torch.relu(self.l1(x))
        x = torch.relu(self.l2(x))
        x = torch.relu(self.l3(x))
        output = self.output(x)
        return output

In [6]:
# Initialize entropy models
bins = 250
dims = 20

model_lst = []
opt_lst = []
for m in range(0, dims):
    if m == 0:
        model_lst.append(None)
    else:
        model_lst.append(ModelBasicClassification(m, bins))

for m in range(0, dims):
    if m == 0:
        opt_lst.append(None)
    else:
        opt_lst.append(optim.Adam(model_lst[m].parameters()))

In [7]:
# Initialize conditional entropy models
model_lst_cond = []
opt_lst_cond = []
for m in range(0, dims):
    model_lst_cond.append(ModelBasicClassification(dims + m, bins))

for m in range(0, dims):
    opt_lst_cond.append(optim.Adam(model_lst_cond[m].parameters()))

In [8]:
# Running different true values of I/rho
r_lst = []
I = np.arange(2, 12, 2)
for i in I:
    r_lst.append((1 - np.exp(-2 * i / dims)) ** 0.5)

In [None]:
# Run 4000 epochs for every rho/MI value
epochs = 4000
batch_size = 256
sub_loss_lst = []
H_y_lst = [[] for _ in range(dims)]
H_yx_lst = [[] for _ in range(dims)]
I_hat = []
H_y_res = []
EMA_SPAN = 200

for r in r_lst:
    for i in range(epochs):
        z_0 = generate_gaussian(r, batch_size, dims)
        for j in range(dims):
            if j != 0:  # if this is not the first co
                x = torch.tensor(z_0[:, :j], dtype=torch.float32)
                y = torch.tensor(z_0[:, j], dtype=torch.long)
                y = torch.reshape(y, (-1, 1))
                y = discretize_batch(y, bins, batch_size)
                y_hot = nn.functional.one_hot(torch.tensor(y).long(), bins)

                opt_lst[j].zero_grad()
                output = model_lst[j](x)
                loss = nn.functional.cross_entropy(output, torch.argmax(y_hot, dim=1).float())
                loss.backward()
                opt_lst[j].step()
                H_y_lst[j].append(loss.item())
            else:
                y = torch.tensor(z_0[:, j], dtype=torch.long)
                y = torch.reshape(y, (-1, 1))
                y = discretize_batch(y, bins, batch_size)
                _, p_1 = np.unique(y, return_counts=True)
                p_1 = p_1 / (p_1.sum() + 10 ** -5)
                loss = -np.sum(np.array(p_1) * np.log(p_1)) + (pd.DataFrame(y).nunique() - 1) / (2 * batch_size)
                H_y_lst[j].append(loss.item())

            x = torch.tensor(z_0[:, :dims+j], dtype=torch.float32)
            y = torch.tensor(z_0[:, dims+j], dtype=torch.long)
            y = torch.reshape(y, (-1, 1))
            y = discretize_batch(y, bins, batch_size)
            y_hot = nn.functional.one_hot(torch.tensor(y).long(), bins)

            opt_lst_cond[j].zero_grad()
            output = model_lst_cond[j](x)
            loss = nn.functional.cross_entropy(output, torch.argmax(y_hot, dim=1).float())
            loss.backward()
            opt_lst_cond[j].step()
            H_yx_lst[j].append(loss.item())

        H_y = pd.Series(np.reshape(np.sum(H_y_lst, axis=0), [-1]))
        H_yx = pd.Series(np.reshape(np.sum(H_yx_lst, axis=0), [-1]))
        I_hat.append(H_y.iloc[-1] - H_yx.iloc[-1])

  y = torch.tensor(z_0[:, j], dtype=torch.long)
  x = torch.tensor(z_0[:, :dims+j], dtype=torch.float32)
  y = torch.tensor(z_0[:, dims+j], dtype=torch.long)
  x = torch.tensor(z_0[:, :j], dtype=torch.float32)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  x = torch.tensor(z_0[:, :dims+j], dtype=torch.float32)
  y = torch.tensor(z_0[:, dims+j], dtype=torch.long)
  x = torch.tensor(z_0[:, :j], dtype=torch.float32)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  x = torch.tensor(z_0[:, :dims+j], dtype=torch.float32)
  y = torch.tensor(z_0[:, dims+j], dtype=torch.long)
  x = torch.tensor(z_0[:, :j], dtype=torch.float32)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  y = torch.tensor(z_0[:, j], dtype=torch.long)
  x = torch.tensor(z_0[:, :dims+j], dtype=torch.float32)
  y = torch.tensor(z_0[:, dims+j], dtype=torch.long)
  x = torch.tensor(z_0[:, :j], dtype=torch.float32)
  y = torch.tens

In [None]:
# Plot
EMA_SPAN = 200
I_real = np.hstack([np.repeat(i, epochs) for i in I])
plt.plot(I_real, 'k', label='True MI')
mi = pd.DataFrame(I_hat).clip(lower=0)
mi_smooth = mi.ewm(span=EMA_SPAN).mean()
plt.plot(mi_smooth, 'tab:red', label='MI_hat')
plt.ylabel('mutual information(nats)')
plt.xlabel('batch number')
plt.xlim(0, epochs * len(I))
plt.legend()
plt.title('rho = 0.99 \n the real entropy of H(x) is (-)5.39')
plt.savefig(f'MI-{epochs}1.png')