# Tensor Train-Singular Value Decomposition (TT-SVD)

In this notebook, we will walk through the process of compressing a high-dimensional tensor into a Tensor Train (TT) format using the TT-SVD algorithm.

## Table of Contents:
1. What is a Tensor Train (TT)?
2. The TT-SVD Algorithm
3. Reversing the decomposition: Restoring the original tensor
4. Examples
5. Exercises

## 1. What is a Tensor Train (TT)?
A **Tensor Train** (TT), also known as a Matrix Product State (MPS), is a format that represents a high-dimensional tensor as a chain of smaller, order-3 tensors (the TT-cores). Neighbouring cores are connected by a shared index called a bond.
This decomposition makes it easier to store, manipulate, and perform computations on high-dimensional tensors.

The goal of TT-SVD is to approximate a high-dimensional tensor as a series of lower-order tensors, reducing the computational and memory requirements.
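In index notation, each entry of the tensor is a product of small matrices, one slice per core. The symbols here are our own notation: $\mathcal{T}$ is an order-$N$ tensor with mode sizes $n_1, \dots, n_N$, and core $A^{(k)}$ has shape $r_{k-1} \times n_k \times r_k$. This matches the order-3 cores with dummy boundary bonds that `tt_svd` returns later in the notebook.

$$
\mathcal{T}[i_1, i_2, \dots, i_N] \;=\; A^{(1)}[:, i_1, :]\; A^{(2)}[:, i_2, :]\; \cdots\; A^{(N)}[:, i_N, :],
\qquad A^{(k)} \in \mathbb{R}^{r_{k-1} \times n_k \times r_k},\quad r_0 = r_N = 1
$$

Each slice $A^{(k)}[:, i_k, :]$ is an $r_{k-1} \times r_k$ matrix, so the product is a $1 \times 1$ matrix, i.e. a single number. The cores hold $\sum_{k=1}^{N} r_{k-1} n_k r_k$ numbers, compared with $\prod_{k=1}^{N} n_k$ for the full tensor. When the bond dimensions $r_k$ stay small, storage therefore grows linearly in $N$ instead of exponentially.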

## 2. The TT-SVD Algorithm
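The TT-SVD algorithm sweeps over the modes (dimensions) of the tensor from left to right. For each mode except the last, it performs the following steps:

1. Reshape the current remainder of the tensor into a 2D matrix (rows: previous bond and current mode combined; columns: all remaining modes)
2. Perform Singular Value Decomposition (SVD)
3. Truncate the small singular values (based on a truncation error tolerance, `eps`, and an optional rank cap, `max_bond_dimension`)
4. Store the left singular vectors as one of the TT-cores (the decomposed tensors)
5. Contract the singular values with the right singular vectors; the result is the new remainder
6. Repeat the process for the next mode

After the last SVD, the remaining matrix is reshaped into the final TT-core.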
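Written out for step $k = 1, \dots, N-1$, one pass of the loop in the code below looks like this. The notation is ours: $n_k$ is the size of mode $k$, $\sigma_1 \ge \sigma_2 \ge \dots$ are the singular values, $\varepsilon$ stands for `eps`, and $r_{\max}$ for `max_bond_dimension`.

$$
\begin{aligned}
&\text{reshape } B_{k-1} \text{ to an } (r_{k-1} n_k) \times (n_{k+1} \cdots n_N) \text{ matrix and factor it as } U \Sigma V^\top, \\
&r_k = \min\Big(r_{\max},\ \max\big(1,\ \min\{ r \ge 0 : \textstyle\sum_{j > r} \sigma_j^2 \le \varepsilon^2 \}\big)\Big), \\
&A^{(k)} = \operatorname{reshape}\big(U_{:,\,1:r_k},\ [r_{k-1}, n_k, r_k]\big), \qquad
B_k = \Sigma_{1:r_k,\,1:r_k}\, V^\top_{1:r_k,\,:}
\end{aligned}
$$

The sweep starts from $B_0 = \mathcal{T}$ with $r_0 = 1$ and ends with $A^{(N)} = \operatorname{reshape}(B_{N-1}, [r_{N-1}, n_N, 1])$.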

Let's first import the necessary libraries and define the TT-SVD function. This function compresses a tensor into TT format and returns both the list of TT-cores and the total truncation error.

In [312]:
import numpy as np
from copy import copy

def tt_svd(tens: np.ndarray, eps: float = 1e-6, max_bond_dimension: int = 10**12) -> tuple:
    """
    Compress a tensor to an MPS/TT using the TT-SVD algorithm.
    
    Args:
        tens: The input tensor
        eps: Truncation error tolerance for each SVD
        max_bond_dimension: Maximum allowable rank for truncation
    
    Returns:
        A tuple (A, error):
        A: An MPS/TT as a list of order-3 tensors (dummy bonds of size 1 are added to the boundary tensors)
        error: The total truncation error
    """
    dims = tens.shape
    N = len(dims)
    tmp = copy(tens)
    A = []
    r_prev = 1
    
    total_trunc_error_squared = 0

    for i in range(N-1):
        # Step 1: Reshape the tensor into a 2D matrix
        tmp = tmp.reshape(r_prev * dims[i], -1)
        
        # Step 2: Perform SVD
        U, s, Vt = np.linalg.svd(tmp, full_matrices=False)
        
        # Step 3: Truncate small singular values (based on eps)
        truncation_error_squared = np.cumsum(s[::-1]**2)
        where_error_is_lower_than_eps = np.where(truncation_error_squared <= eps**2)[0]
        
        num_sv_to_discard = 0 if len(where_error_is_lower_than_eps) == 0 else int(1 + where_error_is_lower_than_eps[-1])
        
        # Compute new rank: should not be lower than 1 or greater than `max_bond_dimension`
        r = min(max_bond_dimension, max(1, len(s) - num_sv_to_discard))

        # Compute truncation error from discarded singular values
        total_trunc_error_squared += sum(s[r:]**2)

        # Step 4: Store the left singular vectors (TT-core)
        A.append(U[:, :r].reshape(r_prev, dims[i], r))
        
        # Step 5: Contract singular values with right singular vectors
        tmp = np.diagflat(s[:r]) @ Vt[:r, :]
        r_prev = r
    
    # Final TT-core
    A.append(tmp.reshape(r_prev, dims[-1], 1))
    
    return (A, np.sqrt(total_trunc_error_squared))

### Explanation of the Code:
- **Input Tensor (`tens`)**: This is the high-dimensional tensor we aim to compress.
- **Truncation Error (`eps`)**: The tolerance applied at each SVD. The smallest singular values are discarded as long as the square root of the sum of their squares stays at or below `eps`; at least one singular value is always kept.
- **Max bond dimension (`max_bond_dimension`)**: An upper limit on every bond dimension (TT-rank). It is applied after the `eps` criterion, so it can force truncation even when `eps` would keep more singular values.
- **`A` List**: This stores the sequence of TT-cores, which are the smaller tensors resulting from the decomposition.
- **`total_trunc_error_squared`**: This accumulates the squared truncation error over all SVD steps; at the end, the function returns its square root.
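To see how `eps` and `max_bond_dimension` pick the rank, here is Step 3 isolated into a small helper. `choose_rank` is a hypothetical name used only for illustration. It counts the discardable singular values with `np.sum` instead of `np.where`, which is equivalent because the cumulative sum of squares never decreases.

```python
import numpy as np

def choose_rank(s: np.ndarray, eps: float, max_bond_dimension: int = 10**12) -> int:
    """Hypothetical helper: the rank rule from Step 3 of `tt_svd`, in isolation.

    `s` must be sorted in decreasing order, as returned by np.linalg.svd.
    """
    # Cumulative squared norm of the tail, starting from the smallest singular value
    tail_error_squared = np.cumsum(s[::-1] ** 2)
    # Number of trailing singular values that can be dropped while staying within eps
    num_sv_to_discard = int(np.sum(tail_error_squared <= eps**2))
    # Keep at least one singular value and at most `max_bond_dimension`
    return min(max_bond_dimension, max(1, len(s) - num_sv_to_discard))

s = np.array([3.0, 1.0, 0.1, 0.01])
print(choose_rank(s, eps=0.2))                          # 2: drops 0.1 and 0.01
print(choose_rank(s, eps=1e-3))                         # 4: nothing can be dropped
print(choose_rank(s, eps=100.0))                        # 1: the floor of one value applies
print(choose_rank(s, eps=1e-3, max_bond_dimension=2))   # 2: the cap applies
```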


## 3. Reversing the decomposition: Restoring the original tensor

When we studied the SVD, we checked that multiplying the U, S, and Vt matrices together recovers the original matrix. We can do the same with the TT decomposition: contracting the cores along their shared bonds restores the original tensor. Below is code to do that:

In [313]:
def restore_full(mps: list) -> np.ndarray:
    """
    Restore full tensor from an MPS/TT

    Args:
        mps: List of order-3 tensors representing an MPS/TT

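    Returns:
        The full tensor
    """
    tmp = copy(mps[0])
    dims = [site.shape[1] for site in mps]
    for site in mps[1:]:
        # Contract the right bond of the running result with the left bond of the next core
        tmp = np.einsum('iuj,jvk->iuvk', tmp, site)
        # Merge the two physical indices so the running result stays order-3
        tmp = tmp.reshape(tmp.shape[0], tmp.shape[1] * tmp.shape[2], tmp.shape[3])
    # Drop the dummy boundary bonds (both of size 1) and split the physical indices
    return tmp.reshape(dims)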
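The same contraction can also be written with `np.tensordot`. At each step it contracts the trailing bond of the running result with the leading bond of the next core, so no intermediate reshape is needed. This is only a sketch of an alternative; `restore_full_tensordot` is a hypothetical name, not part of the notebook. The usage check compares it against a single `np.einsum` over three random cores.

```python
import numpy as np

def restore_full_tensordot(mps: list) -> np.ndarray:
    """Hypothetical alternative to `restore_full`: contract the cores with np.tensordot.

    Each step contracts the last (bond) axis of the running result with the
    first (bond) axis of the next core, so the physical axes accumulate in order.
    """
    result = mps[0]                                   # shape (1, n_1, r_1)
    for core in mps[1:]:
        # (1, n_1, ..., n_k, r_k) x (r_k, n_{k+1}, r_{k+1}) -> (1, n_1, ..., n_{k+1}, r_{k+1})
        result = np.tensordot(result, core, axes=([-1], [0]))
    # Drop the two dummy boundary bonds of size 1
    return result.reshape(result.shape[1:-1])

# Usage: three random cores with bond dimensions 1-2-3-1
rng = np.random.default_rng(0)
cores = [rng.standard_normal((1, 4, 2)),
         rng.standard_normal((2, 5, 3)),
         rng.standard_normal((3, 6, 1))]
full = restore_full_tensordot(cores)
print(full.shape)                     # (4, 5, 6)

# The same contraction written as a single einsum, for comparison
reference = np.einsum('aib,bjc,ckd->ijk', cores[0], cores[1], cores[2])
print(np.allclose(full, reference))   # True
```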

## 4. Examples

Now, let's run through some examples to see the algorithm in action.

### 1. Random tensor
Let’s create a random 7-dimensional (order-7) tensor for this example.

In [314]:
# Create a random 7D tensor with shape (3, 3, 3, 3, 3, 2, 2)
tensor = np.random.rand(3, 3, 3, 3, 3, 2, 2)
tensor.shape  # Confirm the shape

(3, 3, 3, 3, 3, 2, 2)

Next, we will apply the `tt_svd` function to compress this tensor.

In [315]:
# Apply the TT-SVD algorithm
tt_cores, err = tt_svd(tensor, eps=1e-6)

# Display the shapes of the TT-cores
for i, core in enumerate(tt_cores):
    print(f"Core {i+1}: {core.shape}")

print("Total truncation error: ", err)

Core 1: (1, 3, 3)
Core 2: (3, 3, 9)
Core 3: (9, 3, 27)
Core 4: (27, 3, 12)
Core 5: (12, 3, 4)
Core 6: (4, 2, 2)
Core 7: (2, 2, 1)
Total truncation error:  0.0


As you can see, the TT-SVD algorithm has decomposed our 7-dimensional tensor into seven order-3 tensors (TT-cores).

The truncation error is zero because no singular values were discarded: with `eps=1e-6`, the smallest singular value at every step is larger than the tolerance.

### Exercise:
1. What do you notice about the bond dimensions when there is no truncation?
2. How can we calculate the bond dimensions when there is no truncation?

Let's try restoring the original tensor to make sure we get back what we started with:

In [316]:
tensor_restored = restore_full(tt_cores)
print("Restored tensor matches original tensor: ", bool(np.all(np.isclose(tensor_restored, tensor))))

Restored tensor matches original tensor:  True




We can experiment with different tensors and different values of `eps` and `max_bond_dimension` to see how they affect the compression and the sizes of the TT-cores.


In [317]:
tt_cores, err = tt_svd(tensor, eps=1e-1)

# Display the shapes of the TT-cores
for i, core in enumerate(tt_cores):
    print(f"Core {i+1}: {core.shape}")

print("Total truncation error: ", err)

Core 1: (1, 3, 3)
Core 2: (3, 3, 9)
Core 3: (9, 3, 27)
Core 4: (27, 3, 12)
Core 5: (12, 3, 4)
Core 6: (4, 2, 2)
Core 7: (2, 2, 1)
Total truncation error:  0.0


Increasing `eps` to 10^-1 does not change the result: there is still no truncation, because the smallest singular value at every step is still larger than 0.1. We can force truncation by setting `max_bond_dimension`:

In [318]:
tt_cores, err = tt_svd(tensor, eps=1e-1, max_bond_dimension=4)

# Display the shapes of the TT-cores
for i, core in enumerate(tt_cores):
    print(f"Core {i+1}: {core.shape}")

print("Total truncation error: ", err)

Core 1: (1, 3, 3)
Core 2: (3, 3, 4)
Core 3: (4, 3, 4)
Core 4: (4, 3, 4)
Core 5: (4, 3, 4)
Core 6: (4, 2, 2)
Core 7: (2, 2, 1)
Total truncation error:  7.906917862700159


Now every bond dimension is capped at the value we set (4), and the truncation error is large.
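The reported truncation error comes from the discarded singular values. For a single matrix, the Eckart-Young theorem states that the truncated SVD is the best rank-$r$ approximation in the Frobenius norm, and its error is exactly the square root of the sum of the squared discarded singular values. The quick numerical check below uses a random $12 \times 9$ matrix; the size and seed are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)
M = rng.random((12, 9))                      # arbitrary 12 x 9 test matrix
U, s, Vt = np.linalg.svd(M, full_matrices=False)

errors = []
for r in range(1, len(s) + 1):
    # Rank-r truncated SVD: keep the r largest singular values and their vectors
    M_r = U[:, :r] @ np.diag(s[:r]) @ Vt[:r, :]
    actual = np.linalg.norm(M - M_r)         # Frobenius norm (default for 2D arrays)
    predicted = np.sqrt(np.sum(s[r:] ** 2))  # norm of the discarded singular values
    errors.append((actual, predicted))
    print(f"rank {r}: actual {actual:.6f}, predicted {predicted:.6f}")
```

TT-SVD applies this step once per bond, which is why `tt_svd` accumulates `sum(s[r:]**2)` at every step.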

As explained in the course, this truncation error is an upper bound on the actual error of the TT approximation. To compute the actual error, we reconstruct the tensor from the TT with bond dimension at most 4 and compute the Frobenius norm of its difference from the original tensor:

In [319]:
tensor_restored = restore_full(tt_cores)
actual_error = np.linalg.norm(tensor - tensor_restored)
print("Error computed with Frobenius norm: ", actual_error)
print("Upper bound on error: ", err)
print("Actual error is less than or equal to upper bound:", actual_error <= err + 10**-10)

Error computed with Frobenius norm:  7.9069178627001575
Upper bound on error:  7.906917862700159
Actual error is less than or equal to upper bound: True


So the actual error does not exceed the upper bound; in this case, the two even agree to within floating-point precision.
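For reference, here is the bound being checked. The notation is ours: step $k$ of `tt_svd` discards the singular values $\sigma^{(k)}_j$ with $j > r_k$, and $\tilde{\mathcal{T}}$ is the tensor reconstructed from the cores.

$$
\|\mathcal{T} - \tilde{\mathcal{T}}\|_F \;\le\; \sqrt{\sum_{k=1}^{N-1} \sum_{j > r_k} \big(\sigma^{(k)}_j\big)^2} \;=\; \texttt{err}
$$

When `max_bond_dimension` caps none of the ranks, each inner sum is at most $\varepsilon^2$ (with $\varepsilon$ = `eps`). The returned error is then at most $\sqrt{N-1}\,\varepsilon$.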

### 2. Sinusoidal function

The random tensor above does not have a low TT-rank, so it is not approximated well by an MPS/TT with a low bond dimension. Below is an example of a tensor that does have a low TT-rank (rank 2): a sinusoidal signal reshaped into a tensor.

In [330]:
# Create a tensor by reshaping samples of a cosine function:

shape = (2, 2, 2, 2, 2, 2, 2, 2)

w=1
phi=0.5
def x(t):
    return np.cos(w*t + phi)

tens = np.fromiter((x(t) for t in range(np.prod(shape))), dtype=np.float64).reshape(shape)
mps, err = tt_svd(tens, eps=10**-10)

Let's write a function to help us quickly visualize the bond dimensions:

In [336]:
# Get bond dimensions/tt-ranks
def bdims(mps: list):
    return [site.shape[0] for site in mps] + [mps[-1].shape[-1]]

print("MPS bond dimensions: ", bdims(mps))

MPS bond dimensions:  [1, 2, 2, 2, 2, 2, 2, 2, 1]


We can see that the largest bond dimension of the MPS for the cosine function is 2. The truncation error is negligible as well: with `eps=10**-10`, each SVD step discards singular values with a total norm of at most 10^-10.
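Why rank 2? Write the sample index in binary as $t = \sum_{k=1}^{N} i_k 2^{N-k}$, which matches the C-order reshape with $i_1$ as the most significant bit. Let $R(\alpha)$ be the $2 \times 2$ rotation matrix. The angle-addition formula $\cos(a + b) = \cos a \cos b - \sin a \sin b$ gives $\cos(wt + \phi) = [1, 0]\, R(wt)\, [\cos\phi, \sin\phi]^\top$, and $R(wt) = \prod_k R(w\, i_k 2^{N-k})$. Each factor depends on a single bit, so the rotations themselves serve as rank-2 TT-cores. The sketch below builds these cores by hand and checks them against the sampled signal. The helpers `rotation` and `cosine_tt` are hypothetical names introduced for this illustration.

```python
import numpy as np

def rotation(alpha: float) -> np.ndarray:
    """2x2 rotation matrix R(alpha); R(a) @ R(b) == R(a + b)."""
    return np.array([[np.cos(alpha), -np.sin(alpha)],
                     [np.sin(alpha),  np.cos(alpha)]])

def cosine_tt(N: int, w: float, phi: float) -> list:
    """Hypothetical helper: exact rank-2 TT of cos(w*t + phi) for t = 0..2**N - 1.

    Bit i_k of t (k = 1 is the most significant bit) contributes the angle w * i_k * 2**(N-k).
    """
    # core[:, i, :] = R(w * i * 2**(N-k)) for i in {0, 1}, giving shape (2, 2, 2)
    cores = [np.stack([rotation(w * i * 2 ** (N - k)) for i in (0, 1)], axis=1)
             for k in range(1, N + 1)]
    # Left boundary: multiply by the row vector [1, 0], i.e. keep the first row -> (1, 2, 2)
    cores[0] = cores[0][:1, :, :]
    # Right boundary: multiply by the column vector [cos(phi), sin(phi)] -> (2, 2, 1)
    v = np.array([np.cos(phi), np.sin(phi)])
    cores[-1] = np.einsum('aib,b->ai', cores[-1], v)[:, :, None]
    return cores

N, w, phi = 8, 1.0, 0.5
cores = cosine_tt(N, w, phi)
print([c.shape[0] for c in cores] + [cores[-1].shape[-1]])  # [1, 2, 2, 2, 2, 2, 2, 2, 1]

# Contract the cores into the full tensor and compare with the sampled signal
full = cores[0]
for core in cores[1:]:
    full = np.tensordot(full, core, axes=([-1], [0]))
full = full.reshape((2,) * N)
samples = np.cos(w * np.arange(2 ** N) + phi).reshape((2,) * N)
print(np.allclose(full, samples))  # True
```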

### 3. Exponential function

An even lower TT-rank (rank 1) is obtained by reshaping samples of an exponential function.

In [335]:
# Create a tensor by reshaping samples of an exponential function:

shape = (2, 2, 2, 2, 2, 2, 2, 2)

g=0.1
def x(t):
    return np.exp(-g*t)

tens = np.fromiter((x(t) for t in range(np.prod(shape))), dtype=np.float64).reshape(shape)
mps, err = tt_svd(tens, eps=10**-10)

In [337]:
print("MPS bond dimensions: ", bdims(mps))
print("Total truncation error: ", err)

MPS bond dimensions:  [1, 1, 1, 1, 1, 1, 1, 1, 1]
Total truncation error:  6.062073869922691e-16
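The exponential is rank 1 because it turns the binary sum of the index into a product. With $t = \sum_{k=1}^{N} i_k 2^{N-k}$ (the same ordering as the C-order reshape, with $i_1$ as the most significant bit; notation ours):

$$
e^{-g t} \;=\; \prod_{k=1}^{N} e^{-g\, i_k 2^{N-k}}
\quad\Longrightarrow\quad
\mathcal{T} \;=\; a^{(1)} \otimes a^{(2)} \otimes \cdots \otimes a^{(N)},
\qquad a^{(k)} = \big(1,\ e^{-g\, 2^{N-k}}\big)
$$

Each core is therefore a $1 \times 2 \times 1$ tensor holding $a^{(k)}$, and every bond dimension is 1. A rank-1 factorization is unique up to rescaling the factors, so the cores found by `tt_svd` differ from these only by a scale factor per core.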


## 5. Exercises
1. Compute the number of parameters in the MPS representation of the three example tensors: random, cosine, and exponential.
2. Compute the compression ratios, i.e. the factor by which each of the original tensors has been compressed.
3. For the random tensor and cosine tensor, plot the truncation error as a function of the MPS bond dimension.