# Basic experiment

# Make toy data

 Type 1. Independent Gaussian

$$
(x_1, x_2) \sim \mathcal{N}(0,I)
$$

 Type 2. Correlated Gaussian

$$
(y_1, y_2) \sim \mathcal{N}(0,\begin{bmatrix}1&\rho\\ \rho&1\end{bmatrix})
$$

In [1]:
import numpy as np

In [None]:
x = np.random.multivariate_normal( mean=[0,0],
                                  cov=[[1,0],[0,1]],
                                 size = 300)

In [None]:
rho = 0.8
y = np.random.multivariate_normal( mean=[0,0],
                                  cov=[[1,rho],[rho,1]],
                                 size = 300)

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set()

In [None]:
# plt indep Gaussian
sns.scatterplot(x=x[:,0],y=x[:,1])

In [None]:
# plt cor Gaussian
sns.scatterplot(x=y[:,0],y=y[:,1])

# Define MINE

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.autograd as autograd

In [4]:
class Mine(nn.Module):
    def __init__(self, input_size=2, hidden_size=100):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        nn.init.normal_(self.fc1.weight,std=0.02)
        nn.init.constant_(self.fc1.bias, 0)
        nn.init.normal_(self.fc2.weight,std=0.02)
        nn.init.constant_(self.fc2.bias, 0)
        nn.init.normal_(self.fc3.weight,std=0.02)
        nn.init.constant_(self.fc3.bias, 0)
        
    def forward(self, input):
        output = F.elu(self.fc1(input))
        output = F.elu(self.fc2(output))
        output = self.fc3(output)
        return output

In [5]:
def mutual_information(joint, marginal, mine_net):
    t = mine_net(joint)
    et = torch.exp(mine_net(marginal))
    mi_lb = torch.mean(t) - torch.log(torch.mean(et))
    return mi_lb, t, et

def learn_mine(batch, mine_net, mine_net_optim,  ma_et, ma_rate=0.01):
    # batch is a tuple of (joint, marginal)
    joint, marginal = batch
    joint = torch.autograd.Variable(torch.FloatTensor(joint)).cuda()
    marginal = torch.autograd.Variable(torch.FloatTensor(marginal)).cuda()
    mi_lb, t, et = mutual_information(joint, marginal, mine_net)
    ma_et = (1-ma_rate)*ma_et + ma_rate*torch.mean(et)
    
    # unbiasing use moving average
    loss = -(torch.mean(t) - torch.mean(et))/(ma_et.mean().detach())
    # use biased estimator
    # loss = - mi_lb
    
    mine_net_optim.zero_grad()
    autograd.backward(loss)
    mine_net_optim.step()
    return mi_lb, ma_et

In [6]:
def sample_batch(data, batch_size=100, sample_mode='joint'):
    if sample_mode == 'joint':
        index = np.random.choice(range(data.shape[0]), size=batch_size, replace=False)
        batch = data[index]
    else:
        joint_index = np.random.choice(range(data.shape[0]), size=batch_size, replace=False)
        marginal_index = np.random.choice(range(data.shape[0]), size=batch_size, replace=False)
        batch = np.concatenate([data[joint_index][:,0].reshape(-1,1),
                                         data[marginal_index][:,1].reshape(-1,1)], axis=1)
    return batch

In [None]:
joint_data = sample_batch(y, batch_size=100, sample_mode='joint')
sns.scatterplot(x=joint_data[:,0], y=joint_data[:,1], color='red')
marginal_data = sample_batch(y, batch_size=100, sample_mode='marginal')
sns.scatterplot(x=marginal_data[:,0], y=marginal_data[:,1])

In [7]:
def train(data, mine_net,mine_net_optim, batch_size=100, iter_num=int(5e+3), log_freq=int(1e+3)):
    # data is x or y
    result = list()
    ma_et = 1.
    for i in range(iter_num):
        batch = sample_batch(data, batch_size), sample_batch(data, batch_size, sample_mode='marginal')
        mi_lb, ma_et = learn_mine(batch, mine_net, mine_net_optim, ma_et)
        result.append(mi_lb.detach().cpu().numpy())
        if (i+1)%(log_freq)==0:
            print(result[-1])
    return result

In [8]:
def ma(a, window_size=100):
    return [np.mean(a[i:i+window_size]) for i in range(0, len(a)-window_size)]

In [None]:
mine_net_indep = Mine().cuda()
mine_net_optim_indep = optim.Adam(mine_net_indep.parameters(), lr=1e-3)
result_indep = train(x, mine_net_indep, mine_net_optim_indep)

