In [None]:
import autograd.numpy as np
import autograd
import autograd.numpy as anp
import itertools
from functools import partial, namedtuple
from autograd.extend import primitive, defvjp
from autograd import jacobian as jac

In [None]:
def logsumexp( v, axis=0 ):
    """Log-sum-exp of ``v`` reduced along ``axis``.

    The global maximum of ``v`` (not a per-axis maximum) is subtracted before
    exponentiating and added back after taking the log.
    """
    shift = anp.max( v )
    shifted_total = anp.sum( anp.exp( v - shift ), axis=axis )
    return anp.log( shifted_total ) + shift

def logsumexp_vjp(ans, x):
    # Builds the vector-Jacobian-product closure for logsumexp: the incoming
    # gradient g is broadcast to x's shape and scaled by exp( x - ans ).
    # NOTE(review): this signature has no `axis` argument although logsumexp
    # accepts one; it looks like it only matches a full reduction -- confirm.
    x_shape = x.shape
    return lambda g: anp.full(x_shape, g) * anp.exp(x - np.full(x_shape, ans))
# Register the VJP for logsumexp with autograd.
# NOTE(review): logsumexp is not wrapped with @primitive in this notebook, so
# it is unclear whether autograd ever consults this registration -- confirm.
defvjp(logsumexp, logsumexp_vjp)

In [None]:
def gumbelSample( shape, eps=1e-8 ):
    """Draw Gumbel noise of the given ``shape`` as -log( -log( u ) ) from uniform u.

    ``eps`` is added inside both logarithms to keep their arguments away from zero.
    """
    uniform = anp.random.random( shape )
    inner = -anp.log( uniform + eps )
    return -anp.log( inner + eps )
def gumbelSoftmaxSample( logits, g=None, temp=1.0 ):
    """Relaxed categorical sample: softmax( ( logits + g ) / temp ).

    Args:
        logits: array of unnormalised log-probabilities.
        g: optional Gumbel noise with the same shape as ``logits``; drawn with
            gumbelSample when omitted.
        temp: softmax temperature. It now divides the perturbed logits
            *inside* the exponential. The previous ``exp( y ) / temp`` form
            cancelled out under normalisation, so ``temp`` had no effect.

    Returns:
        Array with the shape of ``logits`` whose entries sum to one.

    Results for the default ``temp=1.0`` are unchanged.
    """
    if( g is None ):
        g = gumbelSample( logits.shape )
    y = ( logits + g ) / temp
    # Shift by the max before exponentiating so large logits cannot overflow;
    # the shift cancels in the normalisation below.
    ans = anp.exp( y - anp.max( y ) )
    return ans / ans.sum()

In [None]:
def alphas_unrolled( theta ):
    """Forward messages written out by hand for exactly three time steps.

    Reads ``theta.pi0``, ``theta.pi`` and rows 0..2 of ``theta.L`` and returns
    the three messages stacked into one array.
    """
    emis, trans = theta.L, theta.pi
    T, K = emis.shape  # kept from the original; also fails fast if L is not 2-D
    first = theta.pi0 + emis[ 0 ]
    second = logsumexp( first[ :, None ] + trans, axis=0 ) + emis[ 1 ]
    third = logsumexp( second[ :, None ] + trans, axis=0 ) + emis[ 2 ]
    return anp.array( [ first, second, third ] )

def betas_unrolled( theta ):
    """Backward messages written out by hand for exactly three time steps.

    The last message is all zeros; each earlier one is a log-sum-exp over
    axis 1 of the next message plus ``theta.pi`` plus the next row of ``theta.L``.
    """
    emis, trans = theta.L, theta.pi
    T, K = emis.shape
    last = anp.zeros( K )
    middle = logsumexp( last + trans + emis[ 2 ], axis=1 )
    first = logsumexp( middle + trans + emis[ 1 ], axis=1 )
    return anp.array( [ first, middle, last ] )

def joints_unrolled( alpha, beta, pi=None, L=None ):
    """Pairwise joint terms for a three-step chain, written out by hand.

    Args:
        alpha, beta: forward / backward message arrays indexed by time step.
        pi: transition term added to every pair; falls back to the
            module-level ``pi`` when omitted (the original behaviour).
        L: per-step emission term; falls back to the module-level ``L``
            when omitted (the original behaviour).

    Returns:
        Array stacking the two pairwise terms for steps (0, 1) and (1, 2).

    The original always read the notebook globals, so an ``L`` traced by
    autograd through ``theta`` never reached this function. Passing ``pi`` and
    ``L`` explicitly lets callers route the traced values through here.
    """
    if pi is None:
        pi = globals()[ 'pi' ]
    if L is None:
        L = globals()[ 'L' ]
    j_1 = alpha[ 0 ][ :, None ] + pi + L[ 1 ] + beta[ 1 ]
    j_2 = alpha[ 1 ][ :, None ] + pi + L[ 2 ] + beta[ 2 ]
    return anp.array( [ j_1, j_2 ] )

def predictive_unrolled( alpha, beta ):
    """Pairwise joint terms minus the per-step ``alpha + beta`` of the earlier step.

    Uses joints_unrolled for the pairwise terms; the subtracted term is
    reshaped to ( T-1, K, 1 ) so it broadcasts over the last axis.
    """
    n_steps, n_states = alpha.shape
    pairwise = joints_unrolled( alpha, beta )
    marginal = ( alpha + beta )[ :-1 ]
    return pairwise - anp.reshape( marginal, ( n_steps-1, n_states, 1 ) )
    
