In [13]:
import autograd.numpy as np
import autograd
import autograd.numpy as anp
import itertools
from functools import partial, namedtuple
from autograd.extend import primitive, defvjp
from autograd import jacobian as jac

In [2]:
def logsumexp( v, axis=0 ):
    max_v = anp.max( v )
    return anp.log( anp.sum( anp.exp( v - max_v ), axis=axis ) ) + max_v

def logsumexp_vjp(ans, x):
    x_shape = x.shape
    return lambda g: anp.full(x_shape, g) * anp.exp(x - np.full(x_shape, ans))
defvjp(logsumexp, logsumexp_vjp)

In [5]:
def gumbelSample( shape, eps=1e-8 ):
    u = anp.random.random( shape )
    return -anp.log( -anp.log( u + eps ) + eps )
def gumbelSoftmaxSample( logits, g=None, temp=1.0 ):
    if( g is None ):
        g = gumbelSample( logits.shape )
    y = logits + g
    ans = anp.exp( y ) / temp
    return ans / ans.sum()

In [17]:
def alphas_unrolled( theta ):
    T, K = theta.L.shape
    a_0 = theta.pi0 + theta.L[ 0 ]
    a_1 = logsumexp( a_0[ :, None ] + theta.pi, axis=0 ) + theta.L[ 1 ]
    a_2 = logsumexp( a_1[ :, None ] + theta.pi, axis=0 ) + theta.L[ 2 ]
    return anp.array( [ a_0, a_1, a_2 ] )

def betas_unrolled( theta ):
    T, K = theta.L.shape
    b_2 = anp.zeros( K )
    b_1 = logsumexp( b_2 + theta.pi + theta.L[ 2 ], axis=1 )
    b_0 = logsumexp( b_1 + theta.pi + theta.L[ 1 ], axis=1 )
    return anp.array( [ b_0, b_1, b_2 ] )    

def joints_unrolled( alpha, beta ):
    j_1 = alpha[ 0 ][ :, None ] + pi + L[ 1 ] + beta[ 1 ]
    j_2 = alpha[ 1 ][ :, None ] + pi + L[ 2 ] + beta[ 2 ]
    return anp.array( [ j_1, j_2 ] )

def predictive_unrolled( alpha, beta ):
    T, d_latent = alpha.shape
    joint = joints_unrolled( alpha, beta )
    return joint - anp.reshape( ( alpha + beta )[ :-1 ], ( T-1, d_latent, 1 ) )
    
def alphas( theta ):
    T, K = theta.L.shape
    alpha = anp.zeros( ( T, K ) )
    alpha[ 0 ] = theta.pi0 + theta.L[ 0 ]
    for t in range( 1, T ):
        alpha[ t ] = logsumexp( alpha[ t - 1 ][ :, None ] + theta.pi, axis=0 ) + theta.L[ t ]
    return alpha

def betas( theta ):
    T, K = theta.L.shape
    beta = anp.zeros( ( T, K ) )
    for t in reversed( range( 0, T - 1 ) ):
        beta[ t ] = logsumexp( beta[ t + 1 ] + theta.pi + theta.L[ t + 1 ], axis=1 )
    return beta

def joints( alpha, beta ):
    joints = anp.zeros( ( T-1, d_latent, d_latent ) )
    for t in range( T - 1 ):
        joints[ t ] = alpha[ t ][ :, None ] + pi + L[ t + 1 ] + beta[ t + 1 ]
    return joints

def predictive( alpha, beta ):
    T, d_latent = alpha.shape
    joint = joints( alpha, beta )
    return joint - anp.reshape( ( alpha + beta )[ :-1 ], ( T-1, d_latent, 1 ) )

In [69]:
def sampleX( alpha, beta ):
    T, K = alpha.shape
    log_z = logsumexp( alpha[ -1 ] )
    
    preds = predictive( alpha, beta )
    x_samples = anp.zeros( ( T, K ) )
    
    x_samples[ 0 ] = alpha[ 0 ] + beta[ 0 ] - log_z
    
    for t, p in enumerate( preds ):
        logits = logsumexp( p + x_samples[ t ], axis=1 )
        x_samples[ t + 1 ] = anp.log( gumbelSoftmaxSample( logits ) )
        
    return x_samples

