In [1]:
"""
ADMM solver for tensor completion with n-rank minimization
"""
import numpy as np
from math_utils import *
from scipy import linalg
from sktensor import ktensor

def gen_orth_tensor(input_shape,R):
    # all orthogonal cases
    U = np.random.random((input_shape[0], R))
    U_orth = linalg.orth(U)

    V = np.random.random((input_shape[1], R))
    V_orth = linalg.orth(V)

    W = np.random.random((input_shape[2], R))
    W_orth = linalg.orth(W)

    Lambda = np.random.random((R,))

    X = ktensor([U_orth, V_orth, W_orth], lmbda=Lambda)
    X_ten = np.asarray(X.totensor())

    return X_ten


def exact_update(Omega,X, Ws, Ys, params):
    """Exact update of the primal variable"""
    num_modes = np.ndim(Omega)
    beta_val = params['beta']
    lambda_val  = params['lambda']
    
    W_Y_sum = np.sum(W_m + beta_val * Y_m for W_m in Ws for Y_m in Ys)
    X_out = 1.0/(lambda_val +  num_modes * beta_val) * (W_Y_sum + lambda_val * X)   
    X_out[Omega==0] = 1.0/(num_modes * beta_val)*W_Y_sum[Omega==0]
    
    return X_out


def inexact_update(Omega, X, Ws, Ys):
    pass
    
    
def tc_loss(X_out, Omega, X, Ws,Ys, params):
    """loss function of tensor completion"""
    num_modes = np.ndim(X)
    loss = 0.0;
    loss += params['lambda']*0.5 * np.square(tensor_norm(np.subtract(X_out[Omega==1],X[Omega==1]), 'fro'))
    for mode in range(num_modes):
        loss += np.linalg.norm(unfold(Ys[mode], mode),'nuc')
        loss += np.sum(np.multiply(Ws[mode], Ys[mode] - X_out))
        loss += params['beta'] *0.5* np.square(tensor_norm(np.subtract(Ys[mode],X_out),'fro'))
    return loss
                                             
                                             
def tensor_complete_ADMM(Omega, X, succ_thres,params):
    num_modes = np.ndim(X)
    beta_val =  params['beta']
    lambda_val = params['lambda']
    X_out = np.zeros(X.shape)
    Ws = [np.zeros(X.shape)] * num_modes
    Ys = [np.zeros(X.shape)] * num_modes
    loss = np.zeros((max_iter+1,1))
    loss_val = tc_loss(X_out, Omega, X, Ws,Ys, params)
    loss[0] = loss_val
    if params['verbose']:
        print'start:{}'.format(loss_val)
    for k in range(params['max_iter']):
#         params['lambda'] = params['lambda']*c_lambda
#         params['beta'] = params['beta'] * c_beta
        X_out_new  = exact_update(Omega, X, Ws, Ys, params)
        loss_val_new = tc_loss(X_out_new, Omega, X, Ws,Ys, params)
        loss[k+1] = loss_val_new
        if params['verbose']:
            print'iter {}:{}'.format(k, loss_val_new)
        if  abs((loss_val_new-loss_val)/loss_val) < params['stop_thres']:
            loss[k+1:] = loss_val_new
            break;
        X_out = np.copy(X_out_new)
        loss_val = np.copy(loss_val_new)

        for mode in range(num_modes):
            X_W_mat = unfold(X_out,mode) - 1.0/beta_val * unfold(Ws[mode],mode)
            Y_m = shrink(X_W_mat, 1.0/beta_val)
            Ys[mode] = fold(Y_m, mode, X.shape)
            Ws[mode] = Ws[mode] - beta_val * (X_out[mode]- Ys[mode])
       
    return (X_out,loss)
        

In [None]:
from sktensor import dtensor, cp_als

def hard_thres():
    
def RTPM(X):
    """robust tensor power method"""
    X_out = np.zeros(X.shape)
    pass

def TPM(X):
    T = dtensor(X)
    P, fit, itr, exectimes = cp_als(T, 3, init='random')
    return (P.Lambda, P.U)

    
def tensor_complete_ALS(Omega, X, succ_thres,params ):
    """implementation of [Jain 2014]"""
    [Lambda,U0] = RTPM(X) # tensor power method initialization
    [U] = hard_thres(U0)
    for k in range(max_iter):
        for r in range(R):
            u1= rank_one_ls(Omega, X, U)
            Lambda(r) = norm(u1)
            U[:,r] = u1/ Lambda(r)
            
            

In [None]:
"""
test routine  for TensorComplete
"""
import matplotlib.pyplot as plt
%matplotlib inline
var_shape = (5,6,8)
rank_val = 2
Omega = np.random.rand(*var_shape)
Omega = np.array(Omega < 0.8)
X = gen_orth_tensor(var_shape, rank_val)

X_obv = np.copy(X)
X_obv[Omega ==0] = 0
succ_thres = np.float32(1e-3)
beta_val = np.float32(1)
lambda_val = np.float32(1e2)
c_beta = np.float32(1)
c_lambda = np.float32(1e-3)
VERBOSE = False
max_iter = np.int32(1e4);
stop_thres = np.float32(1e-6);
params = {'beta':beta_val, 'lambda':lambda_val,'verbose':VERBOSE,
          'max_iter':max_iter, 'stop_thres':stop_thres}
X_out, loss= tensor_complete_ADMM(Omega, X_obv, succ_thres, params)
print 'error ratio:', tensor_norm(np.subtract(X_out,X),'fro')/tensor_norm(X, 'fro')

  res = _sum_(a)


In [None]:
plt.plot(loss)

In [None]:
# print X
# print X_obv
# print X_out