def alphas( theta ):
    """Forward messages for a chain of arbitrary length.

    Args:
        theta: object exposing ``pi0``, ``pi`` and a 2-D ``L`` of shape (T, K).

    Returns:
        Array stacking the T forward messages.

    The messages are collected in a list and stacked at the end instead of
    being assigned into a preallocated array. Autograd cannot trace in-place
    index assignment, so the original failed when ``theta.L`` was being
    differentiated. Values are unchanged for plain arrays.
    """
    T, K = theta.L.shape
    alpha = [ theta.pi0 + theta.L[ 0 ] ]
    for t in range( 1, T ):
        alpha.append( logsumexp( alpha[ -1 ][ :, None ] + theta.pi, axis=0 ) + theta.L[ t ] )
    return anp.array( alpha )

def betas( theta ):
    """Backward messages for a chain of arbitrary length.

    Args:
        theta: object exposing ``pi`` and a 2-D ``L`` of shape (T, K).

    Returns:
        Array of shape (T, K); the last row is all zeros.

    The messages are built in reverse in a list and stacked at the end rather
    than assigned into a preallocated array, because autograd cannot trace
    in-place index assignment. Values are unchanged for plain arrays.
    """
    T, K = theta.L.shape
    if( T == 0 ):
        # Preserve the original's empty result for an empty sequence.
        return anp.zeros( ( 0, K ) )
    beta = [ anp.zeros( K ) ]
    for t in reversed( range( 0, T - 1 ) ):
        beta.append( logsumexp( beta[ -1 ] + theta.pi + theta.L[ t + 1 ], axis=1 ) )
    # The list was filled from the last step backwards; flip it into time order.
    return anp.array( beta[ ::-1 ] )

def joints( alpha, beta, pi=None, L=None ):
    """Pairwise joint terms for every consecutive pair of time steps.

    Args:
        alpha, beta: forward / backward message arrays of shape (T, K).
        pi: transition term; falls back to the module-level ``pi`` when omitted.
        L: per-step emission term; falls back to the module-level ``L`` when omitted.

    Returns:
        Array of shape (T-1, K, K).

    Changes from the original:
      * the output size comes from ``alpha.shape`` instead of the notebook
        globals ``T`` and ``d_latent``;
      * results are stacked from a list rather than assigned in place, which
        autograd cannot trace;
      * the local no longer shadows the function's own name.
    """
    if pi is None:
        pi = globals()[ 'pi' ]
    if L is None:
        L = globals()[ 'L' ]
    n_steps, n_states = alpha.shape
    pairwise = [ alpha[ t ][ :, None ] + pi + L[ t + 1 ] + beta[ t + 1 ] for t in range( n_steps - 1 ) ]
    if( not pairwise ):
        return anp.zeros( ( 0, n_states, n_states ) )
    return anp.array( pairwise )

def predictive( alpha, beta ):
    """Pairwise joint terms minus the per-step ``alpha + beta`` of the earlier step.

    Same computation as predictive_unrolled, but built on the looped joints().
    """
    n_steps, n_states = alpha.shape
    pairwise = joints( alpha, beta )
    marginal = ( alpha + beta )[ :-1 ]
    return pairwise - anp.reshape( marginal, ( n_steps-1, n_states, 1 ) )

In [None]:
def sampleX( alpha, beta ):
    """Sequentially draw relaxed (log-space) samples for every time step.

    Row 0 is ``alpha[ 0 ] + beta[ 0 ] - log_z``; each later row is the log of
    a Gumbel-softmax draw whose logits combine the predictive term for that
    step with the previous row.
    """
    n_steps, n_states = alpha.shape
    log_norm = logsumexp( alpha[ -1 ] )
    transitions = predictive( alpha, beta )

    samples = anp.zeros( ( n_steps, n_states ) )
    samples[ 0 ] = alpha[ 0 ] + beta[ 0 ] - log_norm

    for prev in range( len( transitions ) ):
        step_logits = logsumexp( transitions[ prev ] + samples[ prev ], axis=1 )
        samples[ prev + 1 ] = anp.log( gumbelSoftmaxSample( step_logits ) )

    return samples

def sampleX_unrolled( alpha, beta, gumb ):
    """Hand-unrolled three-step version of the relaxed sampler with fixed noise.

    Args:
        alpha, beta: forward / backward message arrays of shape (3, K).
        gumb: Gumbel noise indexed by time step; rows 1 and 2 are used.

    Returns:
        Array stacking the three log-space samples.

    Fix: the third sample now uses ``gumb[ 2 ]``. The original reused
    ``gumb[ 1 ]`` for both draws, which disagrees with the other samplers in
    this notebook that index the noise by time step.
    """
    T, K = alpha.shape
    log_z = logsumexp( alpha[ -1 ] )

    preds = predictive_unrolled( alpha, beta )

    x_0 = alpha[ 0 ] + beta[ 0 ] - log_z

    logits = anp.log( anp.exp( preds[ 0 ] ) @ anp.exp( x_0 ) )
    x_1 = anp.log( gumbelSoftmaxSample( logits, g=gumb[ 1 ] ) )

    logits = anp.log( anp.exp( preds[ 1 ] ) @ anp.exp( x_1 ) )
    x_2 = anp.log( gumbelSoftmaxSample( logits, g=gumb[ 2 ] ) )

    return anp.array( [ x_0, x_1, x_2 ] )

