# PyTorch mutual information neural estimation tests

Simple tests on MNIST images corrupted by multiplicative noise drawn from a correlated uniform distribution with known mutual information

In [None]:
import sys
sys.path.append("../python")

In [None]:
import numpy as np

In [None]:
import torch
import torchvision
import torchkld

In [None]:
import mutinfo

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
# Uncomment the override below to force CPU execution.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"
#device = "cpu"
print(f"Device: {device}")
print(f"Devices count: {torch.cuda.device_count()}")

In [None]:
from tqdm import tqdm, trange

In [None]:
from misc.modules import *
from misc.utils import *
from misc.plots import *

In [None]:
# Experiment configuration: filled in by the cells below and saved together
# with the training history at the end of the notebook.
config = dict()

## Dataset

Datasets and dataloaders

In [None]:
from mutinfo.distributions.base import CorrelatedUniform

config.update(
    mutual_information=3.0,  # target MI of the governing CorrelatedUniform distribution
    n_copies=1,              # X/Y dimensionality of that distribution (and critic channels)
)

In [None]:
# Images are only converted to float tensors. Per-channel normalization is
# disabled; the commented-out statistics are kept for reference.
image_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        #torchvision.transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]
)

In [None]:
config["dataset"] = "MNIST"
#config["dataset"] = "CIFAR10"
config["n_classes"] = 10


def aggregate(x_list):
    """Concatenate a list of tensors along dim 1."""
    return torch.cat(x_list, dim=1)


# Resolve the torchvision dataset class by name once, then build both splits
# (downloaded into ./.cache on first use).
dataset_class = getattr(torchvision.datasets, config["dataset"])
train_dataset = dataset_class(root="./.cache", download=True, transform=image_transform)
test_dataset = dataset_class(root="./.cache", download=True, transform=image_transform, train=False)

In [None]:
config["batch_size"] = 512

# Only the training split is shuffled; the test split keeps a fixed order.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=config["batch_size"],
    shuffle=False,
)

## Estimating MI

Model

In [None]:
class MNIST_T_network(torchkld.mutual_information.MINE):
    """Convolutional critic (statistics network T) for neural MI estimation on images.

    X and Y are concatenated along the channel axis. The result passes through
    three 3x3 'same'-padded convolutions (the first two followed by 2x2 average
    pooling) and then a two-layer MLP that outputs one scalar score per pair.

    Args:
        X_channels: number of channels of X.
        Y_channels: number of channels of Y.
        image_size: side length of the square input images. The two 2x2
            poolings reduce it to ``image_size // 4``. The default of 28 (MNIST)
            reproduces the original 7x7 feature map; 32 (e.g. CIFAR10) gives 8x8.
    """

    def __init__(self, X_channels: int=1, Y_channels: int=1, image_size: int=28) -> None:
        super().__init__()

        self.conv2d_1 = torch.nn.Conv2d(X_channels + Y_channels, 64, 3, padding='same')
        self.conv2d_2 = torch.nn.Conv2d(64, 128, 3, padding='same')
        self.conv2d_3 = torch.nn.Conv2d(128, 128, 3, padding='same')

        # 'same' padding keeps the spatial size, and each AvgPool2d(2) floors
        # it by half. Two poolings therefore leave floor(image_size / 4) per side.
        feature_size = image_size // 4
        self.linear_1 = torch.nn.Linear(128 * feature_size * feature_size, 128)
        self.linear_2 = torch.nn.Linear(128, 1)

        self.pooling = torch.nn.AvgPool2d(2)
        self.activation = torch.nn.LeakyReLU()

    def forward(self, x: torch.Tensor, y: torch.Tensor, marginalize: "bool | str"=False) -> torch.Tensor:
        """Return the critic score T(x, y) for every pair in the batch.

        Args:
            x, y: image batches; assumed (batch, channels, image_size, image_size).
                TODO: confirm.
            marginalize: forwarded unchanged to ``MINE.forward``. This notebook
                passes a string such as "permute" (see ``config["marginalize"]``).
                Presumably it selects how y is decoupled from x; the semantics
                are defined by torchkld.

        Returns:
            Tensor of shape (batch, 1) holding one score per (x, y) pair.
        """
        # The base class may rearrange y when marginalize is set.
        x, y = super().forward(x, y, marginalize)

        z = torch.cat([x, y], dim=1)

        z = self.conv2d_1(z)
        z = self.pooling(z)
        z = self.activation(z)

        z = self.conv2d_2(z)
        z = self.pooling(z)
        z = self.activation(z)

        z = self.conv2d_3(z)
        z = self.activation(z)

        z = z.flatten(start_dim=1)

        z = self.linear_1(z)
        z = self.activation(z)

        z = self.linear_2(z)

        return z

In [None]:
# The critic's X and Y channel counts both follow config["n_copies"].
model = MNIST_T_network(
    X_channels=config["n_copies"],
    Y_channels=config["n_copies"],
).to(device)

total_parameters = 0
for parameter in model.parameters():
    total_parameters += parameter.numel()
print(f"Total parameters: {total_parameters}")

In [None]:
# Record the model size alongside the other hyperparameters.
config.update(n_parameters=total_parameters)

Loss

In [None]:
# Loss.
config.update(
    biased=False,
    ema_multiplier=1.0e-2,
    marginalize="permute",  # "permute", "product"
)

# Candidate objectives. Only the one named by config["loss_name"] is used.
losses = dict(
    DonskerVaradhan=torchkld.loss.DonskerVaradhanLoss(
        biased=config["biased"],
        ema_multiplier=config["ema_multiplier"],
    ),
    NWJ=torchkld.loss.NWJLoss(),
    Nishiyama=torchkld.loss.NishiyamaLoss(),
    InfoNCE=torchkld.loss.InfoNCELoss(),
)

