In [1]:
import torch
import time

I want to compute the mode-$n$ product $$\mathcal{Z} = \mathcal{A}\times_nU_n,$$ where $\mathcal{A}$ has size $I_1 \times I_2\times \cdots \times I_d$ and $U_n$ has size $J \times I_n$. The output tensor has size $I_1 \times \cdots\times I_{n-1}\times J\times I_{n+1}\times \cdots \times I_d$, and its entries are


$$\mathcal{Z}(i_1,\cdots,i_{n-1},j,i_{n+1},\cdots,i_d) = \sum_{l=1}^{I_n}A(i_1,\cdots,i_{n-1},l,i_{n+1},\cdots,i_d)U_n(j,l)$$

The computation can also be done using unfoldings:

$$\mathcal{Z}_{(n)} = U_n\mathcal{A}_{(n)},$$

where $\mathcal{Z}_{(n)}$ and $\mathcal{A}_{(n)}$ are the mode-$n$ unfoldings of $\mathcal{Z}$ and $\mathcal{A}$.
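A minimal sketch of the unfolding route as reusable helpers, assuming 0-indexed modes and PyTorch's row-major `reshape`. The column order of this unfolding differs from a column-major convention, but `fold` undoes it consistently, so the product is unaffected. The names `unfold`, `fold` and `mode_n_product` are hypothetical helpers, not PyTorch functions.

```python
import torch

def unfold(T: torch.Tensor, mode: int) -> torch.Tensor:
    """Mode-`mode` unfolding (0-indexed): the mode-`mode` fibers become columns."""
    # Move the chosen mode to the front, then flatten the remaining modes (row-major).
    return torch.movedim(T, mode, 0).reshape(T.shape[mode], -1)

def fold(M: torch.Tensor, mode: int, shape) -> torch.Tensor:
    """Inverse of `unfold`; `shape` is the full shape of the target tensor."""
    # After unfolding, the chosen mode sits in front and the others keep their order.
    front_shape = [shape[mode]] + [s for k, s in enumerate(shape) if k != mode]
    return torch.movedim(M.reshape(front_shape), 0, mode)

def mode_n_product(T: torch.Tensor, U: torch.Tensor, mode: int) -> torch.Tensor:
    """Z = T x_mode U via Z_(mode) = U @ T_(mode); U has shape (J, T.shape[mode])."""
    new_shape = list(T.shape)
    new_shape[mode] = U.shape[0]
    return fold(U @ unfold(T, mode), mode, new_shape)

A = torch.rand(2, 3, 4)
U = torch.rand(5, 3)
Z = mode_n_product(A, U, 1)  # second mode in 0-indexed terms
print(Z.shape)  # torch.Size([2, 5, 4])
```

Using distinct sizes for every mode makes a wrong fold shape fail loudly instead of silently working.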

### 1. Toy example 
Let's focus on a small example: a tensor $\mathcal{A}$ of size $2\times 2\times 2$ and a matrix $U$ of size $3\times 2$.

In [2]:
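d = 3  # order of the tensor A
n = 3  # J: number of rows of U (new size of the multiplied mode)
r = 2  # I_k: size of every mode of A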

A = torch.rand([r for _ in range(d)])
U = torch.rand([n, r])

print('Tensor A:', A)
print('Matrix U:', U)

Tensor A: tensor([[[0.0977, 0.4686],
         [0.4027, 0.5704]],

        [[0.0820, 0.9125],
         [0.8529, 0.0545]]])
Matrix U: tensor([[0.7287, 0.3158],
        [0.4553, 0.2043],
        [0.9357, 0.4596]])


#### Compute the mode product between $\mathcal{A}$ and $U$ for each mode of $\mathcal{A}$. 

##### For ***mode 1***

Using the basic definition $$\mathcal{Z}(j,i_2,i_3) = \sum_{l=1}^{I_1}A(l,i_2,i_3)U(j,l),$$ we expect a tensor of size $3\times 2\times 2$.

In [3]:
B = torch.zeros([n,r,r])
for i in range(n):
    for j in range(r):
        for k in range(r):
            for l in range(r):
                B[i,j,k] = B[i,j,k] + U[i,l]*A[l,j,k]
print(B)

tensor([[[0.0971, 0.6296],
         [0.5628, 0.4329]],

        [[0.0612, 0.3998],
         [0.3576, 0.2709]],

        [[0.1291, 0.8579],
         [0.7688, 0.5588]]])
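The index formula translates directly into `torch.einsum`, where a subscript that appears in the inputs but not in the output is summed over. A small sketch with fresh random tensors of the toy sizes (so the numbers differ from the printout above), checked against the same loops:

```python
import torch

# Toy sizes from above: A is I_1 x I_2 x I_3 = 2 x 2 x 2, U is J x I_1 = 3 x 2.
A = torch.rand(2, 2, 2)
U = torch.rand(3, 2)

# Z(j, i2, i3) = sum_l U(j, l) A(l, i2, i3): the repeated index l is summed out.
Z_einsum = torch.einsum('jl,lab->jab', U, A)

# Same computation with explicit loops.
Z_loop = torch.zeros(3, 2, 2)
for j in range(3):
    for a in range(2):
        for b in range(2):
            for l in range(2):
                Z_loop[j, a, b] += U[j, l] * A[l, a, b]

print(torch.allclose(Z_einsum, Z_loop))  # True
```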


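Using the unfolding method, $$\mathcal{Z}_{(1)} = U\mathcal{A}_{(1)}$$
`torch.moveaxis(A,0,0)` does not change anything here; it is kept so that the code has the same form as for the other modes.

`A_1.reshape([n,r,r])` folds the result back from a matrix into a tensor of size $3\times 2\times 2$.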

In [4]:
A_1 = torch.moveaxis(A, 0, 0).reshape(A.shape[0],-1)
A_1 = U@A_1
B = torch.moveaxis(A_1.reshape([n,r,r]), 0, 0)
print(B)

tensor([[[0.0971, 0.6296],
         [0.5628, 0.4329]],

        [[0.0612, 0.3998],
         [0.3576, 0.2709]],

        [[0.1291, 0.8579],
         [0.7688, 0.5588]]])


Using `tensordot`, we want to change the first mode from size $I_1$ to $J$, so we contract the first dimension of $\mathcal{A}$ with the second dimension of $U$.

However, `tensordot` outputs a tensor of size $I_2\times I_3\times J$, while we need $J\times I_2\times I_3$. We can fix this as follows:
- `B.permute(2,0,1)` moves the last dimension to the front: $(0,1,2)\to (2,0,1)$. All dimension indices must be listed, which also keeps $I_2$ and $I_3$ in order.
- `torch.moveaxis(B,2,0)` or `B.movedim(2,0)` does the same: it moves the last dimension to the first position while keeping the order of $I_2$ and $I_3$.
- Note that `torch.moveaxis(B,2,0)`, $(0,1,2)\to (2,0,1)$, is not the same as `torch.moveaxis(B,0,2)`, $(0,1,2)\to (1,2,0)$.
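A quick way to see the difference between these calls is to use a tensor whose dimensions all have different sizes. A small sketch (the sizes are arbitrary):

```python
import torch

# Distinct sizes (2, 3, 4) make it easy to see where each dimension ends up.
B = torch.arange(24).reshape(2, 3, 4)

print(B.permute(2, 0, 1).shape)       # torch.Size([4, 2, 3])
print(torch.moveaxis(B, 2, 0).shape)  # torch.Size([4, 2, 3])
print(B.movedim(2, 0).shape)          # torch.Size([4, 2, 3])
print(torch.moveaxis(B, 0, 2).shape)  # torch.Size([3, 4, 2])

# permute(2, 0, 1) and moveaxis(B, 2, 0) give identical tensors, not just shapes.
print(torch.equal(B.permute(2, 0, 1), torch.moveaxis(B, 2, 0)))  # True
```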

In [5]:
B = torch.tensordot(A, U, dims=([0], [1]))
B = B.permute(2,0,1)
print(B)

tensor([[[0.0971, 0.6296],
         [0.5628, 0.4329]],

        [[0.0612, 0.3998],
         [0.3576, 0.2709]],

        [[0.1291, 0.8579],
         [0.7688, 0.5588]]])


##### Let's do the same for ***mode 2***

Using the basic definition $$\mathcal{Z}(i_1,j,i_3) = \sum_{l=1}^{I_2}A(i_1,l,i_3)U(j,l),$$ we expect a tensor of size $2\times 3\times 2$.

In [6]:
B = torch.zeros([r,n,r])
for i in range(r):
    for j in range(n):
        for k in range(r):
            for l in range(r):
                B[i,j,k] = B[i,j,k] + U[j,l]*A[i,l,k]
print(B)

tensor([[[0.1983, 0.5216],
         [0.1268, 0.3300],
         [0.2765, 0.7007]],

        [[0.3290, 0.6822],
         [0.2116, 0.4266],
         [0.4686, 0.8789]]])


Using the unfolding method, $$\mathcal{Z}_{(2)} = U\mathcal{A}_{(2)}$$
`torch.moveaxis(A,1,0)` moves the second dimension to the first position and shifts the other dimensions accordingly.

We need this step because only then are the mode-2 fibers the columns of $\mathcal{A}_{(2)}$.

`A_2.reshape([n,r,r])` folds the matrix back into a tensor. Its size is $3\times 2\times 2$ (not $2\times 3\times 2$) because mode 2 was moved to the front earlier.

`torch.moveaxis(..., 0, 1)` then moves that dimension back to the second position.
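To see why the `moveaxis` step matters, here is a small check (sizes chosen arbitrarily) that each column of the unfolded matrix is a mode-2 fiber $\mathcal{A}(i_1,:,i_3)$. With PyTorch's row-major `reshape`, column $i_1 I_3 + i_3$ (0-indexed) holds that fiber; reshaping without moving the axis first mixes entries from different modes.

```python
import torch

I1, I2, I3 = 2, 3, 4
A = torch.arange(I1 * I2 * I3, dtype=torch.float32).reshape(I1, I2, I3)

# Mode-2 unfolding: move mode 2 (index 1) to the front, then flatten the rest.
A_2 = torch.moveaxis(A, 1, 0).reshape(I2, -1)  # shape (I2, I1 * I3)

# Column i1 * I3 + i3 is the mode-2 fiber A[i1, :, i3].
for i1 in range(I1):
    for i3 in range(I3):
        assert torch.equal(A_2[:, i1 * I3 + i3], A[i1, :, i3])

# Without moveaxis, reshape(I2, -1) only regroups memory and the columns are not fibers.
wrong = A.reshape(I2, -1)
print(torch.equal(wrong[:, 0], A[0, :, 0]))  # False
```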


In [7]:
A_2 = torch.moveaxis(A, 1, 0).reshape(A.shape[1],-1)
A_2 = U@A_2
B = torch.moveaxis(A_2.reshape([n,r,r]), 0, 1)
print(B)

tensor([[[0.1983, 0.5216],
         [0.1268, 0.3300],
         [0.2765, 0.7007]],

        [[0.3290, 0.6822],
         [0.2116, 0.4266],
         [0.4686, 0.8789]]])


Using `tensordot`, we want to change the second mode from size $I_2$ to $J$, so we contract the second dimension of $\mathcal{A}$ with the second dimension of $U$.

However, `tensordot` outputs a tensor of size $I_1\times I_3\times J$, while we need $I_1\times J\times I_3$, so `permute(0,2,1)` is required to move the last dimension to the second position.

In [8]:
B = torch.tensordot(A, U, dims=([1], [1]))
B = B.permute(0,2,1)
print(B)

tensor([[[0.1983, 0.5216],
         [0.1268, 0.3300],
         [0.2765, 0.7007]],

        [[0.3290, 0.6822],
         [0.2116, 0.4266],
         [0.4686, 0.8789]]])


##### ***Last Mode***

Using the basic definition $$\mathcal{Z}(i_1,i_2,j) = \sum_{l=1}^{I_3}A(i_1,i_2,l)U(j,l),$$ we expect a tensor of size $2\times 2\times 3$.

In [9]:
B = torch.zeros([r,r,n])
for i in range(r):
    for j in range(r):
        for k in range(n):
            for l in range(r):
                B[i,j,k] = B[i,j,k] + U[k,l]*A[i,j,l]
print(B)

tensor([[[0.2192, 0.1402, 0.3068],
         [0.4736, 0.2999, 0.6390]],

        [[0.3479, 0.2238, 0.4960],
         [0.6387, 0.3995, 0.8231]]])


Using the unfolding method, $$\mathcal{Z}_{(3)} = U\mathcal{A}_{(3)}$$
`torch.moveaxis(A,2,0)` moves the third dimension to the first position and shifts the other dimensions accordingly.

We need this step because only then are the mode-3 fibers the columns of $\mathcal{A}_{(3)}$.

`A_3.reshape([n,r,r])` folds the matrix back into a tensor of size $3\times 2\times 2$, since mode 3 was moved to the front earlier.

`torch.moveaxis(..., 0, 2)` then moves that dimension back to the last position.



In [10]:
A_3 = torch.moveaxis(A, 2, 0).reshape(A.shape[2],-1)
A_3 = U@A_3
B = torch.moveaxis(A_3.reshape([n,r,r]), 0, 2)
print(B)

tensor([[[0.2192, 0.1402, 0.3068],
         [0.4736, 0.2999, 0.6390]],

        [[0.3479, 0.2238, 0.4960],
         [0.6387, 0.3995, 0.8231]]])


Using `tensordot`, we want to change the third mode from size $I_3$ to $J$, so we contract the third dimension of $\mathcal{A}$ with the second dimension of $U$.

`tensordot` outputs a tensor of size $I_1\times I_2\times J$, which is exactly what we need, so `permute` is not required in this case.

Note: `permute(0,1,2)` is the identity permutation and does not change anything.

In [11]:
B = torch.tensordot(A, U, dims=([2], [1]))
B = B.permute(0,1,2)  # identity permutation: no change
print(B)

tensor([[[0.2192, 0.1402, 0.3068],
         [0.4736, 0.2999, 0.6390]],

        [[0.3479, 0.2238, 0.4960],
         [0.6387, 0.3995, 0.8231]]])
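The three `tensordot` cases follow one pattern: contract mode $n$ of $\mathcal{A}$ with the second axis of $U$, then move the new axis from the end back to position $n$. A sketch of that as a helper (the name `mode_n_product` is hypothetical, and modes are 0-indexed), checked against `einsum` for every mode of the toy sizes:

```python
import torch

def mode_n_product(T, U, mode):
    """Z = T x_mode U for a 0-indexed mode; U has shape (J, T.shape[mode])."""
    # tensordot contracts T's `mode` axis with U's second axis and appends J last...
    Z = torch.tensordot(T, U, dims=([mode], [1]))
    # ...so move that last axis back to position `mode`.
    return Z.movedim(-1, mode)

A = torch.rand(2, 2, 2)
U = torch.rand(3, 2)
for mode in range(3):
    Z = mode_n_product(A, U, mode)
    # Reference: the same contraction in einsum, with 'l' in the contracted slot.
    subs = 'abc'
    in_sub = subs[:mode] + 'l' + subs[mode + 1:]
    out_sub = subs[:mode] + 'j' + subs[mode + 1:]
    Z_ref = torch.einsum(f'jl,{in_sub}->{out_sub}', U, A)
    # Prints the mode, the output shape, and whether it matches the reference.
    print(mode, tuple(Z.shape), torch.allclose(Z, Z_ref))
```

Moving the new axis with `movedim(-1, mode)` avoids writing a different `permute` tuple for each mode.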


### 2. More factors

We want to compute $$\mathcal{Z} = \mathcal{A}\times_1U_1\times_2 U_2\times_3\cdots\times_dU_d,$$ where $\mathcal{A}$ has size $I_1 \times I_2\times \cdots \times I_d$ and $U_n$ has size $J_n \times I_n$. The output tensor has size $J_1 \times J_2\times \cdots \times J_d$, and its entries are


$$\mathcal{Z}(j_1,j_2,\cdots,j_d) = \sum_{i_1=1}^{I_1}\cdots\sum_{i_d=1}^{I_d}A(i_1,i_2,\cdots,i_d)U_1(j_1,i_1)\cdots U_d(j_d,i_d)$$


Consider a tensor of size $10\times 10\times 10$ (`r = 10`) and matrix factors of size $1000\times 10$ (`n = 1000`), so the output has size $1000\times 1000\times 1000$.

In [12]:
d = 3     # order of the tensor
n = 1000  # J_k: number of rows of each factor U_k
r = 10    # I_k: size of each mode of A
A = torch.rand([r for _ in range(d)])
U = [torch.rand([n, r]) for _ in range(d)]  # U[k] has size J_k x I_k

#### Using Definition

$$\mathcal{Z}(j_1,j_2,j_3) = \sum_{i_1=1}^{I_1}\sum_{i_2=1}^{I_2}\sum_{i_3=1}^{I_3}A(i_1,i_2,i_3)U_1(j_1,i_1)U_2(j_2,i_2)U_3(j_3,i_3)$$
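For small sizes the definition can be checked directly. A sketch with arbitrary small sizes (not the sizes used in this section), comparing the six nested loops of the definition with a single `torch.einsum` call that sums over the repeated indices:

```python
import torch

# Small sizes so the six nested loops stay cheap: A is 2 x 3 x 4, U_k is J_k x I_k.
A = torch.rand(2, 3, 4)
U1, U2, U3 = torch.rand(5, 2), torch.rand(6, 3), torch.rand(7, 4)

# Z(j1, j2, j3) = sum_{i1,i2,i3} A(i1,i2,i3) U1(j1,i1) U2(j2,i2) U3(j3,i3)
Z_einsum = torch.einsum('abc,ia,jb,kc->ijk', A, U1, U2, U3)

# Direct translation of the definition: J1*J2*J3*I1*I2*I3 Python-level updates.
Z_def = torch.zeros(5, 6, 7)
for j1 in range(5):
    for j2 in range(6):
        for j3 in range(7):
            for i1 in range(2):
                for i2 in range(3):
                    for i3 in range(4):
                        Z_def[j1, j2, j3] += A[i1, i2, i3] * U1[j1, i1] * U2[j2, i2] * U3[j3, i3]

print(torch.allclose(Z_einsum, Z_def))  # True
```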

In [13]:
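# Left commented out: with n = 1000 and r = 10 the loops below perform
# n^3 * r^3 = 10^12 scalar updates in Python, which is far too slow.
#start_time= time.time()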
#B_def = torch.zeros([n,n,n])
#for j1 in range(n):
#    for j2 in range(n):
#        for j3 in range(n):
#            for i1 in range(r):
#                for i2 in range(r):
#                    for i3 in range(r):
#                        B_def[j1,j2,j3] = B_def[j1,j2,j3] + U[0][j1,i1]*U[1][j2,i2]*U[2][j3,i3]*A[i1,i2,i3]
                
#print('Time taken:', time.time()-start_time)

#### Using the unfolding strategy

Iteratively unfold, multiply, and fold back for each mode: $$\mathcal{Z}_{(n)} = U_n\mathcal{A}_{(n)}$$
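Because each step only touches one mode, the order in which the factors are applied should not matter: mode products in distinct modes commute. A quick numerical check of that, using a hypothetical `tensordot`-based helper `mode_product` and arbitrary small sizes:

```python
import itertools
import torch

def mode_product(T, U, mode):
    """T x_mode U (0-indexed mode) via tensordot, with the new axis moved back into place."""
    return torch.tensordot(T, U, dims=([mode], [1])).movedim(-1, mode)

A = torch.rand(3, 4, 5)
U = [torch.rand(6, 3), torch.rand(7, 4), torch.rand(8, 5)]

# Apply the three factors in every possible order.
results = []
for order in itertools.permutations(range(3)):
    Z = A
    for mode in order:
        Z = mode_product(Z, U[mode], mode)
    results.append(Z)

# All 6 orders give the same 6 x 7 x 8 tensor (up to rounding).
print(all(torch.allclose(results[0], R) for R in results))  # True
```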

In [14]:
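start_time= time.time()
B_unf = A.clone()
for i in range(d):
    # Mode-i unfolding: move mode i to the front, flatten the remaining modes
    B_i = torch.moveaxis(B_unf, i, 0).reshape(B_unf.shape[i], -1)
    B_i = U[i]@B_i
    # Fold back: reshape to (J_i, remaining modes in order), then return mode i to position i
    front_shape = [U[i].shape[0]] + [s for k, s in enumerate(B_unf.shape) if k != i]
    B_unf = torch.moveaxis(B_i.reshape(front_shape), 0, i)
print('Time taken:', time.time()-start_time)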

Time taken: 0.3908851146697998


Another unfolding approach applies all factors at once using Kronecker products:
$$\mathcal{Z}_{(1)} = U_1\mathcal{A}_{(1)}(U_3\otimes U_2)^T$$
$$\mathcal{Z}_{(2)} = U_2\mathcal{A}_{(2)}(U_3\otimes U_1)^T$$
$$\mathcal{Z}_{(3)} = U_3\mathcal{A}_{(3)}(U_2\otimes U_1)^T$$


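***Note***: these formulas assume column-major unfolding, where the earliest remaining index varies fastest along the columns. PyTorch's `reshape` uses row-major (lexicographic) order, where the last index varies fastest, so with our unfoldings the Kronecker factors appear in the opposite order: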

$$\mathcal{Z}_{(1)} = U_1\mathcal{A}_{(1)}(U_2\otimes U_3)^T$$
$$\mathcal{Z}_{(2)} = U_2\mathcal{A}_{(2)}(U_1\otimes U_3)^T$$
$$\mathcal{Z}_{(3)} = U_3\mathcal{A}_{(3)}(U_1\otimes U_2)^T$$
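Both conventions can be checked numerically for mode 1. A sketch with arbitrary small sizes, assuming that a column-major unfolding can be emulated by swapping the last two modes before a row-major `reshape`; the reference result comes from `einsum`:

```python
import torch

I1, I2, I3 = 2, 3, 4
J1, J2, J3 = 5, 6, 7
A = torch.rand(I1, I2, I3, dtype=torch.float64)
U1 = torch.rand(J1, I1, dtype=torch.float64)
U2 = torch.rand(J2, I2, dtype=torch.float64)
U3 = torch.rand(J3, I3, dtype=torch.float64)

# Reference result Z = A x_1 U1 x_2 U2 x_3 U3.
Z = torch.einsum('abc,ia,jb,kc->ijk', A, U1, U2, U3)

# Row-major (PyTorch reshape) mode-1 unfolding: i3 varies fastest along the columns.
A1_row = A.reshape(I1, -1)
Z1_row = U1 @ A1_row @ torch.kron(U2, U3).T
print(torch.allclose(Z1_row, Z.reshape(J1, -1)))  # True

# Column-major mode-1 unfolding: i2 varies fastest, emulated by swapping the
# last two modes before the row-major reshape.
A1_col = A.permute(0, 2, 1).reshape(I1, -1)
Z1_col = U1 @ A1_col @ torch.kron(U3, U2).T
print(torch.allclose(Z1_col, Z.permute(0, 2, 1).reshape(J1, -1)))  # True
```

Mixing the conventions (row-major unfolding with $U_3\otimes U_2$) still produces matrices of matching shapes here, so the mistake would not raise an error; only a numerical comparison catches it.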


In [15]:
start_time= time.time()
U_kron = torch.ones(1)
B_kron = A.clone()
for i in range(1,d):
    U_kron = torch.kron(U_kron,U[i])
        
B_1 = B_kron.reshape(B_kron.shape[0],-1)
B_kron = U[0]@B_1@U_kron.T
B_kron = B_kron.reshape([n for _ in range(d)])

print('Time taken:', time.time()-start_time)


Time taken: 0.6207261085510254


In [16]:
start_time= time.time()
U_kron_2 = torch.ones(1)
B_kron_2 = A.clone()
for i in range(d):
    if i!=1:
        U_kron_2 = torch.kron(U_kron_2,U[i])
        
B_2 = B_kron_2.movedim(1,0).reshape(B_kron_2.shape[1],-1)
B_kron_2 = U[1]@B_2@U_kron_2.T
B_kron_2 = B_kron_2.reshape([n for _ in range(d)]).movedim(0,1)

print('Time taken:', time.time()-start_time)


Time taken: 2.8304221630096436


In [17]:
start_time= time.time()
U_kron_3 = torch.ones(1)
B_kron_3 = A.clone()
for i in range(d):
    if i!=2:
        U_kron_3 = torch.kron(U_kron_3,U[i])
        
B_3 = B_kron_3.movedim(2,0).reshape(B_kron_3.shape[2],-1)
B_kron_3 = U[2]@B_3@U_kron_3.T
B_kron_3 = B_kron_3.reshape([n for _ in range(d)]).movedim(0,2)

print('Time taken:', time.time()-start_time)

Time taken: 2.070183277130127


In [18]:
print(torch.allclose(B_kron, B_kron_2))
print(torch.allclose(B_kron, B_kron_3))

True
True


#### Using `tensordot`

In [19]:
start_time= time.time()
B_dot = A.clone()
for i, Ui in enumerate(U):
    B_dot = torch.tensordot(B_dot, Ui, dims=([i], [1]))
    B_dot = B_dot.movedim(-1, i)
print('Time taken:', time.time()-start_time)

Time taken: 0.4466838836669922


#### Compare the results

In [20]:
print(torch.allclose(B_unf, B_kron))
print(torch.allclose(B_unf, B_dot))


True
True


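In this single-run timing, `tensordot` (about 0.45 s) and the iterative unfolding approach (about 0.39 s) perform best. The Kronecker approach is slower because it forms the Kronecker product explicitly: for mode 1, $U_2\otimes U_3$ has size $n^2\times r^2 = 10^6\times 100$, i.e. $10^8$ entries.

Since the two fastest approaches are comparable, `tensordot` is probably the better choice because it needs less code.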
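Each timing above comes from a single `time.time()` measurement, which is sensitive to warm-up and background load. A sketch of a slightly fairer comparison, assuming the best of 5 wall-clock runs with `time.perf_counter` is good enough (PyTorch also ships `torch.utils.benchmark` for more careful measurements). The helper names are hypothetical, and the sizes are smaller than above to keep it quick, so the numbers will not match the earlier ones.

```python
import time
import torch

def best_time(fn, repeats=5):
    """Best wall-clock time in seconds over several runs of fn(), plus its last result."""
    best, out = float('inf'), None
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn()
        best = min(best, time.perf_counter() - start)
    return best, out

def via_tensordot(A, U):
    Z = A
    for i, Ui in enumerate(U):
        Z = torch.tensordot(Z, Ui, dims=([i], [1])).movedim(-1, i)
    return Z

def via_unfolding(A, U):
    Z = A
    for i, Ui in enumerate(U):
        shape = list(Z.shape)
        Zi = Ui @ torch.moveaxis(Z, i, 0).reshape(shape[i], -1)
        # Fold back from (J_i, remaining modes in order) and return mode i to position i.
        front = [Ui.shape[0]] + shape[:i] + shape[i + 1:]
        Z = torch.moveaxis(Zi.reshape(front), 0, i)
    return Z

torch.manual_seed(0)
d, n, r = 3, 100, 10
A = torch.rand([r] * d)
U = [torch.rand(n, r) for _ in range(d)]

t_dot, Z_dot = best_time(lambda: via_tensordot(A, U))
t_unf, Z_unf = best_time(lambda: via_unfolding(A, U))
print(f'tensordot: {t_dot:.4f} s, unfolding: {t_unf:.4f} s')
print(torch.allclose(Z_dot, Z_unf))  # True
```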