def hmmSamples( theta ):
    """Compute the forward and backward messages for ``theta`` and sample from them."""
    forward = alphas( theta )
    backward = betas( theta )
    return sampleX( forward, backward )

def hmmSamples_unrolled( theta, gumb ):
    """Unrolled counterpart of hmmSamples that takes the Gumbel noise ``gumb`` explicitly."""
    forward = alphas_unrolled( theta )
    backward = betas_unrolled( theta )
    return sampleX_unrolled( forward, backward, gumb )

def neuralNet( x_samples, theta ):
    """Toy linear read-out used as a differentiable stand-in.

    The weight matrix is ``arange( d_latent * d_obs )`` reshaped to
    ( d_latent, d_obs ). The result is the sum over time steps of the score
    in the column selected by ``theta.y``.
    """
    n_latent, n_obs = theta.d_latent, theta.d_obs
    W = anp.arange( n_latent * n_obs ).reshape( ( n_latent, n_obs ) )

    scores = anp.einsum( 'ij,ti->tj', W, x_samples )
    steps = anp.arange( theta.T )
    return anp.sum( scores[ steps, theta.y ] )

In [None]:
# Container bundling everything the functions above read from `theta`.
Theta = namedtuple( 'Theta', [ 'pi0', 'pi', 'L', 'T', 'd_latent', 'd_obs', 'y' ] )
# Problem sizes; the *_unrolled helpers are hard-coded for three time steps.
T = 3
d_latent = 3
d_obs = 2
# One random observation index per time step.
# NOTE(review): no random seed is set, so values differ on every run.
y = np.random.choice( d_obs, size=T )

In [None]:
# Random positive values for the initial and transition terms, moved to log space.
# NOTE(review): the values are not normalised before the log -- confirm that is intended.
pi0 = np.random.random( d_latent )
pi = np.random.random( ( d_latent, d_latent ) )
pi0 = np.log( pi0 )
pi = np.log( pi )
# Draw a ( d_latent, d_obs ) table, then pick the row of its transpose for each
# entry of y, giving one length-d_latent row per time step.
L = np.random.random( ( d_latent, d_obs ) )
L = L.T[ y ]

In [None]:
gumb = gumbelSample( ( T, d_latent ) )

In [None]:
def trueAnswer( L ):
    """End-to-end scalar (unrolled sampler followed by neuralNet) as a function of ``L`` only."""
    theta = Theta( pi0=pi0, pi=pi, L=L, T=T, d_latent=d_latent, d_obs=d_obs, y=y )
    return neuralNet( hmmSamples_unrolled( theta, gumb ), theta )
jac( trueAnswer )( L )

In [None]:
# Split the end-to-end computation into two stages so that their Jacobians
# can be computed separately.
def a( L ):
    # Stage 1: L -> relaxed samples from the unrolled sampler.
    theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
    x_samples = hmmSamples_unrolled( theta, gumb )
    return x_samples
def b( x_samples ):
    # Stage 2: samples -> scalar. This reads the module-level L, not an argument.
    theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
    return neuralNet( x_samples, theta )
# Jacobian of stage 1 at L, and of stage 2 at the stage-1 output.
da = jac( a )( L )
db = jac( b )( a( L ) )

In [None]:
np.einsum( 'ijab,ij->ab', da, db )

## Step 1 - Compute x samples
## Step 2 - Compute dlogP( y | x )/dx
## Step 3 - Accumulate dlogP( y | x )/dL by computing dx/dL and summing immediately