config["loss_name"] = "DonskerVaradhan"
loss = losses[config["loss_name"]]

Optimizer

In [None]:
# Plain Adam over all critic parameters.
config["learning_rate"] = 1.0e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=config["learning_rate"])

Training

In [None]:
import matplotlib
from matplotlib import pyplot as plt

In [None]:
config.update(
    n_epochs=500,        # total number of training epochs
    average_epochs=20,   # trailing window of epochs used to average the estimate
)

In [None]:
from mutinfo.distributions.base import CorrelatedUniform

def apply_noise(
    samples,
    labels,
    governing_random_variable=CorrelatedUniform(
        mutual_information=config["mutual_information"],
        X_dim=config["n_copies"],
        Y_dim=config["n_copies"],
        randomize_interactions=False,
        shuffle_interactions=False,
    )
):
    """Build a dependent pair of noisy image batches with controlled mutual information.

    Two parameter arrays are drawn with ``governing_random_variable.rvs``.
    The first output is ``samples`` scaled by ``1 - parameters_1``. The second
    output is a randomly permuted copy of the batch scaled by ``1 - parameters_2``.
    Because the second batch is shuffled, the dependence between the outputs is
    presumably carried by the correlated scaling parameters.

    Args:
        samples: image batch, assumed (batch, channels, H, W). The parameters
            get two trailing singleton axes so they broadcast over H and W.
            TODO: confirm the layout.
        labels: unused. It is accepted because this function is called as
            ``apply_noise(x, y)`` and passed as ``transform=`` to
            ``get_mutual_information``.
        governing_random_variable: object whose ``rvs(n)`` returns two
            per-sample parameter arrays. The default is built once, when this
            cell runs, from the current ``config``; re-run the cell after
            changing ``config``.

    Returns:
        Tuple ``(x_1, x_2)`` of tensors on ``samples.device``.
    """
    n_samples = samples.shape[0]
    target_device = samples.device

    def to_attenuation(parameters):
        # Convert parameters p to multiplicative factors 1 - p with two
        # trailing singleton axes for broadcasting over the spatial dims.
        as_tensor = torch.tensor(parameters, dtype=torch.float32, device=target_device)
        return 1.0 - as_tensor.unsqueeze(-1).unsqueeze(-1)

    first_parameters, second_parameters = governing_random_variable.rvs(n_samples)
    first_scale = to_attenuation(first_parameters)
    second_scale = to_attenuation(second_parameters)

    permutation = torch.randperm(n_samples, device=target_device)
    return samples * first_scale, samples[permutation] * second_scale

In [None]:
from collections import defaultdict
from IPython.display import clear_output
from tqdm import trange

history = defaultdict(list)

# Splits on which MI is re-estimated after every epoch, in append order.
evaluation_splits = (("train", train_dataloader), ("test", test_dataloader))

for epoch in trange(1, config["n_epochs"] + 1, mininterval=1):
    # Training: one optimization pass over the training set.
    for x, y in train_dataloader:
        x, y = x.to(device), y.to(device)

        # apply_noise returns tensors on the same device as x, so no
        # further .to(device) is needed for x_1 / x_2.
        x_1, x_2 = apply_noise(x, y)

        optimizer.zero_grad()

        # Critic scores on joint samples and on marginalized pairs.
        T_joined = model(x_1, x_2)
        T_marginal = model(x_1, x_2, marginalize=config["marginalize"])
        _loss = loss(T_joined, T_marginal)
        _loss.backward()

        optimizer.step()

    # Evaluation: MI estimates on the train split, then the test split.
    for split, split_dataloader in evaluation_splits:
        history[f"{split}_mutual_information"].append(
            model.get_mutual_information(
                split_dataloader,
                loss,
                device,
                marginalize=config["marginalize"],
                transform=apply_noise,
            )
        )

    # Refresh the plots every 5 epochs and always after the last epoch, so
    # the final state is shown even when n_epochs is not a multiple of 5.
    if epoch % 5 == 0 or epoch == config["n_epochs"]:
        clear_output(wait=True)
        plot_estimated_MI_trainig(config["mutual_information"], np.arange(1, epoch+1), history["train_mutual_information"])
        plot_estimated_MI_trainig(config["mutual_information"], np.arange(1, epoch+1), history["test_mutual_information"])
        print(f"Current estimate: {history['test_mutual_information'][-1]:.3f}")
        print(f"Running median: {np.median(history['test_mutual_information'][-config['average_epochs']:]):.3f}")

### Saving the results

In [None]:
# Display the full experiment configuration as the cell output.
config

In [None]:
from datetime import datetime

# Unique run name: dataset, n_copies, target MI and a timestamp.
# The name is later used as a directory component of experiment_path.
# The time therefore uses '-' rather than ':', which is not allowed in
# Windows file and directory names.
timestamp = datetime.now().strftime('%d-%b-%Y_%H-%M-%S')
experiment_name = f"{config['dataset']}_{config['n_copies']}_{config['mutual_information']:.1f}__{timestamp}"
print(experiment_name)

In [None]:
import os
from pathlib import Path

# Results go to <cwd>/../../data/<dataset>/<loss_name>/<experiment_name>.
# abspath normalizes the '..' components.
data_path = Path(os.path.abspath(os.path.join(os.getcwd(), "..", "..", "data")))
experiment_path = data_path.joinpath(f"{config['dataset']}", config['loss_name'], experiment_name)

In [None]:
# Persist the training history and configuration under experiment_path.
# save_results comes from misc.utils. average_epochs is forwarded to it,
# presumably to summarize the final estimate over the last epochs (verify
# there).
save_results(history, config, experiment_path, average_epochs=config['average_epochs'])