In [1]:
import numpy as np
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
sess = tf.InteractiveSession()

In [3]:
class GMM(object):
    def __init__(self, mu, sigma, weights, dim):
        # Required parameters 
        self.mu = mu
        self.sigma = sigma
        self.weights = weights
        self.dim = dim
        
    def log_px(self, x):
        # log_px = log(sum(exp(log(w_i) + log(p_i(x)))))
        log_px = []
        for i in range(weights.shape[0]):
            mu_, sigma_ = self.mu[i] * tf.ones(dim), self.sigma[i] * tf.ones(dim)
            mvn = ds.MultivariateNormalDiag(loc=mu_, scale_diag=sigma_)
            # Calculate log_px for each component
            log_px_i = tf.reduce_logsumexp(mvn.log_prob(x)) + tf.log(tf.to_float(weights[i]))
            log_px.append(log_px_i)
        return tf.reduce_logsumexp(log_px)
    
    def d_log_px(self, x):
        # d_log_px = 1 / exp(log(sum(exp(log(w_i) + log(p_i(x)))))) 
        #            * sum(exp(log(w_i) + log(p_i(x)) + log(-(x - mu)/sigma^2)))
        # Use symbolic differentiation instead
        log_px = self.log_px(x)
        return tf.gradients(log_px, [x])

In [4]:
mu = np.array([1., -1.]); sigma = np.sqrt(np.array([0.1, 0.05])); weights = np.array([1./3, 2./3]); dim=6
gmm = GMM(mu, sigma, weights, dim)

In [5]:
def median(x):
    x = tf.reshape(x, [-1])
    med = tf.floordiv(tf.shape(x)[0], 2)
    check_parity = tf.equal(tf.to_float(med), tf.divide(tf.to_float(tf.shape(x)[0]), 2.))
    def is_true():
        return tf.reduce_sum(tf.nn.top_k(x, med+1).values[-2:]) / 2.
    def is_false():
        return tf.nn.top_k(x, med+1).values[-1]
    return tf.cond(check_parity, is_true, is_false) 

