In [1]:
import numpy as np
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
# Check if multivariate normal with diagonal covariance == univariate normals with independent components 

class GMM(object):
    """Gaussian mixture whose components are diagonal-covariance normals.

    Component ``i`` is a ``dim``-dimensional normal whose mean vector is
    ``mu[i]`` repeated ``dim`` times and whose ``scale_diag`` is ``sigma[i]``
    repeated ``dim`` times; it enters the mixture with weight ``weights[i]``.
    """

    def __init__(self, mu, sigma, weights, dim):
        """Store the mixture parameters.

        Args:
            mu: per-component scalar means, indexable by component.
            sigma: per-component scalar scales (used as ``scale_diag``).
            weights: per-component mixing weights; must expose ``.shape[0]``.
            dim: dimensionality of every component.
        """
        # Required parameters
        self.mu = mu
        self.sigma = sigma
        self.weights = weights
        self.dim = dim

    def log_px(self, x, y):
        """Return the mixture log-mass of the batches ``x`` and ``y``.

        For each input ``v`` this computes
        ``log(sum_i w_i * sum_n p_i(v_n))``: the per-sample component
        log-densities are first collapsed over the batch with a log-sum-exp,
        the log-weight is added, and the components are then combined with a
        second log-sum-exp. The result is therefore one scalar per input, not
        one value per sample.
        NOTE(review): confirm that collapsing over the batch is intended.

        Args:
            x: first batch of points, accepted by ``mvn.log_prob``.
            y: second batch of points, treated exactly like ``x``.

        Returns:
            Tuple ``(log_px, log_py)`` of scalar tensors.
        """
        log_px, log_py = [], []
        # Read the parameters from the instance rather than from notebook
        # globals, so the object does not depend on surrounding cell state.
        for i in range(self.weights.shape[0]):
            mu_ = self.mu[i] * np.ones(self.dim)
            sigma_ = self.sigma[i] * np.ones(self.dim)
            mvn = ds.MultivariateNormalDiag(loc=mu_, scale_diag=sigma_)
            log_w = tf.log(self.weights[i])
            # log(w_i) + log(sum_n p_i(v_n)) for each input
            log_px.append(tf.reduce_logsumexp(mvn.log_prob(x)) + log_w)
            log_py.append(tf.reduce_logsumexp(mvn.log_prob(y)) + log_w)
        return tf.reduce_logsumexp(log_px), tf.reduce_logsumexp(log_py)

    def d_log_px(self, x, y):
        """Return the gradients of ``log_px`` with respect to ``x`` and ``y``.

        Analytically the gradient of a log-mixture is
        ``sum_i w_i p_i(x) * (-(x - mu_i) / sigma_i^2) / sum_i w_i p_i(x)``,
        but symbolic differentiation is used here instead of a hand-written
        formula.

        Args:
            x: first batch of points (tensor or array-like of floats).
            y: second batch of points (tensor or array-like of floats).

        Returns:
            Tuple ``(grads_x, grads_y)``; each element is the one-item list
            returned by ``tf.gradients``.
        """
        # Differentiate with respect to the tensors actually fed into the
        # graph; converting first lets array-like inputs be differentiated too.
        x = tf.convert_to_tensor(x)
        y = tf.convert_to_tensor(y)
        log_px, log_py = self.log_px(x, y)
        return tf.gradients(log_px, [x]), tf.gradients(log_py, [y])

In [3]:
# Interactive session used to evaluate tensors in the cells below.
sess = tf.InteractiveSession()

# Two-component mixture in six dimensions; sigma is the square root of the
# per-component values 0.1 and 0.05.
mu = np.array([1, -1])
sigma = np.sqrt(np.array([0.1, 0.05]))
weights = np.array([1./3, 2./3])
dim = 6

gmm = GMM(mu, sigma, weights, dim)

In [4]:
# Batch of 10 identical 6-dimensional points, every coordinate equal to -0.1.
x = np.full((10, 6), -0.1, dtype=np.float64)

In [5]:
# `GMM.log_px` is a method taking the input tensors, and the model exposes no
# placeholder attribute, so build the placeholder and the graph node here.
x_ph = tf.placeholder(tf.float64, shape=(None, gmm.dim))
log_px_op, _ = gmm.log_px(x_ph, x_ph)  # second return value (log_py) is unused
sess.run(log_px_op, feed_dict={x_ph: x})

-33.701830290660368