In [1]:
import numpy as np

In [2]:
def lp_original(X, Z, sigma_X = None, sigma_A = None):
    N, D = X.shape
    K = Z.shape[1]
    invMat = np.linalg.inv(Z.T @ Z + sigma_X**2 / sigma_A**2 * np.eye(K))
    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res += D / 2 * np.log(np.linalg.det(invMat))
    res -= 1 / (2 * sigma_X**2) * np.trace(X.T @ (np.eye(N) - Z @ invMat @ Z.T) @ X)
    return res

def lp(X, Z, sigma_X = None, sigma_A = None):
    N, D = X.shape
    trX = np.trace(X.T @ X)
    K = Z.shape[1]
    
    u, s, v = np.linalg.svd(Z,full_matrices=False)
    det = np.prod(s**2 + sigma_X**2 / sigma_A**2 * np.ones(len(s)))
    l = s**2 / (s**2 + sigma_X**2 / sigma_A**2 * np.ones(len(s)))
    uTX = u.T @ X
    uX = [np.linalg.norm(uTX[x,:])**2 for x in range(uTX.shape[0])]

    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res -= D / 2 * np.log(det)
    res -= 1 / (2 * sigma_X**2) * (trX - sum(l * uX))
    return res

In [3]:
def test_lp(N, D, K, times, sigma_X, sigma_A):
    if sigma_X is None:
        sigma_X = 1
    if sigma_A is None:
        sigma_A = 1
    diff = np.zeros(times)
    diff[:] = np.nan
    for i in range(times):
        X = np.random.normal(0, 1, (N, D))
        Z = np.random.normal(0, 1, (N, K))  
        diff[i] = lp_original(X = X, Z = Z, sigma_X = 1, sigma_A = 1)-lp(X = X, Z = Z, sigma_X = 1, sigma_A = 1)
    return diff

In [4]:
diffs = test_lp(200, 10, 20, 10000, 1, 1)
max(abs(diffs))

1.3642420526593924e-12

In [5]:
%prun -q -D test.prof test_lp(200, 10, 20, 10000, 1, 1)

 
*** Profile stats marshalled to file 'test.prof'. 


In [6]:
import pstats
p = pstats.Stats('test.prof')
p.print_stats()
pass

Thu Apr 23 22:05:31 2020    test.prof

         3690005 function calls (3490005 primitive calls) in 16.249 seconds

   Random listing order was used

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    40000    0.016    0.000    0.016    0.000 {method 'get' of 'dict' objects}
    10000    0.004    0.000    0.004    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000   16.249   16.249 {built-in method builtins.exec}
    20000    0.011    0.000    0.011    0.000 {built-in method builtins.getattr}
    20000    0.015    0.000    0.015    0.000 {built-in method builtins.isinstance}
   490000    0.099    0.000    0.099    0.000 {built-in method builtins.issubclass}
    20000    0.007    0.000    0.007    0.000 {built-in method builtins.len}
    10000    0.128    0.000    0.128    0.000 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    10000    2.797    0.000    4.600    0.