In [None]:
def deriv( theta, gumb ):
    """Hand-derived derivative of the relaxed samples with respect to theta.L.

    Runs the same sequential sampling recursion as the samplers above using
    the fixed noise ``gumb``, and fills ``dXtdLs``, a ( T, T, K, K ) array
    indexed as ``[ t, s ]``, which is returned.

    NOTE(review): entries with s > t are never written and stay zero.
    NOTE(review): this is exploratory; whether the result matches autograd is
    checked elsewhere in the notebook, not asserted here.
    """
    # Temperature is fixed at 1.0 here, so the np.log( temp ) term below is zero.
    temp = 1.0
    
    # Get the needed stats
    alpha, beta = alphas( theta ), betas( theta )
    preds = predictive( alpha, beta )
    log_z = logsumexp( alpha[ -1 ] )
    
    T, K = alpha.shape
    
    # Initialize the variables
    dXtdLs = np.zeros( ( T, T, K, K ) )
    dXtdXt1 = np.zeros( ( T-1, K, K ) )
    x_samples = np.zeros( ( T, K ) )
    
    # Base case derivative
    x_samples[ 0 ] = alpha[ 0 ] + beta[ 0 ] - log_z
    dXtdLs[ 0, 0 ] = np.eye( K ) - np.exp( x_samples[ 0 ] )

    # Debug output of the base-case block.
    print( '\nt', 0, 's', 0, 'dXtdLs[ t, t ]\n', dXtdLs[ 0, 0 ] )
    
    # preds only drives the iteration count: the loop variable p is
    # overwritten on the first line of the body.
    for i, p in enumerate( preds ):
        t = i + 1
        
        # Compute x_t | x_t-1
        p = theta.pi + theta.L[ t ] + beta[ t ] - beta[ t - 1 ][ :, None ]
        logits = logsumexp( p + x_samples[ t - 1 ], axis=1 )
        unnormx = logits + gumb[ t ] - np.log( temp )
        x_samples[ t ] = unnormx - logsumexp( unnormx )
        
        # Compute dLogit / dL_t
        dXPdLt = np.eye( K ) - np.exp( theta.L[ t ] + theta.pi + beta[ t ] ).T
        
        # Compute dLogit / dX_t-1
        dLogitdXt1 = np.exp( p + x_samples[ t - 1 ] - logits[ :, None ] )
        # Same array under a second name; used for the s == t term below.
        dLogitdP = dLogitdXt1
        
        # Compute dX_t / dLogit
        dXtdLogit = np.eye( K ) - np.exp( x_samples[ t ] )
        
        # Compute dX_t / dX_t-1
        dXtdXt1[ i ] = dXtdLogit @ dLogitdXt1
    
        # Compute the derivative dX_t / dL_s for s == t
        dXtdLs[ t, t ] = dXtdLogit @ dLogitdP @ dXPdLt
        
        # Update each of the L derivatives
        for s in reversed( range( t ) ):
            # Compute dX_t / dL_s for s < t
            dXtdLs[ t, s ] = dXtdXt1[ i ] @ dXtdLs[ t-1, s ]
        
    return dXtdLs

In [None]:
theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
deriv( theta, gumb )

In [None]:
da.transpose( 0, 2, 1, 3 )

In [None]:
alpha, beta = alphas( theta ), betas( theta )

In [None]:
def forAG( func ):
    """Decorator turning a function of ``theta`` into a function of ``L`` alone.

    The wrapper builds a Theta from the notebook globals with its ``L`` field
    replaced by the argument, so the result can be handed straight to jac().
    """
    def wrapper( _L ):
        return func( Theta( pi0, pi, _L, T, d_latent, d_obs, y ) )
    return wrapper

In [None]:
def pred_unrolled( theta ):
    """Conditioned transition terms for steps 1 and 2 of the three-step chain.

    Each term is ``pi + L[ t ] + beta[ t ] - beta[ t-1 ][ :, None ]``.
    """
    beta = betas_unrolled( theta )
    steps = []
    for t in ( 1, 2 ):
        steps.append( theta.pi + theta.L[ t ] + beta[ t ] - beta[ t - 1 ][ :, None ] )
    return anp.array( steps )
# L-only view of pred_unrolled (via forAG) so it can be passed to jac().
@forAG
def pred_unrolled_ag( theta ):
    return pred_unrolled( theta )

In [None]:
def mypred_jac( theta ):
    # NOTE(review): unfinished stub -- it computes preds and then implicitly
    # returns None; no Jacobian is produced yet.
    preds = pred_unrolled( theta )

# d*beta*<sup>(t)</sup>/d*L*<sup>(s)</sup>

In [None]:
# NOTE(review): this re-defines betas_unrolled with the same body as the
# definition earlier in the notebook; the later definition shadows the earlier one.
def betas_unrolled( theta ):
    # Hand-unrolled backward messages for exactly three time steps.
    T, K = theta.L.shape
    b_2 = anp.zeros( K )
    b_1 = logsumexp( b_2 + theta.pi + theta.L[ 2 ], axis=1 )
    b_0 = logsumexp( b_1 + theta.pi + theta.L[ 1 ], axis=1 )
    return anp.array( [ b_0, b_1, b_2 ] )    

# L-only view of betas_unrolled (via forAG) so it can be passed to jac().
@forAG
def betas_unrolled_ag( theta ):
    return betas_unrolled( theta )

def dbetadL( theta ):
    """Hand-derived Jacobian of the backward messages with respect to theta.L.

    Returns a ( T, K, T, K ) array indexed ``[ t, :, s, : ]``. Only entries
    with s > t are filled; the rest stay zero. The s == t+1 block is computed
    directly and later blocks are chained through the t+1 row.
    """
    beta_vals = betas_unrolled( theta )
    n_steps, n_states = beta_vals.shape
    jacobian = np.zeros( ( n_steps, n_states, n_steps, n_states ) )
    for t in reversed( range( n_steps - 1 ) ):
        step = np.exp( theta.pi + theta.L[ t+1 ] + beta_vals[ t+1 ] - beta_vals[ t ][ :, None ] )
        jacobian[ t, :, t+1, : ] = step
        for s in range( t+2, n_steps ):
            jacobian[ t, :, s, : ] = step @ jacobian[ t+1, :, s, : ]
    return jacobian

# d*F<sup>(t)</sup>*/d*L<sup>(s)</sup>*

In [None]:
def F_unrolled( theta ):
    """Conditioned transition terms F for steps 1 and 2 of the three-step chain.

    Each term is ``pi + L[ t ] + beta[ t ] - beta[ t-1 ][ :, None ]``.
    """
    beta = betas_unrolled( theta )
    terms = [ theta.pi + theta.L[ t ] + beta[ t ] - beta[ t - 1 ][ :, None ] for t in ( 1, 2 ) ]
    return anp.array( terms )

