In [1]:
import numpy as np
import scipy as sp
import tensorflow as tf
import tensorflow.contrib.distributions as ds

In [2]:
# Session shared by the later cells, which evaluate tensors through sess.run(...).
sess = tf.InteractiveSession()

In [5]:
class SteinIS(object):
    """Small debugging graph for the kernel terms between two particle sets.

    Attributes (all TensorFlow tensors):
        A: first particle set, one particle per row (the "leader" set in the
            comments below).
        B: second particle set, one particle per row.
        k_A_B: kernel matrix with entries exp(-||A[i] - B[j]||^2).
        sum_grad_A_k_A_B: kernel-weighted sums built by construct_map().
    """

    def __init__(self, A=None, B=None):
        """Create the input tensors and build the graph.

        Args:
            A: optional 2-D array-like; defaults to np.arange(9.).reshape((3, 3)).
            B: optional 2-D array-like; defaults to
                np.arange(9., 18.).reshape((3, 3)).
        """
        if A is None:
            A = np.arange(9.).reshape((3, 3))
        if B is None:
            B = np.arange(9., 18.).reshape((3, 3))
        self.A = tf.convert_to_tensor(A)
        self.B = tf.convert_to_tensor(B)

        # Register functions for debugging
        self.k_A_B, self.sum_grad_A_k_A_B = self.construct_map()

    @staticmethod
    def replace_none_with_zero(l):
        """Return a new list equal to `l` with every None entry replaced by 0.

        Declared static so it can be called on the class or on an instance.
        Uses an identity test so elements that overload `==` are never compared.
        """
        return [0 if i is None else i for i in l]

    def construct_map(self):
        """Build the kernel matrix and the kernel-weighted difference sums.

        Returns:
            (k_A_B, sum_grad_A_k_A_B) where k_A_B[i, j] = exp(-||A[i] - B[j]||^2)
            and sum_grad_A_k_A_B[i] = sum_j k_A_B[i, j] * (-2 * (A[j] - B[i])).

        Raises:
            ValueError: if the number of rows of B is not statically known.
        """
        # Pairwise squared distances via ||a||^2 - 2 a.b + ||b||^2, with A as
        # the "leader" set as in SteinIS. No bandwidth division is applied here.
        x2_A_B_T = 2. * tf.matmul(self.A, tf.transpose(self.B))  # n_A x n_B
        A_Squared = tf.reduce_sum(tf.square(self.A), keep_dims=True, axis=1)  # n_A x 1
        B_Squared = tf.reduce_sum(tf.square(self.B), keep_dims=True, axis=1)  # n_B x 1
        A_B_Distance_Squared = A_Squared - x2_A_B_T + tf.transpose(B_Squared)  # n_A x n_B
        k_A_B = tf.exp(-A_B_Distance_Squared)

        # The Python loop below needs a concrete row count at graph-build time.
        n_rows = self.B.get_shape().as_list()[0]
        if n_rows is None:
            raise ValueError("B must have a statically known number of rows")

        # Can't use vanilla tf.gradients: it sums dy/dx over the outputs y for
        # each x, whereas one summed gradient per output is wanted, and
        # tf.map_fn is not available. The per-row terms are built by hand.
        #
        # As written, each matmul is (1 x n_B) by (n_A x d), so it requires
        # n_A == n_B.
        # NOTE(review): k_A_B[i, :] holds k(A[i], B[j]) over j, but it weights
        # rows of (A - B[i]), which are indexed by A's rows. If the intent is
        # sum_j grad_{A[j]} k(A[j], B[i]), the weights would be the column
        # k_A_B[:, i] instead. Left unchanged because the notebook's NumPy
        # cross-check reproduces this exact formula -- confirm the intent.
        per_row = [
            tf.matmul(tf.reshape(k_A_B[i, :], (1, -1)), -2 * (self.A - self.B[i]))
            for i in range(n_rows)
        ]
        sum_grad_A_k_A_B = tf.stack(per_row)  # n_rows x 1 x d

        # Drop only the singleton matmul axis so genuine size-1 dims survive.
        return k_A_B, tf.squeeze(sum_grad_A_k_A_B, axis=[1])

In [6]:
# Build the debugging graph, then evaluate both inputs and both derived tensors
# in a single run: v = A, w = B, x = kernel matrix, y = weighted sums.
model = SteinIS()
v, w, x, y = sess.run([
    model.A,
    model.B,
    model.k_A_B,
    model.sum_grad_A_k_A_B,
])

In [7]:
# Show the inputs (v = A, w = B), then the kernel matrix x and the weighted sums y.
# Parenthesised single-argument form prints identically on Python 2 and also
# runs on Python 3, where the bare `print v` statement is a SyntaxError.
print(v)
print(w)
print(x)
print(y)

[[ 0.  1.  2.]
 [ 3.  4.  5.]
 [ 6.  7.  8.]]
[[  9.  10.  11.]
 [ 12.  13.  14.]
 [ 15.  16.  17.]]
[[  2.92712250e-106   2.42540248e-188   7.09945017e-294]
 [  1.24794646e-047   2.92712250e-106   2.42540248e-188]
 [  1.87952882e-012   1.24794646e-047   2.92712250e-106]]
[[  5.26882049e-105   5.26882049e-105   5.26882049e-105]
 [  2.99507151e-046   2.99507151e-046   2.99507151e-046]
 [  5.63858645e-011   5.63858645e-011   5.63858645e-011]]


In [8]:
# NumPy re-computation of the same weighted sum for row 0, to compare with y[0].
np.matmul(x[:1], -2 * (v - w[0]))

array([[  5.26882049e-105,   5.26882049e-105,   5.26882049e-105]])