<a href="https://colab.research.google.com/github/VRehnberg/mutual-information/blob/main/mutual_information.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
%%bash
pip install git+https://github.com/VRehnberg/torch-utils.git

Collecting git+https://github.com/VRehnberg/torch-utils.git
  Cloning https://github.com/VRehnberg/torch-utils.git to /tmp/pip-req-build-6hy6n1it
  Installing build dependencies: started
  Installing build dependencies: finished with status 'done'
  Getting requirements to build wheel: started
  Getting requirements to build wheel: finished with status 'done'
    Preparing wheel metadata: started
    Preparing wheel metadata: finished with status 'done'
Building wheels for collected packages: torch-utils
  Building wheel for torch-utils (PEP 517): started
  Building wheel for torch-utils (PEP 517): finished with status 'done'
  Created wheel for torch-utils: filename=torch_utils-0.0.1-cp37-none-any.whl size=5808 sha256=8def2f286943b2280c3383b628f09d38dbb4732334f369457d728eded0e3c231
  Stored in directory: /tmp/pip-ephem-wheel-cache-mcpxp9f1/wheels/69/5c/fd/8fe71800b6c3026c5683efb40102541c2e0d5a44121b15e6fa
Successfully built torch-utils


  Running command git clone -q https://github.com/VRehnberg/torch-utils.git /tmp/pip-req-build-6hy6n1it


In [17]:
import torch
from torch import nn, linalg
from torch.autograd.functional import jacobian

from torchutils.kmeans import kmeans

# Mutual Information
This notebook was written to investigate a few different ways to estimate the mutual information between random variables from sampled data. This is then compared with the true mutual information.

1. Analytical mutual information TODO
2. Jacobian/Hessian based mutual information TODO
3. Quantized/binned mutual information TODO

## Network

In [18]:
class NormalLinear(nn.Linear):
    def __init__(self, input_shape, output_shape):
        super().__init__(input_shape, output_shape)
        self.input_shape = input_shape
    
    def output_mutual_information(self, partition):
        '''
            partition (BoolTensor): n_modules × output_shape
        '''
        if (partition.int().sum(0) > 1).any():
            raise NotImplementedError("MI for overlapping modules not implemented.")

        weight, bias = self.parameters()
        #mean = bias
        cov_full = weight @ weight.T

        # Check ranks (this is unescessary if full rank)
        rank0 = linalg.matrix_rank(cov_full)
        rank1 = torch.sum(torch.hstack([
            linalg.matrix_rank(cov_full[mask, :][:,mask]) for mask in partition
        ]))
        if rank0 < rank1:
            return float("inf")
        elif rank1 != cov_full.size(0):
            raise NotImplementedError()
        
        # Compute MI
        det0 = cov_full.det()
        det1 = torch.prod(torch.hstack([
            cov_full[mask, :][:,mask].det() for mask in partition
        ]))
        return -0.5 * torch.log(det0 / det1)
        
    def sample(self, batch_shape):
        return torch.randn(batch_shape, self.input_shape)


In [19]:
n_modules = 2
in_size, out_size = (15, 7)
nl = NormalLinear(in_size, out_size)
partition = nn.functional.one_hot(torch.randint(n_modules, (out_size,))).bool().T
with torch.no_grad():
    print(nl.output_mutual_information(partition))


tensor(0.4771)


## Network utilities

In [20]:
def get_activations(network, x):
    activations = []
    hooks = []
    for name, m in network.named_modules():
        if isinstance(m, nn.Linear):
            save_activations = lambda mod, inp, out: activations.append(out)
            hooks.append(m.register_forward_hook(save_activations))
    
    network(x)
    for h in hooks:
        h.remove()
    
    return torch.hstack(activations)


def batched_jacobian(func, x, to_embedding=False, **kwargs):

    # Copmute batched Jacobian
    new_func = lambda x: func(x).sum(0)
    jac = jacobian(new_func, x, **kwargs)
    
    # Move batch dimension first
    dims = torch.arange(jac.ndim, device=device)
    batch_dim = dims[-x.ndim]
    jac.movedims(dims[:batch_dim + 1], [batch_dim, *dims[:batch_dim]])

    return jac


def clean_mem():
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()