# L-only view of F_unrolled (via forAG) so it can be passed to jac().
@forAG
def F_unrolled_ag( theta ):
    return F_unrolled( theta )

def dFdL( theta ):
    """Hand-derived Jacobian of F_unrolled with respect to theta.L.

    Returns an array of shape ``F.shape + theta.L.shape`` indexed
    ``[ t-1, i, j, s, k ]``. Each block is the beta Jacobian of step t
    (broadcast over i) minus the beta Jacobian of step t-1 (broadcast over j),
    plus an identity in ( j, k ) when s == t for the direct ``L[ t ]`` term.

    Change from the original: the unused per-iteration
    ``np.eye( K ) - np.exp( F[ t-1 ] )`` computation was removed; it was
    never read.
    """
    F = F_unrolled( theta )
    T, K = theta.L.shape
    f_jac = np.zeros( F.shape + theta.L.shape )
    b_jac = dbetadL( theta )
    for t in range( 1, T ):
        for s in range( T ):
            f_jac[ t-1, :, :, s, : ] = b_jac[ t, :, s, : ] - b_jac[ t-1, :, s, : ][ :, None, : ]
            if( t == s ):
                f_jac[ t-1, :, :, s, : ] += np.eye( K )

    return f_jac

# d*H<sup>(t)</sup>*/d*L<sup>(s)</sup>*

In [None]:
def x_samples_unrolled( theta, gumb, temp=1.0 ):
    """Hand-unrolled three-step relaxed sampler written in terms of F.

    Args:
        theta: parameter bundle read through F_unrolled, alphas_unrolled and
            betas_unrolled, plus ``theta.pi0`` and ``theta.L[ 0 ]``.
        gumb: Gumbel noise indexed by time step; rows 1 and 2 are used.
        temp: temperature; enters as ``- log( temp )`` on the unnormalised
            samples.

    Returns:
        Array stacking the three log-space samples.

    Change from the original: the unused ``x_samples`` buffer and the unused
    ``T`` / ``K`` locals were removed; the computed values are unchanged.
    """
    F = F_unrolled( theta )

    alpha, beta = alphas_unrolled( theta ), betas_unrolled( theta )
    log_z = logsumexp( alpha[ -1 ] )
    x_samples_0 = theta.pi0 + theta.L[ 0 ] + beta[ 0 ] - log_z

    # H are the logits, G the unnormalised samples.
    H0 = logsumexp( F[ 0 ] + x_samples_0, axis=1 )
    G = H0 + gumb[ 1 ] - anp.log( temp )
    x_samples_1 = G - logsumexp( G )

    H1 = logsumexp( F[ 1 ] + x_samples_1, axis=1 )
    G = H1 + gumb[ 2 ] - anp.log( temp )
    x_samples_2 = G - logsumexp( G )

    return anp.array( [ x_samples_0, x_samples_1, x_samples_2 ] )

# L-only view of x_samples_unrolled (via forAG); uses the module-level gumb noise.
@forAG
def x_samples_unrolled_ag( theta ):
    return x_samples_unrolled( theta, gumb )

def dHdL( theta, gumb, temp=1.0 ):
    """Hand-derived Jacobian of the relaxed samples with respect to theta.L.

    Re-runs the sampling recursion used by x_samples_unrolled and fills
    ``dX``, a ( T, K, T, K ) array indexed ``[ t, :, s, : ]``, which is
    returned.

    NOTE(review): exploratory derivation; its agreement with autograd is
    checked elsewhere in the notebook, not asserted here.
    """
    F = F_unrolled( theta )
    dF = dFdL( theta )
    dB = dbetadL( theta )
    
    T, K = theta.T, theta.d_latent
    x_samples = anp.zeros( ( T, K ) )
    dX = np.zeros( ( T, K, T, K ) )
    
    alpha, beta = alphas_unrolled( theta ), betas_unrolled( theta )
    log_z = logsumexp( alpha[ -1 ] )
    x_samples[ 0 ] = theta.pi0 + theta.L[ 0 ] + beta[ 0 ] - log_z
    
    # Base case: start from the beta[ 0 ] Jacobian, subtract a term broadcast
    # from alpha + beta - log_z, and add the identity for the direct L[ 0 ] term.
    dX[ 0 ] = dB[ 0 ]
    dX[ 0, :, :, : ] -= np.exp( alpha + beta - log_z )
    dX[ 0, :, 0, : ] += np.eye( K )
    
    for t in range( 1, T ):
        H = logsumexp( F[ t-1 ] + x_samples[ t-1 ], axis=1 )
        G = H + gumb[ t ] - anp.log( temp )
        x_samples[ t ] = G - logsumexp( G )
        
        # H are the logits
        # F are the conditioned transition matrix
        # G are the unnormalized x samples

        for s in range( T ):
            # Local named `deriv` shadows the notebook-level deriv() inside this function only.
            deriv = dF[ t-1, :, :, s, : ] + dX[ t-1, :, s, : ][ None, :, : ]
            # val and tmp do not depend on s; they are recomputed on every pass.
            val = np.exp( F[ t-1 ] + x_samples[ t-1 ] - H[ :, None ] )
            dH = np.einsum( 'ij,ijk->ijk', val, deriv )
            tmp = 1 - np.exp( G - logsumexp( G ) )
            dX[ t, :, s, : ] = np.einsum( 'i,ijk->ik', tmp, dH )
        
    return dX

