**Demo for `teneva.core_jax.als`**

---

This module contains the function `als`, which computes a TT-approximation of a tensor by the TT-ALS algorithm from given random samples.

## Loading and importing modules

In [1]:
import jax
jax.config.update('jax_enable_x64', True)  # Enable float64 before any arrays are created
import jax.numpy as np
import teneva as teneva_base
import teneva.core_jax as teneva
from time import perf_counter as tpc
rng = jax.random.PRNGKey(42)

## Function `als`

Build a TT-tensor by the TT-ALS method from given random tensor samples.

In [2]:
d = 50             # Dimension of the function
n = 20             # Mode size (number of grid points per dimension)
r = 3              # TT-rank of the initial approximation
nswp = 50          # Number of ALS sweeps
m = int(1.E+4)     # Number of calls to the target function (train points)
m_tst = int(1.E+4) # Number of test points

We set the target function. It takes as input a multi-index `i` of shape `[d]`, which is transformed into a point `x` of a uniform spatial grid:

In [3]:
a = -2.048 # Lower bound for the spatial grid
b = +2.048 # Upper bound for the spatial grid

def func_base(i):
    """Rosenbrock function."""
    x = i / n * (b - a) + a              # Multi-index -> point of the grid
    y1 = 100. * (x[1:] - x[:-1]**2)**2
    y2 = (x[:-1] - 1.)**2
    return np.sum(y1 + y2)

func = jax.vmap(func_base)
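
To make the index-to-grid mapping concrete, here is a minimal plain-Python sketch of the same transformation and of the Rosenbrock sum for a small case. The helper names `index_to_point` and `rosenbrock` and the sizes used are illustrative assumptions; they are not part of teneva.

```python
def index_to_point(i, n, a, b):
    """Map a multi-index (ints in [0, n)) to grid coordinates via k / n * (b - a) + a."""
    return [k / n * (b - a) + a for k in i]


def rosenbrock(x):
    """Sum of 100 * (x[k+1] - x[k]^2)^2 + (x[k] - 1)^2 over neighbouring pairs."""
    return sum(100.0 * (x1 - x0**2) ** 2 + (x0 - 1.0) ** 2
               for x0, x1 in zip(x[:-1], x[1:]))


# Illustrative values (much smaller than in the demo): 4 grid nodes per mode.
n_demo, a_demo, b_demo = 4, -2.0, 2.0
x_demo = index_to_point([0, 2, 3], n_demo, a_demo, b_demo)
print(x_demo)              # [-2.0, 0.0, 1.0]
print(rosenbrock(x_demo))  # 1710.0
```

With this mapping the largest index `n - 1` lands one grid step below the upper bound `b`, so the bound itself is never sampled.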

We prepare the train data using Latin hypercube sampling (LHS):

In [4]:
rng, key = jax.random.split(rng)
I_trn = teneva.sample_lhs(d, n, m, key)
y_trn = func(I_trn)
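
The idea behind Latin hypercube sampling on an index grid can be sketched in plain Python: along every mode, each index value is used an (almost) equal number of times, and the order is shuffled independently per mode. This only illustrates the principle; the helper `lhs_indices` is hypothetical, and `teneva.sample_lhs` may be implemented differently.

```python
import random


def lhs_indices(d, n, m, seed=0):
    """Return m multi-indices of length d with balanced index usage per mode."""
    rng = random.Random(seed)
    columns = []
    for _ in range(d):
        # Cycle through 0..n-1 until m values are produced, then shuffle them.
        col = [k % n for k in range(m)]
        rng.shuffle(col)
        columns.append(col)
    # Transpose: one row per sample, one column per mode.
    return [list(row) for row in zip(*columns)]


samples = lhs_indices(d=3, n=4, m=8, seed=42)
for mode in range(3):
    counts = [sum(1 for s in samples if s[mode] == k) for k in range(4)]
    print(counts)  # [2, 2, 2, 2] for every mode
```

Compared with fully independent random indices, this balancing guarantees that no grid index along any mode is left without samples when `m >= n`.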

We prepare the test data from random tensor multi-indices:

In [5]:
rng, key = jax.random.split(rng)
I_tst = teneva.sample_rand(d, n, m_tst, key)
y_tst = func(I_tst)

We build the initial approximation by the TT-ANOVA method:

In [6]:
# TODO: replace with jax-version!
Y_anova_base = teneva_base.anova(I_trn, y_trn, r)
Y_anova = teneva.convert(Y_anova_base)

Now we build the TT-tensor that approximates the target function, using the TT-ALS method:

In [7]:
t = tpc()
Y = teneva.als(I_trn, y_trn, Y_anova, nswp)
t = tpc() - t

print(f'Build time     : {t:-10.2f}')

Build time     :      11.86
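
The result of ALS is a tensor in TT format, so a single element is obtained by multiplying one matrix slice per core. The plain-Python sketch below shows this for a tiny hand-made TT-tensor stored as a list of nested-list cores of shape `[r_left][n][r_right]`. This storage layout and the helper `tt_get` are assumptions made for illustration and may differ from the internal format used by `teneva.core_jax`.

```python
def tt_get(cores, index):
    """Evaluate a TT-tensor at one multi-index by chained vector-matrix products."""
    vec = [1.0]  # Row vector of length r_0 = 1
    for core, k in zip(cores, index):
        r_right = len(core[0][k])
        # Multiply the running row vector by the matrix slice core[:, k, :].
        vec = [sum(vec[p] * core[p][k][q] for p in range(len(vec)))
               for q in range(r_right)]
    return vec[0]  # The last rank is r_d = 1


# Rank-2 TT representation of the 2 x 2 x 2 tensor T[i, j, k] = i + j + k.
first = [[[0.0, 1.0], [1.0, 1.0]]]       # Slice for index i is the row (i, 1)
middle = [[[1.0, 0.0], [1.0, 0.0]],      # Slice for index j is [[1, 0], [j, 1]]
          [[0.0, 1.0], [1.0, 1.0]]]
last = [[[1.0], [1.0]],                  # Slice for index k is the column (1, k)
        [[0.0], [1.0]]]
cores = [first, middle, last]

print(tt_get(cores, [1, 0, 1]))  # 2.0
print(tt_get(cores, [1, 1, 1]))  # 3.0
```

The cost of one evaluation grows linearly with the number of cores, which is what makes evaluation at many multi-indices cheap even for `d = 50`.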


We can check the accuracy of the result:

In [8]:
# Compute the approximation at the train points:
y_our = teneva.get_many(Y, I_trn)

# Relative error of the result on the train points:
e_trn = np.linalg.norm(y_our - y_trn)
e_trn /= np.linalg.norm(y_trn)

# Compute the approximation at the test points:
y_our = teneva.get_many(Y, I_tst)

# Relative error of the result on the test points:
e_tst = np.linalg.norm(y_our - y_tst)
e_tst /= np.linalg.norm(y_tst)

print(f'Error on train : {e_trn:-10.2e}')
print(f'Error on test  : {e_tst:-10.2e}')

Error on train :   4.24e-02
Error on test  :   4.24e+05
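
The accuracy measure used above is the relative error in the Euclidean norm, `||y_our - y_ref|| / ||y_ref||`. A minimal plain-Python sketch with made-up numbers (the helper name `rel_error` and the data are illustrative only):

```python
import math


def rel_error(y_approx, y_ref):
    """Relative error ||y_approx - y_ref|| / ||y_ref|| in the Euclidean norm."""
    diff = math.sqrt(sum((p - q) ** 2 for p, q in zip(y_approx, y_ref)))
    return diff / math.sqrt(sum(q ** 2 for q in y_ref))


y_ref = [3.0, 4.0]      # Reference values, norm = 5
y_approx = [3.0, 4.5]   # Approximation, off by 0.5 in one entry
print(f'{rel_error(y_approx, y_ref):-10.2e}')  # prints '  1.00e-01'
```

A value well above 1, as in the test error above, means the approximation error exceeds the norm of the reference values themselves.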


We can compare the result with the base (numpy) ALS method, run on the same train data with the same initial approximation and parameters:

In [9]:
t = tpc()
Y = teneva_base.als(I_trn, y_trn, Y_anova_base, nswp, e=-1.)
t = tpc() - t

print(f'Build time     : {t:-10.2f}')

# Compute the approximation at the train points:
y_our = teneva_base.get_many(Y, I_trn)

# Relative error of the result on the train points:
e_trn = np.linalg.norm(y_our - y_trn)
e_trn /= np.linalg.norm(y_trn)

# Compute the approximation at the test points:
y_our = teneva_base.get_many(Y, I_tst)

# Relative error of the result on the test points:
e_tst = np.linalg.norm(y_our - y_tst)
e_tst /= np.linalg.norm(y_tst)

print(f'Error on train : {e_trn:-10.2e}')
print(f'Error on test  : {e_tst:-10.2e}')

Build time     :      19.64
Error on train :   2.05e-02
Error on test  :   3.26e-01


---