This notebook contains simple examples to get started with tensor-train (TT) decomposition using tensorly (http://tensorly.org/stable/index.html) and ttpy (https://github.com/oseledets/ttpy).

Author: Suhan Shetty (suhan.n.shetty@gmail.com | suhan.shetty@idiap.ch)

In [7]:
import numpy as np

In [8]:
# Define the input tensor by sampling a continuous function on a regular grid.

# Discretization parameters
K = 50  # grid points per axis: index i in range(K) maps to x = i/K, i.e. {0, 1/K, ..., (K-1)/K} in [0, 1)
d = 3   # number of dimensions (order of the tensor)

# Define a d-dimensional function with domain [0,1]^d.
def f_c(x):
    """Continuous test function exp(-||x||_2).

    Args:
        x: a single point of shape (d,), or a batch of points of shape (n, d).

    Returns:
        A scalar for a single 1-D point, or an array of shape (n,) with one
        value per row of a batch.
    """
    # axis=-1 takes the Euclidean norm of each point separately; without it a
    # batch (n, d) would collapse into a single Frobenius norm over all points.
    return np.exp(-np.linalg.norm(x, axis=-1))


def f(I):
    """Evaluate f_c at grid index (or indices) I, mapping index i to coordinate i/K.

    Args:
        I: integer index of length d, or a batch of indices reshapable to (n, d).

    Returns:
        A scalar when I holds exactly one index (as before), otherwise an
        array of shape (n,) with one value per index.
    """
    I = np.asarray(I)
    x = I.reshape(-1, d) / K
    values = f_c(x)
    # Keep the scalar return for a single index so element-wise assignment
    # such as F[i, j, k] = f(...) keeps working.
    return values[0] if values.shape[0] == 1 else values
# Note: in general f could be any function that returns the entry of a d-dimensional array given an index I.

# Evaluate f at every grid index to build the full d-dimensional tensor F of shape (K,)*d.
# np.ndindex enumerates all K**d index tuples in row-major order, so this works for any d
# (not only d == 3). Cost: K**d calls to f.
F = np.empty([K]*d)
for idx in np.ndindex(F.shape):
    F[idx] = f(np.array(idx))



## Using Tensorly

In [9]:
import tensorly
from tensorly.decomposition import tensor_train as TT

In [10]:
# Find the TT decomposition of tensor F.
# TT ranks: the boundary ranks must be 1; the d-1 inner ranks trade accuracy for size
# (increase them for a better approximation). Built from d so it works for any order.
tt_rank = [1] + [5] * (d - 1) + [1]
ttF = TT(F, rank=tt_rank)
factors = ttF.factors  # list of TT cores
F_apprx = ttF.to_tensor()  # reconstruct the full d-dimensional tensor from its TT decomposition
print("Error: ", np.linalg.norm(F - F_apprx))
print("Relative error: ", np.linalg.norm(F - F_apprx) / np.linalg.norm(F))
print("Number of elements in the original array: ", F.size)
# Core i holds rank[i] * n_i * rank[i+1] entries; use the actual mode size n_i instead of assuming K.
print("Number of elements in tt format: ", sum(ttF.rank[i] * F.shape[i] * ttF.rank[i + 1] for i in range(d)))

Error:  0.018017938381372296
Number of elements in the original array:  125000
Number of elements in tt format:  1750


## Using ttpy

In [11]:
# !pip install ttpy
import tt
from tt.cross import rect_cross as tt_cross

In [12]:
# Given the full d-dimensional array F, compute its TT decomposition using the TT-SVD algorithm.
# eps sets the requested accuracy of the TT-SVD truncation (smaller eps -> larger ranks).
ttF = tt.core.vector.tensor(F, eps=1e-3)  # find the decomposition
F_apprx = ttF.full()  # reconstruct the full d-dimensional tensor from its TT decomposition
print("Error: ", np.linalg.norm(F - F_apprx))
print("Relative error: ", np.linalg.norm(F - F_apprx) / np.linalg.norm(F))
factors = tt.vector.to_list(ttF)  # list of all the TT cores

print("Number of elements in the original array: ", F.size)
# Core i holds r[i] * n_i * r[i+1] entries; use the actual mode size n_i instead of assuming K.
print("Number of elements in tt format: ", sum(ttF.r[i] * F.shape[i] * ttF.r[i + 1] for i in range(d)))

Error:  0.053601566915513335
Number of elements in the original array:  125000
Number of elements in tt format:  1200