In [6]:
class stein_is(object):
    def __init__(self, gmm_model, mu, sigma, dim, n_leaders, n_followers, step_size=0.01): # n_trials, step_size=0.01):
        # Required parameters
        self.gmm_model = gmm_model
        self.mu = mu
        self.sigma = sigma
        self.dim = dim
        self.n_leaders = n_leaders
        self.n_followers = n_followers
        # self.n_trials = n_trials
        self.step_size = step_size
        self.eps = 1e-10
        
        # Set seed
        seed = 30
        
        # Intialisation
        self.B, self.B_density, self.A = self.initialise_variables()
        self.pB = self.gmm_model.log_px(self.B)
        
        # Register functions for debugging
        # self.k_A_A, self.sum_grad_A_k_A_A, self.A_Squared, self.h_0 = self.construct_map()
        # self.k_A_B, self.sum_grad_A_k_A_B, self.grad_A_grad_B_k_A_B, self.grad_B_k_A_B = self.apply_map()        
        # self.A, self.B, self.phi_B, self.grad_B_phi_B = self.svgd_update()
        # self.q_density = self.density_update()
        
        
    def initialise_variables(self):
        init_distribution = tf.contrib.distributions.MultivariateNormalDiag(self.mu * tf.ones(dim), self.sigma * tf.ones(dim))
        
        # followers = tf.reshape(init_distribution.sample(self.n_trials * self.n_followers, seed=123), [self.n_trials, self.n_followers, self.h_dim] 
        # leaders = tf.reshape(init_distribution.sample(self.n_trials * self.n_leaders, seed=123), [self.n_trials, self.n_leaders, self.h_dim] 
        
        followers = tf.reshape(init_distribution.sample(self.n_followers, seed=123), [self.n_followers, self.dim]) 
        q_density = init_distribution.log_prob(followers)
        leaders = tf.reshape(init_distribution.sample(self.n_leaders, seed=123), [self.n_leaders, self.dim])
                           
        return followers, q_density, leaders
                             
    def construct_map(self):
        # Calculate ||leader - leader'||^2/h_0, refer to leader as A as in SteinIS
        x2_A_A_T = tf.multiply(2., tf.matmul(self.A, tf.transpose(self.A)))
        A_Squared = tf.reduce_sum(tf.square(self.A), 1)
        A_Distance = tf.add(tf.subtract(A_Squared, x2_A_A_T), tf.transpose(A_Squared))   
        # h_0 = tf.divide(tf.add(median(A_Distance), self.eps), 2. * (tf.log(tf.cast(self.n_leaders, tf.float32)) + 1.))
        h_0 = tf.divide(median(A_Distance), 2. * (tf.log(tf.to_float(self.n_leaders)) + 1.))
        k_A_A = tf.exp(-tf.div(A_Distance, tf.square(h_0)))
        sum_grad_A_k_A_A = tf.reduce_sum(tf.gradients(k_A_A, [self.A]), 1)
        return k_A_A, sum_grad_A_k_A_A, A_Squared, h_0
    
    def apply_map(self):
        # Calculate ||leader - follower||^2/h_0, refer to follower as B as in SteinIS
        x2_A_B_T = tf.multiply(2., tf.matmul(self.A, tf.transpose(self.B)))
        B_Squared = tf.reduce_sum(tf.square(self.B), 1)
        A_B_Distance  = tf.add(tf.subtract(self.A_Squared, x2_A_B_T), B_Squared)
        k_A_B = tf.exp(-tf.div(A_B_Distance, tf.square(self.h_0)))
        sum_grad_A_k_A_B = tf.reduce_sum(tf.gradients(k_A_B, [self.A]), 1)
        return k_A_B, sum_grad_A_k_A_B 
                    
    def svgd_update(self):
        self.k_A_A, self.sum_grad_A_k_A_A, self.A_Squared, self.h_0 = self.construct_map()
        self.k_A_B, self.sum_grad_A_k_A_B = self.apply_map()
        self.d_log_pA = self.gmm_model.d_log_px(self.A)[0]
        sum_d_log_pA_T_k_A_A = tf.reduce_sum(tf.matmul(self.k_A_A, self.d_log_pA), 0)       
        phi_A = (1. / tf.to_float(self.n_leaders)) * tf.add(sum_d_log_pA_T_k_A_A, self.sum_grad_A_k_A_A)
        A = tf.add(self.A, self.step_size * phi_A)  
        sum_d_log_pA_T_k_A_B = tf.reduce_sum(tf.matmul(self.k_A_B, self.d_log_pA), 0)       
        phi_B = (1. / tf.to_float(self.n_leaders)) * tf.add(sum_d_log_pA_T_k_A_B, self.sum_grad_A_k_A_B)
        B = tf.add(self.B, self.step_size * phi_B) 
        grad_B_phi_B = tf.gradients(phi_B, [self.B])
        return A, B, phi_B, grad_B_phi_B[0]
    
    def density_update(self):
        I = tf.eye(self.dim)
        inv_abs_det_I_grad_B_phi = tf.map_fn(lambda x: 1./tf.abs(tf.matrix_determinant(tf.add(I, x))), self.grad_B_phi_B)
        return tf.multiply(self.B_density, inv_abs_det_I_grad_B_phi) 
    
    def main(self, iteration):
        for i in range(iteration):
            self.A, self.B, self.phi_B, self.grad_B_phi_B = self.svgd_update()
            self.q_density = self.density_update()
            print 'Iteration ', str(i), ' done'
        self.importance_weights = tf.divide(self.q_density, self.pB)
        self.normalisation_constant = 1./tf.to_float(self.n_followers) * tf.reduce_sum(self.importance_weights)
        self.final_B = self.B
        return self.final_B, self.importance_weights, self.normalisation_constant                       

In [7]:
mu = 0.; sigma = 1.; dim = 6; n_leaders = 100; n_followers = 100;
model = stein_is(gmm,  mu, sigma, dim, n_leaders, n_followers)

In [8]:
sess.run([model.q_density])

[array([ -1.03152409e+01,  -1.51769896e+01,  -6.91211319e+00,
         -9.09164906e+00,  -1.66581357e+00,  -3.77475023e+00,
         -3.40547347e+00,  -5.77358770e+00,  -5.48494339e-01,
         -1.16953516e+01,  -4.43257749e-01,  -7.55881071e-01,
         -5.07833898e-01,  -2.52877498e+00,  -3.10027575e+00,
         -3.16426206e+00,  -2.75525767e-02,  -9.08142948e+00,
         -3.49198437e+00,  -7.64666021e-01,  -8.81662941e+00,
         -5.76153088e+00,  -7.92342424e+00,  -3.00987768e+00,
         -3.99649167e+00,  -4.43154144e+00,  -1.14873905e+01,
         -3.34710050e+00,  -7.03338194e+00,  -1.07595420e+00,
         -9.31732170e-03,  -8.85389709e+00,  -8.23503304e+00,
         -5.73243809e+00,  -3.09585500e+00,  -7.30990982e+00,
         -8.96162415e+00,  -6.74901628e+00,  -7.30081940e+00,
         -3.87950373e+00,  -2.41132855e+00,  -8.01904774e+00,
         -1.48389554e+00,  -3.09095216e+00,  -4.00711834e-01,
         -1.04508793e+00,  -4.29516840e+00,  -4.06943232e-01,
        