In [None]:
result_indep_ma = ma(result_indep)
print(result_indep_ma[-1])
plt.plot(range(len(result_indep_ma)), result_indep_ma)

In [None]:
mine_net_cor = Mine().cuda()
mine_net_optim_cor = optim.Adam(mine_net_cor.parameters(), lr=1e-3)
result_cor = train(y, mine_net_cor, mine_net_optim_cor)

In [None]:
result_cor_ma = ma(result_cor)
print(result_cor_ma[-1])
plt.plot(range(len(result_cor_ma)), result_cor_ma)

# Test with various correlations

In [None]:
correlations = np.linspace(-0.9,0.9,19)
print(correlations)

In [None]:
final_result = []
for rho in correlations:
    rho_data = np.random.multivariate_normal( mean=[0,0], cov=[[1,rho],[rho,1]], size = 300)
    mine_net = Mine().cuda()
    mine_net_optim = optim.Adam(mine_net.parameters(), lr=1e-3)
    result = train(rho_data, mine_net, mine_net_optim)
    result_ma = ma(result)
    final_result.append(result_ma[-1])
    print(str(rho) + ' : ' + str(final_result[-1]))
    plt.plot(range(len(result_ma)), result_ma)

In [None]:
plt.plot(correlations, final_result)

# Equitability experiment

Check more information at [original equitability paper](http://www.pnas.org/content/pnas/111/9/3354.full.pdf).

# Make toy data

In [9]:
x = np.random.uniform(low=-1.,high=1.,size=3000)
f1 = x
f2 = 2*x
f3 = np.sin(np.pi*x)
f4 = x**3
f5 = np.exp(x)
eps = np.random.normal(size=3000)

In [None]:
sns.scatterplot(x, f1, color='red')

In [None]:
sns.scatterplot(x, f2, color='blue')

In [None]:
sns.scatterplot(x, f3, color='green')

In [None]:
sns.scatterplot(x, f4,color='purple')

In [None]:
sns.scatterplot(x, f5, color='yellow')

In [10]:
sigmas = np.linspace(0.0,0.9,10)
fs = [f1, f2, f3, f4, f5]
print(sigmas)

[0.  0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9]


In [11]:
final_result = []
for sigma in sigmas:
    for fi, f in enumerate(fs):
        data = np.concatenate([x.reshape(-1,1), (f + sigma * eps).reshape(-1,1)], axis=1)
        mine_net = Mine().cuda()
        mine_net_optim = optim.Adam(mine_net.parameters(), lr=1e-3)
        result = train(data,mine_net,mine_net_optim,iter_num=int(5e+3))
        result_ma = ma(result)
        final_result.append(result_ma[-1])
        print(str(sigma) + ',' + str(fi) +' : ' + str(final_result[-1]))
        plt.plot(range(len(result_ma)),result_ma)

3.0363154
3.3839862
5.149334
5.3287854
4.0147395
0.0,0 : 5.291139
3.5445228
4.5538654
4.011033
3.231633
4.7704268
0.0,1 : 6.2936225
0.616921
1.2972538
1.2289668
1.6995142
1.1948456
0.0,2 : 1.3565087
1.0250157
1.4436197
0.8813411
1.038464
1.2029274
0.0,3 : 1.2853531
1.1445699
2.037866
2.87651
2.180523
3.3242702
0.0,4 : 3.1188543
1.4978955
2.164853
1.660373
1.7273035
2.2097654
0.1,0 : 1.6564119
1.6917387
2.306057
1.9149991
2.2249746
2.1372004
0.1,1 : 2.424303
0.70135653
0.959475
1.0716975
1.0516242
1.2927284
0.1,2 : 1.1211165
0.7749643
0.7997097
0.6392024
0.77669096
0.950734
0.1,3 : 1.01199
1.3237885
1.7541288
1.8285522
2.2041268
1.9267054
0.1,4 : 1.6947986
1.1329917
1.1430949
0.93972576
0.95999014
0.94904727
0.2,0 : 1.0627402
1.8591108
1.6010873
1.54544
1.7203736
1.6013261
0.2,1 : 1.669944
0.80149806
0.8266838
1.0344738
0.7742781
1.195403
0.2,2 : 0.88936406
0.39594558
0.49725848
0.7590377
0.44830307
0.8826014
0.2,3 : 0.60874134
0.9515597
1.1663088
0.95945156
0.75200325
1.1532695
0.2,4 :

In [None]:
re_final_result = np.transpose(np.array([final_result[4*i: 4*i+4]  for i in range(10)]))
print(re_final_result)

In [None]:
sns.heatmap(re_final_result,cmap="YlGnBu",linewidths=.5)