def sampleX_unrolled( alpha, beta, gumb ):
    T, K = alpha.shape
    log_z = logsumexp( alpha[ -1 ] )
        
    preds = predictive_unrolled( alpha, beta )
    
    x_0 = alpha[ 0 ] + beta[ 0 ] - log_z
    
    logits = anp.log( anp.exp( preds[ 0 ] ) @ anp.exp( x_0 ) )
    x_1 = anp.log( gumbelSoftmaxSample( logits, g=gumb[ 1 ] ) )
    
    logits = anp.log( anp.exp( preds[ 1 ] ) @ anp.exp( x_1 ) )
    x_2 = anp.log( gumbelSoftmaxSample( logits, g=gumb[ 1 ] ) )
    
    return anp.array( [ x_0, x_1, x_2 ] )

def hmmSamples( theta ):
    alpha, beta = alphas( theta ), betas( theta )
    return sampleX( alpha, beta )

def hmmSamples_unrolled( theta, gumb ):
    alpha, beta = alphas_unrolled( theta ), betas_unrolled( theta )
    return sampleX_unrolled( alpha, beta, gumb )

def neuralNet( x_samples, theta ):
    
    N = theta.d_latent * theta.d_obs
    W = anp.arange( N ).reshape( ( theta.d_latent, theta.d_obs ) )
    
    y_dist = anp.einsum( 'ij,ti->tj', W, x_samples )
    probs = y_dist[ anp.arange( theta.T ), theta.y ]
    return anp.sum( probs )

In [63]:
Theta = namedtuple( 'Theta', [ 'pi0', 'pi', 'L', 'T', 'd_latent', 'd_obs', 'y' ] )
T = 3
d_latent = 3
d_obs = 2
y = np.random.choice( d_obs, size=T )

In [64]:
pi0 = np.random.random( d_latent )
pi = np.random.random( ( d_latent, d_latent ) )
pi0 = np.log( pi0 )
pi = np.log( pi )
L = np.random.random( ( d_latent, d_obs ) )
L = L.T[ y ]

In [65]:
gumb = gumbelSample( ( T, d_latent ) )

In [66]:
def trueAnswer( L ):
    theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
    x_samples = hmmSamples_unrolled( theta, gumb )
    return neuralNet( x_samples, theta )
jac( trueAnswer )( L )

array([[-1.4662907 , -0.6942795 ,  2.16057019],
       [-0.3437912 ,  0.46290751, -0.1191163 ],
       [ 0.03566366, -0.10849283,  0.07282917]])

In [258]:
def a( L ):
    theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
    x_samples = hmmSamples_unrolled( theta, gumb )
    return x_samples
def b( x_samples ):
    theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
    return neuralNet( x_samples, theta )
da = jac( a )( L )
db = jac( b )( a( L ) )

In [259]:
np.einsum( 'ijab,ij->ab', da, db )

array([[-1.13000478, -1.38964251,  2.51964729],
       [-0.83208038,  1.29573408, -0.46365371],
       [-0.08749182,  0.13307068, -0.04557886]])

## Step 1 - Compute x samples
## Step 2 - Compute dlogP( y | x )/dx
## Step 3 - Accumulate dlogP( y | x )/dL by computing dx/dL and summing immediately

