In [55]:
import ctf
import numpy as np
import numpy.linalg as la
from ctf import random
glob_comm = ctf.comm()

In [56]:
def normalize(Z):
    norms = ctf.tensor(r)
    norms.i("u") << Z.i("pu")*Z.i("pu")
    norms = 1./norms**.5
    X = ctf.tensor(copy=Z)
    Z.set_zero()
    Z.i("pu") << X.i("pu")*norms.i("u")
    return 1./norms

In [57]:
def updateOmega(I,J,K,sparsity):
    '''
    Gets a random subset of rows for each U,V,W iteration
    '''
    Actf = ctf.tensor((I,J,K),sp=True)
    Actf.fill_sp_random(0,1,sparsity)
    omegactf = ((Actf > 0)*ctf.astensor(1.))
    return omegactf


def getIndexFromOmega(omega,I,J,K):
    '''
    Gets a random i,j,k contained in Ω[0...i-1,0...j-1,0...k-1]
    '''
    #TODO: Better sampling instead of MC.
    i = random.randint(0,I-1)
    j = random.randint(0,J-1)
    k = random.randint(0,K-1)
    while omega[i][j][k] != 1:
        i = random.randint(0,I-1)
        j = random.randint(0,J-1)
        k = random.randint(0,K-1)
        
    return (i,j,k)

In [58]:
def updateU(T,V,W,regParam,omega,I,J,K,r):
    '''Update U matrix by using the formula'''
    
    M1 = ctf.tensor((J,K,r))
    M1.i("jku") << V.i("ju")*W.i("ku")
    [U_,S_,V_] = ctf.svd(M1.reshape((J*K,r)))
    S_ = 1./S_
    U.set_zero()
    U.i("iu") << V_.i("vu")*S_.i("v")*U_.reshape((J,K,r)).i("jkv")*T.i("ijk")
    normalize(U)
    
    return U   
    
    
def updateV(T,U,W,regParam,omega,I,J,K,r):
    '''Update V matrix by using the formula'''
    
    M2 = ctf.tensor((I,K,r))
    M2.i("iku") << U.i("iu")*W.i("ku")
    [U_,S_,V_] = ctf.svd(M2.reshape((I*K,r)))
    S_ = 1./S_
    V.set_zero()
    V.i("ju") << V_.i("vu")*S_.i("v")*U_.reshape((I,K,r)).i("ikv")*T.i("ijk")
    normalize(V)
    
    return V   

def updateW(T,U,V,regParam,omega,I,J,K,r):
    '''Update V matrix by using the formula'''
    
    M3 = ctf.tensor((I,J,r))
    M3.i("iju") << U.i("iu")*V.i("ju")
    [U_,S_,V_] = ctf.svd(M3.reshape((I*J,r)))
    S_ = 1./S_
    W.set_zero()
    W.i("ku") << V_.i("vu")*S_.i("v")*U_.reshape((I,J,r)).i("ijv")*T.i("ijk")
    normalize(W)
    
    return W

In [63]:
def getALSCtf(T,U,V,W,regParam,omega,I,J,K,r):
    '''
    Same thing as above, but CTF
    '''
    it = 0
    E = ctf.tensor((I,J,K))
    E.i("ijk") << T.i("ijk") - omega.i("ijk")*U.i("iu")*V.i("ju")*W.i("ku")
    curr_err_norm = ctf.vecnorm(E) + (ctf.vecnorm(U) + ctf.vecnorm(V) + ctf.vecnorm(W))*regParam
    
    while True:
        U = updateU(T,V,W,regParam,omega,I,J,K,r) 
        V = updateV(T,U,W,regParam,omega,I,J,K,r) 
        W = updateW(T,U,V,regParam,omega,I,J,K,r) 
        E.i("ijk") << T.i("ijk") - omega.i("ijk")*U.i("iu")*V.i("ju")*W.i("ku")
        next_err_norm = ctf.vecnorm(E) + (ctf.vecnorm(U) + ctf.vecnorm(V) + ctf.vecnorm(W))*regParam
        
        if curr_err_norm - next_err_norm > .001 or it > 10:
            break
            
        print(curr_err_norm, next_err_norm)
        curr_err_norm = next_err_norm
        it += 1
        
    return U,V,W

In [64]:
I = 3
J = 4
K = 5
r = 3 # TODO: determine rank?
sparsity = .2
#learningRate = .1
regParam = 2
#regParamSGD = .1
#width = 1  # how many random i,j we're going to get
#convNo = .00002  # When we'll stop converging
  
# 3rd-order tensor
T = ctf.tensor((I,J,K),sp=True)
T.fill_sp_random(0,1,1)

ctf.random.seed(42)
U = ctf.random.random((I,r))
V= ctf.random.random((J,r))
W= ctf.random.random((K,r))

normalize(U)
normalize(V)
normalize(W)

omega = updateOmega(I,J,K,sparsity)

In [65]:
getALSCtf(T,U,V,W,regParam,omega,I,J,K,r)

14.019753130499076 17.437660889861057
17.437660889861057 20.833161839357622
20.833161839357622 24.232768942197666
24.232768942197666 27.63977141098906
27.63977141098906 31.05058946569811
31.05058946569811 34.463935933262036
34.463935933262036 37.87963545021427
37.87963545021427 41.29787466835978
41.29787466835978 44.71893183113951
44.71893183113951 48.14309546583924
48.14309546583924 51.570633605345


(array([[ 0.07512799,  0.69449551,  0.40799817],
        [-0.78159667,  0.53520393,  0.43958365],
        [-0.61924343,  0.48086667,  0.8001898 ]]),
 array([[ 0.51389248,  0.38233145,  0.49891826],
        [-0.32236203,  0.85430083,  0.4165905 ],
        [ 0.40625483,  0.28368424,  0.3612204 ],
        [ 0.68334051,  0.20860491,  0.66862003]]),
 array([[-0.55985128,  0.69938648, -0.37708965],
        [-0.6396375 ,  0.6548355 , -0.25587897],
        [ 0.11859833,  0.15214084,  0.31106806],
        [-0.00801446,  0.24244355,  0.15276784],
        [ 0.51312828,  0.01110442,  0.81989512]]))