<a href="https://colab.research.google.com/github/VRehnberg/mutual-information/blob/main/mutual_information.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
running_on_colab = "google.colab" in str(get_ipython())
if running_on_colab:
    %pip install git+https://github.com/VRehnberg/torch-utils.git
else:
    !git clone https://github.com/VRehnberg/torch-utils && cp -alf torch-utils/src/torchutils . && rm -rf torch-utils


Cloning into 'torch-utils'...
remote: Enumerating objects: 111, done.[K
remote: Counting objects: 100% (111/111), done.[K
remote: Compressing objects: 100% (67/67), done.[K
remote: Total 111 (delta 37), reused 90 (delta 29), pack-reused 0[K
Receiving objects: 100% (111/111), 20.25 KiB | 2.89 MiB/s, done.
Resolving deltas: 100% (37/37), done.


In [2]:
import torch
from torch import nn, linalg
from torch.autograd.functional import jacobian

from torchutils import batched_jacobian
from torchutils.kmeans import kmeans

# Mutual Information
This notebook was written to investigate a few different ways to estimate the mutual information between random variables from sampled data. This is then compared with the true mutual information.

1. Analytical mutual information TODO
2. Jacobian/Hessian based mutual information TODO
3. Quantized/binned mutual information TODO

## Network

In [3]:
class NormalLinear(nn.Linear):
    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        self.input_shape = input_shape
    
    def output_mutual_information(self, partition):
        '''
            partition (BoolTensor): n_modules × output_shape
        '''
        if (partition.int().sum(0) > 1).any():
            raise NotImplementedError("MI for overlapping modules not implemented.")

        weight, bias = self.parameters()
        #mean = bias
        cov_full = weight @ weight.T

        # Check ranks (this is unescessary if full rank)
        rank0 = linalg.matrix_rank(cov_full)
        rank1 = torch.sum(torch.hstack([
            linalg.matrix_rank(cov_full[mask, :][:,mask]) for mask in partition
        ]))
        if rank0 < rank1:
            return float("inf")
        elif rank1 != cov_full.size(0):
            raise NotImplementedError()
        
        # Compute MI
        det0 = cov_full.det()
        det1 = torch.prod(torch.hstack([
            cov_full[mask, :][:,mask].det() for mask in partition
        ]))
        return -0.5 * torch.log(det0 / det1)
        
    def sample(self, batch_shape):
        return torch.randn(batch_shape, self.input_shape, requires_grad=True)


In [4]:
n_modules = 2
in_size, out_size = (15, 7)
nl = NormalLinear(in_size, out_size)
partition = nn.functional.one_hot(torch.randint(n_modules, (out_size,))).bool().T
with torch.no_grad():
    print(nl.output_mutual_information(partition))


tensor(0.5537)


## Mutual information

In [5]:
def jacobian_mutual_information(jac_full, jac_blocks):
    assert jac_full.ndim == 3

    # Covariences
    def det(jac):
        return jac.bmm(jac.transpose(1, 2)).det()

    det_full = det(jac_full)
    det_blocks = torch.hstack([det(jac_block).view(-1, 1) for jac_block in jac_blocks])

    # Local mutual information
    jmi = -0.5 * torch.log(det_full / torch.prod(det_blocks, 1))

    return jmi.mean(0)


In [6]:
jac_full = torch.rand(7, 10, 30)
partition = torch.randint(3, (10,))
jac_blocks = [jac_full[:, id==partition, :] for id in torch.unique(partition)]
jacobian_mutual_information(jac_full, jac_blocks)

tensor(2.2743)

In [7]:
help(kmeans)

Help on function kmeans in module torchutils.kmeans:

kmeans(points, k, global_start=1, max_iter=None, reinitialize_empty_clusters=True, verbose=False)
    Naive K-means with parallelized global start.
    
    Arguments:
        points (Tensor): Number of points times number of features.
        k (int): Number of clusters.
        global_start (int): Different intializations that are run in parallel.
            Default 1.
        reinatialize_empty_clusters (bool): If empty clusters are reinitialized.
            Default True.
        verbose (bool): Controls verbosity. Default True. 
    
    Returns:
        i_cluster (LongTensor): What cluster each point belongs to.



In [24]:
def quantized_mutual_information(activations, partition, n_bins, cluster_method="kmeans"):
    if not isinstance(partition, torch.BoolTensor):
        raise TypeError("Datatype of partition should be BoolTensor.")
    device = activations.device
    batch_size, n_features = activations.shape
    n_parts = partition.size(0)
    assert n_features == partition.size(1)
    assert batch_size >= n_bins

    if cluster_method=="kmeans":
        def quantize(points):
            # Cosine similarity
            points = (points - points.mean(0, keepdim=True)) / points.std(0, keepdim=True)
            return kmeans(points, n_bins)
    else:
        raise ValueError()

    quantized_activations = torch.zeros((batch_size, n_parts), dtype=int, device=device)
    for i_part, mask in enumerate(partition):
        quantized_activations[:, i_part] = quantize(activations[:, mask])

    # Compute pmfs
    activations_onehot = nn.functional.one_hot(quantized_activations).float()
    p_xy = torch.einsum("bij, bkl -> ikjl", activations_onehot, activations_onehot) / batch_size

    print(activations_onehot.shape)
    print(p_xy.shape)

    # Compute pairwise mutual information
    p_x = torch.einsum("iikk -> ik", p_xy)
    qmin = p_xy.div(p_x.unsqueeze(0).unsqueeze(2)).div(p_x.unsqueeze(1).unsqueeze(3)).pow(p_xy).log().sum((2, 3))
    return qmin
    

n_modules = 2
in_size, out_size = (15, 7)
activations = torch.rand(2000, out_size)
partition = nn.functional.one_hot(torch.randint(n_modules, (out_size,))).bool().T
with torch.no_grad():
    print(quantized_mutual_information(activations, partition, 10))


torch.Size([2000, 2, 10])
torch.Size([2, 2, 10, 10])
tensor([[2.2919, 0.0236],
        [0.0236, 2.2965]])


In [13]:
# Gaussian test set-up
n_modules = 2
in_size, out_size = (15, 7)
batch_size = 20
network = NormalLinear(in_size, out_size)
x = network.sample(batch_size)
activations = network(x)

partition = nn.functional.one_hot(torch.randint(n_modules, (out_size,))).bool().T

In [20]:

# True mutual information
mi = network.output_mutual_information(partition)
print(f"MI: {mi}")

# Local mutual information through Jacobian
jac_full = batched_jacobian(network, x)
jac_blocks = [jac_full[:, mask, :] for mask in partition]
lmi = jacobian_mutual_information(jac_full, jac_blocks)
print(f"LMI: {lmi}")

# Quantized mutual information through clustering
with torch.no_grad():
    qmi = quantized_mutual_information(activations, partition, 10)
print(f"QMI: {qmi}") #TODO


MI: 0.7629395723342896
LMI: 0.7629393935203552
QMI: tensor([[2.9957, 2.9957],
        [2.9957, 2.9957]])