In [None]:
dHdL( theta, gumb )

In [None]:
jac( x_samples_unrolled_ag )( L )

In [None]:
np.ones( 4 )[ None, : ]

In [None]:
# NOTE(review): this re-defines gumbelSample with the same body as the
# definition earlier in the notebook, and then redraws the module-level gumb,
# so any result computed with the previous noise is no longer comparable.
def gumbelSample( shape, eps=1e-8 ):
    # Gumbel noise as -log( -log( u ) ) with eps guarding both logs.
    u = anp.random.random( shape )
    return -anp.log( -anp.log( u + eps ) + eps )
gumb = gumbelSample( ( T, d_latent ) )

In [None]:
def activation( x ):
    """Element-wise sine, used as the non-linearity in check()."""
    result = anp.sin( x )
    return result

def check( x ):
    """Two-layer toy network with fixed arange weights, reduced to a scalar.

    The layer sizes are ( x.shape[ -1 ] -> 100 ) and ( 100 -> 1 ). As in the
    original, the second bias is added to the hidden activations before the
    second weight matrix is applied.
    """
    hidden = 100
    n_in = x.shape[ -1 ]
    W1 = anp.arange( n_in * hidden ).reshape( ( n_in, hidden ) )
    b1 = anp.arange( hidden )

    n_out = 1
    W2 = anp.arange( hidden * n_out ).reshape( ( hidden, n_out ) )
    b2 = anp.arange( n_out )

    z = activation( anp.einsum( 'ij,ti->tj', W1, x ) + b1 )
    out = activation( anp.einsum( 'ij,ti->tj', W2, z + b2 ) )
    return anp.sum( out )

In [None]:
x = np.random.random( ( 10, 3 ) )

In [None]:
check( x )

In [None]:
%%timeit
jac( check )( x )

In [None]:
x

In [None]:
def asdf( x ):
    """Sum of squares of ``x``; its gradient is ``2 * x``."""
    squared = x**2
    return anp.sum( squared )

In [None]:
jac( asdf )( x )

In [None]:
2*x

In [None]:
def fwd( theta ):
    """Forward messages built by appending to a list, then stacked into one array."""
    n_steps, n_states = theta.L.shape
    history = [ theta.pi0 + theta.L[ 0 ] ]
    for t in range( 1, n_steps ):
        prev = history[ -1 ]
        history.append( logsumexp( prev[ :, None ] + theta.pi, axis=0 ) + theta.L[ t ] )
    return anp.array( history )

In [None]:
# L-only view of fwd (via forAG) so it can be passed to jac().
@forAG
def fwdAG( theta ):
    return fwd( theta )

In [None]:
jac( fwdAG )( L )

In [None]:
# L-only view of alphas_unrolled (via forAG), for comparison against fwdAG.
# NOTE(review): the name `blah` is re-bound many times later in this notebook.
@forAG
def blah( theta ):
    return alphas_unrolled( theta )

In [None]:
jac( blah )( L )

In [None]:
def nonFBSMultiplyTerms( terms ):
    """Broadcast-add a collection of log-space arrays (an einsum in log space).

    Every term is padded with trailing singleton axes up to the largest ndim
    among the terms. The terms are then summed into an output whose extent
    along each axis is taken from the terms' non-singleton axes.

    Args:
        terms: iterable of arrays.

    Returns:
        The broadcast sum of all terms as a float array.

    Raises:
        AssertionError: if the output would hold more than 1e8 elements.

    Fixes relative to the original:
      * a term whose axes are all size 1 (e.g. shape ( 1, 1 )) produced an
        empty index list, which became a float index array and raised
        IndexError; such terms are now skipped when sizing the output.
      * the output shape is read from ``term.shape`` at the selected axes
        instead of via ``squeeze()``.
      * the unused loop variable in the accumulation loop was dropped.
    """
    terms = list( terms )

    # Separate out where the feedback set axes start and get the largest fbs_axis.
    # Need to handle case where ndim of term > all fbs axes
    # terms, fbs_axes_start = list( zip( *terms ) )
    # NOTE: fbs_axes_start is hard-coded, so the padding branch below is
    # currently never taken; it is kept as written.
    fbs_axes_start = [ -1, -1 ]

    if( max( fbs_axes_start ) != -1 ):
        max_fbs_axis = max( [ ax if ax != -1 else term.ndim for ax, term in zip( fbs_axes_start, terms ) ] )

        if( max_fbs_axis > 0 ):
            # Pad extra dims at each term so that the fbs axes start the same way for every term
            for i, ax in enumerate( fbs_axes_start ):
                if( ax == -1 ):
                    for _ in range( max_fbs_axis - terms[ i ].ndim + 1 ):
                        terms[ i ] = terms[ i ][ ..., None ]
                else:
                    for _ in range( max_fbs_axis - ax ):
                        terms[ i ] = anp.expand_dims( terms[ i ], axis=ax )
    else:
        max_fbs_axis = -1

    ndim = max( [ len( term.shape ) for term in terms ] )

    # Non-singleton axes of every term.
    axes = [ [ i for i, s in enumerate( t.shape ) if s != 1 ] for t in terms ]

    # Get the shape of the output
    shape = anp.ones( ndim, dtype=int )
    for ax, term in zip( axes, terms ):
        if( ax ):
            shape[ anp.array( ax ) ] = [ term.shape[ i ] for i in ax ]

    total_elts = shape.prod()
    if( total_elts > 1e8 ):
        assert 0, 'Don\'t do this on a cpu!  Too many terms: %d'%( int( total_elts ) )

    # Basically np.einsum in log space
    ans = anp.zeros( shape )
    for term in terms:

        # Pad trailing singleton axes so the term broadcasts against ans.
        for _ in range( ndim - term.ndim ):
            term = term[ ..., None ]

        ans += term

    return ans

