**Demo for `teneva.core_jax.svd`**

---

This module contains a basic implementation of the TT-SVD algorithm, as well as functions for constructing the skeleton decomposition of matrices.

## Loading and importing modules

In [1]:
import jax
import jax.numpy as np
import teneva as teneva_base
import teneva.core_jax as teneva
from time import perf_counter as tpc
rng = jax.random.PRNGKey(42)

## Function `matrix_skeleton`

Construct a truncated skeleton decomposition A ≈ U V (of rank at most r) for the given matrix.

In [2]:
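# Nominal shape of the matrix (note: the vectors v below are generated
# with length m, so A is actually m x m, as the printed shapes confirm):
m, n = 100, 30

# Build a random matrix of rank 3 as a sum of three rank-1 matrices: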
rng, key = jax.random.split(rng)
keys = jax.random.split(key, 6)
u = [jax.random.normal(keys[i], (m, )) for i in range(3)]
v = [jax.random.normal(keys[i], (m, )) for i in range(3, 6)]
A = np.outer(u[0], v[0]) + np.outer(u[1], v[1]) + np.outer(u[2], v[2])

In [3]:
# Compute the skeleton decomposition:
U, V = teneva.matrix_skeleton(A, r=3)

# Approximation error:
e = np.linalg.norm(A - U @ V) / np.linalg.norm(A)

print('Shape of U :', U.shape)
print('Shape of V :', V.shape)
print(f'Error      : {e:-8.2e}')

Shape of U : (100, 3)
Shape of V : (3, 100)
Error      : 3.31e-07


In [4]:
# Compute the skeleton decomposition with a smaller rank:
U, V = teneva.matrix_skeleton(A, r=2)

# Approximation error:
e = np.linalg.norm(A - U @ V) / np.linalg.norm(A)
print('Shape of U :', U.shape)
print('Shape of V :', V.shape)
print(f'Error      : {e:-8.2e}')

Shape of U : (100, 2)
Shape of V : (2, 100)
Error      : 4.62e-01
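
A rank-r factorization of this kind can be obtained from a truncated SVD. The snippet below is a minimal sketch of that idea in plain NumPy (not `jax.numpy`); the helper name `skeleton_by_svd` is hypothetical, and this is not claimed to be how `teneva.matrix_skeleton` is implemented. It shows that a matrix of exact rank 3 is recovered almost exactly with `r=3`, while `r=2` leaves a noticeable error.

```python
import numpy as np  # plain NumPy here, not jax.numpy


def skeleton_by_svd(A, r):
    """Rank-r factorization A ~ U @ V from a truncated SVD (illustrative)."""
    W, s, Vt = np.linalg.svd(A, full_matrices=False)
    U = W[:, :r] * s[:r]   # scale the r leading left singular vectors
    V = Vt[:r, :]          # keep the r leading right singular vectors
    return U, V


# Random matrix of exact rank 3 (sum of three outer products):
rng = np.random.default_rng(42)
m, n = 100, 30
A = sum(np.outer(rng.normal(size=m), rng.normal(size=n)) for _ in range(3))

for r in (3, 2):
    U, V = skeleton_by_svd(A, r)
    e = np.linalg.norm(A - U @ V) / np.linalg.norm(A)
    print(f'r = {r} | U: {U.shape} | V: {V.shape} | error: {e:.2e}')
```

Folding the singular values into `U` is one possible convention; they could equally be absorbed into `V` or split between the two factors.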


## Function `svd`

Construct a TT-tensor from the given full tensor using the TT-SVD algorithm. Note that this function does not take advantage of jax's ability to speed up the code, so it can be slow; in practice it is only suitable for tensors of small dimension.
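
For orientation, the general TT-SVD idea can be sketched as a sequence of truncated SVDs applied to successive unfoldings of the tensor. The code below is an illustrative plain NumPy version under that assumption; the names `tt_svd_sketch` and `tt_full` are hypothetical, the cores are returned as a simple list of 3D arrays, and none of this is claimed to match the internals or the storage format of `teneva.svd`.

```python
import numpy as np  # plain NumPy here, not jax.numpy


def tt_svd_sketch(Z, r):
    """Split a full tensor into TT-cores by sequential truncated SVDs."""
    shape = Z.shape
    cores = []
    q = 1                                     # current left TT-rank
    M = np.asarray(Z)
    for n in shape[:-1]:
        M = M.reshape(q * n, -1)              # unfold: (rank * mode) x (rest)
        W, s, Vt = np.linalg.svd(M, full_matrices=False)
        q_new = min(r, len(s))                # keep at most r singular triplets
        cores.append(W[:, :q_new].reshape(q, n, q_new))
        M = s[:q_new, None] * Vt[:q_new]      # carry the remainder to the right
        q = q_new
    cores.append(M.reshape(q, shape[-1], 1))  # last core
    return cores


def tt_full(cores):
    """Contract the TT-cores back into a full tensor."""
    Y = cores[0]
    for G in cores[1:]:
        Y = np.tensordot(Y, G, axes=1)        # sum over the shared rank index
    return Y[0, ..., 0]                       # drop the boundary ranks of size 1


d = 5
Z = np.cos(np.arange(2**d)).reshape([2] * d, order='F')
for r in (2, 1):
    Y = tt_svd_sketch(Z, r)
    e = np.linalg.norm(tt_full(Y) - Z) / np.linalg.norm(Z)
    print(f'r = {r} | core shapes: {[G.shape for G in Y]} | error: {e:.2e}')
```

Each step costs a full SVD of an unfolding matrix, which is why this kind of construction is practical only when the full tensor is small enough to store and factorize.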

In [5]:
d = 5               # Number of dimensions
t = np.arange(2**d) # Tensor will have 2^d elements

# Construct d-dim full array:
Z_full = np.cos(t).reshape([2] * d, order='F')

In [6]:
# Construct TT-tensor by TT-SVD:
Y = teneva.svd(Z_full, r=2)

# Convert it back to the full tensor to check the result:
Y_full = teneva.full(Y)

# Compute error for TT-tensor vs full tensor:
e = np.linalg.norm(Y_full - Z_full)
e /= np.linalg.norm(Z_full)

In [7]:
# Size of the original tensor:
print(f'Size (np) : {Z_full.size:-8d}')

# Size of the TT-tensor:
print(f'Size (tt) : {Y[0].size + Y[1].size + Y[2].size:-8d}') # TODO

# Rel. error for the TT-tensor vs full tensor:
print(f'Error     : {e:-8.2e}')

Size (np) :       32
Size (tt) :       32
Error     : 6.60e-07


We can also try a lower rank (it will lead to a large error in this case):

In [8]:
# Construct TT-tensor by TT-SVD:
Y = teneva.svd(Z_full, r=1)

# Convert it back to the full tensor to check the result:
Y_full = teneva.full(Y)

# Compute error for TT-tensor vs full tensor:
e = np.linalg.norm(Y_full - Z_full)
e /= np.linalg.norm(Z_full)

print(f'Size (np) : {Z_full.size:-8d}')
print(f'Size (tt) : {Y[0].size + Y[1].size + Y[2].size:-8d}') # TODO
print(f'Error     : {e:-8.2e}')

Size (np) :       32
Size (tt) :       10
Error     : 7.13e-01
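
The two `Size (tt)` values printed above can be reproduced by simple counting. The sketch below assumes that all inner TT-ranks are equal to `r` and that the boundary ranks are 1, which is consistent with the printed sizes but is an assumption about the layout; the helper name `tt_size` is hypothetical.

```python
def tt_size(d, n, r):
    """Number of stored entries in a TT-tensor with d modes of size n
    and all inner TT-ranks equal to r (boundary ranks are 1)."""
    first = 1 * n * r              # first core: 1 x n x r
    middle = (d - 2) * r * n * r   # d - 2 inner cores: r x n x r each
    last = r * n * 1               # last core: r x n x 1
    return first + middle + last


d, n = 5, 2
print('Full tensor :', n**d)                     # 32
for r in (2, 1):
    print(f'TT, r = {r}   :', tt_size(d, n, r))  # 32 for r = 2, 10 for r = 1
```

For such a small tensor, rank 2 gives no compression at all (32 stored values in both formats); the benefit of the TT format appears only as `d` grows, since the full size scales as `n**d` while the TT size grows linearly in `d`.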


Note that in the jax version, the rank cannot be greater than the mode size:

In [9]:
try:
    Y = teneva.svd(Z_full, r=3)
except ValueError as e:
    print('Error :', e)

Error : Rank can not be greater than mode size
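
One way to see where such a limit can come from (an interpretation, not a statement about teneva's internals): the k-th TT-rank is the rank of the k-th unfolding matrix, so it cannot exceed the smaller dimension of that matrix. If a single rank `r` is shared by all cores, it is bounded by the tightest of these limits, which for equal mode sizes is the mode size itself. The helper name `max_tt_ranks` below is hypothetical.

```python
from math import prod


def max_tt_ranks(shape):
    """Upper bounds on the TT-ranks: r_k <= min(n_1*...*n_k, n_{k+1}*...*n_d)."""
    return [
        min(prod(shape[:k]), prod(shape[k:]))
        for k in range(1, len(shape))
    ]


bounds = max_tt_ranks([2] * 5)
print(bounds)       # [2, 4, 4, 2]
print(min(bounds))  # 2: a single rank shared by all cores cannot exceed this
```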


---