In [1]:
import numpy as np
import os
import sys
import scipy

In [2]:
# Make the directory two levels above the current working directory importable,
# so the local `GenModels` package imported in a later cell can be resolved.
# os.path is used instead of splitting on '/' so this also works with Windows
# path separators and does not collapse to '' for shallow working directories.
top_level_dir = os.path.dirname(os.path.dirname(os.getcwd()))
if top_level_dir not in sys.path:
    sys.path.append(top_level_dir)

In [144]:
from GenModels.GM.Distributions import TensorNormal, TensorRegression
from GenModels.GM.Distributions import InverseWishart, Normal

In [4]:
np.array(list(TensorNormal.generate(Ds=[2, 3, 4], size=1).shape))  # shape of one generated sample, including the size axis

array([1, 2, 3, 4])

In [5]:
M = np.random.random((4, 2, 3))
# One InverseWishart draw per axis of M, in axis order (sizes 4, 2, 3).
covs = [InverseWishart.generate(D=dim) for dim in (4, 2, 3)]

In [6]:
X = np.random.random(size=(4, 2, 3))  # evaluation point with the same shape as M

In [7]:
TensorNormal.log_partition(x=X, params=(M, covs))  # params are (mean tensor, per-axis covariances)

1340.982196878431

In [8]:
import autograd.numpy as anp
import autograd
import string
import itertools
from functools import reduce

In [85]:
n1, n2 = TensorNormal.standardToNat(M, covs)  # library natural params; n1 is reassigned by a later cell

In [86]:
cov_invs = list(map(np.linalg.inv, covs))  # per-axis inverse covariances, in axis order

In [110]:
# Dense Kronecker product of the per-axis inverses. This replaces the n1 returned by
# standardToNat with a 2-D matrix whose side is prod(M.shape).
n1 = reduce(np.kron, cov_invs)

In [126]:
def invs(x, k=None):
    """Invert every matrix in ``x``.

    Parameters
    ----------
    x : iterable of square arrays
        Matrices to invert with ``np.linalg.inv``.
    k : optional
        Used only as a flag: when it is not None, the first inverse is
        multiplied by -0.5. Its value is never otherwise read.

    Returns
    -------
    list of ndarray
        The inverses, in input order.
    """
    inverses = list(map(np.linalg.inv, x))
    if k is None:
        return inverses
    # NOTE(review): only the first inverse is scaled; confirm this is intended.
    inverses[0] = -0.5 * inverses[0]
    return inverses

def realStandardToNat(M, covs):
    """Convert a mean tensor and per-axis covariances into natural parameters.

    The full precision is the Kronecker product of the per-axis inverse
    covariances, ``inv(covs[0]) ⊗ inv(covs[1]) ⊗ ...``, reshaped so that
    ``precision[a..., d...]`` pairs output index ``a...`` with input ``d...``.

    Parameters
    ----------
    M : array_like
        Mean tensor. Axis ``i`` must have length ``covs[i].shape[0]``.
    covs : sequence of 2-D arrays
        One square covariance per axis of ``M``, given in axis order.

    Returns
    -------
    n1 : ndarray, shape ``M.shape + M.shape``
        ``-0.5 * precision``.
    n2 : ndarray, shape ``M.shape``
        ``precision`` contracted with ``M`` (``precision @ vec(M)`` reshaped).

    Raises
    ------
    ValueError
        If the number of covariances differs from ``M.ndim`` or any covariance
        is not ``(D_i, D_i)`` for the matching axis length ``D_i``. Without
        this check, covariances listed in the wrong axis order whose sizes
        multiply to the same total would reshape silently into wrong results.
    """
    M = np.asarray(M)
    covs = list(covs)
    if len(covs) == 0 or len(covs) != M.ndim:
        raise ValueError(
            'expected one covariance per axis of M: got {} covariances for M.ndim={}'.format(
                len(covs), M.ndim))
    for axis, (dim, cov) in enumerate(zip(M.shape, covs)):
        if np.shape(cov) != (dim, dim):
            raise ValueError(
                'covariance {} has shape {}, expected {} to match M.shape[{}]'.format(
                    axis, np.shape(cov), (dim, dim), axis))

    # NOTE: the dense Kronecker product has prod(M.shape)**2 entries, which is
    # expensive for large tensors; it is built because n1 is returned densely.
    cov_invs = [np.linalg.inv(cov) for cov in covs]
    n1 = reduce(np.kron, cov_invs).reshape(M.shape + M.shape)
    # Sum over the trailing M.ndim axes of n1 against every axis of M:
    # n2[a...] = sum_d n1[a..., d...] * M[d...].
    n2 = np.tensordot(n1, M, axes=M.ndim)
    return -0.5 * n1, n2

