In [12]:
# Tensorly library
import tensorly as tl
import numpy as np
from functools import reduce

In [44]:
def tensor_to_tensor_regression_1(X, Y, rank, lambda_reg=1e-1, return_factors=False, eps=1e-6):
    """
    TENSOR TO TENSOR REGRESSION v1

    Fits a weight tensor W so that contracting the trailing modes of X with the
    leading modes of W approximates Y. X and Y are unfolded along mode 0, the
    unfolded weight matrix is modelled as a sum of `rank` outer products
    u_r v_r^T, and the u_r / v_r vectors are updated alternately until the
    reported error changes by less than `eps` between two consecutive sweeps.
    The error is printed after every sweep.

      Inputs:
        X - Input Tensor, shape (N, I_1, ..., I_{M1})
        Y - Output Tensor, shape (N, J_1, ..., J_{M2})
        rank - The rank of the solution weights matrix (number of u_r v_r^T terms)
        lambda_reg - how much l2 regularization to use
        return_factors - whether to include the weight matrix factors along with the
                         weight matrix itself. NOTE: not implemented yet, the flag is
                         currently ignored and only W is returned.
        eps - the precision of our estimate (threshold on the change in error)

      Outputs:
        W - weight matrix tensor, shape (I_1, ..., I_{M1}, J_1, ..., J_{M2})

      Raises:
        ValueError - if X and Y do not share the same leading dimension N
    """
    # Check that the N is consistent between X and Y tensors. Fail fast instead of
    # printing and continuing into matrix products that cannot be formed.
    if X.shape[0] != Y.shape[0]:
        raise ValueError(
            "Wrong leading dimensions for tensors X and Y: "
            "X has {} samples but Y has {}".format(X.shape[0], Y.shape[0])
        )

    # Sizes of the unfolded X and Y matrices. np.prod of an empty shape is 1, so a
    # tensor without trailing modes is treated as having a single column.
    I = int(np.prod(X.shape[1:]))
    J = int(np.prod(Y.shape[1:]))

    # Initialize the u_r and v_r vectors which we will directly optimize
    # (uniform in [-0.5, 0.5); unseeded, so results vary between runs)
    u_rs = [np.random.random(I) - 0.5 for r in range(rank)]
    v_rs = [np.random.random(J) - 0.5 for r in range(rank)]

    # Matricize X and Y along the sample mode
    X_mat = tl.unfold(X, mode=0)
    Y_mat = tl.unfold(Y, mode=0)

    # Precompute some repeatedly used matrices
    XtXpL = (tl.transpose(X_mat) @ X_mat) + lambda_reg * tl.eye(I)
    XtXpL_inv = np.linalg.inv(XtXpL)
    XtY = tl.transpose(X_mat) @ Y_mat
    YtX = tl.transpose(XtY)

    steps = 0
    prev_error = -1  # negative sentinel: no previous error available yet
    # Keep optimizing until the error has converged
    while True:
        # For every rank-one term of the weights matrix
        for r in range(rank):

            # Update u_r first: accumulate the contribution of all other terms
            u_comp = tl.zeros(I)
            for r1 in range(rank):
                # Only want cases where r1 != r
                if r1 == r:
                    continue
                # Computing the summation in equations (13)
                u_comp += (v_rs[r1] @ v_rs[r]) * (XtXpL @ u_rs[r1])
            # Finishing off the calculation from (13).
            # NOTE(review): the cross term is halved here (and in the v update below);
            # a direct stationarity derivation of the ridge objective appears to give
            # it without the 1/2 - TODO confirm against the referenced equation (13).
            u_rs[r] = (XtXpL_inv @ ((XtY @ v_rs[r]) - (u_comp / 2))) / (v_rs[r] @ v_rs[r])

            # Now update v_r, using the freshly updated u_r
            v_comp = np.zeros(J)
            for r1 in range(rank):
                # Only want cases where r1 != r
                if r1 == r:
                    continue
                v_comp += (u_rs[r] @ XtXpL @ u_rs[r1]) * v_rs[r1]
            v_rs[r] = ((YtX @ u_rs[r]) - (v_comp / 2)) / (u_rs[r] @ XtXpL @ u_rs[r])

        # Normalization step (ensures numerical stability): rescale each pair so that
        # u_r and v_r end up with equal norms; the product u_r v_r^T is unchanged.
        for r in range(rank):
            u_scale = (tl.norm(v_rs[r], 2) / tl.norm(u_rs[r], 2)) ** 0.5
            v_scale = 1 / u_scale
            u_rs[r] *= u_scale
            v_rs[r] *= v_scale

        # Rebuild the unfolded weight matrix as the sum of the rank-one terms
        W_mat = np.zeros((I, J))
        for r in range(rank):
            # kron of two vectors reshaped to (I, J) is the outer product u_r v_r^T
            W_mat += tl.kron(u_rs[r], v_rs[r]).reshape(I, J)
        # Reported error: mean squared residual PLUS the l2 penalty (mean of squared
        # weights scaled by lambda_reg). Means are used to normalize by tensor size.
        error = np.square(Y_mat - X_mat @ W_mat).mean() + lambda_reg * np.square(W_mat).mean()
        # Determine if we have converged
        if prev_error >= 0 and abs(prev_error - error) < eps:
            print("Converged after", steps, "steps. Final Error:", error)
            break
        else:
            print("Step:", steps, "Error:", error)
        # Remember this sweep's error for the next convergence check
        prev_error = error
        # Next step
        steps += 1

    # TODO: Add functionality to make the factors (return_factors is ignored for now)

    # Fold the (I, J) matrix back into shape (I_1, ..., I_{M1}, J_1, ..., J_{M2})
    return tl.reshape(W_mat, (X.shape[1:]) + (Y.shape[1:]))
    

