In [175]:
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.autograd import Variable


# Sliced Mutual Information Estimator

## Mutual Information Estimator
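Following the paper, the sliced mutual information (SMI) estimator averages a base mutual information (MI) estimate over random one-dimensional projections of the data, so we first define a base MI estimator to apply to each projected pair.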

### 1. Mutual Information Neural Estimator (MINE)

In [176]:
class MINE(nn.Module):
    """Statistics network T(x, y) used by the MINE mutual-information estimator.

    x and y are embedded separately into a shared ``proj_size``-dimensional space,
    the embeddings are summed, and a small ReLU MLP maps the result to one score
    per sample.

    Args:
        x_size: number of features in x (last dimension).
        y_size: number of features in y (last dimension).
        proj_size: width of the shared embedding space.
    """

    def __init__(self, x_size, y_size, proj_size=32):
        super().__init__()
        # Layers are created in the order fc1..fc4 so that seeded initialisation
        # (and state_dict keys) stay the same as in the original definition.
        self.fc1 = nn.Linear(x_size, proj_size)
        self.fc2 = nn.Linear(y_size, proj_size)
        hidden = 16
        self.fc3 = nn.Linear(proj_size, hidden)
        self.fc4 = nn.Linear(hidden, 1)

    def forward(self, x, y):
        """Score paired rows of x and y; the output's last dimension has size 1."""
        joint_embedding = torch.relu(self.fc1(x) + self.fc2(y))
        hidden_activation = torch.relu(self.fc3(joint_embedding))
        return self.fc4(hidden_activation)
    

In [177]:
def train_mine(x, y, num_epoch=500, lr=0.01, critic=None):
    """Estimate the mutual information I(X; Y) with MINE.

    Trains a statistics network T by full-batch gradient ascent on the
    Donsker-Varadhan lower bound
        I(X; Y) >= E_{P_XY}[T] - log E_{P_X x P_Y}[exp(T)],
    using the paired rows of ``x``/``y`` as joint samples and a row-shuffled
    ``y`` as samples from the product of marginals.

    Args:
        x: array of shape (n, dx) or (n,); a 1-D array is treated as one feature.
        y: array of shape (n, dy) or (n,); row i is paired with row i of ``x``.
        num_epoch: number of Adam steps (>= 1).
        lr: Adam learning rate.
        critic: optional torch module called as ``critic(x, y)`` that returns one
            score per row. Defaults to ``MINE(dx, dy)``.

    Returns:
        The bound (in nats, as a float) evaluated at the last step, before that
        step's parameter update.

    Raises:
        ValueError: if x and y have different numbers of rows, are empty, or
            ``num_epoch < 1``.
    """
    x_arr = np.asarray(x, dtype=np.float32)
    y_arr = np.asarray(y, dtype=np.float32)
    # Scalar projections (e.g. from sliced MI) arrive as 1-D arrays; treat them
    # as a single feature so the critic always sees (n, d) inputs.
    if x_arr.ndim == 1:
        x_arr = x_arr.reshape(-1, 1)
    if y_arr.ndim == 1:
        y_arr = y_arr.reshape(-1, 1)
    if x_arr.shape[0] != y_arr.shape[0]:
        raise ValueError(
            f"x and y must have the same number of rows, got {x_arr.shape[0]} and {y_arr.shape[0]}"
        )
    n = x_arr.shape[0]
    if n == 0:
        raise ValueError("train_mine needs at least one sample")
    if num_epoch < 1:
        raise ValueError(f"num_epoch must be >= 1, got {num_epoch}")

    model = critic if critic is not None else MINE(x_arr.shape[1], y_arr.shape[1])
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # The samples are data, not parameters: no requires_grad (and no deprecated Variable).
    x_t = torch.from_numpy(np.ascontiguousarray(x_arr))
    y_t = torch.from_numpy(np.ascontiguousarray(y_arr))
    log_n = float(np.log(n))

    mi = float("nan")
    for _ in range(num_epoch):
        # Reshuffle every step so the critic cannot overfit one fixed set of
        # "independent" pairs; np.random keeps this under the notebook's numpy seed.
        y_marginal = y_t[torch.from_numpy(np.random.permutation(n))]
        t_joint = model(x_t, y_t)
        t_marginal = model(x_t, y_marginal)

        # log(mean(exp(t))) computed as logsumexp(t) - log(n): the naive form
        # overflows float32 once scores exceed ~88 and turns the loss into inf/nan.
        log_mean_exp = torch.logsumexp(t_marginal.reshape(-1), dim=0) - log_n
        dv_bound = t_joint.mean() - log_mean_exp
        loss = -dv_bound

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        mi = dv_bound.item()

    return mi

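### 2. RBIG (Rotation-Based Iterative Gaussianization)

**TODO:** not implemented yet; only MINE is currently available as a base MI estimator.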

## Implementing Algorithms

In [178]:
def estmate_smi(m: int, X, Y, pair_num: int, mi_calculator, progress: bool = True):
    """Monte-Carlo estimate of the sliced mutual information SI(X; Y).

    SI(X; Y) is the average of I(theta^T X; phi^T Y) over directions theta, phi
    drawn uniformly from the unit spheres in R^{dim_x} and R^{dim_y}. Each of the
    ``m`` slices draws one direction pair and projects every sample onto it. The
    resulting 1-D pair is scored with ``mi_calculator``, and the slice scores are
    averaged.

    Args:
        m: number of random projection pairs (slices); must be >= 1.
        X: samples of shape (n, dim_x), or (n,) for a scalar variable.
        Y: samples of shape (n, dim_y), or (n,); row i of Y is paired with row i of X.
        pair_num: number of paired samples passed to ``mi_calculator`` per slice.
            If it is smaller than n, a fresh random subset of that many pairs is
            drawn for each slice; otherwise all n pairs are used.
            NOTE(review): the original code ignored this argument. Confirm this
            interpretation against the paper.
        mi_calculator: callable ``(x_proj, y_proj) -> scalar`` that receives
            arrays of shape (k, 1), e.g. ``train_mine``.
        progress: if True, show a tqdm progress bar over the slices.

    Returns:
        The mean of the per-slice MI estimates, as a float.

    Raises:
        ValueError: if ``m`` or ``pair_num`` is < 1, or if X and Y are empty,
            have different row counts, or are not 1-D/2-D.
    """
    def hyperspherical_sample(dim: int) -> np.ndarray:
        # Normalising an isotropic Gaussian vector yields a direction uniform on the sphere.
        while True:
            direction = np.random.normal(size=dim)
            length = np.linalg.norm(direction)
            if length > 0:  # guard the (probability-zero) all-zero draw
                return direction / length

    def as_2d(samples, label: str) -> np.ndarray:
        arr = np.asarray(samples, dtype=float)
        if arr.ndim == 1:
            arr = arr.reshape(-1, 1)
        if arr.ndim != 2:
            raise ValueError(f"{label} must be 1-D or 2-D, got shape {arr.shape}")
        return arr

    if m < 1:
        raise ValueError(f"m must be >= 1, got {m}")
    if pair_num < 1:
        raise ValueError(f"pair_num must be >= 1, got {pair_num}")
    X = as_2d(X, "X")
    Y = as_2d(Y, "Y")
    num_samples = X.shape[0]
    if Y.shape[0] != num_samples:
        raise ValueError(
            f"X and Y must have the same number of rows, got {num_samples} and {Y.shape[0]}"
        )
    if num_samples == 0:
        raise ValueError("X and Y must contain at least one sample")
    n_pairs = min(pair_num, num_samples)

    slices = tqdm(range(m)) if progress else range(m)
    estimates = []
    for _ in slices:
        theta = hyperspherical_sample(X.shape[1])
        phi = hyperspherical_sample(Y.shape[1])

        if n_pairs < num_samples:
            idx = np.random.choice(num_samples, size=n_pairs, replace=False)
            x_rows, y_rows = X[idx], Y[idx]
        else:
            x_rows, y_rows = X, Y

        # Project onto the sampled unit directions. Keep a trailing feature axis of
        # size 1 so estimators that read .shape[1] (such as train_mine) accept the input.
        x_proj = (x_rows @ theta).reshape(-1, 1)
        y_proj = (y_rows @ phi).reshape(-1, 1)
        estimates.append(float(mi_calculator(x_proj, y_proj)))

    return float(np.mean(estimates))
        

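# Test

X and Y below are drawn independently, so their true sliced mutual information is 0 and a good estimate should be close to 0.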

In [179]:
# Fixed seeds so the notebook reproduces the same data and estimates on Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# X and Y are sampled independently (different scales), so their true sliced MI is 0.
X = np.random.normal(0, 0.3, (1000, 10))
Y = np.random.normal(0, 3, (1000, 10))

In [180]:
# Cost note: each slice trains a fresh MINE critic (~2 s per slice in the recorded
# run), so 500 slices take roughly 15-20 minutes; lower N_SLICES for a quick check.
N_SLICES = 500  # number of random projection pairs (theta, phi)
N_PAIRS = 1000  # passed to estmate_smi as pair_num
mi_calculator = train_mine
si = estmate_smi(
    N_SLICES,
    X,
    Y,
    N_PAIRS,
    mi_calculator=mi_calculator,
)

  2%|▏         | 8/500 [00:17<17:52,  2.18s/it]


KeyboardInterrupt: 

In [None]:
# NOTE(review): the recorded `nan` below predates the interrupted run above, so `si`
# there came from earlier kernel state; re-run once the estimate completes.
print(f"{si}")

nan
