In [1]:
# we have to set env variable to use propack, if we want to use sparse2tuck
import os
os.environ["SCIPY_USE_PROPACK"] = "1"

In [2]:
import scipy
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (8,8)
import torch
import tucker_riemopt
from tucker_riemopt import backend as back
from tucker_riemopt import set_backend
from tucker_riemopt.tucker.tucker import Tucker
import tucker_riemopt.tucker.riemannian as riemann

In [18]:
#init tensor, mask and sparse tessor
size = 32
set_backend("pytorch")
A = back.zeros([size,size,size])
for i in range(size):
    for j in range(size):
        for k in range(size):
            A[i][j][k] = np.sin(i+j+k)
            
Omega = back.zeros([size,size,size])
for i in range(size):
    for j in range(size):
        for k in range(size):
            Omega[i][j][k] = 1 if np.random.uniform()<0.5 else 0
X = back.zeros([size,size,size])
for i in range(size):
    for j in range(size):
        for k in range(size):
            X[i][j][k] = np.random.uniform()
Asp = A*Omega #element wise multiplication

In [19]:
Omega_tucker = Tucker.from_dense(Omega)
T = Tucker.from_dense(A)
X = Tucker.from_dense(X)

In [54]:
def Omega_projection(dense_tensor, Omega = Omega):
    return dense_tensor*Omega
def Tucker_Omega_projection(tucker_tensor, Omega = Omega):
    return Tucker.from_dense(Omega_projection(tucker_tensor.to_dense()))
def Euclidean_grad(X,A=T):
    return Tucker_Omega_projection(A) - Tucker_Omega_projection(X)
def f(X,A=T):
    return 1/2*(Tucker_Omega_projection(A) - Tucker_Omega_projection(X)).norm()**2
def line_search(eta,X):
    proj_eta = Tucker_Omega_projection(eta)
    return proj_eta.flat_inner(Euclidean_grad(X))/(proj_eta.norm()**2)
def retraction(X,xi,r):
    return (X+xi).round([r,r,r])
    
# def Riemannian_grad(Grad, X):
#     tucker_riemopt.tucker.riemannian

In [64]:
xi = riemann.grad(f, X)
eta = -xi[0].construct()
alpha = line_search(eta,X)
X =retraction(X, alpha*eta, 2)

In [103]:
max_iter = 1000
for k in range(max_iter):
    xi = riemann.grad(f, X)
    eta = -xi[0].construct() + riemann.project(X, eta).construct()
    alpha = line_search(eta,X)
    X =retraction(X, alpha*eta, 2)
    eps = back.sqrt(2*xi[1])/T.norm()
    if k%50 ==0:
        print(eps)
    if eps <1e-3:
        break

tensor(0.0577, grad_fn=<DivBackward0>)
tensor(0.0558, grad_fn=<DivBackward0>)
tensor(0.0542, grad_fn=<DivBackward0>)
tensor(0.0527, grad_fn=<DivBackward0>)
tensor(0.0515, grad_fn=<DivBackward0>)
tensor(0.0504, grad_fn=<DivBackward0>)
tensor(0.0493, grad_fn=<DivBackward0>)
tensor(0.0484, grad_fn=<DivBackward0>)
tensor(0.0476, grad_fn=<DivBackward0>)
tensor(0.0468, grad_fn=<DivBackward0>)
tensor(0.0461, grad_fn=<DivBackward0>)
tensor(0.0454, grad_fn=<DivBackward0>)
tensor(0.0448, grad_fn=<DivBackward0>)
tensor(0.0443, grad_fn=<DivBackward0>)
tensor(0.0437, grad_fn=<DivBackward0>)
tensor(0.0432, grad_fn=<DivBackward0>)
tensor(0.0427, grad_fn=<DivBackward0>)
tensor(0.0423, grad_fn=<DivBackward0>)
tensor(0.0418, grad_fn=<DivBackward0>)
tensor(0.0414, grad_fn=<DivBackward0>)
