In [9]:
import numpy as np
import os
import sys
import scipy
import numpy as np
import autograd.numpy as anp
import autograd
from functools import partial, namedtuple
import itertools
from autograd.extend import primitive, defvjp
from autograd import jacobian as jac

In [10]:
top_level_dir = '/'.join( os.getcwd().split( '/' )[ :-2 ] )
if top_level_dir not in sys.path:
    sys.path.append( top_level_dir )

In [11]:
def logsumexp( v, axis=0 ):
    max_v = anp.max( v )
    return anp.log( anp.sum( anp.exp( v - max_v ), axis=axis ) ) + max_v

def logsumexp_vjp(ans, x):
    x_shape = x.shape
    return lambda g: anp.full(x_shape, g) * anp.exp(x - np.full(x_shape, ans))
defvjp(logsumexp, logsumexp_vjp)

In [12]:
def gumbelSample( shape, eps=1e-8 ):
    u = anp.random.random( shape )
    return -anp.log( -anp.log( u + eps ) + eps )
def gumbelSoftmaxSample( logits, g=None, temp=1.0 ):
    if( g is None ):
        g = gumbelSample( logits.shape )
    y = logits + g
    ans = anp.exp( y ) / temp
    return ans / ans.sum()

In [13]:
def alphas( theta ):
    T, K = theta.L.shape
    alpha = []
    alpha.append( theta.pi0 + theta.L[ 0 ] )
    for t in range( 1, T ):
        alpha.append( logsumexp( alpha[ -1 ][ :, None ] + theta.pi, axis=0 ) + theta.L[ t ] )
    return anp.array( alpha )

def betas( theta ):
    T, K = theta.L.shape
    beta = []
    beta.append( anp.zeros( K ) )
    for t in reversed( range( 0, T - 1 ) ):
        beta.append( logsumexp( beta[ -1 ] + theta.pi + theta.L[ t + 1 ], axis=1 ) )
    return anp.array( list( reversed( beta ) ) )

def joints( alpha, beta ):
    joints = []
    for t in range( T - 1 ):
        joints.append( alpha[ t ][ :, None ] + pi + L[ t + 1 ] + beta[ t + 1 ] )
    return anp.array( joints )

def predictive( alpha, beta ):
    T, d_latent = alpha.shape
    joint = joints( alpha, beta )
    return joint - anp.reshape( ( alpha + beta )[ :-1 ], ( T-1, d_latent, 1 ) )

def sampleX( alpha, beta ):
    T, K = alpha.shape
    log_z = logsumexp( alpha[ -1 ] )
    
    preds = predictive( alpha, beta )
    x_samples = []
    
    x_samples.append( alpha[ 0 ] + beta[ 0 ] - log_z )
    
    for t, p in enumerate( preds ):
        logits = logsumexp( p + x_samples[ -1 ], axis=1 )
        x_samples.append( anp.log( gumbelSoftmaxSample( logits ) ) )
        
    return anp.array( x_samples )

def hmmSamples( theta ):
    alpha, beta = alphas( theta ), betas( theta )
    return sampleX( alpha, beta )

def neuralNet( x_samples, theta ):
    
    N = theta.d_latent * theta.d_obs
    W = anp.arange( N ).reshape( ( theta.d_latent, theta.d_obs ) )
    
    y_dist = anp.einsum( 'ij,ti->tj', W, x_samples )
    probs = y_dist[ anp.arange( theta.T ), theta.y ]
    return anp.sum( probs )

In [14]:
Theta = namedtuple( 'Theta', [ 'pi0', 'pi', 'L', 'T', 'd_latent', 'd_obs', 'y' ] )
T = 3
d_latent = 3
d_obs = 2
y = np.random.choice( d_obs, size=T )

In [15]:
pi0 = np.random.random( d_latent )
pi = np.random.random( ( d_latent, d_latent ) )
pi0 = np.log( pi0 )
pi = np.log( pi )
L = np.random.random( ( d_latent, d_obs ) )
L = L.T[ y ]

In [17]:
def forAG( func ):
    def wrapper( _L ):
        theta = Theta( pi0, pi, _L, T, d_latent, d_obs, y )
        return func( theta )
    return wrapper

In [33]:
@forAG
def alphasAG( theta ):
    return alphas( theta )

@forAG
def betasAG( theta ):
    return betas( theta )

@forAG
def jointsAG( theta ):
    return joints( alphas( theta ), betas( theta ) )

@forAG
def predictiveAG( theta ):
    return predictive( alphas( theta ), betas( theta ) )

@forAG
def sampleXAG( theta ):
    return sampleX( alphas( theta ), betas( theta ) )

@forAG
def hmmSamplesAG( theta ):
    return hmmSamples( theta )

@forAG
def neuralNetAG( theta ):
    x_samples = sampleX( alphas( theta ), betas( theta ) )
    return neuralNet( x_samples, theta )

In [34]:
gumb = gumbelSample( ( T, d_latent ) )

In [35]:
jac( neuralNetAG )( L )

array([[-0.4403543 , -0.29535462,  0.73570893],
       [ 0.04340087, -0.19661872,  0.15321785],
       [ 0.15411257, -0.50888753,  0.35477496]])