# In[1]:
import numpy as np
from numba import njit, prange

# In[2]:
def _to_numpy(t):
    """Return ``t`` as a NumPy array, going through ``t.numpy()`` when it exists (e.g. TF tensors)."""
    return np.asarray(t.numpy() if hasattr(t, "numpy") else t)


def mttkrp(X, factors, n, rank, dims):
    """Matricized-tensor times Khatri-Rao product (MTTKRP) for a sparse COO tensor.

    Computes ``X_(n) @ khatri_rao(factors[m] for m != n)`` without forming the
    Khatri-Rao product: every stored entry ``X[i_0, ..., i_{N-1}] = v`` adds
    the row ``v * prod_{m != n} factors[m][i_m, :]`` to ``output[i_n, :]``.

    Parameters
    ----------
    X : sparse tensor
        Object with ``indices`` (nnz x ndim integer coordinates) and ``values``
        (nnz,) attributes, e.g. a ``tf.SparseTensor``. Attributes exposing a
        ``.numpy()`` method are converted through it; plain array-likes are
        used directly.
    factors : sequence of 2-D arrays
        ``factors[m]`` is indexed by the mode-``m`` coordinate and must have
        ``rank`` columns.
    n : int
        Mode whose MTTKRP is computed.
    rank : int
        Number of CP components (columns of each factor).
    dims : sequence of int
        Tensor shape; only ``dims[n]`` is used.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(dims[n], rank)``.
    """
    values = _to_numpy(X.values).reshape(-1)
    nnz = values.shape[0]
    # One row of coordinates per stored entry; the explicit reshape also makes
    # an empty tensor come out as (0, ndim) so column indexing below works.
    indices = (
        _to_numpy(X.indices)
        .astype(np.intp, copy=False)
        .reshape(nnz, len(factors))
    )

    # contrib[l] = values[l] * prod_{m != n} factors[m][indices[l, m]]
    contrib = np.empty((nnz, rank))
    contrib[:] = values[:, np.newaxis]
    for mode in range(len(factors)):
        if mode != n:
            contrib *= np.asarray(factors[mode])[indices[:, mode]]

    # Scatter-add into the mode-n rows. np.add.at is unbuffered, so entries that
    # share the same mode-n index all accumulate; a parallel loop doing
    # output[i] += row would race on such shared rows.
    output = np.zeros((dims[n], rank))
    np.add.at(output, indices[:, n], contrib)
    return output

def cp_als(X, rank, n_iter_max=50, *, verbose=True):
    """Fit a CP (CANDECOMP/PARAFAC) decomposition by alternating least squares.

    The model is ``X ~= sum_r weights[0, r] * a_r^(0) o a_r^(1) o ...`` where
    ``a_r^(m)`` is column ``r`` of ``factors[m]``. Each sweep updates every
    factor in turn by solving its least-squares subproblem with the other
    factors and the weights held fixed. The new factor is then rescaled to
    unit-norm columns and the scale is folded into ``weights``.

    Parameters
    ----------
    X : sparse tensor
        Passed unchanged to ``mttkrp``. ``X.shape`` may be a ``tf.TensorShape``
        or any iterable of ints.
    rank : int
        Number of CP components.
    n_iter_max : int
        Number of full ALS sweeps to run (there is no early stopping).
    verbose : bool, keyword-only
        If true, print the current sweep index, overwriting the same line.

    Returns
    -------
    weights : numpy.ndarray
        Shape ``(1, rank)``; component scales.
    factors : list of numpy.ndarray
        ``factors[m]`` has shape ``(dims[m], rank)``. Initialization is
        ``np.random.random``. After at least one sweep, columns have unit
        norm, except columns whose norm fell below 1e-12, which are left
        unscaled.
    """
    dims = [int(d) for d in X.shape]  # works for tf.TensorShape and plain tuples
    nd = len(dims)
    factors = [np.random.random((d, rank)) for d in dims]
    weights = np.ones((1, rank))
    # Gram matrices A_m^T A_m, refreshed only when factor m changes.
    grams = [f.T @ f for f in factors]

    for iteration in range(n_iter_max):
        if verbose:
            print(iteration, end="\r")
        for n in range(nd):
            # Normal-equation matrix: (lambda lambda^T) * Hadamard_{i != n}(A_i^T A_i).
            # This equals diag(lambda) V diag(lambda), so the solve below yields
            # the mode-n factor with the current weights divided back out.
            h = weights.T @ weights
            for i, gram in enumerate(grams):
                if i != n:
                    h *= gram
            vinv = np.linalg.pinv(h)

            # A_n = MTTKRP(X, n) diag(lambda) pinv(h)
            mk = mttkrp(X, factors, n, rank, dims)
            an = (mk * weights[0]) @ vinv

            # Normalize columns; near-zero columns keep scale 1 to avoid dividing by ~0.
            col_norms = np.linalg.norm(an, axis=0)
            scale = np.where(col_norms < 1e-12, 1, col_norms)
            weights[0] *= scale
            an /= scale

            factors[n] = an
            grams[n] = an.T @ an

    return weights, factors