In [1]:
%matplotlib inline
import numpy as np
import time
from tqdm import tqdm,trange
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from numba import jit, vectorize, float64, int64, njit, prange
import math
import random
import sys
import sympy 

In [2]:
def lp_original(trX, X, Z, sigma_X = None, sigma_A = None):
    """Reference closed-form log-likelihood log P(X | Z), evaluated directly.

    Computes, with r = sigma_X**2 / sigma_A**2 and M = (Z^T Z + r I)^-1,

        -N*D/2 log(2 pi) - (N-K)*D log sigma_X - K*D log sigma_A
        + D/2 log|M| - 1/(2 sigma_X**2) tr(X^T (I - Z M Z^T) X)

    This is the straightforward baseline: it builds an explicit K x K
    inverse and an N x N identity, so it is slow and memory-hungry for
    large N. It is the yardstick that `lp` and `lp_numba` are checked
    against.

    Parameters
    ----------
    trX : float
        Not used by this implementation; accepted so every lp variant
        shares the same call signature.
    X : ndarray, shape (N, D)
    Z : ndarray, shape (N, K)
    sigma_X, sigma_A : float, optional
        ``None`` falls back to 1.0, the same default `test_lp` uses.

    Returns
    -------
    float
        The log-likelihood value.
    """
    if sigma_X is None:
        sigma_X = 1.0
    if sigma_A is None:
        sigma_A = 1.0
    N, D = X.shape
    K = Z.shape[1]
    invMat = np.linalg.inv(Z.T @ Z + sigma_X**2 / sigma_A**2 * np.eye(K))
    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    # slogdet returns log|det| directly. det() of a K x K inverse easily
    # underflows to 0 (giving log(0) = -inf) once K is in the hundreds.
    res += D / 2 * np.linalg.slogdet(invMat)[1]
    res -= 1 / (2 * sigma_X**2) * np.trace(X.T @ (np.eye(N) - Z @ invMat @ Z.T) @ X)
    return res

def lp(trX, X, Z = None, sigma_X = None, sigma_A = None):
    """Log-likelihood log P(X | Z) computed through a thin SVD of Z.

    Same quantity as `lp_original`, but it avoids explicit inverses. With
    Z = U S V^T and r = sigma_X**2 / sigma_A**2:

        log|Z^T Z + r I| = sum_i log(s_i**2 + r)  [+ (K - R) log r if K > R]
        tr(X^T Z M Z^T X) = sum_i s_i**2 / (s_i**2 + r) * ||(U^T X)[i, :]||**2

    Here R = len(s) = min(N, K).

    Parameters
    ----------
    trX : float
        Precomputed trace of X^T X (i.e. the sum of squared entries of X).
    X : ndarray, shape (N, D)
    Z : ndarray, shape (N, K)
        Required despite the ``None`` default.
    sigma_X, sigma_A : float, optional
        ``None`` falls back to 1.0, the same default `test_lp` uses.

    Returns
    -------
    float
        The log-likelihood value.
    """
    if sigma_X is None:
        sigma_X = 1.0
    if sigma_A is None:
        sigma_A = 1.0
    N, D = X.shape
    K = Z.shape[1]
    ratio = sigma_X**2 / sigma_A**2
    u, s, _ = np.linalg.svd(Z, full_matrices=False)
    s2 = s**2
    det = np.sum(np.log(s2 + ratio))
    # The thin SVD has only min(N, K) singular values. When K > N, the other
    # K - N eigenvalues of Z^T Z + ratio*I all equal `ratio`.
    if K > s.size:
        det += (K - s.size) * np.log(ratio)
    shrink = s2 / (s2 + ratio)
    uX = np.sum((u.T @ X) ** 2, axis = 1)

    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res -= D / 2 * det
    res -= 1 / (2 * sigma_X**2) * (trX - np.sum(shrink * uX))
    return res

@jit
def matrix_multiply_numba(A, B):
    """Dense matrix product ``A @ B`` written as explicit loops for Numba.

    Parameters
    ----------
    A : 2-D array, shape (m, n)
    B : 2-D array, shape (n, p)

    Returns
    -------
    C : 2-D float64 array, shape (m, p)

    Raises
    ------
    ValueError
        If the inner dimensions of A and B differ.
    """
    m, n = A.shape
    n_b, p = B.shape
    # The draft reused `n` for both shapes. Numba does no bounds checking by
    # default, so a mismatch read out-of-range memory instead of failing.
    if n != n_b:
        raise ValueError("matrix_multiply_numba: inner dimensions do not match")
    C = np.zeros((m, p))
    # i-k-j order walks rows of B and C contiguously when they are C-ordered.
    # Each C[i, j] still accumulates its terms in increasing k, so the results
    # match the i-j-k order exactly.
    for i in range(m):
        for k in range(n):
            a_ik = A[i, k]
            for j in range(p):
                C[i, j] += a_ik * B[k, j]
    return C

