### Tensor-train decomposition

In [6]:
import numpy as np

In [7]:
def random_tensor(ndim=None, shape=None, seed=None, ndim_bounds=(2, 5), shape_bounds=(2, 10), values_bounds=(0, 10)):
    """Return a tensor filled with uniform random values.

    Omitted ``ndim`` / ``shape`` are drawn at random from the global NumPy
    RNG, in that order. ``ndim`` is drawn whenever it is None, even if
    ``shape`` is given; this advances the RNG, but ``shape`` wins.

    Parameters
    ----------
    ndim : int, optional
        Number of dimensions, used only to draw a random ``shape``. When None
        it is drawn from ``[ndim_bounds[0], ndim_bounds[1])``.
    shape : sequence of int or numpy.ndarray, optional
        Tensor shape. When None each extent is drawn from
        ``[shape_bounds[0], shape_bounds[1])``.
    seed : int, optional
        If not None, the global NumPy RNG is reseeded with it first
        (``np.random.seed``). ``seed=0`` is honoured.
    ndim_bounds, shape_bounds : tuple of int
        Half-open ``[low, high)`` ranges for the random ``ndim`` / extents.
    values_bounds : tuple of float
        ``(low, high)`` bounds passed to ``np.random.uniform``.

    Returns
    -------
    numpy.ndarray
        Float array of the chosen shape.
    """
    # Explicit `is None` checks: `if seed:` skipped seeding for seed=0, and
    # `if not shape:` raised for a multi-element NumPy array.
    if seed is not None:
        np.random.seed(seed)
    if ndim is None:
        # Draw with size=1, as before, so seeded results are unchanged.
        ndim = int(np.random.randint(ndim_bounds[0], ndim_bounds[1], 1)[0])
    if shape is None:
        shape = np.random.randint(shape_bounds[0], shape_bounds[1], ndim)
    return np.random.uniform(values_bounds[0], values_bounds[1], shape)

In [8]:
# Test tensor: A[i_0, ..., i_{ndim-1}] = i_0 + ... + i_{ndim-1}, every index in [0, N).
# A sum of per-index terms is separable, so every unfolding has rank 2.
ndim = 4
N = 32
# directions[k] is the ij-indexed grid holding index i_k at every position.
directions = np.meshgrid(*(np.arange(N) for _ in range(ndim)), indexing='ij')
A = np.sum(directions, axis=0).astype(float)

In [9]:
def tt_svd(tensor, rel_tol=1e-6):
    """Decompose ``tensor`` into tensor-train (TT) cores by sequential SVDs.

    Parameters
    ----------
    tensor : numpy.ndarray
        Tensor with at least one dimension and no zero-length axis.
    rel_tol : float
        In each unfolding, singular values ``<= rel_tol * s_max`` are
        discarded, where ``s_max`` is that unfolding's largest singular value.
        A relative cutoff makes the detected ranks independent of the tensor's
        scale; an absolute cutoff would drop everything for a small-valued tensor.

    Returns
    -------
    cores : list of numpy.ndarray
        ``cores[k]`` has shape ``(ranks[k], tensor.shape[k], ranks[k + 1])``.
    ranks : list of int
        TT ranks, with ``ranks[0] == ranks[-1] == 1``.
    """
    shape = tensor.shape
    d = tensor.ndim
    cores = []
    ranks = [1]
    # First unfolding: mode 0 as rows, all remaining modes as columns.
    unfolding = tensor.reshape(shape[0], -1)
    for k in range(d - 1):
        U, s, Vh = np.linalg.svd(unfolding, full_matrices=False)
        # s is sorted in descending order, so the threshold keeps a leading prefix.
        # Keep at least one component so an all-zero tensor still gives valid cores.
        rank = max(1, int(np.count_nonzero(s > rel_tol * s[0])))
        ranks.append(rank)
        U, s, Vh = U[:, :rank], s[:rank], Vh[:rank, :]
        # Absorb the singular values into the left factor (U @ diag(s)).
        # Vh keeps orthonormal rows and is decomposed further.
        cores.append((U * s).reshape(ranks[k], shape[k], rank))
        # Next unfolding: (rank * shape[k+1]) rows, remaining modes as columns.
        unfolding = Vh.reshape(rank * shape[k + 1], -1)
    # The remainder (the whole tensor when d == 1) becomes the last core.
    ranks.append(1)
    cores.append(unfolding.reshape(ranks[d - 1], shape[d - 1], 1))
    return cores, ranks


U_list, r = tt_svd(A)
shape = A.shape

In [10]:
# Contract the TT cores left to right. After absorbing core i, the running
# matrix has one row per multi-index (i_0, ..., i_i) and r[i + 1] columns.
A_approx = U_list[0].reshape(shape[0], r[1])
for i in range(1, ndim):
    core_matrix = U_list[i].reshape(r[i], shape[i] * r[i + 1])
    merged_rows = A_approx.shape[0] * shape[i]
    product = A_approx @ core_matrix
    A_approx = product.reshape(merged_rows, r[i + 1])
# Frobenius norm of the reconstruction error; displayed as the cell's output.
np.linalg.norm(A_approx.reshape(shape) - A)

1.9528505702174035e-09