In [1]:
import numpy as np
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
# Check if multivariate normal with diagonal covariance == univariate normals with independent components 

class GMM(object):
    """Gaussian mixture model built on ``tf.contrib.distributions``.

    Component ``i`` is a ``MultivariateNormalDiag`` whose mean vector is
    ``mu[i]`` repeated ``dim`` times and whose diagonal scale is ``sigma[i]``
    repeated ``dim`` times. Components are mixed with probabilities
    ``weights``.

    Two evaluation paths are exposed so that they can be compared:

    * ``log_px`` / ``d_log_px`` -- explicit log-sum-exp over the components.
    * ``log_px_`` / ``d_log_px_`` -- delegate to ``Mixture.log_prob``.
    """

    def __init__(self, mu, sigma, weights, dim):
        """Build the per-component distributions and the mixture.

        Args:
            mu: indexable of per-component scalar means (``mu[i]``).
            sigma: indexable of per-component scalar scales (``sigma[i]``),
                passed to ``MultivariateNormalDiag`` as ``scale_diag``.
            weights: array of mixing probabilities; ``weights.shape[0]``
                determines the number of components.
            dim: dimensionality of each component.
        """
        # Required parameters
        self.mu = mu
        self.sigma = sigma
        self.weights = weights
        self.dim = dim

        # Build every component exactly once; both evaluation paths reuse
        # these objects instead of re-creating them per call.
        self._components = []
        for i in range(self.weights.shape[0]):
            loc = self.mu[i] * np.ones(self.dim)
            scale_diag = self.sigma[i] * np.ones(self.dim)
            self._components.append(
                ds.MultivariateNormalDiag(loc=loc, scale_diag=scale_diag))

        self.gauss_mix = ds.Mixture(
            ds.Categorical(probs=self.weights), self._components)

    def log_px(self, x):
        """Return ``log p(x)`` via an explicit log-sum-exp over components.

        ``log p(x) = log(sum_i(exp(log(w_i) + log(p_i(x)))))``

        Args:
            x: array-like or tensor of points; converted with
                ``tf.convert_to_tensor``. Presumably the last axis has size
                ``dim`` (required by the components' ``log_prob``).

        Returns:
            A tensor holding one log density per point in ``x``.
        """
        # convert_to_tensor hands back the same object when given a Tensor,
        # so callers may differentiate the result w.r.t. a tensor they pass.
        y = tf.convert_to_tensor(x)
        weighted_log_probs = [
            component.log_prob(y) + tf.log(self.weights[i])
            for i, component in enumerate(self._components)
        ]
        # Components go on a new trailing axis, which is then reduced away.
        stacked = tf.stack(weighted_log_probs, axis=-1)
        return tf.reduce_logsumexp(stacked, axis=-1)

    def log_px_(self, x):
        """Return ``log p(x)`` as computed by ``Mixture.log_prob``.

        Args:
            x: array-like or tensor of points.

        Returns:
            The tensor produced by ``self.gauss_mix.log_prob``.
        """
        y = tf.convert_to_tensor(x)
        return self.gauss_mix.log_prob(y)

    def d_log_px(self, x):
        """Return the gradient of ``log_px`` with respect to ``x``.

        Analytically:
        ``d_log_px = 1 / p(x) * sum_i(w_i * p_i(x) * (-(x - mu_i) / sigma_i^2))``
        Symbolic differentiation (``tf.gradients``) is used instead of
        coding that expression by hand.

        Args:
            x: array-like or tensor of points.

        Returns:
            The list returned by ``tf.gradients`` (a single entry, for ``x``).
        """
        # tf.gradients needs a Tensor that is part of the graph, so convert
        # first and differentiate w.r.t. that very tensor.
        y = tf.convert_to_tensor(x)
        return tf.gradients(self.log_px(y), [y])

    def d_log_px_(self, x):
        """Return the gradient of ``log_px_`` with respect to ``x``.

        Same quantity as ``d_log_px`` but differentiating through
        ``Mixture.log_prob``.

        Args:
            x: array-like or tensor of points.

        Returns:
            The list returned by ``tf.gradients`` (a single entry, for ``x``).
        """
        y = tf.convert_to_tensor(x)
        return tf.gradients(self.gauss_mix.log_prob(y), [y])

In [3]:
# Mixture parameters: two components in `dim` dimensions.
# `sigma` is the element-wise square root of [0.1, 0.05] and is used by GMM
# as the diagonal scale of each component.
dim = 6
mu = np.array([1, -1])
sigma = np.sqrt(np.array([0.1, 0.05]))
weights = np.array([1./3, 2./3])

# Open the session, then build the model.
sess = tf.InteractiveSession()
gmm = GMM(mu, sigma, weights, dim)

In [4]:
# 10 identical evaluation points, every coordinate equal to -0.1.
# The second axis (6) has to match the `dim` the GMM was built with.
x = np.full((10, 6), -0.1, dtype=np.float64)
# Function-call form of print behaves the same on Python 2 and Python 3.
print(x)

[[-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]
 [-0.1 -0.1 -0.1 -0.1 -0.1 -0.1]]


In [5]:
# NOTE: stale -- GMM does not define the `x` placeholder (or `grad`) that this line expects.
# sess.run(gmm.grad, feed_dict={gmm.x: x})[0].shape

In [6]:
# x, y = sess.run([gmm.d_log_px(x), gmm.d_log_px_(x)])
# Evaluate the gradient obtained through the Mixture-based path.
# NOTE: `x` is rebound here from the input points to the fetched result,
# so re-running this cell feeds the previous result back in.
x = sess.run(fetches=gmm.d_log_px_(x))

In [7]:
# Function-call form of print behaves the same on Python 2 and Python 3.
print(x)
# print(y)

[array([[ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.99788814],
       [ 10.99788814,  10.99788814,  10.99788814,  10.99788814,
         10.99788814,  10.