In [1]:
import numpy as np
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
# Session for this notebook; a later cell evaluates the graph with sess.run.
sess = tf.InteractiveSession()

In [3]:
class SteinIS(object):
    """Small TensorFlow graph for inspecting an RBF kernel over "leader" particles.

    The leaders are called ``A`` as in SteinIS, one particle per row. The graph builds
    the pairwise squared distances, the kernel matrix
    ``k(a_i, a_j) = exp(-||a_i - a_j||^2 / h)`` and the gradient of one kernel row's sum
    with respect to ``A``. All of them are kept as attributes so they can be fetched
    with ``Session.run`` for debugging.
    """

    def __init__(self, A=None, h=1., row=2):
        """Build the graph.

        Args:
            A: floating-point array of shape (N, D), one leader per row.
                Defaults to ``np.arange(9.).reshape((3, 3))``.
            h: kernel bandwidth h_0 dividing the squared distance. The default
                1.0 leaves the distances unscaled.
            row: index of the kernel row whose sum is differentiated to build
                ``sum_grad_A0_k_A_A``.
        """
        if A is None:
            A = np.arange(9.).reshape((3, 3))
        self.A = tf.convert_to_tensor(A)
        self.h = h
        self.row = row

        # Keep the intermediate tensors as attributes for debugging
        (self.k_A_A, self.sum_grad_A0_k_A_A,
         self.A_Squared, self.A_A_Distance_Squared) = self.construct_map()

    def construct_map(self):
        """Build the kernel tensors for ``self.A``.

        Returns:
            Tuple ``(k_A_A, sum_grad_A0_k_A_A, A_Squared, A_A_Distance_Squared)``:
            the (N, N) kernel matrix, the (N, D) gradient of
            ``sum_j k_A_A[self.row, j]`` w.r.t. ``A``, the (N, 1) squared row
            norms and the (N, N) squared pairwise distances.
        """
        # ||a_i - a_j||^2 = ||a_i||^2 - 2 a_i.a_j + ||a_j||^2 for all pairs at once.
        x2_A_A_T = 2. * tf.matmul(self.A, tf.transpose(self.A))  # N x N
        # expand_dims instead of keep_dims/keepdims, whose name changed between TF 1.x releases.
        A_Squared = tf.expand_dims(tf.reduce_sum(tf.square(self.A), axis=1), 1)  # N x 1
        A_A_Distance_Squared = A_Squared - x2_A_A_T + tf.transpose(A_Squared)  # N x N
        # Cancellation in the expansion can leave tiny negative values; a squared
        # distance is never negative (tf.maximum still passes the gradient at equality).
        A_A_Distance_Squared = tf.maximum(A_A_Distance_Squared,
                                          tf.zeros_like(A_A_Distance_Squared))
        k_A_A = tf.exp(-A_A_Distance_Squared / self.h)
        # tf.gradients sums over every element of ys, so this is d/dA of sum_j k(a_row, a_j).
        # For rows j != row, the result is dk(a_row, a_j)/da_j. Row `row` instead accumulates
        # the contributions of every pair it appears in. Per-pair gradients would need one
        # tf.gradients call per element.
        sum_grad_A0_k_A_A = tf.gradients(k_A_A[self.row], self.A)[0]
        return k_A_A, sum_grad_A0_k_A_A, A_Squared, A_A_Distance_Squared

In [4]:
model = SteinIS()
# One session run fetches, in order:
#   w: the leaders A
#   x: kernel row 2
#   y: the full kernel matrix
#   z: the sum_grad_A0_k_A_A gradient tensor
w, x, y, z = sess.run([
    model.A,
    model.k_A_A[2],
    model.k_A_A,
    model.sum_grad_A0_k_A_A,
])

In [5]:
# Single-argument print(...) prints the same on Python 2 and is valid on Python 3.
print(w)  # A
print(x)  # kernel row 2
print(y)  # full kernel matrix
print(z)  # sum_grad_A0_k_A_A

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]
[  1.24794646e-47   1.87952882e-12   1.00000000e+00]
[[  1.00000000e+00   1.87952882e-12   1.24794646e-47]
 [  1.87952882e-12   1.00000000e+00   1.87952882e-12]
 [  1.24794646e-47   1.87952882e-12   1.00000000e+00]]
[[  1.49753576e-46   1.49753576e-46   1.49753576e-46]
 [  1.12771729e-11   1.12771729e-11   1.12771729e-11]
 [ -1.12763132e-11  -1.12763132e-11  -1.12763132e-11]]


In [6]:
# a_j - a_2 for every leader row j; row 2 of w is broadcast down the rows.
w - w[2, :]

array([[-6., -6., -6.],
       [-3., -3., -3.],
       [ 0.,  0.,  0.]])

In [7]:
# Manual per-pair gradient, for comparison with z:
#   d/da_j k(a_2, a_j) = -2 (a_j - a_2) k(a_2, a_j)
# x[:, np.newaxis] broadcasts kernel row 2 across the columns, so this works for any
# number of leaders. Rows j != 2 should match z. Row 2 differs because tf.gradients sums
# the contributions of every pair containing a_2, while this formula only keeps the
# (zero) self-pair term.
-2 * (w - w[2]) * x[:, np.newaxis]

array([[  1.49753576e-46,   1.49753576e-46,   1.49753576e-46],
       [  1.12771729e-11,   1.12771729e-11,   1.12771729e-11],
       [ -0.00000000e+00,  -0.00000000e+00,  -0.00000000e+00]])