**Demo for `teneva.core_jax.data`**

---

This module contains functions for working with datasets, including the `accuracy_on_data` function.

## Loading and importing modules

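```python
import jax
import jax.numpy as np
import teneva as teneva_base           # NumPy-based teneva (used here for sampling)
import teneva.core_jax as teneva       # JAX version of the teneva functions
from time import perf_counter as tpc
rng = jax.random.PRNGKey(42)           # root key for all random draws below
```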

## Function `accuracy_on_data`

Compute the relative error of a TT-tensor on a dataset.
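The name suggests the standard relative error. Assuming the Euclidean norm is used, the quantity is the one below. The exact formula, including any safeguard against division by zero, is not stated here.

$$
\varepsilon = \frac{\sqrt{\sum_{j=1}^{m} \left(\mathcal{Y}[\mathbf{i}_j] - y_j\right)^2}}{\sqrt{\sum_{j=1}^{m} y_j^2}}
$$

The symbols are as follows:

- $m$ is the number of samples.
- $\mathbf{i}_j$ is the $j$-th multi-index in `I_data`.
- $y_j$ is the corresponding value in `y_data`.
- $\mathcal{Y}[\mathbf{i}_j]$ is the TT-tensor entry at that multi-index.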

Let us generate a random TT-tensor:

```python
d = 20  # Dimension of the tensor
n = 10  # Mode size of the tensor
r = 2   # TT-rank of the tensor
```

```python
rng, key = jax.random.split(rng)   # fresh subkey for this random call
Y = teneva.rand(d, n, r, key)      # random TT-tensor with the given d, n, r
```
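To make the TT format concrete, here is a minimal pure-JAX sketch that does not use teneva. It assumes the common textbook layout: a list of `d` cores, where the `k`-th core has shape `(r_{k-1}, n, r_k)`, the boundary ranks equal 1, and entries are drawn uniformly from [-1, 1). The helper names `rand_tt_cores` and `tt_get` are hypothetical. `teneva.core_jax` may store its cores in a different layout.

```python
import jax
import jax.numpy as jnp

def rand_tt_cores(d, n, r, key):
    """Hypothetical helper: d random TT-cores of shape (r_{k-1}, n, r_k).

    Boundary ranks are 1 and inner ranks are r; entries are uniform on [-1, 1).
    """
    ranks = [1] + [r] * (d - 1) + [1]
    keys = jax.random.split(key, d)
    return [
        jax.random.uniform(keys[k], (ranks[k], n, ranks[k + 1]),
                           minval=-1.0, maxval=1.0)
        for k in range(d)
    ]

def tt_get(cores, idx):
    """Value of the TT-tensor at one multi-index: a chain of matrix products."""
    v = jnp.ones((1,))
    for G, i in zip(cores, idx):
        v = v @ G[:, i, :]   # (r_{k-1},) @ (r_{k-1}, r_k) -> (r_k,)
    return v[0]

d, n, r = 20, 10, 2
cores = rand_tt_cores(d, n, r, jax.random.PRNGKey(0))
n_params = sum(G.size for G in cores)
print(n_params)                 # 760
print(tt_get(cores, [0] * d))   # entry Y[0, 0, ..., 0]
```

With `d = 20`, `n = 10` and `r = 2`, the cores hold 2·20 + 18·40 = 760 numbers. The full tensor would have 10^20 entries.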

Then we generate random multi-indices with Latin hypercube sampling, compute the corresponding tensor values, and add some noise:

```python
m = 100  # Size of the dataset
I_data = teneva_base.sample_lhs([n]*d, m)   # (m, d) multi-indices from LHS
y_data = teneva.get_many(Y, I_data)         # exact tensor values at I_data

rng, key = jax.random.split(rng)
y_data = y_data + 1.E-5*jax.random.normal(key, (m, ))  # add Gaussian noise
```
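Here is a simple discrete Latin hypercube sampler in JAX that illustrates the idea behind `sample_lhs`. It is an assumption for illustration, not teneva's implementation, and `sample_lhs_sketch` is a hypothetical name. The sampler works in three steps:

- For each mode, it splits the indices `0..n-1` into `m` nearly equal integer strata.
- It draws one index from each stratum.
- It shuffles every mode independently.

As a result, each grid index of each mode is used about `m / n` times. For `n = 10` and `m = 100`, that is exactly 10 times.

```python
import jax
import jax.numpy as jnp

def sample_lhs_sketch(d, n, m, key):
    """Hypothetical discrete Latin hypercube sampler (not teneva's code).

    Returns an int array of shape (m, d) with entries in [0, n).
    """
    key_r, key_p = jax.random.split(key)
    k = jnp.arange(m)
    lo = (k * n) // m                                # first index of stratum k
    hi = jnp.maximum(((k + 1) * n) // m, lo + 1)     # exclusive end, at least lo + 1
    # One random index inside every stratum, drawn separately for each mode.
    I = jax.random.randint(key_r, (d, m), lo, hi)
    # Shuffle each mode independently so strata are paired at random across modes.
    perm_keys = jax.random.split(key_p, d)
    cols = [jax.random.permutation(perm_keys[j], I[j]) for j in range(d)]
    return jnp.stack(cols, axis=1)

I_demo = sample_lhs_sketch(d=20, n=10, m=100, key=jax.random.PRNGKey(0))
print(I_demo.shape)                            # (100, 20)
print(jnp.bincount(I_demo[:, 0], length=10))   # [10 10 10 10 10 10 10 10 10 10]
```

Uniform sampling with `jax.random.randint(key, (m, d), 0, n)` can leave some indices of a mode unused. The stratified version spreads the samples evenly over each mode.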

Finally, we compute the accuracy:

```python
eps = teneva.accuracy_on_data(Y, I_data, y_data)  # relative error on the noisy data

print(f'Accuracy     : {eps:-8.2e}')
```

Accuracy     : 4.11e-04
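The following self-contained JAX sketch shows what an `accuracy_on_data`-style check computes. It evaluates the TT-tensor at every multi-index of the dataset and compares the result with the stored values. It assumes the relative Euclidean-norm error and reuses the hypothetical list-of-cores layout, in which each core has shape `(r_{k-1}, n, r_k)`. It is not teneva's implementation; `tt_get_many` and `accuracy_on_data_sketch` are illustrative names.

```python
import jax
import jax.numpy as jnp

def tt_get_many(cores, I):
    """Evaluate a TT-tensor, given as a list of cores of shape (r_prev, n, r_next),
    at every row of the integer array I of shape (m, d)."""
    v = jnp.ones((I.shape[0], 1))                     # (m, 1): left boundary rank is 1
    for k, G in enumerate(cores):
        # G[:, I[:, k], :] has shape (r_prev, m, r_next): one matrix per sample.
        v = jnp.einsum('ma,amb->mb', v, G[:, I[:, k], :])
    return v[:, 0]                                    # right boundary rank is 1

def accuracy_on_data_sketch(cores, I_data, y_data):
    """Relative Euclidean-norm error of the TT-tensor on the dataset."""
    y_pred = tt_get_many(cores, I_data)
    return jnp.linalg.norm(y_pred - y_data) / jnp.linalg.norm(y_data)

# Synthetic check with demo-sized random cores (uniform on [-1, 1)).
d, n, r, m = 20, 10, 2, 100
key_c, key_i, key_e = jax.random.split(jax.random.PRNGKey(0), 3)
ranks = [1] + [r] * (d - 1) + [1]
cores = [jax.random.uniform(sub, (ranks[k], n, ranks[k + 1]), minval=-1.0, maxval=1.0)
         for k, sub in enumerate(jax.random.split(key_c, d))]
I_data = jax.random.randint(key_i, (m, d), 0, n)      # uniform indices, for brevity
y_exact = tt_get_many(cores, I_data)
y_noisy = y_exact + 1.E-5 * jax.random.normal(key_e, (m,))

err_exact = accuracy_on_data_sketch(cores, I_data, y_exact)
err_noisy = accuracy_on_data_sketch(cores, I_data, y_noisy)
print(err_exact)   # 0.0, the data match the tensor exactly
print(err_noisy)   # small and positive, set by the noise level
```

The error on the noisy data is driven by the noise amplitude relative to the typical size of the tensor values. For random cores with `d = 20`, those values are already small.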


---