In [362]:
def deriv( theta, gumb ):
    temp = 1.0
    
    # Get the needed stats
    alpha, beta = alphas( theta ), betas( theta )
    preds = predictive( alpha, beta )
    log_z = logsumexp( alpha[ -1 ] )
    
    T, K = alpha.shape
    
    # Initialize the variables
    dXtdLs = np.zeros( ( T, T, K, K ) )
    dXtdXt1 = np.zeros( ( T-1, K, K ) )
    x_samples = np.zeros( ( T, K ) )
    
    # Base case derivative
    x_samples[ 0 ] = alpha[ 0 ] + beta[ 0 ] - log_z
    dXtdLs[ 0, 0 ] = np.eye( K ) - np.exp( x_samples[ 0 ] )

    print( '\nt', 0, 's', 0, 'dXtdLs[ t, t ]\n', dXtdLs[ 0, 0 ] )
    
    for i, p in enumerate( preds ):
        t = i + 1
        
        # Compute x_t | x_t-1
        p = theta.pi + theta.L[ t ] + beta[ t ] - beta[ t - 1 ][ :, None ]
        logits = logsumexp( p + x_samples[ t - 1 ], axis=1 )
        unnormx = logits + gumb[ t ] - np.log( temp )
        x_samples[ t ] = unnormx - logsumexp( unnormx )
        
        # Compute dLogit / dL_t
        dXPdLt = np.eye( K ) - np.exp( theta.L[ t ] + theta.pi + beta[ t ] ).T
        
        # Compute dLogit / dX_t-1
        dLogitdXt1 = np.exp( p + x_samples[ t - 1 ] - logits[ :, None ] )
        dLogitdP = dLogitdXt1
        
        # Compute dX_t / dLogit
        dXtdLogit = np.eye( K ) - np.exp( x_samples[ t ] )
        
        # Compute dX_t / dX_t-1
        dXtdXt1[ i ] = dXtdLogit @ dLogitdXt1
    
        # Compute the derivative dX_t / dL_s for s == t
        dXtdLs[ t, t ] = dXtdLogit @ dLogitdP @ dXPdLt
        
        # Update each of the L derivatives
        for s in reversed( range( t ) ):
            # Compute dX_t / dL_s for s < t
            dXtdLs[ t, s ] = dXtdXt1[ i ] @ dXtdLs[ t-1, s ]
        
    return dXtdLs

In [363]:
theta = Theta( pi0, pi, L, T, d_latent, d_obs, y )
deriv( theta, gumb )


t 0 s 0 dXtdLs[ t, t ]
 [[ 0.75561822 -0.44904658 -0.30657163]
 [-0.24438178  0.55095342 -0.30657163]
 [-0.24438178 -0.44904658  0.69342837]]


