In [13]:
import numpy as np
import tensorly as tl
from tensorly.tenalg import mode_dot
from numpy.random import randn, rand, randint
import time

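## Defining a random complex tensor A

The real and imaginary parts are drawn from a standard normal distribution, seeded with 0 and 1 respectively, so the tensor is reproducible.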

In [55]:
# Real part: standard-normal draws from seed 0
np.random.seed(0)
shape = (2, 3, 4)
A = randn(*shape)

# Imaginary part: standard-normal draws from seed 1, added on top
np.random.seed(1)
A = A + 1j * randn(*shape)

A

array([[[ 1.76405235+1.62434536j,  0.40015721-0.61175641j,
          0.97873798-0.52817175j,  2.2408932 -1.07296862j],
        [ 1.86755799+0.86540763j, -0.97727788-2.3015387j ,
          0.95008842+1.74481176j, -0.15135721-0.7612069j ],
        [-0.10321885+0.3190391j ,  0.4105985 -0.24937038j,
          0.14404357+1.46210794j,  1.45427351-2.06014071j]],

       [[ 0.76103773-0.3224172j ,  0.12167502-0.38405435j,
          0.44386323+1.13376944j,  0.33367433-1.09989127j],
        [ 1.49407907-0.17242821j, -0.20515826-0.87785842j,
          0.3130677 +0.04221375j, -0.85409574+0.58281521j],
        [-2.55298982-1.10061918j,  0.6536186 +1.14472371j,
          0.8644362 +0.90159072j, -0.74216502+0.50249434j]]])

In [80]:
# for i in A[0]:
#     print("{} & {} & {} & {} \\\\".format(
#         str(round(i[0], 1)).strip('(').strip(')'), 
#         str(round(i[1], 1)).strip('(').strip(')'),
#         str(round(i[2], 1)).strip('(').strip(')'),
#         str(round(i[3], 1)).strip('(').strip(')')))

# Unfolding

In [27]:
def unfold(A, n):
    '''Outputs the n-mode unfolding of A

    Axis n is moved to the front and the remaining axes, kept in their
    original order, are flattened in C order. This is the convention used by
    tensorly's `unfold`. Note that it orders columns differently from the
    Fortran-order convention of Kolda & Bader.

    Inputs:
    -------
        A: Tensor (nd.array, or anything np.asarray accepts)
        n: the n-mode of the unfold; negative values count from the last
           mode (int)

    Outputs:
    --------
        X: n-mode unfolding matrix of A, with shape
           (A.shape[n], product of the other dimensions) (nd.array)

    Raises:
    -------
        IndexError: if n is not a valid mode of A
    '''
    A = np.asarray(A)
    mode = n + A.ndim if n < 0 else n
    if not 0 <= mode < A.ndim:
        raise IndexError(
            "mode {} is out of range for a tensor of order {}".format(n, A.ndim))

    M = A.shape[mode]
    # int(...) keeps N an integer even when there are no other modes:
    # np.prod of an empty sequence is the float 1.0, which reshape rejects.
    N = int(np.prod(A.shape[:mode] + A.shape[mode + 1:]))

    X = np.reshape(np.moveaxis(A, mode, 0), (M, N))

    return X

In [86]:
# Compare the full arrays element by element. `x.all() == y.all()` would only
# compare two booleans, and would pass for almost any pair of arrays.
for n in range(A.ndim):
    np.testing.assert_array_equal(unfold(A, n), tl.unfold(tl.tensor(A), n))

In [79]:
# for i in unfold(A, 2):
#     print("{} & {} & {} & {} & {} & {} \\\\".format(
#         str(round(i[0], 1)).strip('(').strip(')'), 
#         str(round(i[1], 1)).strip('(').strip(')'),
#         str(round(i[2], 1)).strip('(').strip(')'),
#         str(round(i[3], 1)).strip('(').strip(')'),
#         str(round(i[4], 1)).strip('(').strip(')'),
#         str(round(i[5], 1)).strip('(').strip(')')))

# Folding

In [17]:
def fold(unfolded_A, n, shape):
    '''Folds an unfolded tensor back to tensor form (inverse of `unfold`)

    Inputs:
    -------
        unfolded_A: n-mode unfolded tensor matrix (nd.array)
        n: n-mode of the unfolding; negative values count from the last
           mode (int)
        shape: original shape of the tensor (tuple or list of ints)

    Outputs:
    --------
        X: Folded matrix A in tensor format, with shape `shape` (nd.array)

    Raises:
    -------
        IndexError: if n is not a valid mode for `shape`
    '''
    order = len(shape)
    mode = n + order if n < 0 else n
    if not 0 <= mode < order:
        raise IndexError(
            "mode {} is out of range for a tensor of order {}".format(n, order))

    # The unfolding put mode n first and kept the other modes in order.
    # Rebuild that intermediate shape, then move mode n back into place.
    mode_first_shape = (shape[mode],) + tuple(shape[:mode]) + tuple(shape[mode + 1:])
    A = np.reshape(unfolded_A, mode_first_shape)

    X = np.moveaxis(A, 0, mode)

    return X

In [89]:
for n in range(A.ndim):
    unfolded_A = unfold(A, n)
    folded_A = fold(unfolded_A, n, shape)
    # Element-wise comparison. `x.all() == y.all()` only compares two booleans.
    np.testing.assert_array_equal(folded_A, tl.fold(unfolded_A, n, shape))
    # Folding must exactly undo unfolding.
    np.testing.assert_array_equal(folded_A, A)

# n-mode Product

In [19]:
def n_mode_product(A, U, n):
    '''Computes the n-mode product Y = A x_n U of a tensor with a matrix

    Inputs:
    -------
        A: Tensor (nd.array)
        U: Matrix whose number of columns equals A.shape[n] (nd.array)
        n: the mode along which to multiply; negative values count from the
           last mode (int)

    Outputs:
    --------
        Y: Tensor with the same shape as A, except that
           Y.shape[n] == U.shape[0] (nd.array)

    Raises:
    -------
        IndexError: if n is not a valid mode of A
        ValueError: if U is not a matrix with A.shape[n] columns
    '''
    if n < 0:
        n += A.ndim
    if U.ndim != 2 or U.shape[1] != A.shape[n]:
        raise ValueError(
            "U must be a matrix with {} columns to multiply mode {}, got shape {}".format(
                A.shape[n], n, U.shape))

    # The mode-n unfolding turns the tensor product into a matrix product:
    # Y_(n) = U @ A_(n)
    unfolded_Y = np.matmul(U, unfold(A, n))

    y_shape = list(A.shape)
    y_shape[n] = U.shape[0]

    Y = fold(unfolded_Y, n, tuple(y_shape))

    return Y

In [107]:
P, Q, R = 2, 3, 4
I, J, K = 5, 6, 7

# Real parts from seed 0. The tuple is evaluated left to right, so the draw
# order is fixed: A, then B, then C, then X.
np.random.seed(0)
A, B, C, X = randn(P, I), randn(Q, J), randn(R, K), randn(I, J, K)

# Imaginary parts from seed 1, drawn in the same order.
np.random.seed(1)
A, B, C, X = (
    A + 1j * randn(P, I),
    B + 1j * randn(Q, J),
    C + 1j * randn(R, K),
    X + 1j * randn(I, J, K),
)

In [121]:
Y = n_mode_product(n_mode_product(n_mode_product(X, A, 0), B, 1), C, 2)
Y_tl = mode_dot(mode_dot(mode_dot(X, A, 0), B, 1), C, 2)

# Element-wise comparison with a floating-point tolerance: the two
# implementations may accumulate round-off differently. The original
# `Y.all() == Y_tl.all()` only compared two booleans.
assert Y.shape == (P, Q, R)
np.testing.assert_allclose(Y, Y_tl)

In [119]:
# for i in mode_dot(mode_dot(mode_dot(X, A, 0), B, 1), C, 2)[0]:
#     print("{} & {} & {} & {} \\\\".format(
#         str(round(i[0], 1)).strip('(').strip(')'), 
#         str(round(i[1], 1)).strip('(').strip(')'),
#         str(round(i[2], 1)).strip('(').strip(')'),
#         str(round(i[3], 1)).strip('(').strip(')')))

In [120]:
# for i in C:
#     print("{} & {} & {} & {} & {} & {} & {} \\\\".format(
#         str(round(i[0], 1)).strip('(').strip(')'), 
#         str(round(i[1], 1)).strip('(').strip(')'),
#         str(round(i[2], 1)).strip('(').strip(')'),
#         str(round(i[3], 1)).strip('(').strip(')'), 
#         str(round(i[4], 1)).strip('(').strip(')'),
#         str(round(i[5], 1)).strip('(').strip(')'),
#         str(round(i[6], 1)).strip('(').strip(')')))