In [127]:
m1, m2 = realStandardToNat(M, covs)  # reference natural params built from the dense Kronecker precision

[4, 4, 2, 2, 3, 3]
contract def,abcdef->abc
M.shape (4, 2, 3)
n1.shape (4, 2, 3, 4, 2, 3)


In [128]:
n1 @ M.ravel()  # dense precision times vec(M); compare with m2.ravel() in the next cell

array([ 168.83914916,   24.98008754,  123.75923897,  126.18610477,
         25.85595776,   89.49576316,  517.10980888,   53.72652612,
        387.52820584,  496.76506971,   54.56003215,  371.1125931 ,
       -202.58342627,  -20.53750595, -151.71883261, -197.80463401,
        -21.12368699, -147.82725783,  601.79081378,   63.42556405,
        451.16096331,  569.88554907,   63.67172041,  425.67262986])

In [131]:
m2.reshape(-1)  # flattened second natural parameter

array([ 168.83914916,   24.98008754,  123.75923897,  126.18610477,
         25.85595776,   89.49576316,  517.10980888,   53.72652612,
        387.52820584,  496.76506971,   54.56003215,  371.1125931 ,
       -202.58342627,  -20.53750595, -151.71883261, -197.80463401,
        -21.12368699, -147.82725783,  601.79081378,   63.42556405,
        451.16096331,  569.88554907,   63.67172041,  425.67262986])

In [108]:
np.shape(m1)

(4, 4, 2, 2, 3, 3)

In [109]:
np.shape(n1)

(4, 4, 3, 3, 2, 2)

In [141]:
Ds = [4, 2, 3]
# Draw order on the global RNG matters: xs, then y, then M, then sigma.
# NOTE(review): this reassigns M, replacing the mean tensor from the earlier cells.
xs = [np.random.random(size) for size in Ds[1:]]
y = np.random.random((1, Ds[0]))  # single observation as a (1, Ds[0]) row
M = np.random.random(Ds)
sigma = InverseWishart.generate(D=Ds[0])

In [145]:
Normal.log_likelihood(
    y,
    # mean[i] = sum_jk M[i, j, k] * xs[0][j] * xs[1][k]
    params=(np.einsum('ijk,j,k->i', M, *xs), sigma),
)

-2.211015165457313

In [146]:
Xs = list(map(lambda vec: np.outer(vec, vec), xs))  # x x^T for each input vector

In [158]:
# Kronecker product of the outer products, reshaped so axes read (b, c, j, k);
# for the two factors here, x_kron[b, c, j, k] == Xs[0][b, j] * Xs[1][c, k].
# Shapes are derived from the data instead of being hard-coded, so changing Ds
# in the data cell does not silently break this cell.
x_dims = [X.shape[0] for X in Xs]
x_kron = reduce(np.kron, Xs).reshape(x_dims * 2)
# np.kron on two N-d arrays interleaves their axes:
# M_kron[a, i, b, j, c, k] == M[a, b, c] * M[i, j, k].
M_kron = np.kron(M, M).reshape([d for D in M.shape for d in (D, D)])
M_kron.shape

(4, 4, 2, 2, 3, 3)


In [151]:
np.einsum('abc,ijk,b,c,j,k->ai', M, M, *xs, *xs)  # (M·x)(M·x)^T contracted directly from the vectors

array([[0.93272824, 0.85234579, 0.76609866, 0.51937742],
       [0.85234579, 0.7788907 , 0.70007633, 0.47461751],
       [0.76609866, 0.70007633, 0.62923703, 0.42659193],
       [0.51937742, 0.47461751, 0.42659193, 0.28920846]])

In [155]:
np.einsum('abc,ijk,bcjk->ai', M, M, x_kron)  # same quantity via the Kronecker product of x x^T

array([[0.93272824, 0.85234579, 0.76609866, 0.51937742],
       [0.85234579, 0.7788907 , 0.70007633, 0.47461751],
       [0.76609866, 0.70007633, 0.62923703, 0.42659193],
       [0.51937742, 0.47461751, 0.42659193, 0.28920846]])

In [162]:
np.einsum('aibjck,bcjk->ai', M_kron, x_kron)  # same quantity via kron(M, M) with interleaved axes

array([[0.93272824, 0.85234579, 0.76609866, 0.51937742],
       [0.85234579, 0.7788907 , 0.70007633, 0.47461751],
       [0.76609866, 0.70007633, 0.62923703, 0.42659193],
       [0.51937742, 0.47461751, 0.42659193, 0.28920846]])