**Demo for `teneva_jax.transformation`**

---

This module contains functions for the orthogonalization and truncation of TT-tensors.

## Loading and importing modules

In [1]:
import jax
jax.config.update('jax_enable_x64', True)  # Use 64-bit floats

import os
os.environ['JAX_PLATFORM_NAME'] = 'cpu'

In [2]:
import jax
import jax.numpy as jnp
import teneva as teneva_base
import teneva_jax as teneva
from time import perf_counter as tpc
rng = jax.random.PRNGKey(42)

## Function `full`

Builds the tensor in full (dense) format from a given TT-tensor (list of TT-cores). Since the full tensor has `n**d` elements, this function can only be used for relatively small tensors.

In [3]:
d = 5     # Dimension of the tensor
n = 6     # Mode size of the tensor
r = 4     # Rank of the tensor

rng, key = jax.random.split(rng)
Y = teneva.rand(d, n, r, key)
teneva.show(Y)

Z = teneva.full(Y)

# Compare one value of original tensor and reconstructed tensor:
k = jnp.array([0, 1, 2, 3, 4])
y = teneva.get(Y, k)
z = Z[tuple(k)]
e = jnp.abs(z-y)
print(f'Error : {e:7.1e}')

TT-tensor-jax | d =     5 | n =     6 | r =     4 |
Error : 5.6e-17
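
To make the conversion to the full format more concrete, here is a minimal NumPy sketch of the contraction. It is not the `teneva_jax` implementation. The three-part layout (`Yl` of shape `(1, n, r)`, `Ym` of shape `(d-2, r, n, r)`, `Yr` of shape `(r, n, 1)`) is an assumption inferred from how the cores are unpacked and sliced in this demo, and the name `full_sketch` is hypothetical.

```python
import numpy as np

def full_sketch(Yl, Ym, Yr):
    """Contract TT-cores into a dense array.

    Assumed layout: Yl is (1, n, r), Ym is (d-2, r, n, r), Yr is (r, n, 1).
    """
    Z = Yl
    for G in Ym:
        # Contract the last (rank) axis of Z with the first (rank) axis of G
        Z = np.tensordot(Z, G, axes=1)
    Z = np.tensordot(Z, Yr, axes=1)
    # Drop the two boundary rank axes of size 1
    return Z[0, ..., 0]

d, n, r = 5, 6, 4
gen = np.random.default_rng(0)
Yl = gen.standard_normal((1, n, r))
Ym = gen.standard_normal((d - 2, r, n, r))
Yr = gen.standard_normal((r, n, 1))

Z = full_sketch(Yl, Ym, Yr)

# One element is the product of the matrix slices selected by the multi-index
k = (0, 1, 2, 3, 4)
y = Yl[:, k[0], :]
for G, i in zip(Ym, k[1:-1]):
    y = y @ G[:, i, :]
y = (y @ Yr[:, k[-1], :])[0, 0]

print(Z.shape, abs(Z[k] - y))
```

The dense array has `n**d = 7776` entries here, which is why this approach is only practical for small `d` and `n`.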


## Function `orthogonalize_rtl`

Orthogonalizes a TT-tensor from right to left.

In [4]:
rng, key = jax.random.split(rng)
Y = teneva.rand_norm(d=7, n=4, r=3, key=key)
Z = teneva.orthogonalize_rtl(Y)
teneva.show(Z)

TT-tensor-jax | d =     7 | n =     4 | r =     3 |


We can verify that the values of the orthogonalized tensor have not changed:

In [5]:
Y_full = teneva.full(Y)
Z_full = teneva.full(Z)
e = jnp.max(jnp.abs(Y_full - Z_full))
print(f'Error     : {e:-8.2e}')

Error     : 5.68e-13


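We can also check that all TT-cores except the first one have become orthogonal (in the TT sense): for each such core `G`, the sum of `G[:, j, :] @ G[:, j, :].T` over `j` is the identity matrix. The first core is not orthogonal; the `1 x 1` value printed for it is the squared Frobenius norm of the tensor: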

In [6]:
Zl, Zm, Zr = Z

v = [Zl[:, j, :] @ Zl[:, j, :].T for j in range(Zl.shape[1])]
print(jnp.sum(jnp.array(v), axis=0))

for G in Zm:
    v = [G[:, j, :] @ G[:, j, :].T for j in range(G.shape[1])]
    print(jnp.sum(jnp.array(v), axis=0))
    
v = [Zr[:, j, :] @ Zr[:, j, :].T for j in range(Zr.shape[1])]
print(jnp.sum(jnp.array(v), axis=0))

[[34549434.73187065]]
[[ 1.00000000e+00 -2.08166817e-17  2.77555756e-17]
 [-2.08166817e-17  1.00000000e+00  1.38777878e-17]
 [ 2.77555756e-17  1.38777878e-17  1.00000000e+00]]
[[ 1.00000000e+00 -2.77555756e-17 -2.77555756e-17]
 [-2.77555756e-17  1.00000000e+00 -1.11022302e-16]
 [-2.77555756e-17 -1.11022302e-16  1.00000000e+00]]
[[ 1.00000000e+00  2.77555756e-17  4.16333634e-17]
 [ 2.77555756e-17  1.00000000e+00 -2.77555756e-17]
 [ 4.16333634e-17 -2.77555756e-17  1.00000000e+00]]
[[ 1.00000000e+00 -1.66533454e-16 -2.77555756e-17]
 [-1.66533454e-16  1.00000000e+00 -2.77555756e-17]
 [-2.77555756e-17 -2.77555756e-17  1.00000000e+00]]
[[ 1.00000000e+00 -1.80411242e-16  1.11022302e-16]
 [-1.80411242e-16  1.00000000e+00 -5.55111512e-17]
 [ 1.11022302e-16 -5.55111512e-17  1.00000000e+00]]
[[1.00000000e+00 3.12250226e-17 8.32667268e-17]
 [3.12250226e-17 1.00000000e+00 2.77555756e-16]
 [8.32667268e-17 2.77555756e-16 1.00000000e+00]]
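
The right-to-left sweep can be sketched with a QR decomposition per core. This is a plain NumPy illustration under stated assumptions, not the code used by `teneva_jax`: the TT-tensor is taken to be a simple list of 3D cores of shape `(r_prev, n, r_next)`, and the helper names `orthogonalize_rtl_sketch` and `full_from_list` are hypothetical.

```python
import numpy as np

def orthogonalize_rtl_sketch(cores):
    """Right-to-left orthogonalization of a list of cores (r_prev, n, r_next)."""
    cores = [G.copy() for G in cores]
    for i in range(len(cores) - 1, 0, -1):
        r_prev, n, r_next = cores[i].shape
        # Unfold the core into an (r_prev, n * r_next) matrix M
        M = cores[i].reshape(r_prev, n * r_next)
        # QR of the transpose: M.T = Q @ R, hence M = R.T @ Q.T
        Q, R = np.linalg.qr(M.T)
        # Q.T has orthonormal rows and becomes the new core
        cores[i] = Q.T.reshape(Q.shape[1], n, r_next)
        # The remaining factor R.T is absorbed by the left neighbour
        cores[i - 1] = np.einsum('anb,bc->anc', cores[i - 1], R.T)
    return cores

def full_from_list(cores):
    """Dense tensor from a list of cores (only for small tensors)."""
    Z = cores[0]
    for G in cores[1:]:
        Z = np.tensordot(Z, G, axes=1)
    return Z[0, ..., 0]

d, n, r = 5, 4, 3
gen = np.random.default_rng(0)
shapes = [(1, n, r)] + [(r, n, r)] * (d - 2) + [(r, n, 1)]
Y = [gen.standard_normal(s) for s in shapes]
Z = orthogonalize_rtl_sketch(Y)

# The represented tensor is unchanged
err = np.max(np.abs(full_from_list(Y) - full_from_list(Z)))

# Orthogonality check, same form as in the cell above
grams = [sum(G[:, j, :] @ G[:, j, :].T for j in range(G.shape[1]))
         for G in Z[1:]]
print(err, [np.allclose(g, np.eye(g.shape[0])) for g in grams])
```

Summing `G[:, j, :] @ G[:, j, :].T` over `j` is the same as multiplying the unfolded `(r_prev, n * r_next)` matrix by its transpose, which is why orthonormal rows of `Q.T` give the identity matrices.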


## Function `orthogonalize_rtl_stab`

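Orthogonalizes a TT-tensor from right to left with a stabilization factor. The function returns the orthogonalized TT-tensor together with the stabilization factors `p_stab`, which store the scale of the tensor as powers of two.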

In [7]:
rng, key = jax.random.split(rng)
Y = teneva.rand_norm(d=7, n=4, r=3, key=key)
Z_stab, p_stab = teneva.orthogonalize_rtl_stab(Y)
teneva.show(Z_stab)

TT-tensor-jax | d =     7 | n =     4 | r =     3 |


After rescaling the first TT-core by `2**sum(p_stab)`, we can verify that the values of the orthogonalized tensor have not changed:

In [8]:
Z = teneva.copy(Z_stab)
Z[0] *= 2**jnp.sum(p_stab)

Y_full = teneva.full(Y)
Z_full = teneva.full(Z)
e = jnp.max(jnp.abs(Y_full - Z_full))
print(f'Error     : {e:-8.2e}')

Error     : 2.56e-13
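
The rescaling by `2**sum(p_stab)` relies on the fact that multiplying or dividing floating-point numbers by a power of two is exact (barring overflow and underflow). The sketch below illustrates only this general idea in NumPy. The rule that `orthogonalize_rtl_stab` actually uses to choose `p_stab` is not shown in this demo, and the helper name `split_power_of_two` is hypothetical.

```python
import numpy as np

def split_power_of_two(A):
    """Return (A_scaled, p) such that A == A_scaled * 2**p exactly
    and the largest absolute entry of A_scaled lies in [1, 2)."""
    m = np.max(np.abs(A))
    if m == 0:
        return A.copy(), 0
    # frexp gives m = mantissa * 2**e with mantissa in [0.5, 1)
    _, e = np.frexp(m)
    p = int(e) - 1
    return A / 2.0**p, p

A = np.array([[3.0e7, -1.5e6], [2.0e5, 4.0e6]])
A_scaled, p = split_power_of_two(A)

print(p, np.max(np.abs(A_scaled)))
# The original values are recovered without rounding error
print(np.array_equal(A_scaled * 2.0**p, A))
```

Keeping the exponent `p` separately lets the stored entries stay of moderate size while the overall scale is tracked as an integer.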


As before, we can make sure that all TT-cores, except the first one, have become orthogonalized:

In [9]:
Zl, Zm, Zr = Z_stab

v = [Zl[:, j, :] @ Zl[:, j, :].T for j in range(Zl.shape[1])]
print(jnp.sum(jnp.array(v), axis=0))

for G in Zm:
    v = [G[:, j, :] @ G[:, j, :].T for j in range(G.shape[1])]
    print(jnp.sum(jnp.array(v), axis=0))
    
v = [Zr[:, j, :] @ Zr[:, j, :].T for j in range(Zr.shape[1])]
print(jnp.sum(jnp.array(v), axis=0))

[[7.15816805]]
[[ 1.00000000e+00  1.52655666e-16  0.00000000e+00]
 [ 1.52655666e-16  1.00000000e+00 -1.38777878e-17]
 [ 0.00000000e+00 -1.38777878e-17  1.00000000e+00]]
[[ 1.00000000e+00  5.55111512e-17 -2.77555756e-17]
 [ 5.55111512e-17  1.00000000e+00 -2.77555756e-17]
 [-2.77555756e-17 -2.77555756e-17  1.00000000e+00]]
[[ 1.00000000e+00 -6.24500451e-17 -2.77555756e-17]
 [-6.24500451e-17  1.00000000e+00  1.38777878e-17]
 [-2.77555756e-17  1.38777878e-17  1.00000000e+00]]
[[ 1.00000000e+00 -4.16333634e-17  0.00000000e+00]
 [-4.16333634e-17  1.00000000e+00 -9.71445147e-17]
 [ 0.00000000e+00 -9.71445147e-17  1.00000000e+00]]
[[ 1.00000000e+00 -2.77555756e-17 -1.24900090e-16]
 [-2.77555756e-17  1.00000000e+00  0.00000000e+00]
 [-1.24900090e-16  0.00000000e+00  1.00000000e+00]]
[[1.00000000e+00 1.94289029e-16 5.55111512e-17]
 [1.94289029e-16 1.00000000e+00 1.38777878e-17]
 [5.55111512e-17 1.38777878e-17 1.00000000e+00]]


---