In [1]:
import numpy as np
import numpy.ma as ma
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
class GMM(object):
    """Gaussian mixture with diagonal covariances and Fisher-score helpers.

    Each component is a ``MultivariateNormalDiag`` whose two positional
    arguments are ``mu[i] * np.ones(dim)`` and ``sigma[i] * np.ones(dim)``;
    the components are combined through ``Mixture`` with a ``Categorical``
    over ``weights``.  All methods return TensorFlow tensors (the notebook
    evaluates them with ``sess.run``).

    Attributes:
        mu: ndarray stacking the per-component location vectors.
        sigma: ndarray stacking the per-component scale vectors.
        weights: mixing weights, stored exactly as passed in.
        dim: dimensionality handed to ``np.ones``.
        distributions: list with one component distribution per weight.
        mix: the ``Mixture`` distribution built from ``distributions``.
    """

    def __init__(self, mu, sigma, weights, dim):
        """Create the component distributions and the mixture.

        Args:
            mu: indexable; ``mu[i]`` is broadcast against ``np.ones(dim)``.
            sigma: indexable; ``sigma[i]`` is broadcast the same way and
                passed as the second positional argument of
                ``MultivariateNormalDiag``.
                NOTE(review): in the TF 1.x API that argument is
                ``scale_diag`` (standard deviations) -- confirm for the
                installed version.
            weights: mixing probabilities; must expose ``.shape[0]`` (the
                number of components) and support ``weights[i]``.
            dim: number of coordinates of each sample.
        """
        self.weights = weights
        self.dim = dim

        locations, scales, components = [], [], []
        for i in range(self.weights.shape[0]):
            mu_, sigma_ = mu[i] * np.ones(dim), sigma[i] * np.ones(dim)
            components.append(tf.contrib.distributions.MultivariateNormalDiag(mu_, sigma_))
            locations.append(mu_)
            scales.append(sigma_)

        self.distributions = components
        self.mix = tf.contrib.distributions.Mixture(
            tf.contrib.distributions.Categorical(probs=self.weights),
            self.distributions)
        self.mu = np.asarray(locations)
        self.sigma = np.asarray(scales)

    def reshape_fish_comp(self, vec):
        """Flatten one stacked Fisher-vector part into a per-sample matrix.

        Args:
            vec: rank-3 tensor ``[a, b, c]`` (in this class: component,
                sample, feature).

        Returns:
            Tensor ``[b, a * c]``: the first two axes are swapped and the
            last two axes of the result are merged.
        """
        vec_shape = tf.shape(vec)
        sample_major = tf.transpose(vec, perm=[1, 0, 2])
        return tf.reshape(sample_major, [vec_shape[1], vec_shape[0] * vec_shape[2]])

    def log_px_(self, x):
        """Return ``self.mix.log_prob(x)``; ``x`` is forwarded unconverted."""
        return self.mix.log_prob(x)

    def dx_log_px(self, x):
        """Gradient of the mixture log-density with respect to ``x``.

        A closed form exists (responsibility-weighted sum of the component
        scores), but symbolic differentiation via ``tf.gradients`` is used
        instead.

        Args:
            x: sample points; converted with ``tf.convert_to_tensor``.

        Returns:
            Tuple ``(gradient, log_px)`` where ``gradient`` is
            ``tf.gradients(log_px, [x])[0]`` and ``log_px`` is the mixture
            log-density at ``x``.
        """
        x = tf.convert_to_tensor(x)
        log_px = self.log_px_(x)
        return tf.gradients(log_px, [x])[0], log_px

    def dtheta_log_px(self, x):
        """L2-normalised Fisher vector of ``log p(x)`` plus its intermediates.

        With ``K`` components and ``x`` of shape ``[n, d]`` the per-component
        terms are stacked along a leading axis of length ``K``:

        * weights part:  ``p_i(x) / p(x)``                 -> ``[K, n, 1]``
        * mean part:     ``w_i p_i(x) / p(x) * e_i``       -> ``[K, n, d]``
        * variance part: ``w_i p_i(x) / p(x) * xi_i``      -> ``[K, n, d]``

        where ``e_i = (x - mean_i) / variance_i`` and
        ``xi_i = 0.5 * (e_i ** 2 - 1 / variance_i)``.

        Args:
            x: sample points; anything ``tf.convert_to_tensor`` accepts.

        Returns:
            8-tuple, in this order:
              0. ``dtheta_log_px``: the three parts flattened with
                 ``reshape_fish_comp``, concatenated along axis 1
                 (``[n, K + 2 * K * d]``) and divided by their row norm.
              1. list ``[weights part, mean part, variance part]``, each in
                 stacked layout and divided by the same row norm.
              2. the un-normalised stacked mean part.
              3. list of ``w_i * p_i(x)`` tensors, one ``[n, 1]`` per component.
              4. list of ``e_i`` tensors.
              5. list of ``xi_i`` tensors.
              6. ``p(x)`` reshaped to ``[n, 1]``.
              7. the row norms, ``[n, 1]``.
        """
        x_t = tf.convert_to_tensor(x)
        px = tf.reshape(self.mix.prob(x_t), (-1, 1))

        dw_log_px_, dmu_log_px_, dsigma2_log_px_ = [], [], []
        w_px_, exponent_, xi_ = [], [], []
        for i in range(self.weights.shape[0]):
            component = self.distributions[i]
            variance_i = component.variance()

            px_i = tf.reshape(component.prob(x_t), [-1, 1])
            w_px_i = self.weights[i] * px_i
            exponent_i = tf.divide(x_t - component.mean(), variance_i)
            xi_i = 0.5 * (tf.multiply(exponent_i, exponent_i) - tf.divide(1, variance_i))

            dw_log_px_.append(px_i)
            w_px_.append(w_px_i)
            exponent_.append(exponent_i)
            xi_.append(xi_i)
            dmu_log_px_.append(w_px_i * exponent_i)
            dsigma2_log_px_.append(w_px_i * xi_i)

        # Divide every stacked part by p(x); px broadcasts over the component axis.
        dw_log_px_ = tf.stack(dw_log_px_) / px
        dmu_log_px_ = tf.stack(dmu_log_px_) / px
        dsigma2_log_px_ = tf.stack(dsigma2_log_px_) / px

        dtheta_log_px_ = tf.concat(
            [self.reshape_fish_comp(dw_log_px_),
             self.reshape_fish_comp(dmu_log_px_),
             self.reshape_fish_comp(dsigma2_log_px_)],
            1)
        dtheta_log_px_norm = tf.norm(dtheta_log_px_, axis=1, keep_dims=True)
        dtheta_log_px = dtheta_log_px_ / dtheta_log_px_norm

        normalised_parts = [
            dw_log_px_ / dtheta_log_px_norm,
            dmu_log_px_ / dtheta_log_px_norm,
            dsigma2_log_px_ / dtheta_log_px_norm,
        ]
        return (dtheta_log_px, normalised_parts, dmu_log_px_, w_px_,
                exponent_, xi_, px, dtheta_log_px_norm)

    def dx_dtheta_log_px(self, dmu_log_px_, w_px_, exponent_, xi_, px, dtheta_log_px_norm):
        """Derivative with respect to ``x`` of the normalised Fisher-vector parts.

        The arguments are elements 2-7 of the tuple returned by
        ``dtheta_log_px`` (either as tensors or as their evaluated values).

        Args:
            dmu_log_px_: stacked un-normalised mean part, ``[K, n, d]``.
            w_px_: per-component ``w_i * p_i(x)``, each ``[n, 1]``.
            exponent_: per-component ``(x - mean_i) / variance_i``.
            xi_: per-component ``0.5 * (e_i ** 2 - 1 / variance_i)``.
            px: mixture density, ``[n, 1]``.
            dtheta_log_px_norm: Fisher-vector row norms, ``[n, 1]``.

        Returns:
            List of three tensors stacked over components: the weights block
            (``[K, n, 1, d]``), the mean block (``[K, n, d, d]``) and the
            variance block (``[K, n, d, d]``).

        NOTE(review): the norm is used as a plain scale factor here; its own
        dependence on ``x`` is not differentiated.  The weights block is also
        scaled by ``w_i p_i / p`` although the weights part returned by
        ``dtheta_log_px`` is ``p_i / p`` -- confirm both are intended.
        """
        dx_dw_log_px_, dx_dmu_log_px_, dx_dsigma2_log_px_ = [], [], []

        # Sum of the mean part over components.
        zeta = tf.reduce_sum(dmu_log_px_, [0])
        # Loop-invariant denominator shared by every component.
        px_times_norm = px * dtheta_log_px_norm

        for i in range(self.weights.shape[0]):
            # [n, 1, d] row vectors (zeta - e_i).
            zeta_m_exponent_T = tf.expand_dims(zeta - exponent_[i], 1)
            # [n, 1, 1] scale w_i p_i / (p * norm), reused by all three blocks.
            scale_i = tf.expand_dims(w_px_[i] / px_times_norm, -1)

            dx_dw_log_px_.append(zeta_m_exponent_T * scale_i)

            # [d, d] diagonal matrix of 1 / variance_i.
            precision = tf.diag(1. / self.distributions[i].variance())
            exponent_col = tf.expand_dims(exponent_[i], -1)
            dx_dmu_log_px_.append(
                (tf.matmul(exponent_col, zeta_m_exponent_T) + precision) * scale_i)

            xi_col = tf.expand_dims(xi_[i], -1)
            exponent_diag = tf.matrix_diag(exponent_[i])
            precision_batched = tf.expand_dims(precision, 0)
            dx_dsigma2_log_px_.append(
                (tf.matmul(xi_col, zeta_m_exponent_T) + exponent_diag * precision_batched) * scale_i)

        return [tf.stack(dx_dw_log_px_), tf.stack(dx_dmu_log_px_), tf.stack(dx_dsigma2_log_px_)]