In [45]:
# Synthetic problem: N random 5x7 inputs, each mapped to a 2x3 output
N = 11
X = np.random.random((N, 5, 7))
Y = np.zeros((N, 2, 3))
# Fill each output entry from a few hand-picked input entries
# (four linear combinations and two shifted squares)
for n, x in enumerate(X):
    Y[n, 0, 0] = x[0, 0] + x[1, 1]
    Y[n, 0, 1] = 2 * x[1, 0] - x[3, 2]
    Y[n, 0, 2] = (x[4, 5] + 1) ** 2
    Y[n, 1, 0] = -x[1, 6] + 3 * x[1, 5]
    Y[n, 1, 1] = -x[0, 6] - x[3, 3]
    Y[n, 1, 2] = (x[2, 5] + 2) ** 2
# Fit a rank-5 weight tensor
W = tensor_to_tensor_regression_1(X, Y, rank=5, lambda_reg=1)
print(W.shape)
# Sanity check: contract every non-sample mode of X with the leading modes of W
Y_pred = tl.tenalg.contract(X, range(1, tl.ndim(X)), W, range(tl.ndim(X) - 1))
# Plain prediction MSE; the error reported by the fit additionally contains the
# lambda_reg * mean(W**2) penalty, so the two agree only up to that term
print(np.square(Y_pred - Y).mean())

Step: 0 Error: 4.187710139472902
Step: 1 Error: 5.427425206371436
Step: 2 Error: 4.9988849742282735
Step: 3 Error: 4.673435500330027
Step: 4 Error: 4.428251184822083
Step: 5 Error: 4.2798215775395425
Step: 6 Error: 4.214043536002181
Step: 7 Error: 4.199231549187298
Step: 8 Error: 4.206932411397745
Step: 9 Error: 4.219520426704756
Step: 10 Error: 4.229418737321797
Step: 11 Error: 4.235171413308442
Step: 12 Error: 4.237823432126397
Step: 13 Error: 4.2388057323003
Step: 14 Error: 4.239129972410975
Step: 15 Error: 4.23930403519958
Step: 16 Error: 4.239499393582396
Step: 17 Error: 4.239729324162558
Step: 18 Error: 4.239961519485483
Step: 19 Error: 4.240167783780741
Step: 20 Error: 4.240336024857642
Step: 21 Error: 4.24046665336701
Step: 22 Error: 4.240565972787705
Step: 23 Error: 4.240641414328005
Step: 24 Error: 4.24069929347633
Step: 25 Error: 4.240744273553565
Step: 26 Error: 4.2407796075239945
Step: 27 Error: 4.240807565056719
Step: 28 Error: 4.240829787084753
Step: 29 Error: 4.24084751

In [46]:
# For timing
import time

# perf_counter() is a monotonic, high-resolution clock intended for measuring
# elapsed intervals; time.time() is wall-clock time and may jump if the system
# clock is adjusted while the fit is running.
start_time = time.perf_counter()
W = tensor_to_tensor_regression_1(X, Y, 5, lambda_reg=10)
# Elapsed seconds for the whole fit (includes the per-step printing)
print(time.perf_counter() - start_time)

Step: 0 Error: 4.0214449899317515
Step: 1 Error: 4.615278855711329
Step: 2 Error: 4.330122752733615
Step: 3 Error: 3.9053734534798177
Step: 4 Error: 3.6026137719382403
Step: 5 Error: 3.4503076963740353
Step: 6 Error: 3.399219947618061
Step: 7 Error: 3.397285823498813
Step: 8 Error: 3.410641701689073
Step: 9 Error: 3.423343543054102
Step: 10 Error: 3.4308872772046177
Step: 11 Error: 3.4339501762629157
Step: 12 Error: 3.434492264382796
Step: 13 Error: 3.434083972665766
Step: 14 Error: 3.433562029351891
Step: 15 Error: 3.433216856867045
Step: 16 Error: 3.4330632900403915
Step: 17 Error: 3.4330291269262503
Step: 18 Error: 3.4330451177581387
Step: 19 Error: 3.433070386632795
Step: 20 Error: 3.4330887479150083
Step: 21 Error: 3.433097858556881
Step: 22 Error: 3.433100518950238
Converged after 23 steps. Final Error: 3.4331000359015005
0.09400153160095215
