In [1]:
%matplotlib inline

Sampling from a Hierarchical Dirichlet Process

As we saw earlier, the Dirichlet process is a distribution over probability distributions. It takes two parameters: a base distribution $H_0$ and a dispersion parameter $\alpha$. A sample from the Dirichlet process is itself a probability distribution that resembles $H_0$. On average, the larger $\alpha$ is, the closer a sample from $\mathrm{DP}\left(\alpha H_0\right)$ will be to $H_0$.
Suppose we're feeling masochistic and want to use a distribution sampled from a Dirichlet process as the base distribution of a new Dirichlet process. (It will turn out that there are good reasons for this!) Conceptually this makes sense. But can we construct such a thing in practice? Said another way, can we build a sampler that will draw samples from a probability distribution drawn from these nested Dirichlet processes? We might initially try to construct a sample (a probability distribution) from the first Dirichlet process before feeding it into the second.
But recall that fully constructing a sample (a probability distribution!) from a Dirichlet process would require drawing a countably infinite number of samples from $H_0$ and from the beta distribution to generate the weights. This would take forever, even with Hadoop!
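To make the cost concrete, here is a minimal sketch of what we would have to settle for if we built the sample eagerly: a stick-breaking construction cut off after a fixed number of atoms. The function name `truncated_stick_breaking`, the truncation level `n_atoms`, and the standard normal base distribution are assumptions chosen for illustration. The result is only an approximation, because the weight of every stick past the cutoff is discarded.

```python
import numpy as np

def truncated_stick_breaking(alpha, base_rvs, n_atoms, rng):
    """Approximate a draw from DP(alpha * H_0) using only the first n_atoms sticks."""
    # Fraction of the *remaining* stick broken off at each step: Beta(1, alpha).
    fractions = rng.beta(1.0, alpha, size=n_atoms)
    # Stick left before each break: 1, (1 - b1), (1 - b1)(1 - b2), ...
    remaining = np.concatenate(([1.0], np.cumprod(1.0 - fractions)[:-1]))
    weights = fractions * remaining
    # One atom per stick, drawn from the base distribution H_0.
    atoms = base_rvs(n_atoms)
    return atoms, weights

rng = np.random.default_rng(0)
# Hypothetical base distribution H_0: a standard normal.
atoms, weights = truncated_stick_breaking(
    alpha=10.0, base_rvs=rng.standard_normal, n_atoms=1000, rng=rng
)
# Probability mass lost to the truncation.
leftover = 1.0 - weights.sum()
```

The `leftover` mass shrinks as `n_atoms` grows but reaches zero only in the limit, which is why an exact sample cannot be built up front.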
Dan Roy et al. helpfully described a technique of using stochastic memoization to construct a distribution sampled from a Dirichlet process in a just-in-time manner. This process provides us with the equivalent of the SciPy rvs method for the sampled distribution. Stochastic memoization is equivalent to the Chinese restaurant process: sometimes you get seated at an occupied table (i.e. sometimes you're given a sample you've seen before) and sometimes you're put at a new table (given a unique sample).
Here is our memoization class again:


In [2]:
from numpy.random import choice
from scipy.stats import beta

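class DirichletProcessSample:
    """Lazily built sample (a probability distribution) from DP(alpha * base_measure)."""

    def __init__(self, base_measure, alpha):
        self.base_measure = base_measure  # zero-argument callable that draws from H_0
        self.alpha = alpha                # dispersion parameter

        self.cache = []             # atoms drawn from the base measure so far
        self.weights = []           # stick-breaking weight of each cached atom
        self.total_stick_used = 0.  # total weight handed out so far

    def __call__(self):
        # Pick a cached atom in proportion to its weight, or the unbroken
        # remainder of the stick (the last index).
        remaining = 1.0 - self.total_stick_used
        i = DirichletProcessSample.roll_die(self.weights + [remaining])
        if i is not None and i < len(self.weights):
            # Seated at an occupied table: return a value we have seen before.
            return self.cache[i]
        else:
            # New table: break a Beta(1, alpha) fraction off the remaining stick
            # and pair it with a fresh draw from the base measure.
            stick_piece = beta(1, self.alpha).rvs() * remaining
            self.total_stick_used += stick_piece
            self.weights.append(stick_piece)
            new_value = self.base_measure()
            self.cache.append(new_value)
            return new_value

    @staticmethod
    def roll_die(weights):
        """Return an index drawn with probabilities `weights`, or None if empty."""
        if weights:
            return choice(range(len(weights)), p=weights)
        else:
            return None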
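The Chinese restaurant analogy mentioned above can also be written down directly, without tracking stick lengths. The following is a minimal, self-contained sketch, not the class above: `make_dp_sampler` is a hypothetical helper that seats each new draw at an existing table with probability proportional to the table's count, or at a new table with probability proportional to `alpha`. Because the returned sampler is just a zero-argument callable, one can be passed as the base sampler of another, which is one way to nest two Dirichlet processes. The standard normal base and the two `alpha` values are assumptions for illustration.

```python
import numpy as np

def make_dp_sampler(base_sampler, alpha, rng):
    """Return a zero-argument sampler for a distribution drawn from DP(alpha * H_0).

    Uses the Chinese restaurant (Polya urn) form of stochastic memoization.
    """
    values = []  # one value per table
    counts = []  # number of customers seated at each table

    def sample():
        n = sum(counts)
        # New table with probability alpha / (n + alpha); always taken when n == 0.
        if rng.random() < alpha / (n + alpha):
            values.append(base_sampler())
            counts.append(1)
            return values[-1]
        # Otherwise join table k with probability counts[k] / n.
        k = rng.choice(len(counts), p=np.array(counts) / n)
        counts[k] += 1
        return values[k]

    return sample

rng = np.random.default_rng(0)
base = lambda: rng.standard_normal()               # hypothetical H_0
inner = make_dp_sampler(base, alpha=10.0, rng=rng)  # sample from the first DP
outer = make_dp_sampler(inner, alpha=5.0, rng=rng)  # second DP with the first as its base
draws = [outer() for _ in range(200)]
n_distinct = len(set(draws))
```

Every value returned by `outer` was first produced by `inner`, so `n_distinct` is far smaller than the number of draws: the nested sampler keeps revisiting a shared, lazily discovered set of atoms.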