In [None]:
# default_exp joint_entropy

In [None]:
# hide
import blackhc.project.script

Appended /home/blackhc/PycharmProjects/blackhc.batchbald/src to paths
Switched to directory /home/blackhc/PycharmProjects/blackhc.batchbald
%load_ext autoreload
%autoreload 2


# joint_entropy
> Compute joint entropy estimates

Module to help compute joint entropies of dependent categorical variables given via a conditional density $p((y_i)_i \mid w)$ in the Bayesian setting. We compute the density $p((y_i)_i)$ by marginalizing over $w$, i.e. by averaging $p((y_i)_i \mid w)$ over samples of $w$.

Two classes are provided:

* exact computation (which works for up to 5 joint variables);
* an estimate using Monte Carlo (MC) sampling of configurations.

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
# exporti
import torch
from toma import toma

In [None]:
import torch
import itertools
import math
import numpy as np
from toma import toma

To run tests, we need a few example distributions.

In [None]:
def get_mixture_prob_dist(p1, p2, m):
    """Linearly interpolate between two categorical distributions.

    Returns `(1 - m) * p1 + m * p2` as a numpy array, so `m=0` yields `p1` and `m=1` yields `p2`.
    """
    first, second = np.asarray(p1), np.asarray(p2)
    return first * (1. - m) + second * m

p1 = [0.1, 0.2, 0.2, 0.5]
p2 = [0.5, 0.2, 0.1, 0.2]
get_mixture_prob_dist(p1, p2, 0.5)

Number of inference samples `K`:

In [None]:
K = 20  # number of samples of w (inference samples) used by all examples below

In [None]:
p1 = [0.1, 0.2, 0.2, 0.5]
p2 = [0.5, 0.2, 0.1, 0.2]
# One distribution per sample of w, sweeping from p1 (mix=0) to p2 (mix=1).
y1_ws = [get_mixture_prob_dist(p1, p2, mix) for mix in np.linspace(0.0, 1.0, K)]

In [None]:
p1 = [0.1, 0.6, 0.2, 0.1]
p2 = [0.0, 0.5, 0.5, 0.0]
# Second variable: again one distribution per sample of w, from p1 (mix=0) to p2 (mix=1).
y2_ws = [get_mixture_prob_dist(p1, p2, mix) for mix in np.linspace(0.0, 1.0, K)]

In [None]:
def nested_to_tensor(*l):
    """Convert every argument to a tensor and stack them along a new leading dimension."""
    tensors = [torch.as_tensor(item) for item in l]
    return torch.stack(tensors)

In [None]:
# Eight variables alternating y1, y2, y1, y2, ...
ys_ws = nested_to_tensor(*([y1_ws, y2_ws] * 4))

## Computing exact joint entropies

To compute exact joint entropies, we have to compute all possible configurations of the $y_i$ and evaluate $p(y_1, \dots, y_n)$ by averaging over $p(y_1, \dots, y_n|w)$.

The number of configurations is $M = C^N$, where $N$ is the number of variables in the joint and $C$ is the number of classes.

For this, we provide a class `ExactJointEntropy`. `ExactJointEntropy.empty(K)` creates an instance for $K$ samples of $w$ that starts with no variables in the joint.

In [None]:
# exports