array([[[[ 0.75561822, -0.44904658, -0.30657163],
         [-0.24438178,  0.55095342, -0.30657163],
         [-0.24438178, -0.44904658,  0.69342837]],

        [[ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ]],

        [[ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ]]],


       [[[ 0.08986504, -0.10107791,  0.01121287],
         [ 0.12682073, -0.30038841,  0.17356768],
         [-0.02320787,  0.05281617, -0.0296083 ]],

        [[ 0.15352277, -0.2211536 ,  0.40992592],
         [ 0.14916173, -1.0293167 ,  0.92114144],
         [-0.02821793,  0.17856803, -0.16388745]],

        [[ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ],
         [ 0.        ,  0.        ,  0.        ]]],


       [[[-0.00797907,  0.0194045 , -0.01142542],
         [-0.02253939,  0.05361719

In [364]:
da.transpose( 0, 2, 1, 3 )

array([[[[ 7.55618217e-01, -4.49046583e-01, -3.06571634e-01],
         [-2.44381783e-01,  5.50953417e-01, -3.06571634e-01],
         [-2.44381783e-01, -4.49046583e-01,  6.93428366e-01]],

        [[ 3.26463435e-02,  2.52806087e-02, -5.79269522e-02],
         [ 6.21143449e-02, -1.40845923e-01,  7.87315777e-02],
         [-1.17004974e-01,  1.86149838e-01, -6.91448646e-02]],

        [[ 8.91936531e-03, -1.43948637e-02,  5.47549836e-03],
         [-1.66117260e-02,  4.00088512e-02, -2.33971252e-02],
         [ 1.72217774e-02, -4.71276331e-02,  2.99058558e-02]]],


       [[[ 8.98650435e-02, -1.01077910e-01,  1.12128664e-02],
         [ 1.26820735e-01, -3.00388411e-01,  1.73567676e-01],
         [-2.32078695e-02,  5.28161663e-02, -2.96082968e-02]],

        [[-1.26362414e-01,  1.30269462e-01, -3.90704849e-03],
         [-1.86000309e-01,  3.55624649e-01, -1.69624341e-01],
         [ 3.39327950e-02, -6.27325253e-02,  2.87997303e-02]],

        [[ 2.40068934e-04, -4.40108480e-03,  4.16101587e-0

In [367]:
alpha, beta = alphas( theta ), betas( theta )

In [379]:
def forAG( func ):
    def wrapper( _L ):
        theta = Theta( pi0, pi, _L, T, d_latent, d_obs, y )
        return func( theta )
    return wrapper

In [391]:
def pred_unrolled( theta ):
    beta = betas_unrolled( theta )
    p1 = theta.pi + theta.L[ 1 ] + beta[ 1 ] - beta[ 0 ][ :, None ]
    p2 = theta.pi + theta.L[ 2 ] + beta[ 2 ] - beta[ 1 ][ :, None ]
    return anp.array( [ p1, p2 ] )
@forAG
def pred_unrolled_ag( theta ):
    return pred_unrolled( theta )

In [388]:
def mypred_jac( theta ):
    preds = pred_unrolled( theta )

(2, 3, 3)

# d*beta*<sup>(t)</sup>/d*L*<sup>(s)</sup>

In [416]:
def betas_unrolled( theta ):
    T, K = theta.L.shape
    b_2 = anp.zeros( K )
    b_1 = logsumexp( b_2 + theta.pi + theta.L[ 2 ], axis=1 )
    b_0 = logsumexp( b_1 + theta.pi + theta.L[ 1 ], axis=1 )
    return anp.array( [ b_0, b_1, b_2 ] )    

@forAG
def betas_unrolled_ag( theta ):
    return betas_unrolled( theta )

def dbetadL( theta ):
    betas = betas_unrolled( theta )
    T, K = betas.shape
    betas_jac = np.zeros( ( T, K, T, K ) )
    for t in range( T-2, -1, -1 ):
        val = np.exp( theta.pi + theta.L[ t+1 ] + betas[ t+1 ] - betas[ t ][ :, None ] )
        betas_jac[ t, :, t+1, : ] = val
        for s in range( t+2, T ):
            betas_jac[ t, :, s, : ] = val @ betas_jac[ t+1, :, s, : ]
    return betas_jac

# d*F<sup>(t)</sup>*/d*L<sup>(s)</sup>*

In [947]:
def F_unrolled( theta ):
    beta = betas_unrolled( theta )
    f1 = theta.pi + theta.L[ 1 ] + beta[ 1 ] - beta[ 0 ][ :, None ]
    f2 = theta.pi + theta.L[ 2 ] + beta[ 2 ] - beta[ 1 ][ :, None ]
    return anp.array( [ f1, f2 ] )

@forAG
def F_unrolled_ag( theta ):
    return F_unrolled( theta )

def dFdL( theta ):
    F = F_unrolled( theta )
    T, K = theta.L.shape
    f_jac = np.zeros( F.shape + theta.L.shape )
    b_jac = dbetadL( theta )
    for t in range( 1, T ):
        val = np.eye( K ) - np.exp( F[ t-1 ] )
        for s in range( T ):
            
            f_jac[ t-1, :, :, s, : ] = b_jac[ t, :, s, : ] - b_jac[ t-1, :, s, : ][ :, None, : ]
            if( t == s ):
                f_jac[ t-1, :, :, s, : ] += np.eye( K )
            
    return f_jac

# d*H<sup>(t)</sup>*/d*L<sup>(s)</sup>*

In [1299]:
def x_samples_unrolled( theta, gumb, temp=1.0 ):
    F = F_unrolled( theta )
    T, K = theta.T, theta.d_latent
    x_samples = anp.zeros( ( T, K ) )
    
    alpha, beta = alphas_unrolled( theta ), betas_unrolled( theta )
    log_z = logsumexp( alpha[ -1 ] )
    x_samples_0 = theta.pi0 + theta.L[ 0 ] + beta[ 0 ] - log_z
    
    H0 = logsumexp( F[ 0 ] + x_samples_0, axis=1 )
    G = H0 + gumb[ 1 ] - anp.log( temp )
    x_samples_1 = G - logsumexp( G )

    H1 = logsumexp( F[ 1 ] + x_samples_1, axis=1 )
    G = H1 + gumb[ 2 ] - anp.log( temp )
    x_samples_2 = G - logsumexp( G )

    return anp.array( [ x_samples_0, x_samples_1, x_samples_2 ] )

@forAG
def x_samples_unrolled_ag( theta ):
    return x_samples_unrolled( theta, gumb )

def dHdL( theta, gumb, temp=1.0 ):
    F = F_unrolled( theta )
    dF = dFdL( theta )
    dB = dbetadL( theta )
    
    T, K = theta.T, theta.d_latent
    x_samples = anp.zeros( ( T, K ) )
    dX = np.zeros( ( T, K, T, K ) )
    
    alpha, beta = alphas_unrolled( theta ), betas_unrolled( theta )
    log_z = logsumexp( alpha[ -1 ] )
    x_samples[ 0 ] = theta.pi0 + theta.L[ 0 ] + beta[ 0 ] - log_z
    
    dX[ 0 ] = dB[ 0 ]
    dX[ 0, :, :, : ] -= np.exp( alpha + beta - log_z )
    dX[ 0, :, 0, : ] += np.eye( K )
    
    for t in range( 1, T ):
        H = logsumexp( F[ t-1 ] + x_samples[ t-1 ], axis=1 )
        G = H + gumb[ t ] - anp.log( temp )
        x_samples[ t ] = G - logsumexp( G )
        
        # H are the logits
        # F are the conditioned transition matrix
        # G are the unnormalized x samples

        for s in range( T ):
            deriv = dF[ t-1, :, :, s, : ] + dX[ t-1, :, s, : ][ None, :, : ]
            val = np.exp( F[ t-1 ] + x_samples[ t-1 ] - H[ :, None ] )
            dH = np.einsum( 'ij,ijk->ijk', val, deriv )
            tmp = 1 - np.exp( G - logsumexp( G ) )
            dX[ t, :, s, : ] = np.einsum( 'i,ijk->ik', tmp, dH )
        
    return dX

In [1300]:
dHdL( theta, gumb )

array([[[[ 0.75561822, -0.44904658, -0.30657163],
         [ 0.03264634,  0.02528061, -0.05792695],
         [ 0.00891937, -0.01439486,  0.0054755 ]],

        [[-0.24438178,  0.55095342, -0.30657163],
         [ 0.06211434, -0.14084592,  0.07873158],
         [-0.01661173,  0.04000885, -0.02339713]],

        [[-0.24438178, -0.44904658,  0.69342837],
         [-0.11700497,  0.18614984, -0.06914486],
         [ 0.01722178, -0.04712763,  0.02990586]]],


       [[[-0.02791065,  0.00398982,  0.02392082],
         [-0.03125553,  0.04752293, -0.0162674 ],
         [ 0.00420886, -0.0118033 ,  0.00759444]],

        [[-0.0310497 , -0.18078   ,  0.21182971],
         [-0.09592882,  0.13779689, -0.04186807],
         [ 0.01148529, -0.0333516 ,  0.02186631]],

        [[-0.11751182,  0.10545616,  0.01205566],
         [-0.01968889,  0.05098026, -0.03129136],
         [ 0.00639153, -0.01494169,  0.00855016]]],


       [[[-0.05785748,  0.03102342,  0.02683406],
         [-0.02560338,  0.04650705

In [1296]:
jac( x_samples_unrolled_ag )( L )

array([[[[ 7.55618217e-01, -4.49046583e-01, -3.06571634e-01],
         [ 3.26463435e-02,  2.52806087e-02, -5.79269522e-02],
         [ 8.91936531e-03, -1.43948637e-02,  5.47549836e-03]],

        [[-2.44381783e-01,  5.50953417e-01, -3.06571634e-01],
         [ 6.21143449e-02, -1.40845923e-01,  7.87315777e-02],
         [-1.66117260e-02,  4.00088512e-02, -2.33971252e-02]],

        [[-2.44381783e-01, -4.49046583e-01,  6.93428366e-01],
         [-1.17004974e-01,  1.86149838e-01, -6.91448646e-02],
         [ 1.72217774e-02, -4.71276331e-02,  2.99058558e-02]]],


       [[[ 3.86093376e-02, -4.57825240e-02,  7.17318646e-03],
         [-1.57952071e-02,  1.27514407e-02,  3.04376643e-03],
         [ 1.24872450e-04, -1.94678441e-03,  1.82191196e-03]],

        [[ 7.55650288e-02, -2.45093025e-01,  1.69527996e-01],
         [-3.84774105e-02,  3.87961262e-02, -3.18715660e-04],
         [ 1.67867654e-03, -7.50090644e-03,  5.82222991e-03]],

        [[-7.44635755e-02,  1.08111552e-01, -3.36479767e-0

In [1243]:
np.ones( 4 )[ None, : ]

array([[1., 1., 1., 1.]])

In [1276]:
def gumbelSample( shape, eps=1e-8 ):
    u = anp.random.random( shape )
    return -anp.log( -anp.log( u + eps ) + eps )
gumb = gumbelSample( ( T, d_latent ) )

In [1333]:
def activation( x ):
    return anp.sin( x )

def check( x ):
    d_in = x.shape[ -1 ]
    d_out = 100
    W1 = anp.arange( d_in * d_out ).reshape( ( d_in, d_out ) )
    b1 = anp.arange( d_out )
    
    d_in = 100
    d_out = 1
    W2 = anp.arange( d_in * d_out ).reshape( ( d_in, d_out ) )
    b2 = anp.arange( d_out )
    
    z = activation( anp.einsum( 'ij,ti->tj', W1, x ) + b1 )
    
    return anp.sum( activation( anp.einsum( 'ij,ti->tj', W2, z + b2 ) ) )

In [1334]:
x = np.random.random( ( 10, 3 ) )

In [1335]:
check( x )

-2.2740183982017976

In [1338]:
%%timeit
jac( check )( x )

355 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [1337]:
x

array([[0.60521847, 0.02983659, 0.33941179],
       [0.77536835, 0.88879314, 0.19546331],
       [0.60456678, 0.443581  , 0.8629852 ],
       [0.00963859, 0.16892713, 0.72825752],
       [0.79318225, 0.98739385, 0.26031268],
       [0.22024778, 0.12707345, 0.38055451],
       [0.99660973, 0.55210765, 0.07690179],
       [0.23970282, 0.62977837, 0.90661498],
       [0.87995858, 0.89019672, 0.84628363],
       [0.32152712, 0.35871458, 0.8535235 ]])

In [1344]:
def asdf( x ):
    return anp.sum( x**2 )

In [1345]:
jac( asdf )( x )

array([[1.21043693, 0.05967318, 0.67882357],
       [1.5507367 , 1.77758628, 0.39092662],
       [1.20913355, 0.887162  , 1.72597039],
       [0.01927718, 0.33785425, 1.45651504],
       [1.58636451, 1.9747877 , 0.52062536],
       [0.44049555, 0.2541469 , 0.76110901],
       [1.99321946, 1.10421531, 0.15380357],
       [0.47940564, 1.25955675, 1.81322995],
       [1.75991717, 1.78039344, 1.69256726],
       [0.64305424, 0.71742917, 1.70704699]])

In [1347]:
2*x

array([[1.21043693, 0.05967318, 0.67882357],
       [1.5507367 , 1.77758628, 0.39092662],
       [1.20913355, 0.887162  , 1.72597039],
       [0.01927718, 0.33785425, 1.45651504],
       [1.58636451, 1.9747877 , 0.52062536],
       [0.44049555, 0.2541469 , 0.76110901],
       [1.99321946, 1.10421531, 0.15380357],
       [0.47940564, 1.25955675, 1.81322995],
       [1.75991717, 1.78039344, 1.69256726],
       [0.64305424, 0.71742917, 1.70704699]])

In [1348]:
def fwd( theta ):
    T, K = theta.L.shape
    alpha = []
    alpha.append( theta.pi0 + theta.L[ 0 ] )
    for t in range( 1, T ):
        alpha.append( logsumexp( alpha[ -1 ][ :, None ] + theta.pi, axis=0 ) + theta.L[ t ] )
    return anp.array( alpha )

In [1349]:
@forAG
def fwdAG( theta ):
    return fwd( theta )

In [1350]:
jac( fwdAG )( L )

array([[[[1.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 1.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]]],


       [[[0.28211374, 0.58096008, 0.13692618],
         [1.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.26349003, 0.25343247, 0.4830775 ],
         [0.        , 1.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.21395349, 0.52503869, 0.26100783],
         [0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ]]],


       [[[0.25767243, 0.40356353, 0.33876404],
         [0.27519173, 0.50390694, 0.22090133],
         [1.        , 0.        , 0.        

In [1351]:
@forAG
def blah( theta ):
    return alphas_unrolled( theta )

In [1352]:
jac( blah )( L )

array([[[[1.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 1.        , 0.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]]],


       [[[0.28211374, 0.58096008, 0.13692618],
         [1.        , 0.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.26349003, 0.25343247, 0.4830775 ],
         [0.        , 1.        , 0.        ],
         [0.        , 0.        , 0.        ]],

        [[0.21395349, 0.52503869, 0.26100783],
         [0.        , 0.        , 1.        ],
         [0.        , 0.        , 0.        ]]],


       [[[0.25767243, 0.40356353, 0.33876404],
         [0.27519173, 0.50390694, 0.22090133],
         [1.        , 0.        , 0.        

In [1393]:
def nonFBSMultiplyTerms( terms ):
    # Basically np.einsum but in log space
    terms = list( terms )

    # Separate out where the feedback set axes start and get the largest fbs_axis.
    # Need to handle case where ndim of term > all fbs axes
    # terms, fbs_axes_start = list( zip( *terms ) )
    fbs_axes_start = [ -1, -1 ]

    if( max( fbs_axes_start ) != -1 ):
        max_fbs_axis = max( [ ax if ax != -1 else term.ndim for ax, term in zip( fbs_axes_start, terms ) ] )

        if( max_fbs_axis > 0 ):
            # Pad extra dims at each term so that the fbs axes start the same way for every term
            for i, ax in enumerate( fbs_axes_start ):
                if( ax == -1 ):
                    for _ in range( max_fbs_axis - terms[ i ].ndim + 1 ):
                        terms[ i ] = terms[ i ][ ..., None ]
                else:
                    for _ in range( max_fbs_axis - ax ):
                        terms[ i ] = anp.expand_dims( terms[ i ], axis=ax )
    else:
        max_fbs_axis = -1

    ndim = max( [ len( term.shape ) for term in terms ] )

    axes = [ [ i for i, s in enumerate( t.shape ) if s != 1 ] for t in terms ]

    # Get the shape of the output
    shape = anp.ones( ndim, dtype=int )
    for ax, term in zip( axes, terms ):
        shape[ anp.array( ax ) ] = term.squeeze().shape

    total_elts = shape.prod()
    if( total_elts > 1e8 ):
        assert 0, 'Don\'t do this on a cpu!  Too many terms: %d'%( int( total_elts ) )

    # Basically np.einsum in log space
    ans = anp.zeros( shape )
    for ax, term in zip( axes, terms ):

        for _ in range( ndim - term.ndim ):
            term = term[ ..., None ]

        ans += term
#         ans += np.broadcast_to( term, ans.shape )

    return ans

In [1394]:
def blah( x ):
    y = np.random.random( ( 1, 4, 2, 1, 1, 1, 8 ) )
    return nonFBSMultiplyTerms( ( x, y ) )

In [1395]:
x = np.random.random( ( 3, 1, 2, 1 ) )

In [1400]:
import recordclass

In [1403]:
Something = recordclass.recordclass( 'something', [ 'a', 'b' ] )

In [1418]:
def blah2( s ):
    s.a *= 4
    return ( s.a * s.b )**2

In [1422]:
def blahh( a ):
    s = Something( a, 5 )
    return blah2( s )

In [1431]:
blahh( 3 )

3600

In [1432]:
jac( blahh )( 3.0 )

array(2400.)

In [1425]:
def blah3( a ):
    a *= 4
    return ( a*5 )**2

In [1429]:
jac( blah3 )( 3.0 )

array(2400.)

In [1430]:
blah3( 3.0 )

3600.0

In [1436]:
class tester():
    def __init__( self ):
        self.a = 4
        
    def blah( self, b ):
        self.b = b
        self.stuff()
        return self.c
        
    def stuff( self ):
        self.b *= 4
        self.c = ( self.b * 3 )**2

In [1434]:
t = tester()

In [1435]:
t.blah( 2.4 )

829.4399999999998

In [1437]:
def blah4( b ):
    t = tester()
    return t.blah( b )

In [1438]:
blah4( 2.4 )

829.4399999999998

In [1439]:
jac( blah4 )( 2.4 )

array(691.2)

In [1440]:
from recordclass import recordclass

In [1441]:
Succ = recordclass( 'Succ', [ 'a', 'b' ] )
class blahh( Succ ):
    @property
    def yo( self ):
        return self.a*2

In [1448]:
blah = blahh( 3, 1 )

In [1449]:
blah.yo

6

In [1450]:
a, b = blah

In [1451]:
a

3

In [1474]:
from autograd.misc.optimizers import adam
from autograd import grad
import autograd.numpy as np

In [1475]:
def blah( params, iter ):
    a, b, c = params
    return np.sum( np.sum( a )*np.sum( b ) + c )

In [1476]:
a = np.random.random( ( 4, 3 ) )
b = np.random.random( ( 4, 3 ) )
c = np.random.random( ( 4, 3 ) )
params = ( a, b, c )

In [1477]:
g = grad( blah )

In [1481]:
g( params, 1 )

(array([[85.06240416, 85.06240416, 85.06240416],
        [85.06240416, 85.06240416, 85.06240416],
        [85.06240416, 85.06240416, 85.06240416],
        [85.06240416, 85.06240416, 85.06240416]]),
 array([[91.70995451, 91.70995451, 91.70995451],
        [91.70995451, 91.70995451, 91.70995451],
        [91.70995451, 91.70995451, 91.70995451],
        [91.70995451, 91.70995451, 91.70995451]]),
 array([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]]))

In [1484]:
def blah( a, b ):
    print( a, b )


In [1485]:
k = partial( blah, b=4 )

In [1486]:
k( 1 )

1 4


In [1487]:
k( 5 )

5 4


In [1488]:
import string
letters = string.ascii_lowercase[ :10 ]

In [1490]:
letters + ',' + ','.join( [ i for i in letters[ :-1 ] ] ) + '->' + letters[ -1 ]

'abcdefghij,a,b,c,d,e,f,g,h,i->j'

In [1515]:
from autograd import value_and_grad

In [1528]:
t = None
def blah( a ):
    global t
    b = np.sum( a, axis=0 )
    def blah2( b ):
        c = np.linalg.cholesky( b )
        d = np.sum( c )
        return d
    val, t = value_and_grad( blah2 )( b )
    return val

def blahh( a ):
    b = np.sum( a, axis=0 )
    c = np.linalg.cholesky( b )
    d = np.sum( c )
    return d

In [1529]:
import scipy.stats
a = scipy.stats.invwishart.rvs( scale=np.eye( 4 ), df=5, size=5 )

In [1530]:
value_and_grad( blah )( a )

(60.326246428875415,
 array([[[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
         [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
         [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
         [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],
 
        [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
         [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
         [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
         [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],
 
        [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
         [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
         [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
         [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],
 
        [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
         [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
         [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
         [-0.05644718,  0

In [1531]:
t._value

array([[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
       [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
       [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
       [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]])

In [1532]:
grad( blahh )( a )

array([[[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
        [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
        [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
        [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],

       [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
        [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
        [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
        [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],

       [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
        [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
        [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
        [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],

       [[ 0.02395243, -0.1135541 , -0.05461702, -0.05644718],
        [-0.1135541 ,  0.79227359,  0.40746491,  0.4258389 ],
        [-0.05461702,  0.40746491,  0.21296255,  0.21281433],
        [-0.05644718,  0.4258389 ,  0.21281433,  0.22582264]],