@jit(float64(float64, float64[:,:], float64[:,:], float64, float64))
def lp_numba(trX, X, Z = None, sigma_X = None, sigma_A = None):
    """Numba-compiled twin of `lp`: log-likelihood via a thin SVD of Z.

    The explicit signature types every argument as float64 (scalars) or a
    2-D float64 array, so all five arguments are expected to be passed. The
    ``None`` defaults only mirror `lp`'s signature; passing None presumably
    fails to match the compiled signature (verify).

    Parameters
    ----------
    trX : float
        Precomputed trace of X^T X.
    X : float64 array, shape (N, D)
    Z : float64 array, shape (N, K)
    sigma_X, sigma_A : float

    Returns
    -------
    float
        The log-likelihood value.
    """
    N, D = X.shape
    K = Z.shape[1]
    ratio = sigma_X**2 / sigma_A**2
    u, s, _ = np.linalg.svd(Z, full_matrices=False)
    s2 = s**2
    det = np.sum(np.log(s2 + ratio))
    # The thin SVD has only min(N, K) singular values. When K > N, the other
    # K - N eigenvalues of Z^T Z + ratio*I all equal `ratio`.
    if K > s.size:
        det += (K - s.size) * np.log(ratio)
    l = s2 / (s2 + ratio)
    uTX = matrix_multiply_numba(u.T, X)
    uX = np.sum(uTX ** 2, axis = 1)

    res = - N * D / 2 * np.log(2 * np.pi)
    res -= (N - K) * D * np.log(sigma_X)
    res -= K * D * np.log(sigma_A)
    res -= D / 2 * det
    res -= 1 / (2 * sigma_X**2) * (trX - np.sum(l * uX))
    return res

In [3]:
def test_lp(func1, func2, func3, N, D, K, times, sigma_X = None, sigma_A = None):
    """Cross-check and time three log-likelihood implementations.

    For each of `times` trials, this:
    - draws X ~ N(0, 1) of shape (N, D) and a random 0/1 matrix Z of shape (N, K);
    - calls func1, func2 and func3 on identical inputs;
    - records the largest pairwise absolute difference and each call's time.

    Parameters
    ----------
    func1, func2, func3 : callable
        Called as ``f(trX=..., X=..., Z=..., sigma_X=..., sigma_A=...)``.
    N, D, K : int
        X is N x D and Z is N x K.
    times : int
        Number of random trials.
    sigma_X, sigma_A : float, optional
        Forwarded to every function. ``None`` means 1.

    Returns
    -------
    tuple
        ``(max_diff, mean_times, max_times)``: the largest pairwise difference
        over all trials, plus per-function mean and max times (seconds) as
        arrays of length 3.
    """
    if sigma_X is None:
        sigma_X = 1
    if sigma_A is None:
        sigma_A = 1
    funcs = (func1, func2, func3)
    diff = np.full(times, np.nan)
    time_arr = np.zeros((len(funcs), times))
    for i in trange(times):
        X = np.random.normal(0, 1, (N, D))
        # randint's upper bound is exclusive: (0, 2) gives a binary Z.
        # The draft's (0, 1) always produced an all-zero Z.
        Z = np.random.randint(0, 2, (N, K)).astype('float')
        # tr(X^T X) is the sum of squares; vdot avoids forming the D x D Gram matrix.
        trX = np.vdot(X, X)
        answers = []
        for f_idx, func in enumerate(funcs):
            time_start = time.perf_counter()
            # Forward the caller's sigmas. The draft hard-coded 1 here.
            answers.append(func(trX = trX, X = X, Z = Z, sigma_X = sigma_X, sigma_A = sigma_A))
            time_arr[f_idx, i] = time.perf_counter() - time_start
        ans1, ans2, ans3 = answers
        diff[i] = max(np.abs(ans2 - ans1), np.abs(ans3 - ans1), np.abs(ans3 - ans2))

    return np.max(diff), np.mean(time_arr, axis = 1), np.max(time_arr, axis = 1)

In [None]:
# N=5000, D=1000, K=500, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 5000, 1000, 500, 100, 1, 1)
print(diffs)

In [None]:
# N=5000, D=2000, K=500, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 5000, 2000, 500, 100, 1, 1)
print(diffs)

In [None]:
# N=5000, D=2000, K=1000, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 5000, 2000, 1000, 100, 1, 1)
print(diffs)

In [None]:
# N=10000, D=2000, K=500, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 10000, 2000, 500, 100, 1, 1)
print(diffs)

In [None]:
# N=10000, D=2000, K=1000, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 10000, 2000, 1000, 100, 1, 1)
print(diffs)

In [None]:
# N=10000, D=4000, K=1000, 100 trials, sigma_X = sigma_A = 1.
diffs = test_lp(lp_original, lp, lp_numba, 10000, 4000, 1000, 100, 1, 1)
print(diffs)