In [3]:
# Default session used by every sess.run call in the cells below.
sess = tf.InteractiveSession()

# Two-component mixture in six dimensions; sigma holds square roots of the
# per-component variances 0.1 and 0.05.
mu = np.array([1, -1])
sigma = np.sqrt([0.1, 0.05])
weights = np.array([1. / 3, 2. / 3])
dim = 6

# Alternative ten-component, two-dimensional configuration kept for reference:
# mu = np.array([[-.5], [.5], [-1.], [1.0], [-1.5], [1.5], [-2.0], [2.0], [-2.5], [2.5]]); sigma = np.sqrt(2) * np.ones(10); weights = (1 / 10.0) * np.ones(10); dim = 2

gmm = GMM(mu, sigma, weights, dim)

In [4]:
# Ten identical sample points at -0.1 in every one of the six coordinates.
# (Two-dimensional variant: np.full((10, 2), -0.1, dtype=np.float64).)
x = np.full((10, 6), -0.1, dtype=np.float64)

In [5]:
# Gradient of log p(x) with respect to x, together with log p(x) itself.
dx_log_px, log_px = sess.run(gmm.dx_log_px(x))

# Normalised Fisher vector and the intermediates that dx_dtheta_log_px needs.
# `y` is the list of normalised per-parameter parts; `zeta` receives the
# stacked un-normalised mean part (third element of the returned tuple).
(dtheta_log_px, y, zeta, w_px_, exponent_, xi_,
 px, dtheta_log_px_norm) = sess.run(gmm.dtheta_log_px(x))