In [None]:
def blah( x ):
    # Scratch driver: combines x with a fresh random 7-D array through
    # nonFBSMultiplyTerms, so the result differs on every call.
    y = np.random.random( ( 1, 4, 2, 1, 1, 1, 8 ) )
    return nonFBSMultiplyTerms( ( x, y ) )

In [None]:
x = np.random.random( ( 3, 1, 2, 1 ) )

In [None]:
import recordclass

In [None]:
Something = recordclass.recordclass( 'something', [ 'a', 'b' ] )

In [None]:
def blah2( s ):
    # Experiment: mutate a field of the record in place, then use it.
    # The caller's object is modified as a side effect.
    s.a *= 4
    return ( s.a * s.b )**2

In [None]:
def blahh( a ):
    # Wrap the scalar in a Something record (b fixed at 5) and run blah2 on it.
    s = Something( a, 5 )
    return blah2( s )

In [None]:
blahh( 3 )

In [None]:
jac( blahh )( 3.0 )

In [None]:
def blah3( a ):
    # Same arithmetic as blahh/blah2 but with the in-place update applied
    # directly to the argument rather than to a record field.
    a *= 4
    return ( a*5 )**2

In [None]:
jac( blah3 )( 3.0 )

In [None]:
blah3( 3.0 )

In [None]:
class tester():
    # Experiment: does differentiation survive values being stored on, and
    # mutated through, instance attributes?
    def __init__( self ):
        self.a = 4
        
    def blah( self, b ):
        # Store the input on the instance, run stuff(), and return the result it leaves in self.c.
        self.b = b
        self.stuff()
        return self.c
        
    def stuff( self ):
        # In-place update of the stored input, then derive self.c from it.
        self.b *= 4
        self.c = ( self.b * 3 )**2

In [None]:
t = tester()

In [None]:
t.blah( 2.4 )

In [None]:
def blah4( b ):
    # Function-of-b wrapper around a fresh tester instance, so it can be passed to jac().
    t = tester()
    return t.blah( b )

In [None]:
blah4( 2.4 )

In [None]:
jac( blah4 )( 2.4 )

In [None]:
from recordclass import recordclass

In [None]:
# Mutable record type with fields a and b, extended with a derived property.
# NOTE(review): this class re-uses the name `blahh`, replacing the function of
# the same name defined earlier in the notebook.
Succ = recordclass( 'Succ', [ 'a', 'b' ] )
class blahh( Succ ):
    @property
    def yo( self ):
        # Twice the `a` field.
        return self.a*2

In [None]:
blah = blahh( 3, 1 )

In [None]:
blah.yo

In [None]:
a, b = blah

In [None]:
a

In [None]:
from autograd.misc.optimizers import adam
from autograd import grad
import autograd.numpy as np

In [None]:
def blah( params, iter ):
    """Scalar objective over a 3-tuple of arrays: sum( sum( a ) * sum( b ) + c ).

    ``iter`` is accepted but not used.
    """
    a, b, c = params
    scale = np.sum( a )*np.sum( b )
    return np.sum( scale + c )

In [None]:
a = np.random.random( ( 4, 3 ) )
b = np.random.random( ( 4, 3 ) )
c = np.random.random( ( 4, 3 ) )
params = ( a, b, c )

In [None]:
g = grad( blah )

In [None]:
g( params, 1 )

In [None]:
def blah( a, b ):
    # Prints both arguments; used below to try out functools.partial.
    print( a, b )


In [None]:
k = partial( blah, b=4 )

In [None]:
k( 1 )

In [None]:
k( 5 )

In [None]:
import string
letters = string.ascii_lowercase[ :10 ]

In [None]:
letters + ',' + ','.join( [ i for i in letters[ :-1 ] ] ) + '->' + letters[ -1 ]

In [None]:
from autograd import value_and_grad

In [None]:
# Experiment: call value_and_grad inside a function that is itself being
# differentiated, and stash the inner gradient in the module-level t.
t = None
def blah( a ):
    global t
    b = np.sum( a, axis=0 )
    def blah2( b ):
        # Sum of the entries of the Cholesky factor of b.
        c = np.linalg.cholesky( b )
        d = np.sum( c )
        return d
    # Side effect: overwrites the global t with the inner gradient.
    val, t = value_and_grad( blah2 )( b )
    return val

