# Matrix Product States

In [None]:
import numpy as np

## TT-SVD

In [398]:
N = 10  # number of sites/spins
d = 2  # physical dimension

eps = 0  # SVD truncation error

# Seeded generator so that "Restart Kernel -> Run All" always builds the same tensor
rng = np.random.default_rng(0)
tens = rng.random((d,) * N)  # high-dimensional tensor: N axes, each of size d, uniform in [0, 1)
A = []  # list for storing MPS tensors
tens.shape

(2, 2, 2, 2, 2, 2, 2, 2, 2, 2)

In [399]:
# Step 1: keep the first physical index as rows, flatten all remaining indices into columns
tmp = np.reshape(tens, (tens.shape[0], -1)) # (temporary tensor)
tmp.shape

(2, 512)

In [None]:
# SVD + Truncate (step 2)
# Thin SVD: only the first min(m, n) singular vectors are ever used below, so
# there is no need to build the full n x n right factor of a wide m x n matrix.
U, s, Vt = np.linalg.svd(tmp, full_matrices=False)

# Truncate singular values such that truncation error is less than or equal to eps.
# tail_sq_norm[k-1] is the squared error made by dropping the k smallest singular
# values; it is non-decreasing, so counting the entries <= eps**2 gives the
# largest number of singular values that may be discarded.
tail_sq_norm = np.cumsum(s[::-1]**2)
num_sv_to_discard = int(np.count_nonzero(tail_sq_norm <= eps**2))
r = max(1, len(s) - num_sv_to_discard) # new rank (always keep at least one singular value)

In [402]:
# Keep the first r left singular vectors and store them as the first MPS site,
# shaped (dummy left bond, physical dimension, right bond)
A.append(np.reshape(U[:, :r], (1, d, r)))
A[0].shape

(1, 2, 2)

In [403]:
# Step 3: absorb the kept singular values into the kept rows of Vt
tmp = np.diagflat(s[:r]).dot(Vt[:r, :])
tmp.shape

(2, 512)

In [404]:
# Step 4: fuse the new bond index with the next physical index into the rows
tmp = np.reshape(tmp, (r * tens.shape[1], -1))
tmp.shape
# Repeat steps 2-4, N-1 times

(4, 256)

In [434]:
from copy import copy
def tt_svd(tens: np.ndarray, eps: float = 10**-6) -> list:
    """
    Compress a tensor to an MPS/TT using the TT-SVD algorithm.

    The tensor is swept from left to right: at each site the remaining part is
    reshaped to a matrix, factorised with a thin SVD and truncated. The kept
    left singular vectors become the site tensor, while the kept singular
    values times the right singular vectors are carried on to the next site.

    Args:
        tens: The input tensor (array-like with at least one axis)
        eps: Truncation error for each SVD. The smallest singular values are
            discarded as long as the sum of their squares stays <= eps**2;
            at least one singular value is always kept.
    Return:
        An MPS/TT as a list of order-3 tensors of shape
        (left bond, physical dimension, right bond)
        (dummy bonds are added to boundary tensors)
    """
    # np.array copies, so the returned tensors never alias the caller's data,
    # and it lets array-likes (e.g. nested lists) be passed as well.
    tmp = np.array(tens)
    dims = tmp.shape
    N = len(dims)
    A = []
    r_prev = 1
    for i in range(N-1):
        # Reshape (step 4): fuse previous bond and current physical index into rows
        tmp = tmp.reshape(r_prev * dims[i], -1)

        # SVD + Truncate (step 2)
        # Thin SVD: the full (rows x rows) / (cols x cols) factors are never
        # used and would dominate both memory and run time for large tensors.
        U, s, Vt = np.linalg.svd(tmp, full_matrices=False)
        # Truncate singular values such that truncation error is less than or equal to eps.
        # tail_sq_norm is non-decreasing, so the number of entries <= eps**2 is
        # exactly how many of the smallest singular values may be dropped.
        tail_sq_norm = np.cumsum(s[::-1]**2)
        num_sv_to_discard = int(np.count_nonzero(tail_sq_norm <= eps**2))
        r = max(1, len(s) - num_sv_to_discard) # new rank

        # Reshape and truncate U matrix, store in return list
        A.append(U[:,:r].reshape(r_prev, dims[i], r))

        # Contract s and Vt (step 3); row-wise scaling avoids a dense r x r diagonal matrix
        tmp = s[:r, None] * Vt[:r,:]
        r_prev = r
    # The remainder is the last site; append a dummy right bond of size 1
    A.append(tmp.reshape(r_prev, dims[-1], 1))
    return A

In [442]:
# Compress the tensor, allowing a truncation error of 0.1 in each SVD
eps = 1e-1
mps = tt_svd(tens, eps=eps)

In [443]:
# Bond dimensions: the left bond of every site, followed by the right bond of the last site
[*(core.shape[0] for core in mps), mps[-1].shape[-1]]

[1, 2, 4, 8, 16, 30, 16, 8, 4, 2, 1]

In [444]:
def restore_full(mps: list) -> np.ndarray:
    """
    Contract an MPS/TT back into the dense tensor it represents.

    Sites are absorbed one at a time from left to right. After each
    contraction the two physical indices are fused, so the running result
    stays an order-3 tensor until the final reshape.

    Args:
        mps: List of order-3 tensors representing an MPS/TT

    Return:
        The full tensor, with one axis per site
    """
    merged = mps[0]
    phys_dims = [core.shape[1] for core in mps]
    for core in mps[1:]:
        # Sum over the shared bond index j, keeping both physical indices u and v
        pair = np.einsum('iuj,jvk->iuvk', merged, core)
        left_bond, phys_a, phys_b, right_bond = pair.shape
        # Fuse the two physical indices into one
        merged = pair.reshape(left_bond, phys_a * phys_b, right_bond)
    # Split the fused physical index back into one axis per site
    return merged.reshape(phys_dims)

In [440]:
np.linalg.norm(np.subtract(tens, restore_full(mps))) # Frobenius norm of the difference: the TT/MPS approximation error

0.08972321921920606

In [446]:
np.sqrt(N) * eps # Theoretical upper bound for the TT/MPS approximation error

0.316227766016838