Dirichlet Processes and Hierarchical Dirichlet Processes
==

In [1]:
import numpy as np

Make a simple Gaussian object that we can sample from

In [8]:
class Gaussian(object):
    """Univariate normal distribution exposing a sample() method."""

    def __init__(self,mu,sd):
        self.mu = mu # mean
        self.sd = sd # standard deviation

    def sample(self):
        """Return one draw with mean mu and standard deviation sd."""
        # Draw from the standard normal, then scale by sd and shift by mu
        standard_draw = np.random.randn()
        return self.mu + self.sd * standard_draw

Make a Dirichlet Process object

In [26]:
class DP(object):
    """Dirichlet process that is sampled one value at a time.

    Each call to sample() either draws a fresh value from `base` or repeats
    one of the values drawn earlier. `base` only needs a sample() method, so
    another DP can be used as the base.
    """

    def __init__(self,base,alpha):
        self.base = base # Must be something with a sample() method
        self.alpha = alpha # Numerator of the probability of drawing a new value from base
        self.N = 0 # Keep track of the number of samples
        self.samples = [] # And the samples themselves

    def sample(self):
        """Draw the next value, append it to self.samples and return it.

        With N earlier draws, a new value is taken from the base with
        probability alpha / (N + alpha); otherwise one of the earlier draws
        is repeated.
        """
        # The first draw always comes from the base. Short-circuiting on
        # N == 0 avoids 0/0 when alpha is 0 and consumes no random number.
        # rand() is uniform on [0, 1), so the strict `<` selects the "new"
        # branch with probability exactly p_new; with `<=` a draw of 0.0
        # would still pick "new" when alpha == 0.
        if self.N == 0 or np.random.rand() < float(self.alpha)/float(self.N + self.alpha):
            # Generate from the base
            value = self.base.sample()
        else:
            # Pick one of the previous ones uniformly.
            # This is the same as picking the unique ones
            # proportional to the number of times they've been sampled.
            value = self.samples[np.random.randint(self.N)]
        self.samples.append(value)
        self.N += 1
        return value

    def nice_output(self):
        """Return the unique values and the number of times each was selected.

        The result is a list of (value, count) tuples sorted by count,
        largest first. Values are used as dict keys, so they must be hashable.
        """
        unique = {}
        for s in self.samples:
            unique[s] = unique.get(s, 0) + 1
        return sorted(unique.items(),key = lambda x: x[1], reverse = True)

    def reset(self):
        """Forget every sample drawn so far; base and alpha are kept."""
        self.N = 0
        self.samples = []
        
            

Demonstrate the DP -- varying alpha will change the number of unique values sampled

In [24]:
g = Gaussian(0,100) # make the base object: mean 0, standard deviation 100
d = DP(g,10.0) # DP over that base with alpha = 10
# Draw 100 values; the DP keeps them internally, so the return value is not needed here
for i in range(100):
    d.sample()
# Print each distinct value with its count, most frequent first
op = d.nice_output()
for v,c in op:
    # Printing one pre-formatted string works both as a Python 2 print
    # statement and as a Python 3 print() call
    print("{} {}".format(v,c))

-207.229957516 9
-79.3268286841 9
60.4504331168 8
-78.9814961785 8
97.6532899876 6
-37.9716379396 6
10.8641822076 6
8.17541885598 5
127.460503987 5
-22.5670330183 5
72.6510928538 4
10.6281397734 4
-115.44445357 4
21.3661648854 4
-69.585098738 3
-12.98585771 2
-153.715874257 2
-5.04500494451 2
74.1317816759 2
-60.435099588 1
-69.3381549986 1
39.80566167 1
55.581807788 1
-31.9250139845 1
156.378187127 1


To make a hierarchical DP, simply use another DP as the base

In [33]:
# Top level: a DP over the Gaussian g with alpha = 10
base_dp = DP(g,10.0)
# Lower level: 5 DPs with alpha = 1 that all use the same base_dp as their base
bottom_dps = [DP(base_dp,1.0) for i in range(5)]

Sample from them

In [34]:
# Take 100 draws from each lower DP in turn, finishing one DP before starting the next.
# Whenever a lower DP needs a new value it asks its base (base_dp) for one.
for lower_dp in bottom_dps:
    for _draw in range(100):
        lower_dp.sample()

Display the top-level samples. Each top-level draw corresponds to one table in one of the lower-level restaurants, so the counts here sum to the total number of tables across all restaurants.

In [36]:
# One line per distinct top-level value, "<value> <count>", most frequent first
for v,c in base_dp.nice_output():
    # A single pre-formatted string prints the same under Python 2 and Python 3
    print("{} {}".format(v,c))

59.8566399219 8
253.527008094 5
85.31922073 3
13.0800460328 2
-139.048187902 2
-25.9265423117 2
-203.014057244 1
40.5659593236 1
47.3634316806 1
50.52657938 1
28.0819836225 1
-94.4238591997 1
-59.0510246964 1
142.482242432 1


Print the counts in the individual DPs.
Counts are now the total number of people sitting at tables serving a particular dish.

In [38]:
for i,bottom_dp in enumerate(bottom_dps):
    print("Lower DP {}".format(i))
    # One line per distinct value in this DP, "<value> <count>", most frequent first
    for v,c in bottom_dp.nice_output():
        print("{} {}".format(v,c))
    # Two blank lines between DPs. A bare `print` emits a newline only as a
    # Python 2 statement; print("") does so under Python 3 as well.
    print("")
    print("")

Lower DP 0
50.52657938 26
-94.4238591997 25
47.3634316806 21
253.527008094 17
-59.0510246964 10
40.5659593236 1


Lower DP 1
59.8566399219 87
-203.014057244 13


Lower DP 2
28.0819836225 44
253.527008094 31
59.8566399219 19
13.0800460328 3
-139.048187902 2
85.31922073 1


Lower DP 3
-25.9265423117 66
59.8566399219 19
-139.048187902 12
85.31922073 2
253.527008094 1


Lower DP 4
142.482242432 64
59.8566399219 23
85.31922073 9
13.0800460328 4


