In [1]:
import numpy as np
import math
from numba import jit

## Simulated Dataset

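We generate a simulated 100 × 50 data matrix $X = X^* + E$. Here $X^* = 50\,u v^T$ is a rank-one signal built from the sparse unit vectors $u$ and $v$, and $E$ has i.i.d. standard normal entries.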

In [2]:
# Simulated data set: a rank-one signal 50 * u v^T, where u (length 100) and
# v (length 50) are sparse unit vectors, plus i.i.d. N(0, 1) noise.
np.random.seed(1)

u_hat = np.concatenate([[10, 9, 8, 7, 6, 5, 4, 3], np.repeat(2, 17), np.repeat(0, 75)])
v_hat = np.concatenate([[10, -10, 8, -8, 5, -5],
                        np.repeat(3, 5), np.repeat(-3, 5), np.repeat(0, 34)])

u = u_hat / np.linalg.norm(u_hat)
v = v_hat / np.linalg.norm(v_hat)

# Signal plus noise; the noise is drawn in row-major order from the seeded stream.
X_star = np.outer(50 * u, v)
X = X_star + np.random.normal(0, 1, size=(100, 50))

## Pure Python Code

To find the bottlenecks, we first implement the algorithm in pure Python and profile it.

In [3]:
def thresh(z, delta):
    """Soft-threshold ``z`` at level ``delta`` (element-wise, NumPy broadcasting).

    Entries with ``|z| >= delta`` become ``sign(z) * (|z| - delta)``; all other
    entries become zero.
    """
    magnitude = np.abs(z)
    keep = magnitude >= delta
    return np.sign(z) * keep * (magnitude - delta)

def ssvd(X, gamu = 2, gamv = 2, merr = 10**(-4), niter = 100):
    """Sparse rank-one SVD of ``X`` by alternating BIC-tuned soft-thresholding.

    Starting from the leading singular pair of ``X``, the right vector ``v`` and
    the left vector ``u`` are updated in turn. Each update soft-thresholds the
    weighted projection ``z * |z|**gam`` at every candidate level and scores
    each candidate with a BIC criterion (residual sum of squares / sigma^2 +
    number of non-zeros * log(n*d)). It keeps the minimiser and renormalises.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float, default 2
        Exponents of the weights ``|z|**gam`` used for ``u`` and ``v``.
    merr : float, default 1e-4
        Iterate while ``||u_new - u_old||`` or ``||v_new - v_old||`` exceeds merr.
    niter : int, default 100
        A message is printed and the loop stops once the iteration count
        exceeds ``niter``.

    Returns
    -------
    u1 : ndarray, shape (n,)
        Unit-norm sparse left vector.
    v1 : ndarray, shape (d,)
        Unit-norm sparse right vector.
    s : ndarray, shape (1, 1)
        ``u1^T X v1``.
    iters : int
        Number of iterations performed.
    """
    n, d = X.shape
    log_nd = math.log(n*d)
    SST = np.sum(X*X)

    # Initial values: the leading singular pair. np.linalg.svd returns V
    # already transposed (Vh), so the leading right singular vector is the
    # first ROW of Vh; Vh.T[0] would be the first column of Vh instead.
    U, _, Vh = np.linalg.svd(X)
    u0 = U[:, 0]
    v0 = Vh[0]
    # Bound up front so the result is defined even if the loop never runs
    # (e.g. merr >= 1).
    u1, v1 = u0, v0

    ud = 1
    vd = 1
    iters = 0
    while (ud > merr or vd > merr):
        iters = iters + 1

        # ---- Updating v (u0 held fixed) ----
        z = X.T @ u0
        winv = np.abs(z)**gamv
        sigsq = np.abs(SST - np.sum(z*z))/(n*d-d)
        cand = z*winv
        # Candidate thresholds: 0 and every |cand_j| (np.unique also sorts).
        delt_uniq = np.unique(np.append(np.abs(cand), 0))
        Bv = np.full(len(delt_uniq)-1, np.inf)
        # Keep only entries with non-negligible weight so the division below
        # is safe. Note: 1e-8, not `10^(-8)` -- `^` is bitwise XOR (== -14).
        ind = np.where(winv > 1e-8)
        cand1 = cand[ind]
        winv1 = winv[ind]
        for i in range(len(Bv)):
            temp2 = thresh(cand1, delta = delt_uniq[i])/winv1
            temp3 = np.zeros(d)
            temp3[ind] = temp2
            Bv[i] = np.sum((X - np.outer(u0, temp3))**2)/sigsq + np.sum(temp2 != 0)*log_nd
        # First index of the smallest BIC, as a scalar.
        Iv = np.argmin(Bv)
        temp2 = thresh(cand1, delta = delt_uniq[Iv])/winv1
        v1 = np.zeros(d)
        v1[ind] = temp2
        v1 = v1/((np.sum(v1*v1))**0.5)

        # ---- Updating u (new v1 held fixed) ----
        z = X @ v1
        winu = np.abs(z)**gamu
        sigsq = np.abs(SST - np.sum(z*z))/(n*d-n)
        cand = z*winu
        delt_uniq = np.unique(np.append(np.abs(cand), 0))
        Bu = np.full(len(delt_uniq)-1, np.inf)
        ind = np.where(winu > 1e-8)
        cand1 = cand[ind]
        winu1 = winu[ind]
        for i in range(len(Bu)):
            temp2 = thresh(cand1, delta = delt_uniq[i])/winu1
            temp3 = np.zeros(n)
            temp3[ind] = temp2
            Bu[i] = np.sum((X - np.outer(temp3, v1))**2)/sigsq + np.sum(temp2 != 0)*log_nd
        Iu = np.argmin(Bu)
        temp2 = thresh(cand1, delta = delt_uniq[Iu])/winu1
        u1 = np.zeros(n)
        u1[ind] = temp2
        u1 = u1/((np.sum(u1*u1))**0.5)

        # Change since the previous iteration (convergence criterion).
        ud = np.sum((u0-u1)*(u0-u1))**0.5
        vd = np.sum((v0-v1)*(v0-v1))**0.5

        if iters > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    s = u1[None, :] @ X @ v1[:, None]
    return u1, v1, s, iters

In [4]:
# Profile one call of the pure-Python ssvd to see where the time is spent.
%prun ssvd(X)

 

## Just-In-Time Compilation

`thresh` is called once for every candidate threshold in every iteration, so we compile it with Numba's `@jit` decorator. `ssvd_jit` follows the same steps as `ssvd` but calls `thresh_jit`.

In [5]:
@jit
def thresh_jit(z, delta):
    """Numba-compiled soft-thresholding of ``z`` at level ``delta``.

    Entries with ``|z| >= delta`` become ``sign(z) * (|z| - delta)``; all other
    entries become zero.
    """
    magnitude = np.abs(z)
    keep = magnitude >= delta
    return np.sign(z) * keep * (magnitude - delta)

In [6]:
def ssvd_jit(X, gamu = 2, gamv = 2, merr = 10**(-4), niter = 100):
    """Sparse rank-one SVD of ``X``, using the Numba-compiled ``thresh_jit``.

    Alternates between updating the right vector ``v`` (u fixed) and the left
    vector ``u`` (v fixed). Each update soft-thresholds ``z * |z|**gam`` at
    every candidate level and scores each candidate with a BIC criterion
    (residual sum of squares / sigma^2 + number of non-zeros * log(n*d)). It
    keeps the minimiser and renormalises.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Data matrix.
    gamu, gamv : float, default 2
        Exponents of the weights ``|z|**gam`` used for ``u`` and ``v``.
    merr : float, default 1e-4
        Iterate while ``||u_new - u_old||`` or ``||v_new - v_old||`` exceeds merr.
    niter : int, default 100
        A message is printed and the loop stops once the iteration count
        exceeds ``niter``.

    Returns
    -------
    u1 : ndarray, shape (n,)
        Unit-norm sparse left vector.
    v1 : ndarray, shape (d,)
        Unit-norm sparse right vector.
    s : ndarray, shape (1, 1)
        ``u1^T X v1``.
    iters : int
        Number of iterations performed.
    """
    n, d = X.shape
    log_nd = math.log(n*d)
    SST = np.sum(X*X)

    # Initial values: the leading singular pair. np.linalg.svd returns V
    # already transposed (Vh), so the leading right singular vector is the
    # first ROW of Vh; Vh.T[0] would be the first column of Vh instead.
    U, _, Vh = np.linalg.svd(X)
    u0 = U[:, 0]
    v0 = Vh[0]
    # Bound up front so the result is defined even if the loop never runs
    # (e.g. merr >= 1).
    u1, v1 = u0, v0

    ud = 1
    vd = 1
    iters = 0
    while (ud > merr or vd > merr):
        iters = iters + 1

        # ---- Updating v (u0 held fixed) ----
        z = X.T @ u0
        winv = np.abs(z)**gamv
        sigsq = np.abs(SST - np.sum(z*z))/(n*d-d)
        cand = z*winv
        # Candidate thresholds: 0 and every |cand_j| (np.unique also sorts).
        delt_uniq = np.unique(np.append(np.abs(cand), 0))
        Bv = np.full(len(delt_uniq)-1, np.inf)
        # Keep only entries with non-negligible weight so the division below
        # is safe. Note: 1e-8, not `10^(-8)` -- `^` is bitwise XOR (== -14).
        ind = np.where(winv > 1e-8)
        cand1 = cand[ind]
        winv1 = winv[ind]
        for i in range(len(Bv)):
            temp2 = thresh_jit(cand1, delta = delt_uniq[i])/winv1
            temp3 = np.zeros(d)
            temp3[ind] = temp2
            Bv[i] = np.sum((X - np.outer(u0, temp3))**2)/sigsq + np.sum(temp2 != 0)*log_nd
        # First index of the smallest BIC, as a scalar.
        Iv = np.argmin(Bv)
        temp2 = thresh_jit(cand1, delta = delt_uniq[Iv])/winv1
        v1 = np.zeros(d)
        v1[ind] = temp2
        v1 = v1/((np.sum(v1*v1))**0.5)

        # ---- Updating u (new v1 held fixed) ----
        z = X @ v1
        winu = np.abs(z)**gamu
        sigsq = np.abs(SST - np.sum(z*z))/(n*d-n)
        cand = z*winu
        delt_uniq = np.unique(np.append(np.abs(cand), 0))
        Bu = np.full(len(delt_uniq)-1, np.inf)
        ind = np.where(winu > 1e-8)
        cand1 = cand[ind]
        winu1 = winu[ind]
        for i in range(len(Bu)):
            temp2 = thresh_jit(cand1, delta = delt_uniq[i])/winu1
            temp3 = np.zeros(n)
            temp3[ind] = temp2
            Bu[i] = np.sum((X - np.outer(temp3, v1))**2)/sigsq + np.sum(temp2 != 0)*log_nd
        Iu = np.argmin(Bu)
        temp2 = thresh_jit(cand1, delta = delt_uniq[Iu])/winu1
        u1 = np.zeros(n)
        u1[ind] = temp2
        u1 = u1/((np.sum(u1*u1))**0.5)

        # Change since the previous iteration (convergence criterion).
        ud = np.sum((u0-u1)*(u0-u1))**0.5
        vd = np.sum((v0-v1)*(v0-v1))**0.5

        if iters > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    s = u1[None, :] @ X @ v1[:, None]
    return  u1, v1, s, iters

## Cythonized Code
We rewrote the Python code in Cython so that it compiles to C. Guided by the annotation produced by `%%cython -a`, we removed as many of the yellow-highlighted lines (which indicate interaction with the Python interpreter) as we could.

In [7]:
%load_ext cython

In [8]:
%%cython -a
import numpy as np
cimport numpy as np
from libc.math cimport log
import cython

@cython.boundscheck(False)
@cython.wraparound(False)
cdef double[:] thresh(double[:] z, double delta):
    # Soft-thresholding into a fresh zero-initialised buffer: entries with
    # |z[i]| >= delta become sign(z[i]) * (|z[i]| - delta); the rest stay 0.
    cdef int n = z.shape[0]
    cdef double[:] res = np.zeros(n)
    cdef double mag, shrunk
    cdef int i
    for i in range(n):
        mag = abs(z[i])
        if mag >= delta:
            shrunk = mag - delta
            res[i] = shrunk if z[i] > 0 else -shrunk
    return res

@cython.boundscheck(False)
@cython.wraparound(False)
cdef int[:] get_index(double[:] v, double thresh):
    # Return, in increasing order, the positions i with v[i] > thresh as an
    # int32 array (the Cython counterpart of np.where(v > thresh)[0]).
    # Inside this function `thresh` is the cutoff argument, which shadows the
    # soft-thresholding cdef function of the same name.
    cdef int i
    cdef int j = 0      # next free slot in res; must start at 0
    cdef int count = 0
    cdef int n = v.shape[0]
    # First pass: size the output exactly.
    for i in range(n):
        if v[i] > thresh:
            count += 1
    # "int32" matches a C int on the usual platforms (required by int[:]).
    cdef int[:] res = np.empty(count, dtype="int32")
    # Second pass: record the qualifying positions.
    for i in range(n):
        if v[i] > thresh:
            res[j] = i
            j += 1
    return res
    
@cython.boundscheck(False)
@cython.wraparound(False)
cdef int sparsity(double[:] v):
    # Number of entries of v that are not equal to zero (NaN counts as
    # non-zero). ssvd_cython uses it as the df term of the BIC penalty.
    cdef Py_ssize_t idx
    cdef int nnz = 0
    for idx in range(v.shape[0]):
        nnz += (v[idx] != 0)
    return nnz

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def ssvd_cython(double[:,:] X, int gamu = 2, int gamv = 2, double merr =  0.0001, int niter = 100):
    """Cython sparse rank-one SVD with BIC-tuned soft-thresholding.

    Alternates between updating the right vector v (u fixed) and the left
    vector u (v fixed). Each update soft-thresholds z * |z|**gam at every
    candidate level and scores each candidate with a BIC criterion (residual
    sum of squares / sigma^2 + number of non-zeros * log(n*d)). It keeps the
    minimiser and renormalises.

    Parameters
    ----------
    X : float64 array, shape (n, d)
    gamu, gamv : exponents of the weights |z|**gam for u and v.
    merr : iterate while ||u_new - u_old|| or ||v_new - v_old|| exceeds merr.
    niter : a message is printed and the loop stops once the iteration count
        exceeds niter.

    Returns
    -------
    u1 : ndarray, shape (n,)   unit-norm sparse left vector
    v1 : ndarray, shape (d,)   unit-norm sparse right vector
    s : float                  u1^T X v1
    iters : int                number of iterations performed
    """
    cdef int n = X.shape[0]
    cdef int d = X.shape[1]
    cdef int i, j, k
    cdef double ud = 1.0
    cdef double vd = 1.0
    cdef int iters = 0
    cdef double log_nd = log(n*d)
    cdef double SST = 0.0
    for i in range(n):
        for j in range(d):
            SST += X[i, j]**2

    # Initial values: the leading singular pair. np.linalg.svd returns V
    # already transposed (Vh), so the leading right singular vector is row 0
    # of Vh; Vh.T[0] would be the first column of Vh instead.
    cdef double[:,:] U, Vh
    cdef double[:] S
    U, S, Vh = np.linalg.svd(X)
    cdef double[:] u0 = U.T[0]
    cdef double[:] v0 = Vh[0]

    # Work buffers reused across iterations.
    cdef double[:] winv = np.zeros(d)
    cdef double[:] winu = np.zeros(n)
    cdef double[:] candv = np.zeros(d)
    cdef double[:] candu = np.zeros(n)
    cdef double[:] temp3v = np.zeros(d)
    cdef double[:] temp3u = np.zeros(n)
    cdef double[:] z, delt, delt_uniq, Bv, Bu, cand1, winv1, winu1, temp2
    # Bound up front so the result is defined even if the loop never runs.
    cdef double[:] u1 = u0
    cdef double[:] v1 = v0
    cdef double sigsq, th, s, minimal, v_norm, u_norm, bic, resid, row
    cdef int[:] ind
    cdef int Iv, Iu
    while ud > merr or vd > merr:
        iters += 1

        ## Updating v (u0 held fixed)
        z = np.dot(X.T, u0)
        # weights |z|**gamv and weighted candidates z * |z|**gamv
        for i in range(d):
            winv[i] = abs(z[i])**gamv
            candv[i] = z[i]*winv[i]
        # sigma squared
        sigsq = SST
        for i in range(d):
            sigsq -= z[i]**2
        sigsq = abs(sigsq) / (n*d - d)
        # candidate thresholds: 0 and every |candv[i]| (np.unique also sorts)
        delt = np.zeros(d+1)
        for i in range(d):
            delt[i] = abs(candv[i])
        delt_uniq = np.unique(delt)
        Bv = np.full(delt_uniq.shape[0] - 1, np.inf)
        # keep entries with non-negligible weight so the divisions are safe
        # (1e-8, not `10^(-8)`, which is bitwise XOR and equals -14)
        ind = get_index(winv, 1e-8)
        cand1 = np.zeros(ind.shape[0])
        winv1 = np.zeros(ind.shape[0])
        for i in range(ind.shape[0]):
            cand1[i] = candv[ind[i]]
            winv1[i] = winv[ind[i]]
        # BIC of each candidate; the rank-one fit u0 * temp3v^T is formed
        # element-wise instead of via np.dot so the loop stays in C
        for i in range(Bv.shape[0]):
            temp2 = thresh(cand1, delt_uniq[i])
            for j in range(temp2.shape[0]):
                temp2[j] = temp2[j] / winv1[j]
            for k in range(d):
                temp3v[k] = 0.0
            for j in range(temp2.shape[0]):
                temp3v[ind[j]] = temp2[j]
            bic = 0.0
            for j in range(n):
                for k in range(d):
                    resid = X[j, k] - u0[j]*temp3v[k]
                    bic += resid*resid/sigsq
            Bv[i] = bic + sparsity(temp2)*log_nd
        # first index of the minimal BIC value
        Iv = 0
        minimal = Bv[0]
        for i in range(1, Bv.shape[0]):
            if Bv[i] < minimal:
                minimal = Bv[i]
                Iv = i
        th = delt_uniq[Iv]
        temp2 = thresh(cand1, th)
        for i in range(temp2.shape[0]):
            temp2[i] = temp2[i] / winv1[i]
        # fresh buffer: v0 still refers to the previous v1
        v1 = np.zeros(d)
        for i in range(temp2.shape[0]):
            v1[ind[i]] = temp2[i]
        v_norm = 0.0
        for i in range(d):
            v_norm += v1[i]**2
        v_norm = v_norm**0.5
        for i in range(d):
            v1[i] = v1[i] / v_norm

        ## Updating u (new v1 held fixed)
        z = np.dot(X, v1)
        for i in range(n):
            winu[i] = abs(z[i])**gamu
            candu[i] = z[i]*winu[i]
        sigsq = SST
        for i in range(n):
            sigsq -= z[i]**2
        sigsq = abs(sigsq) / (n*d - n)
        delt = np.zeros(n+1)
        for i in range(n):
            delt[i] = abs(candu[i])
        delt_uniq = np.unique(delt)
        Bu = np.full(delt_uniq.shape[0] - 1, np.inf)
        ind = get_index(winu, 1e-8)
        cand1 = np.zeros(ind.shape[0])
        winu1 = np.zeros(ind.shape[0])
        for i in range(ind.shape[0]):
            cand1[i] = candu[ind[i]]
            winu1[i] = winu[ind[i]]
        # BIC of each candidate; rank-one fit temp3u * v1^T formed element-wise
        for i in range(Bu.shape[0]):
            temp2 = thresh(cand1, delt_uniq[i])
            for j in range(temp2.shape[0]):
                temp2[j] = temp2[j] / winu1[j]
            for j in range(n):
                temp3u[j] = 0.0
            for j in range(temp2.shape[0]):
                temp3u[ind[j]] = temp2[j]
            bic = 0.0
            for j in range(n):
                for k in range(d):
                    resid = X[j, k] - temp3u[j]*v1[k]
                    bic += resid*resid/sigsq
            Bu[i] = bic + sparsity(temp2)*log_nd
        # first index of the minimal BIC value
        Iu = 0
        minimal = Bu[0]
        for i in range(1, Bu.shape[0]):
            if Bu[i] < minimal:
                minimal = Bu[i]
                Iu = i
        th = delt_uniq[Iu]
        temp2 = thresh(cand1, th)
        for i in range(temp2.shape[0]):
            temp2[i] = temp2[i] / winu1[i]
        # fresh buffer: u0 still refers to the previous u1
        u1 = np.zeros(n)
        for i in range(temp2.shape[0]):
            u1[ind[i]] = temp2[i]
        u_norm = 0.0
        for i in range(n):
            u_norm += u1[i]**2
        u_norm = u_norm**0.5
        for i in range(n):
            u1[i] = u1[i] / u_norm

        ## change since the previous iteration (convergence criterion)
        ud = 0.0
        for i in range(n):
            ud += (u0[i] - u1[i])**2
        ud = ud**0.5
        vd = 0.0
        for i in range(d):
            vd += (v0[i] - v1[i])**2
        vd = vd**0.5

        if iters > niter :
            print("Fail to converge! Increase the niter!")
            break

        u0 = u1
        v0 = v1

    # s = u1^T X v1, accumulated in C (avoids float() of a 1x1 array)
    s = 0.0
    for j in range(n):
        row = 0.0
        for k in range(d):
            row += X[j, k]*v1[k]
        s += u1[j]*row
    return np.asarray(u1), np.asarray(v1), s, iters

## Test on Simulated Data
We check that the JIT and Cython versions give the same results as the pure Python code.

In [9]:
# Compare both sparse singular vectors (u and v) with the pure-Python reference.
u_ref, v_ref = ssvd(X)[:2]
u_cy, v_cy = ssvd_cython(X)[:2]
np.testing.assert_array_almost_equal(u_cy, u_ref)
np.testing.assert_array_almost_equal(v_cy, v_ref)

In [10]:
# Compare both sparse singular vectors (u and v) with the pure-Python reference.
u_ref, v_ref = ssvd(X)[:2]
u_jit, v_jit = ssvd_jit(X)[:2]
np.testing.assert_array_almost_equal(u_jit, u_ref)
np.testing.assert_array_almost_equal(v_jit, v_ref)

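## Runtime Comparison

In the run recorded below, all three implementations take roughly the same time (31–35 ms per call). Neither JIT-compiling `thresh` nor the Cython rewrite gives a substantial speed-up on this data.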

In [14]:
# Baseline: pure-Python implementation.
%timeit ssvd(X)

10 loops, best of 3: 35.3 ms per loop


In [15]:
# Same algorithm, calling the Numba-compiled thresh_jit.
%timeit ssvd_jit(X)

10 loops, best of 3: 31.5 ms per loop


In [16]:
# Cython implementation compiled in the %%cython cell above.
%timeit ssvd_cython(X)

10 loops, best of 3: 33.9 ms per loop