## Mutual information

In [21]:
def jacobian_mutual_information(jac_full, jac_blocks):
    assert jac_full.ndim == 3

    # Covariences
    def det(jac):
        return jac.bmm(jac.transpose(1, 2)).det()

    det_full = det(jac_full)
    det_blocks = torch.stack([det(jac_block) for jac_block in jac_blocks], 1)

    # Local mutual information
    jmi = -0.5 * torch.log(det_full / torch.prod(det_blocks, 1))

    return jmi.mean(0)


In [22]:
jac_full = torch.rand(7, 10, 30)
partition = torch.randint(3, (10,))
jac_blocks = [jac_full[:, id==partition, :] for id in torch.unique(partition)]
jacobian_mutual_information(jac_full, jac_blocks)

tensor(2.6649)

In [23]:
help(kmeans)

Help on function kmeans in module torchutils.kmeans:

kmeans(points, k, global_start=1, max_iter=None, reinitialize_empty_clusters=True, verbose=False)
    Naive K-means with parallelized global start.
    
    Arguments:
        points (Tensor): Number of points times number of features.
        k (int): Number of clusters.
        global_start (int): Different intializations that are run in parallel.
            Default 1.
        reinatialize_empty_clusters (bool): If empty clusters are reinitialized.
            Default True.
        verbose (bool): Controls verbosity. Default True. 
    
    Returns:
        i_cluster (LongTensor): What cluster each point belongs to.



In [27]:
def quantized_mutual_information(activations, partition, n_bins, cluster_method="kmeans"):
    if not isinstance(partition, torch.BoolTensor):
        raise TypeError("Datatype of partition should be BoolTensor.")
    device = activations.device
    batch_size, n_features = activations.shape
    n_parts = partition.size(0)
    assert n_features == partition.size(1)
    assert batch_size >= n_bins

    if cluster_method=="kmeans":
        def quantize(points):
            # Cosine similarity
            points = (points - points.mean(0, keepdim=True)) / points.std(0, keepdim=True)
            return kmeans(points, n_bins ** n_parts).argsort(0)
    else:
        raise ValueError()

    quantized_activations = torch.zeros((batch_size, n_parts), dtype=int, device=device)
    for i_part, mask in enumerate(partition):
        quantized_activations[:, i_part] = quantize(activations[:, mask])

    # Compute pmfs
    activations_onehot = nn.functional.one_hot(quantized_activations).float()
    p_xy = torch.einsum("bij, bkl -> ikjl", activations_onehot, activations_onehot) / batch_size

    # Compute pairwise mutual information
    p_x = torch.einsum("iikk -> ik", p_xy)
    qmin = p_xy.div(p_x.unsqueeze(0).unsqueeze(2)).div(p_x.unsqueeze(1).unsqueeze(3)).pow(p_xy).log().sum((2, 3))
    return qmin
    

n_modules = 2
in_size, out_size = (15, 7)
activations = torch.rand(2000, out_size)
partition = nn.functional.one_hot(torch.randint(n_modules, (out_size,))).bool().T
with torch.no_grad():
    print(quantized_mutual_information(activations, partition, 10))


tensor([[ 920,  606],
        [1949, 1064],
        [1722, 1547],
        ...,
        [ 273,  825],
        [  18,  648],
        [ 197, 1965]])
tensor([[7.6008, 7.6008],
        [7.6008, 7.6008]])


In [None]:

    # v-- old below

    ## Quantize activations
    #batch_size, n_activations = activations.shape
    #assert n_activations == partition.numel()
    #assert batch_size >= n_bins
    #quantized_activations = n_bins * activations.argsort(0).argsort(0) // batch_size

    ## Compute pmfs
    #activations_onehot = nn.functional.one_hot(quantized_activations).float()
    #p_xy = torch.einsum("bij, bkl -> ikjl", activations_onehot, activations_onehot) / batch_size

    ## Compute pairwise mutual information
    #p_x = torch.einsum("iikk -> ik", p_xy)
    #qmin = p_xy.div(p_x.unsqueeze(0).unsqueeze(2)).div(p_x.unsqueeze(1).unsqueeze(3)).pow(p_xy).log().sum((2, 3))
    #return qmin
