Here we test the Dirichlet-Multinomial distribution, using the examples from the TensorFlow Probability (TFP) API documentation.


In [11]:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

The examples below are copied from the TFP documentation page:
https://www.tensorflow.org/probability/api_docs/python/tfp/distributions/DirichletMultinomial

The Dirichlet-Multinomial distribution is a distribution over histograms, i.e. over count vectors obtained by summing multiclass indicator vectors.

That is:
* Consider K >= 2 classes.
* Draw a probability vector over these K classes, [p_0, p_1, ..., p_K-1], from a Dirichlet distribution with concentration alpha.
* Given that vector, each sample is an indicator (one-hot) vector of length K.
* Each vector has a 1 in the position of the drawn class k, with k in [0 to K-1], and 0 elsewhere.
* Draw N such vectors.
* Count the classes of these N vectors. This gives a **count** vector of length K: [n_0, n_1, ..., n_K-1], whose entries sum to N.
* For a given N and alpha, the Dirichlet-Multinomial distribution gives the probability of every such possible count vector.
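The sampling story above can be sketched in plain Python. This is only an illustration, not how TFP implements sampling. It assumes the class probabilities are themselves drawn from a Dirichlet with concentration `alpha`, which is what separates this distribution from a plain multinomial. The helper names `sample_dirichlet` and `sample_dirichlet_multinomial` are made up for this sketch.

```python
import random

def sample_dirichlet(alpha, rng):
    # Independent Gamma(alpha_k, 1) draws, normalised to sum to 1,
    # are distributed as Dirichlet(alpha).
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [x / total for x in g]

def sample_dirichlet_multinomial(n, alpha, rng):
    # Step 1: draw the class probabilities p from the Dirichlet.
    p = sample_dirichlet(alpha, rng)
    # Step 2: draw n classes from p and count them. Adding up the
    # corresponding one-hot vectors is the same as incrementing counts[k].
    counts = [0] * len(alpha)
    for k in rng.choices(range(len(alpha)), weights=p, k=n):
        counts[k] += 1
    return counts

rng = random.Random(0)  # fixed seed so the run is repeatable
counts = sample_dirichlet_multinomial(60, [1.0, 2.0, 3.0], rng)
print(counts, sum(counts))  # a length-3 count vector and its total, 60
```

The count vector always has K entries that sum to N, whatever probabilities were drawn in step 1.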

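*Note: A zero count for a class does not break the computation. Class k contributes the factor Gamma(n_k + alpha_k) / (n_k! Gamma(alpha_k)), which equals 1 when n_k = 0 because alpha_k > 0. The example [1., 59., 0.] below indeed evaluates to a finite, nonzero probability.*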

####  Invoke the distribution with some typical settings

In [59]:
alpha = [1., 2., 3.] # Concentration
n = 60. # Total count
dist = tfd.DirichletMultinomial(n, alpha)

#### Compute some examples

In [80]:
# eg.1 counts has the same shape as alpha; its entries must sum to n = 60.
# counts = [0., 0., 60.]
# dist.prob(counts)  # Shape []

# eg.2 alpha will be broadcast to [[1., 2., 3.], [1., 2., 3.]] to match counts.
# counts = [[1., 59., 0.], [10., 20., 30.]]
# TF 1.x graph mode: the actual counts are fed at run time through a placeholder.
counts = tf.placeholder(tf.float32, shape=(2, 3))
P = dist.prob(counts)  # Shape [2]

# Draw a sample from the distribution
smpl = dist.sample()
P_smpl = dist.prob(smpl)

# alpha will be broadcast to shape [5, 7, 3] to match counts.
# counts = [[...]]  # Shape [5, 7, 3]
# dist.prob(counts)  # Shape [5, 7]


####  Run the session and print some values

In [100]:
counts_ = [[1., 59., 0.], [10., 20., 30.]]
sess = tf.Session()
# Evaluate all three tensors in a single run so that P_smpl is the
# probability of the same sample that gets printed.
sP, ssmpl, sP_smpl = sess.run([P, smpl, P_smpl], feed_dict={counts: counts_})
print('Probability of \n{} is:\n{}'.format(counts_, sP))
print('\n\nA Sample from the distribution: {}'.format(ssmpl))
print('Probability of above sample: {}'.format(sP_smpl))

Probability of 
[[1.0, 59.0, 0.0], [10.0, 20.0, 30.0]] is:
[7.2636626e-06 1.2610100e-03]


A Sample from the distribution: [ 5. 45. 10.]
Probability of above sample: 0.00036755765904672444
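
As a cross-check, the printed probabilities can be reproduced from the closed-form Dirichlet-Multinomial pmf, P(n) = N! Gamma(A) / Gamma(N + A) * prod_k Gamma(n_k + alpha_k) / (n_k! Gamma(alpha_k)), with A = sum_k alpha_k. The sketch below uses only the standard library and works in log space with `math.lgamma`. The function name `dirichlet_multinomial_prob` is made up for this check and is not part of TFP.

```python
from math import exp, lgamma

def dirichlet_multinomial_prob(counts, alpha):
    n = sum(counts)  # total count N
    a = sum(alpha)   # total concentration A
    # log of N! * Gamma(A) / Gamma(N + A)
    log_p = lgamma(n + 1) + lgamma(a) - lgamma(n + a)
    # log of Gamma(n_k + alpha_k) / (n_k! * Gamma(alpha_k)) for each class;
    # this term is exactly 0 when n_k = 0, so zero counts are harmless.
    for x, al in zip(counts, alpha):
        log_p += lgamma(x + al) - lgamma(x + 1) - lgamma(al)
    return exp(log_p)

alpha = [1.0, 2.0, 3.0]
for c in ([1, 59, 0], [10, 20, 30], [5, 45, 10]):
    print(c, dirichlet_multinomial_prob(c, alpha))
```

The three values should agree with the TFP output above to roughly four significant digits. The small remaining difference is expected, since the session above runs in float32.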
