In [1]:
import numpy as np
import time
from tqdm import trange

In [4]:
def lp_original(X, Z, sigma_X = None, sigma_A = None):
    N, D = X.shape
    K = Z.shape[1]
    invMat = np.linalg.inv(Z.T @ Z + sigma_X**2 / sigma_A**2 * np.eye(K))
    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res += D / 2 * np.log(np.linalg.det(invMat))
    res -= 1 / (2 * sigma_X**2) * np.trace(X.T @ (np.eye(N) - Z @ invMat @ Z.T) @ X)
    return res

def lp(trX, X, Z, sigma_X = None, sigma_A = None):
    N, D = X.shape
    
    K = Z.shape[1]
    
    u, s, v = np.linalg.svd(Z,full_matrices=False)
    det = np.prod(s**2 + sigma_X**2 / sigma_A**2)
    l = s**2 / (s**2 + sigma_X**2 / sigma_A**2)
    uTX = u.T @ X
    uX = np.sum(uTX ** 2, axis = 1)

    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res -= D / 2 * np.log(det)
    res -= 1 / (2 * sigma_X**2) * (trX - sum(l * uX))
    return res

In [5]:
def test_lp(N, D, K, times, sigma_X = None, sigma_A = None):
    if sigma_X is None:
        sigma_X = 1
    if sigma_A is None:
        sigma_A = 1
    diff = np.zeros(times)
    diff[:] = np.nan
    time_arr = np.zeros((2, times))
    for i in trange(times):
        X = np.random.normal(0, 1, (N, D))
        Z = np.random.randint(0, 1, (N, K))  
        trX = np.trace(X.T @ X)
        time_start = time.time()
        ans1 = lp_original(X = X, Z = Z, sigma_X = 1, sigma_A = 1)
        time_arr[0, i] = time.time() - time_start
        time_start = time.time()
        ans2 = lp(trX = trX, X = X, Z = Z, sigma_X = 1, sigma_A = 1)
        time_arr[1, i] = time.time() - time_start
        diff[i] = ans2 - ans1
    
    return np.max(abs(diff)), np.mean(time_arr, axis = 1), np.max(time_arr, axis = 1)

In [6]:
diffs = test_lp(5000, 1000, 500, 100, 1, 1)
print(diffs)

  1%|          | 1/100 [00:06<10:11,  6.17s/it]

KeyboardInterrupt: 

In [None]:
diffs = test_lp(5000, 2000, 500, 100, 1, 1)
print(diffs)

In [30]:
diffs = test_lp(5000, 2000, 1000, 100, 1, 1)
print(diffs)

(0.0, array([29.50684445,  0.47264665]), array([29.71937585,  0.63346362]))


In [31]:
diffs = test_lp(10000, 2000, 500, 100, 1, 1)
print(diffs)

(0.0, array([13.14423772,  0.3135772 ]), array([13.33555889,  0.37853098]))


In [32]:
diffs = test_lp(10000, 2000, 1000, 100, 1, 1)
print(diffs)

(0.0, array([61.38463615,  0.90489028]), array([61.64979887,  1.05102801]))


In [33]:
diffs = test_lp(10000, 4000, 1000, 100, 1, 1)
print(diffs)

(0.0, array([64.49265615,  1.3256369 ]), array([64.73263001,  1.54279494]))


In [13]:
arr = [0] * 100
X = np.random.normal(0, 1, (10000, 4000))
Z = np.random.randint(0, 1, (10000, 2000))  
trX = np.trace(X.T @ X)
for i in range(100):
    time_start = time.time()
    ans = lp(trX = trX, X = X, Z = Z, sigma_X = 1, sigma_A = 1)
    arr[i] = time.time() - time_start
    print(i, arr[i])
print(np.mean(arr[60:]), np.max(arr[60:]))

0 4.445885896682739
1 4.640552997589111
2 4.650594234466553
3 4.533618927001953
4 4.723412990570068
5 4.5623509883880615
6 4.442514896392822
7 4.512377023696899
8 4.676807880401611
9 4.614368677139282
10 4.49978494644165
11 4.779513120651245
12 4.480829954147339
13 4.751262187957764
14 5.031541109085083
15 6.109114170074463
16 5.403279066085815
17 5.141122102737427
18 5.763437032699585
19 5.781904935836792
20 6.897506237030029
21 6.039398193359375
22 5.9285900592803955
23 5.260403871536255
24 5.254140138626099
25 5.673856973648071
26 4.635652303695679
27 4.96704888343811
28 5.409122705459595
29 5.164546728134155
30 5.149062156677246
31 5.552958011627197
32 5.35636305809021
33 5.880333185195923
34 6.4204747676849365
35 4.544293165206909
36 4.747936010360718
37 5.369999885559082
38 6.62015700340271
39 5.083015203475952
40 4.859413146972656
41 6.485748052597046
42 6.705762147903442
43 5.250417947769165
44 4.858508110046387
45 5.299223899841309
46 5.483964920043945
47 5.624140977859497
48 

NameError: name 'mean' is not defined

In [14]:
print(np.mean(arr[60:]), np.max(arr[60:]))

4.822995638847351 5.329227924346924


In [15]:
%prun -q -D test.prof diffs = test_lp(N = 1000, D = 20, K = 10, times = 10000, sigma_X = 1, sigma_A = 1)


*** Profile stats marshalled to file 'test.prof'. 


In [13]:
import pstats
p = pstats.Stats('test.prof')
p.print_stats()
pass
(0.0, array([1.4094832 , 0.05192785]), array([1.515692  , 0.07405591]))

Fri Apr 24 19:25:57 2020    test.prof

         840005 function calls in 42.547 seconds

   Random listing order was used

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    40000    0.012    0.000    0.012    0.000 {method 'get' of 'dict' objects}
    20000    0.005    0.000    0.005    0.000 {method 'items' of 'dict' objects}
        1    0.000    0.000   42.547   42.547 {built-in method builtins.exec}
    20000    0.009    0.000    0.009    0.000 {built-in method builtins.getattr}
    30000    0.023    0.000    0.023    0.000 {built-in method builtins.isinstance}
    90000    0.014    0.000    0.014    0.000 {built-in method builtins.issubclass}
    10000    0.067    0.000    0.067    0.000 {built-in method builtins.sum}
        1    0.000    0.000    0.000    0.000 {method 'disable' of '_lsprof.Profiler' objects}
    10000   23.538    0.002   30.674    0.003 <ipython-input-1-647e06f663af>:1(lp_original)
    10000    0.589    0.000    2.671    0.000 <ipython