def blahh( a ):
    """Sum of the entries of the Cholesky factor of ``a`` summed over axis 0."""
    summed = np.sum( a, axis=0 )
    factor = np.linalg.cholesky( summed )
    return np.sum( factor )

In [None]:
import scipy.stats
a = scipy.stats.invwishart.rvs( scale=np.eye( 4 ), df=5, size=5 )

In [None]:
value_and_grad( blah )( a )

In [None]:
t._value

In [None]:
grad( blahh )( a )

In [None]:
logsumexp

In [None]:
logsumexp( np.array( [ np.array( [ 1, 2, 3 ] ) ] ), axis=0 )

In [None]:
class blah():
    """Each new instance records the current class-level counter, then bumps it."""
    data = 1
    def __init__( self ):
        current = blah.data
        self.data = current
        blah.data = current + 1

In [None]:
blahs = [ blah() for i in range( 10 ) ]

In [None]:
np.array( blahs )[ [ 1, 2, 3 ]]

In [None]:
asdf = np.array( blahs )

In [None]:
asdf

In [None]:
np.random.shuffle( asdf )

In [None]:
# The original had an empty subscript `[ ]`, which is a SyntaxError.
# A full slice is used as the minimal valid form.
# NOTE(review): the intended index or slice is not recoverable from the notebook -- confirm.
list( range( 10 ) )[ : ]

In [None]:
p = np.array( [-4.90028392, -0.54319243, -0.8875461 ] )

In [None]:
p[ p < 3 ] = 2

In [None]:
p

In [None]:
grad

In [None]:
def blah( x ):
    # Element-wise max( x, 0 ); used below to inspect its Jacobian.
    return np.maximum( x, 0 )

In [None]:
x = np.random.random( 10 ) - 0.5
jac( blah )( x )

In [None]:
blah( x )

In [None]:
w = np.random.random( ( 3, 4 ) )
a = np.random.random( ( 5, 4 ) )
b = np.random.random( 3 )

In [None]:
np.log( np.einsum( 'ij,tj->ti', w, a ) + b[ None ] )

In [None]:
k = logsumexp( np.log( w[ None, :, : ] ) + np.log( a[ :, None, : ] ), axis=2 )

In [None]:
np.log( np.exp( k - 4 ) + np.exp( np.log( b[ None ] ) - 4 ) ) + 4

In [None]:
def logadd( log_a, log_b ):
    """Element-wise log( exp( log_a ) + exp( log_b ) ), computed stably.

    Args:
        log_a, log_b: scalars or broadcast-compatible arrays in log space.

    Returns:
        The log of the sum of the two exponentiated inputs.

    Uses ``np.logaddexp`` instead of the hand-rolled shift by the global
    maximum. The original returned nan when both inputs were -inf (the shift
    became -inf - ( -inf )); ``logaddexp`` returns -inf in that case. Finite
    inputs give the same result.
    """
    return np.logaddexp( log_a, log_b )
import os
import sys
top_level_dir = '/'.join( os.getcwd().split( '/' )[ :-2 ] )
if top_level_dir not in sys.path:
    sys.path.append( top_level_dir )
from GenModels.GM.Distributions.Normal import Normal
from GenModels.GM.Distributions.TensorNormal import TensorNormal

In [None]:
# Set-up for a two-layer "recognizer" with randomly generated weights.
# NOTE(review): this cell re-binds the notebook-level y, d_out and d_in used by earlier cells.
y = np.array( [ [ 1 ] ] )
d_out = 2
d_in = 3
recognizer_hidden_size = 10
# Layer 1 maps d_out -> recognizer_hidden_size, layer 2 maps recognizer_hidden_size -> d_in.
# NOTE(review): the exact return types of TensorNormal.generate / Normal.generate
# are not visible here; the [ 0 ] presumably drops a leading sample axis -- confirm.
Wr1 = TensorNormal.generate( Ds=( recognizer_hidden_size, d_out ) )[ 0 ]
br1 = Normal.generate( D=recognizer_hidden_size )
Wr2 = TensorNormal.generate( Ds=( d_in, recognizer_hidden_size ) )[ 0 ]
br2 = Normal.generate( D=d_in )
recognizer_params = [ ( Wr1, br1 ), ( Wr2, br2 ) ]

In [None]:
# Turn y into a one hot vector
last_layer = np.zeros( ( y.shape[ 0 ], d_out ) )
last_layer[ np.arange( y.shape[ 0 ] ), y ] = 1.0

last_layer = np.log( last_layer )

for W, b in recognizer_params[ :-1 ]:

    # last_layer = np.tanh( np.einsum( 'ij,tj->ti', W, last_layer ) + b[ None ] )
    k = logsumexp( np.log( W[ None, :, : ] ) + last_layer[ :, None, : ], axis=2 )
    last_layer = logadd( k, np.log( b[ None ] ) )
    last_layer = last_layer - logsumexp( last_layer, axis=1 )[ None ]

W, b = recognizer_params[ -1 ]

last_layer = logsumexp( np.log( W[ None, :, : ] ) + last_layer[ :, None, : ], axis=2 )
last_layer = logsumexp( last_layer, axis=0 )
last_layer = logadd( last_layer, np.log( b ) )
# last_layer = np.einsum( 'ij,tj->i', W, last_layer ) + b
logits = last_layer - logsumexp( last_layer )
print( logits )