# PyTorch mutual information neural estimation tests

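Simple sanity checks of the estimator on synthetic distributions with known mutual information (correlated multivariate Gaussian, Student's t, and uniform variants).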

In [1]:
import sys
sys.path.append("../python")

In [2]:
import numpy as np

In [3]:
import torch
import torchkld
import mutinfo

In [4]:
# Prefer the second GPU (the original target) when at least two are visible,
# fall back to the first GPU on single-GPU machines (where "cuda:1" would not
# exist), and use the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"
# To force CPU execution, uncomment: device = "cpu"
print("Device: " + device)
print(f"Devices count: {torch.cuda.device_count()}")

Device: cuda:1
Devices count: 2


In [5]:
from tqdm import tqdm, trange

In [6]:
from mutinfo.distributions.base import *
from mutinfo.distributions.tools import mapped_multi_rv_frozen

In [7]:
from misc.modules import *
from misc.utils import *
from misc.plots import *

## Dataset

Experimental setup

In [8]:
# Dimensionality of each of the two random vectors X and Y.
dimension = 4

# Target mutual information I(X; Y) passed to the synthetic distributions below
# and used as the reference line in the training plots; it scales with `dimension`.
mutual_information = 5.0 * dimension

supported_dataset_types = [
    "CorrelatedNormal",
    "CorrelatedStudent",
    "CorrelatedStudent_arcsinh",
    "CorrelatedUniform",
    "SmoothedUniform",
    "UniformlyQuantized",
]
dataset_type = "CorrelatedUniform"
assert dataset_type in supported_dataset_types, f"Unknown dataset_type: {dataset_type!r}"
degrees_of_freedom = 2  # Only used by the Student's t variants.

# Seed the global NumPy and PyTorch RNGs so that the random projections used in
# training and the network weight initialisation are reproducible under
# Restart & Run All. NOTE(review): assumes the mutinfo samplers also draw from
# the global NumPy RNG — TODO confirm.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [9]:
# Interaction-structure options forwarded to the mutinfo distributions.
randomize_interactions = False
shuffle_interactions = True

# The Student branch below appends a "_dof_<k>" suffix to `dataset_type`.
# Strip any such suffix left by a previous run so re-executing this cell
# selects the same branch instead of silently keeping a stale `random_variable`.
base_dataset_type = dataset_type.split("_dof_")[0]

if base_dataset_type == "CorrelatedNormal":
    random_variable = CorrelatedNormal(mutual_information, dimension, randomize_interactions=randomize_interactions, shuffle_interactions=shuffle_interactions)

elif base_dataset_type in ["CorrelatedStudent", "CorrelatedStudent_arcsinh"]:
    random_variable = CorrelatedStudent(
        mutual_information, dimension, dimension, degrees_of_freedom, randomize_interactions=randomize_interactions, shuffle_interactions=shuffle_interactions
    )

    if base_dataset_type == "CorrelatedStudent_arcsinh":
        # Element-wise arcsinh on both marginals (with sinh as the inverse map).
        random_variable = mapped_multi_rv_frozen(random_variable, lambda x, y: (np.arcsinh(x), np.arcsinh(y)), lambda x, y: (np.sinh(x), np.sinh(y)))

    # Record the degrees of freedom in the dataset name.
    dataset_type = f"{base_dataset_type}_dof_{degrees_of_freedom}"

elif base_dataset_type == "CorrelatedUniform":
    random_variable = CorrelatedUniform(mutual_information, dimension, randomize_interactions=randomize_interactions, shuffle_interactions=shuffle_interactions)

elif base_dataset_type == "SmoothedUniform":
    random_variable = SmoothedUniform(mutual_information, dimension, dimension, randomize_interactions=randomize_interactions)

elif base_dataset_type == "UniformlyQuantized":
    from scipy.stats import norm

    random_variable = UniformlyQuantized(mutual_information, dimension, norm(loc=0.0, scale=1.0), randomize_interactions=randomize_interactions)

else:
    raise ValueError(f"Unsupported dataset_type: {dataset_type!r}")

In [10]:
n_samples = 10_000  # Samples drawn for each of the train and test splits.

## Estimating MI

Datasets and dataloaders

In [11]:
def _sample_tensor_dataset(n: int) -> torch.utils.data.TensorDataset:
    """Draw ``n`` (x, y) samples from ``random_variable`` and wrap them as float32 tensors."""
    x_samples, y_samples = random_variable.rvs(n)
    return torch.utils.data.TensorDataset(
        torch.tensor(x_samples, dtype=torch.float32),
        torch.tensor(y_samples, dtype=torch.float32),
    )


# Independent draws for training and held-out evaluation (train is sampled first).
train_dataset = _sample_tensor_dataset(n_samples)
test_dataset = _sample_tensor_dataset(n_samples)

In [12]:
batch_size = 512

# Only the training split is shuffled; the test split keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [13]:
class DenseT(torchkld.mutual_information.MINE):
    """Concatenation-based critic network T(x, y) for MINE.

    ``x`` and ``y`` are concatenated along dim 1 and passed through an MLP
    with two hidden layers and LeakyReLU activations, giving one score per row.
    """

    def __init__(self, X_dim: int, Y_dim: int, inner_dim: int = 256) -> None:
        """
        Parameters
        ----------
        X_dim : int
            Number of features of ``x`` (size of dim 1).
        Y_dim : int
            Number of features of ``y`` (size of dim 1).
        inner_dim : int
            Width of both hidden layers.
        """
        super().__init__()

        self.model = torch.nn.Sequential(
            torch.nn.Linear(X_dim + Y_dim, inner_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(inner_dim, inner_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(inner_dim, 1),
        )

    # NOTE(review): the torchkld `marginalizable` decorator presumably provides the
    # `marginalize=` keyword that the training loop passes to this model.
    @torchkld.mutual_information.MINE.marginalizable
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return critic scores of shape ``(batch, 1)`` for 2-D inputs ``x`` and ``y``."""
        return self.model(torch.cat((x, y), dim=1))

Model

In [14]:
config = {
    "projection_dim": 2,
    "embedding_dim": dimension,
    "discriminator_network": "DenseT",
    "discriminator_network_inner_dim": 256,
    "discriminator_network_output_dim": 256,
}


def _critic_input_dim() -> int:
    """Width of a projected input: projected coordinates plus the flattened projection basis.

    Read from `config` at call time, so the factory lambdas see the current config.
    """
    return config["projection_dim"] + config["embedding_dim"] * config["projection_dim"]


# Each entry lazily builds one critic architecture on `device`.
# ("AdditiveGaussainT" is spelled to match the class name as defined elsewhere.)
_discriminator_network_factory = {
    "SeparableT": lambda: SeparableT(
        _critic_input_dim(),
        _critic_input_dim(),
        inner_dim=config["discriminator_network_inner_dim"],
        output_dim=config["discriminator_network_output_dim"],
    ).to(device),
    "DenseT": lambda: DenseT(
        _critic_input_dim(),
        _critic_input_dim(),
        inner_dim=config["discriminator_network_inner_dim"],
    ).to(device),
    "AdditiveGaussainT": lambda: AdditiveGaussainT(p=0.99).to(device),
}

selected_network = config["discriminator_network"]
model = _discriminator_network_factory[selected_network]()

Loss

In [15]:
# Loss configuration.
biased = False           # forwarded to DonskerVaradhanLoss
ema_multiplier = 1.0e-2  # forwarded to DonskerVaradhanLoss
marginalize = "permute"  # passed as `marginalize=` to the critic; "permute" or "product"

# All candidate objectives are constructed up front; one is selected by name.
losses = dict(
    DonskerVaradhan=torchkld.loss.DonskerVaradhanLoss(biased=biased, ema_multiplier=ema_multiplier),
    NWJ=torchkld.loss.NWJLoss(),
    Nishiyama=torchkld.loss.NishiyamaLoss(),
    InfoNCE=torchkld.loss.InfoNCELoss(),
)

loss_name = "DonskerVaradhan"
loss = losses[loss_name]

Optimizer

In [16]:
# Adam over every parameter of the critic network.
learning_rate = 1.0e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

Training

In [17]:
import matplotlib
from matplotlib import pyplot as plt

In [18]:
# Length (in epochs) of the trailing window over which the running median of
# the MI estimate is reported during training.
average_epochs = 100

In [None]:
from collections import defaultdict
from IPython.display import clear_output
from tqdm import trange

n_epochs = 1000


def _random_projection(t: torch.Tensor, projection_dim: int) -> torch.Tensor:
    """Project ``t`` onto a fresh random orthonormal basis and append that basis.

    The basis Q (shape ``(features, projection_dim)``) comes from the QR
    decomposition of a Gaussian matrix drawn from the global NumPy RNG. Each
    output row is ``[t_row @ Q, flatten(Q)]``, i.e. the result has
    ``projection_dim * (1 + features)`` columns.
    """
    q, _ = np.linalg.qr(np.random.randn(t.shape[-1], projection_dim))
    q = torch.tensor(q, device=t.device, dtype=torch.float32)
    return torch.cat([t @ q, q.flatten().repeat(t.shape[0], 1)], dim=-1)


history = defaultdict(list)
for epoch in trange(1, n_epochs + 1, mininterval=1):
    # Training: one pass over the training set, accumulating the per-batch
    # MI estimate (the negated loss value).
    mi = 0.0
    batches = 0

    for x, y in train_dataloader:
        # Do not assign to `batch_size` here: that would overwrite the global
        # batch size configured for the dataloaders.
        optimizer.zero_grad()

        # Fresh random projections for each side (X drawn before Y), then a
        # single transfer of each tensor to the device.
        x = _random_projection(x, config["projection_dim"]).to(device)
        y = _random_projection(y, config["projection_dim"]).to(device)

        T_joined = model(x, y)
        T_marginal = model(x, y, marginalize=marginalize)
        _loss = loss(T_joined, T_marginal)
        _loss.backward()

        optimizer.step()

        mi -= _loss.item()
        batches += 1

    # NOTE(review): despite the key name, this is the mean training-batch estimate
    # for the epoch, not a held-out evaluation. The key is kept because later cells
    # read it. A held-out estimate (e.g. via model.get_mutual_information on
    # test_dataloader) would presumably also need the random projections applied.
    history["test_mutual_information"].append(mi / batches)

    if epoch % 10 == 0:
        clear_output(wait=True)
        plot_estimated_MI_trainig(mutual_information, np.arange(1, epoch + 1), history["test_mutual_information"])
        print(f"Current estimate: {history['test_mutual_information'][-1]:.2f}")
        print(f"Running median: {np.median(history['test_mutual_information'][-average_epochs:]):.2f}")

  0%|▍                                                                                                                                          | 3/1000 [00:05<29:15,  1.76s/it]

In [21]:
latest_estimate = history["test_mutual_information"][-1]
print(f"Current estimate: {latest_estimate:.2f}")

Current estimate: 0.84