# x-derivatives of the three Fisher-vector parts, as a list of three arrays.
z = sess.run(
    gmm.dx_dtheta_log_px(zeta, w_px_, exponent_, xi_, px, dtheta_log_px_norm)
)

# Earlier experiments, kept for reference (gmm.px, gmm.grad_kernel and
# gmm.grad_grad_kernel are not defined on the GMM class in this notebook):
# [dw_log_px_gc, dw_log_px, px_], px = sess.run([gmm.dtheta_log_px(x), gmm.px(x)])
# dtheta_log_px = sess.run(gmm.dtheta_log_px(x))
# sum_grad_A_k_A_A = sess.run(tf.reduce_sum(gmm.grad_kernel(y[0], z[0]) + gmm.grad_kernel(y[1], z[1]) + gmm.grad_kernel(y[2], z[2]), axis=0))
# d_A_d_A_T = sess.run(gmm.grad_grad_kernel(z[0], z[0]))

In [60]:
# Transposed gradient of log p(x): one column per sample point.
dx_log_px_T = dx_log_px.transpose()

# Contract the weights part y[0] (axes i, m, k) with its x-derivative z[0]
# (axes i, j, k, l) over the shared axes i and k; the result has axes (m, j, l).
k_x_dx = np.einsum('imk,ijkl->mjl', y[0], z[0])