In [None]:
# Time `lp` alone 100 times on one large problem (N=10000, D=4000, K=2000).
# Only runs 60..99 are summarised, presumably to skip warm-up (confirm).
arr = [0] * 100
X = np.random.normal(0, 1, (10000, 4000))
# randint's upper bound is exclusive, so (0, 2) is needed for a binary Z.
# The draft's (0, 1) gave an all-zero integer matrix.
Z = np.random.randint(0, 2, (10000, 2000)).astype('float')
# tr(X^T X) is the sum of squares; vdot avoids building a 4000 x 4000 Gram matrix.
trX = np.vdot(X, X)
for i in range(100):
    time_start = time.perf_counter()
    ans = lp(trX = trX, X = X, Z = Z, sigma_X = 1, sigma_A = 1)
    arr[i] = time.perf_counter() - time_start
    print(i, arr[i])
print(np.mean(arr[60:]), np.max(arr[60:]))

In [None]:
# Mean and max of timing runs 60..99 recorded in `arr`.
print(*(stat(arr[60:]) for stat in (np.mean, np.max)))

In [None]:
%prun -q -D test.prof diffs = test_lp(N = 1000, D = 20, K = 10, times = 10000, sigma_X = 1, sigma_A = 1)

In [None]:
import pstats

# Load the profile written to test.prof and list functions by cumulative time,
# so the expensive calls come first instead of in arbitrary order.
p = pstats.Stats('test.prof')
p.sort_stats('cumulative').print_stats();
(0.0, array([1.4094832 , 0.05192785]), array([1.515692  , 0.07405591]))

In [None]:
# Compare and time the three implementations over 1000 trials
# with N=1000, D=400, K=200 and sigma_X = sigma_A = 1.
diffs = test_lp(
    func1 = lp_original, func2 = lp, func3 = lp_numba,
    N = 1000, D = 400, K = 200, times = 1000,
    sigma_X = 1, sigma_A = 1,
)
print(
    "difference:", diffs[0],
    "\n mean time for each lp function:", diffs[1],
    "\n max time for each lp function:", diffs[2],
)

 83%|████████▎ | 828/1000 [04:09<00:49,  3.50it/s]

In [None]:
%load_ext cython

## Unfinished: the Cython version still needs a native SVD implementation

In [None]:
%%cython -a

import cython
import numpy as np
from libc.math cimport log, M_PI

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def lp_cython(double trX, double[:,:] X, double[:,:] Z, double sigma_X, double sigma_A):
    """Cython port of `lp`: log-likelihood computed from a thin SVD of Z.

    The SVD itself is still delegated to NumPy (TODO: native SVD). The
    log-determinant and the projected quadratic term run as typed C loops.

    Parameters
    ----------
    trX : double
        Precomputed trace of X^T X.
    X : double[:, :], shape (N, D)
    Z : double[:, :], shape (N, K)
    sigma_X, sigma_A : double

    Returns
    -------
    double
        The log-likelihood value.
    """
    cdef Py_ssize_t N = X.shape[0]
    cdef Py_ssize_t D = X.shape[1]
    cdef Py_ssize_t K = Z.shape[1]
    cdef Py_ssize_t R, i, j, k
    cdef double ratio = (sigma_X * sigma_X) / (sigma_A * sigma_A)
    cdef double det = 0.0
    cdef double prod = 0.0
    cdef double s2, proj, row_sq
    cdef double res
    cdef double[:, :] u
    cdef double[:] s

    # One SVD call. The draft called np.linalg.svd three times, and numpy
    # was never imported inside this Cython module.
    u_arr, s_arr, vt_arr = np.linalg.svd(np.asarray(Z), full_matrices=False)
    u = u_arr
    s = s_arr
    R = s.shape[0]

    for i in range(R):
        s2 = s[i] * s[i]
        det += log(s2 + ratio)
        # row_sq = ||(U^T X)[i, :]||^2, accumulated in place so no temporary
        # arrays are needed. The draft wrote into unallocated views and used
        # `^`, which is XOR in Cython, not a power.
        row_sq = 0.0
        for j in range(D):
            proj = 0.0
            for k in range(N):
                proj += u[k, i] * X[k, j]
            row_sq += proj * proj
        prod += s2 / (s2 + ratio) * row_sq

    # The thin SVD gives only min(N, K) singular values. For K > N the
    # remaining K - N eigenvalues of Z^T Z + ratio*I equal `ratio`.
    if K > R:
        det += (K - R) * log(ratio)

    # Floating-point literals keep these divisions out of C integer division
    # (cdivision is on, and N, D are C integers).
    res = -N * D / 2.0 * log(2 * M_PI)
    res -= (N - K) * D * log(sigma_X)
    res -= K * D * log(sigma_A)
    res -= D / 2.0 * det
    res -= 1.0 / (2.0 * sigma_X * sigma_X) * (trX - prod)
    return res