class ExactJointEntropy:
    """Random variables (all with the same # of categories $C$) can be added via `ExactJointEntropy.add_variables`.

`ExactJointEntropy.compute` computes the joint entropy.

`ExactJointEntropy.compute_batch` computes the joint entropy of the added variables with each of the variables in the provided batch probabilities in turn.

Internally, `joint_probs_M_K[m, k]` holds the probability of the m-th joint configuration of all added variables under the k-th sample of $w$. Starting from `ExactJointEntropy.empty`, adding $N$ variables yields $M = C^N$ configurations. Configurations with zero probability contribute zero entropy (convention 0 * log 0 = 0)."""
    # (M, K): probability of each joint configuration under each sample of w.
    joint_probs_M_K: torch.Tensor

    def __init__(self, joint_probs_M_K: torch.Tensor):
        self.joint_probs_M_K = joint_probs_M_K

    @staticmethod
    def empty(K: int, device=None, dtype=None) -> 'ExactJointEntropy':
        """Create a joint without variables: one configuration with probability 1 for each of the `K` samples."""
        return ExactJointEntropy(torch.ones((1, K), device=device,
                                            dtype=dtype))

    @staticmethod
    def _nats(probs: torch.Tensor) -> torch.Tensor:
        """Elementwise `-p * log(p)`, where zero-probability entries contribute 0 instead of NaN."""
        nats = -torch.log(probs) * probs
        # log(0) = -inf and -inf * 0 = NaN; by convention 0 * log(0) = 0.
        return nats.masked_fill(probs == 0, 0.0)

    def compute(self) -> torch.Tensor:
        """Return the joint entropy (in nats) of all added variables as a 0-dim tensor.

        The joint distribution is obtained by averaging the per-sample joint probabilities over the K samples of w.
        """
        probs_M = torch.mean(self.joint_probs_M_K, dim=1, keepdim=False)
        entropy = torch.sum(self._nats(probs_M))
        return entropy

    def add_variables(self, probs_N_K_C: torch.Tensor) -> 'ExactJointEntropy':
        """Return a new `ExactJointEntropy` whose joint additionally contains the N given variables.

        `probs_N_K_C[n, k, c]` is the probability of class c for the n-th new variable under the k-th sample of w.
        Per sample, probabilities are multiplied, so every added variable grows the number of configurations by
        a factor of C. `self` is left unchanged.
        """
        assert self.joint_probs_M_K.shape[1] == probs_N_K_C.shape[1]

        N, K, C = probs_N_K_C.shape
        joint_probs_K_M_1 = self.joint_probs_M_K.t()[:, :, None]

        # Using lots of memory: all configurations are materialized for every sample.
        for i in range(N):
            # Match the joint's device and dtype.
            probs_i__K_1_C = probs_N_K_C[i][:, None, :].to(joint_probs_K_M_1,
                                                           non_blocking=True)
            # Per sample: every existing configuration times every class of variable i.
            joint_probs_K_M_C = joint_probs_K_M_1 * probs_i__K_1_C
            joint_probs_K_M_1 = joint_probs_K_M_C.reshape((K, -1, 1))

        joint_probs_M_K = joint_probs_K_M_1.squeeze(2).t()
        return ExactJointEntropy(joint_probs_M_K)

    def compute_batch(self, probs_B_K_C: torch.Tensor,
                            output_entropies_B=None):
        """For each of the B variables in `probs_B_K_C`, compute the joint entropy of the added variables together with it.

        `probs_B_K_C[b, k, c]` is the probability of class c for the b-th candidate under the k-th sample of w.
        If `output_entropies_B` is given, results are written into it; otherwise a length-B tensor with the dtype and
        device of `probs_B_K_C` is allocated. Returns `output_entropies_B` (entropies in nats).
        """
        assert self.joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

        B, K, C = probs_B_K_C.shape
        M = self.joint_probs_M_K.shape[0]

        if output_entropies_B is None:
            output_entropies_B = torch.empty(B, dtype=probs_B_K_C.dtype, device=probs_B_K_C.device)

        # toma calls this on consecutive slices probs_B_K_C[start:end] (initial chunk size 1024);
        # presumably it shrinks the chunk size on out-of-memory errors — see toma docs.
        @toma.execute.chunked(probs_B_K_C,
                              initial_step=1024,
                              dimension=0,
                              context="ExactJointEntropy.batch_joint_entropy")
        def chunked_joint_entropy(chunked_probs_b_K_C: torch.Tensor,
                                  start: int, end: int):
            b = chunked_probs_b_K_C.shape[0]

            probs_b_M_C = torch.empty((b, M, C),
                                      dtype=self.joint_probs_M_K.dtype,
                                      device=self.joint_probs_M_K.device)
            for i in range(b):
                # Sum over samples of p(configuration m | w_k) * p(class c | w_k).
                # Index the current chunk: row i here is row start + i of the full batch.
                torch.matmul(self.joint_probs_M_K,
                             chunked_probs_b_K_C[i].to(self.joint_probs_M_K,
                                                       non_blocking=True),
                             out=probs_b_M_C[i])
            # Turn the sum over the K samples into a mean.
            probs_b_M_C /= K

            output_entropies_B[start:end].copy_(torch.sum(
                self._nats(probs_b_M_C), dim=(1, 2)),
                                                non_blocking=True)

        return output_entropies_B

In [None]:
show_doc(ExactJointEntropy.add_variables)
show_doc(ExactJointEntropy.compute)
show_doc(ExactJointEntropy.compute_batch)
# nbdev's show_doc renders the API documentation for each method listed above.

<h4 id="ExactJointEntropy.add_variables" class="doc_header"><code>ExactJointEntropy.add_variables</code><a href="__main__.py#L26" class="source_link" style="float:right">[source]</a></h4>

> <code>ExactJointEntropy.add_variables</code>(**`probs_N_K_C`**:`Tensor`)



<h4 id="ExactJointEntropy.compute" class="doc_header"><code>ExactJointEntropy.compute</code><a href="__main__.py#L20" class="source_link" style="float:right">[source]</a></h4>

> <code>ExactJointEntropy.compute</code>()



<h4 id="ExactJointEntropy.compute_batch" class="doc_header"><code>ExactJointEntropy.compute_batch</code><a href="__main__.py#L43" class="source_link" style="float:right">[source]</a></h4>

> <code>ExactJointEntropy.compute_batch</code>(**`probs_B_K_C`**:`Tensor`, **`output_entropies_B`**=*`None`*)



### Example usage

In [None]:
# Exact joint entropy of the first 4 variables, in double precision.
exact_joint_entropy = ExactJointEntropy.empty(K, dtype=torch.double)
entropy = (exact_joint_entropy
           .add_variables(ys_ws[:4])
           .compute())
assert np.isclose(entropy, 4.6479, atol=0.1)
entropy

tensor(4.6479, dtype=torch.float64)

In [None]:
# Same computation in single precision; checked against the same reference value.
exact_joint_entropy = ExactJointEntropy.empty(K, dtype=torch.float)
entropy = exact_joint_entropy.add_variables(ys_ws[:4]).compute()
assert np.isclose(entropy, 4.6479, atol=0.1)
entropy

tensor(4.6479)

In [None]:
# Joint entropy of the first 4 variables together with each variable of `ys_ws` in turn.
exact_joint_entropy = ExactJointEntropy.empty(K, dtype=torch.float)
entropies = exact_joint_entropy.add_variables(ys_ws[:4]).compute_batch(ys_ws)
# Reference values are rounded to 4 decimals and computed in float32; compare at that precision.
assert np.allclose(entropies, [5.9735, 5.6362, 5.9735, 5.6362, 5.9735, 5.6362, 5.9735, 5.6362], atol=1e-3)
entropies

tensor([5.9735, 5.6362, 5.9735, 5.6362, 5.9735, 5.6362, 5.9735, 5.6362],
       dtype=torch.float64)

## Computing approximate joint entropies