# Tutorial 03.1 Tensor decomposition

In [1]:
import torch
from scipy.linalg import svd
import numpy as np

# Tensor decomposition and entanglement entropy
# Define the size of the tensor
sz = [2, 3, 2, 3, 4]

# Create the rank-5 tensor and normalize it
T = torch.arange(1, np.prod(sz)+1, dtype=torch.float).reshape(sz)
T = T / torch.norm(T)

# Initialize the list to store Q
Q = [None]*len(sz)
R = T  # temporary tensor to be QR-decomposed

szl = 1  # the bond dimension of the left leg of Q[n] to be obtained after
# the QR decomposition at iteration n; for n = 1, szl = 1 for the dummy leg

In [2]:
for it in range(len(sz)-1):
    R = R.reshape([szl*sz[it], np.prod(sz[it+1:])])
    Q[it], R = torch.linalg.qr(R)
    Q[it] = Q[it].reshape([szl, sz[it], int(Q[it].numel()/szl/sz[it])])
    Q[it] = Q[it].permute([0, 2, 1])  # permute to the left-right-bottom order
    szl = Q[it].shape[1]  # update the bond dimension
    R = R.reshape([szl]+sz[it+1:])
    print([szl]+sz[it+1:])

print(R.shape)
print(R.dim())

if R.dim() > 2:
    Q[-1] = R.permute([0, 2, 1])
else:
    temp = R.shape
    Q[-1] = R.reshape([temp[0]] + [1] + [temp[1]])
    
for i, tensor in enumerate(Q):
    print(f"Shape of Q[{i}]: {tensor.shape}")

[2, 3, 2, 3, 4]
[6, 2, 3, 4]
[12, 3, 4]
[4, 4]
torch.Size([4, 4])
2
Shape of Q[0]: torch.Size([1, 2, 2])
Shape of Q[1]: torch.Size([2, 6, 3])
Shape of Q[2]: torch.Size([6, 12, 2])
Shape of Q[3]: torch.Size([12, 4, 3])
Shape of Q[4]: torch.Size([4, 1, 4])


In [3]:
# Contract the tensors Q[n] to make a rank-5 tensor again
T2 = Q[0].permute([2, 1, 0])
for it in range(1, len(sz)):
    T2 = torch.tensordot(T2, Q[it].permute([0, 2, 1]), dims=[[len(T2.shape)-1], [0]])

print(T.shape)
print(T2.shape)
print(torch.max(torch.abs(T - T2)))

torch.Size([2, 3, 2, 3, 4])
torch.Size([2, 2, 3, 2, 3, 4, 1])


RuntimeError: The size of tensor a (3) must match the size of tensor b (4) at non-singleton dimension 5

### Entanglement entropies for 3 different bipartitions

In [None]:
# A = {1, 2}, B = {3, 4, 5}
svec = torch.linalg.svd(T.reshape([sz[0]*sz[1], np.prod(sz[2:])]))[1]  # singular values
Spart = -(svec**2)*torch.log(svec**2)/torch.log(torch.tensor(2.))  # contributions to entanglement entropy
print(torch.sum(Spart[~torch.isnan(Spart)]))  # entanglement entropy

In [None]:
# A = {1, 3}, B = {2, 4, 5}
svec = torch.linalg.svd(T.permute([0, 2, 1, 3, 4]).reshape([sz[0]*sz[2], sz[1]*sz[3]*sz[4]]))[1]  # singular values
Spart = -(svec**2)*torch.log(svec**2)/torch.log(torch.tensor(2.))  # contributions to entanglement entropy
print(torch.sum(Spart[~torch.isnan(Spart)]))  # entanglement entropy

In [None]:
# A = {1, 5}, B = {2, 3, 4}
svec = torch.linalg.svd(T.permute([0, 4, 1, 2, 3]).reshape([sz[0]*sz[4], sz[1]*sz[2]*sz[3]]))[1]  # singular values
Spart = -(svec**2)*torch.log(svec**2)/torch.log(torch.tensor(2.))  # contributions to entanglement entropy
print(torch.sum(Spart[~torch.isnan(Spart)]))  # entanglement entropy

In [None]:
# Use the SVD for the tensor decomposition and compute the entanglement entropy
M = [None]*len(sz)  # MPS tensors
Sent = torch.zeros(len(sz)-1)  # entanglement entropy
R = T  # temporary tensor to be SVD-ed

szl = 1  # the bond dimension of the left leg of R[n] to be obtained 
# after the SVD at iteration n; trivially 1 for n = 1

In [None]:
for it in range(len(sz)-1):
    R = R.reshape([szl*sz[it], np.prod(sz[it+1:])])
    U, svec, V = torch.linalg.svd(R, full_matrices=False)
    
    # truncate the column vectors of U and V associated with 
    # singular values smaller than eps
    ok = svec < torch.finfo(svec.dtype).eps
    U = U[:, ~ok]
    V = V[~ok, :]
    
    M[it] = U.reshape([szl, sz[it], int(U.numel()/szl/sz[it])])
    M[it] = M[it].permute([0, 2, 1])  # permute to the left-right-bottom order
    szl = M[it].shape[1]  # update the bond dimension
    R = torch.diag(svec[~ok]) @ V
    R = R.reshape([szl]+sz[it+1:])
    
    # compute entanglement entropy
    Spart = -(svec**2)*torch.log(svec**2)/torch.log(torch.tensor(2.))
    Sent[it] = torch.sum(Spart[~torch.isnan(Spart)])

if R.dim() == 3:
    M[-1] = R.permute([0, 2, 1])
elif R.dim() == 2:
    M[-1] = R.permute([1, 0])
    
# Check the tensors M and Q
for i, tensor in enumerate(M):
    print(f"Shape of M[{i}]: {tensor.shape}")
    
for i, tensor in enumerate(Q):
    print(f"Shape of Q[{i}]: {tensor.shape}")

In [None]:
# Print the values of entanglement entropy
print(Sent)

In [None]:
# Check whether the contraction of M's give the original tensor T
T2 = M[0].permute([2, 1, 0])
for it in range(1, len(sz)):
    if M[it].dim() == 3:
        T2 = torch.tensordot(T2, M[it].permute([0, 2, 1]), dims=[[len(T2.shape)-1], [0]])
    elif M[it].dim() == 2:
        T2 = torch.tensordot(T2, M[it].permute([1, 0]), dims=[[len(T2.shape)-1], [0]])

print(T.shape)
print(T2.shape)
print(torch.max(torch.abs(T - T2)))