In [1]:
import tensorflow as tf
from tensorflow.contrib.layers import xavier_initializer
import numpy as np
from scipy.spatial.distance import pdist

from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
class VAE(object):
    def __init__(self, x_dim, latent_dim, 
                 random_seed = None):
        self.x_dim = x_dim
        self.latent_dim = latent_dim
        
        #Initializer
        xavier_init = xavier_initializer(seed = random_seed)
        
        params = {
            'W' : tf.Variable(xavier_init([x_dim, latent_dim]), name = 'W')
            'b' : tf.Variable(tf.zeros([latent_dim]), name = 'b')
        }
        self.params = params
        
        #Graph
        X = tf.placeholder(tf.float32, shape=[x_dim, None], name='X')
        self.X = X
        Z = tf.placeholder(tf.float32, shape=[latent_dim, None], name='Z')
        self.Z = Z
        
        
        
        #Loss
        

In [10]:
def med(Z):
    A = Z.reshape(Z.shape[0], -1)
    r = tf.reduce_sum(A*A, 1)
    r = tf.reshape(r, [-1, 1])
    psqdists = r - 2*tf.matmul(A, tf.transpose(A)) + tf.transpose(r)
    return tf.contrib.distributions.percentile(psqdists, 50)

def rbfBandwidth(Z):
    h = tf.square(med(Z))/np.log(Z.shape[0])
    return(h)

def rbf(zi, zj, h):
    return(tf.exp(-tf.square(zi-zj)/h))

def linear(zi, zj):
    return tf.tensordot(zi, zj,axes=1)

def phi_star_j(zi, zj, p, kernel):
    zj_var = tf.constant(zj)
    return kernel(zi, zj)*tf.gradients(p(zj_var) + kernel(zj_var, zi), [zj_var])

def phi_star(Z, p, kernel):
    out = []
    for zi in Z:
        entry = tf.reduce_mean([phi_star_j(zi, zj, p, kernel) for zj in Z])
        out.append(entry)
    return(tf.stack(out))

def Stein(Z, p, kernel):
    ps = phi_star(Z, p, kernel)
    return(tf.tensordot(tf.transpose(Z), tf.stop_gradient(ps), axes = 1))

In [11]:
n = 5
Z = np.random.normal(0,1,n).astype('float32')
Z = Z.reshape(Z.shape[0], -1)

In [14]:
sess = tf.Session()
sess.run(Stein(Z, np.square, lambda x,y: rbf(x,y,rbfBandwidth(Z))))

array([ 3.19304872], dtype=float32)

In [714]:
sess.run(Stein(Z, np.square, linear))

array([ 39.0635643], dtype=float32)