# Scratch experiments kept for reference (`a` is not defined in this notebook):
# print dx_log_px_T.shape, k_x_dx.shape
# print np.einsum('ijk,lk->ijl', a, dx_log_px_T)
# np.sum(a * dx_log_px_T, -1)

# Single-argument print of a formatted string is valid on Python 2 and 3 and
# emits the same text as the former `print a, b` statements.
print('%s %s' % (y[0].shape, y[1].shape))
print('%s %s' % (z[0].shape, z[1].shape))

(2, 10, 1) (2, 10, 6)
(2, 10, 1, 6) (2, 10, 6, 6)


array([[ 10.99788814],
       [ 10.99788814],
       [ 10.99788814],
       [ 10.99788814],
       [ 10.99788814],
       [ 10.99788814]])

In [59]:
# result[j, a, b] = sum_i dx_log_px_T[a, i] * k_x_dx[i, j, b]
# i.e. for every j, the sum over sample points i of the outer product between
# the i-th gradient column and k_x_dx[i, j].  This replaces the explicit loop
# over range(10) with reshape(6, 1), so the sizes now come from the arrays
# themselves.  The summation order differs from the loop, so values may differ
# in the last floating-point digit.
sum_dx_log_px_T_k_x_dx = np.einsum('ai,ijb->jab', dx_log_px_T, k_x_dx)
sum_dx_log_px_T_k_x_dx

array([[[  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05]],

       [[  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,  

In [58]:
# Display the accumulated tensor again; the recorded output matches the one
# shown for the cell that computes it.
sum_dx_log_px_T_k_x_dx

array([[[  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05]],

       [[  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,   3.62580678e-05],
        [  3.62580678e-05,   3.62580678e-05,   3.62580678e-05,
           3.62580678e-05,   3.62580678e-05,  

In [26]:
# Contract z[0] with itself: the shared labels i and l are summed, label m
# (second axis of the second operand) is summed out as well, and the result
# keeps axes (j, k, n).
np.einsum('ijlk,imln->jkn', z[0], z[0])

array([[[  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09]],

       [[  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,   4.64177558e-09],
        [  4.64177558e-09,   4.64177558e-09,   4.64177558e-09,
           4.64177558e-09,   4.64177558e-09,  

In [36]:
# Shape check only: contract z[0] and y[0] over their shared labels i and k,
# order the remaining axes as (j, l, m), then sum away axis 0 (label j).
np.sum(np.einsum('ijkl,imk->jlm', z[0], y[0]), 0).shape

(6, 10)

In [38]:
# Shape check only: same contraction of z[0] and y[0] over i and k, but with
# the output axes ordered (j, m, l) before summing away axis 0 (label j).
np.sum(np.einsum('ijkl,imk->jml', z[0], y[0]), 0).shape

(10, 6)

In [None]:
# Manual two-term sum over the leading axis of z[1] and y[1] (the axis the GMM
# methods create by stacking per component): for each of the two entries,
# z[1][c][0] (index 0 along the second axis) is multiplied by y[1][c] transposed.
np.matmul(z[1][0][0], y[1][0].transpose()) + np.matmul(z[1][1][0], y[1][1].transpose())

In [None]:
# Scratch arithmetic: two products sharing the factor 1.52344603e-05 with
# opposite signs.  NOTE(review): the constants look hand-copied from earlier
# outputs; their origin is not recorded in the notebook.
1.52344603e-05 * 0.02164134 + -1.52344603e-05* 7.88046711e-07