tt-svd

tt-ортогонализация

tt-дожатие

tt-cross

tt-multifinc

In [2]:
import numpy as np
import scipy
import time
import itertools
import collections

### TT basics

In [3]:
def tensor_from_tt(factors):
    ndim = len(factors)
    res = factors[0].copy()
    for k in range(1, ndim):
        res = np.tensordot(res, factors[k], axes=([-1], [0]))
        # ? res = np.moveaxis(res, -1, k)
    return res.reshape(list(res.shape)[1:-1])

In [82]:
def reverse_tt(factors):
    new_factors = []
    for f in factors:
        new_factors.append(f.transpose((2,1,0)))
    return new_factors[::-1]

In [4]:
def ranks_from_factors(factors):
    ranks = [1]
    for f in factors:
        ranks.append(f.shape[2])
    return ranks

In [5]:
def dims_from_factors(factors):
    dims = []
    for f in factors:
        dims.append(f.shape[1])
    return dims

In [6]:
def t_less_k(factors, k):
    ndim = len(factors)
    ranks = ranks_from_factors(factors)
    dims = dims_from_factors(factors)
    if k <= 1 or k > ndim:
        print(f'wrong {k=}')
        return -1
    res = factors[0].copy()
    for j in range(1, k - 1):
        res = np.tensordot(res, factors[j], axes=([-1], [0]))
    return res.reshape( (np.prod(dims[:k-1]), ranks[k - 1]) )
    
def t_more_k(factors, k):
    ndim = len(factors)
    ranks = ranks_from_factors(factors)
    dims = dims_from_factors(factors)
    res = factors[k].copy()
    for j in range(k + 1, ndim):
        res = np.tensordot(res, factors[j], axes=([-1], [0]))
    return res.reshape( (ranks[k], np.prod(dims[k:])) )

In [7]:
def tt_add(factors_1, factors_2):
    factors = []
    new_cores = []
    ndim = len(factors_1)
    for i in range(ndim):
        f1 = factors_1[i]
        f2 = factors_2[i]
        f = None
        if i == 0:
            # first factor
            f = np.concatenate([f1, f2], axis=2)
        elif i == ndim - 1:
            # last factor
            f = np.concatenate([f1, f2], axis=0)
        else:
            r1, n, r_next1 = f1.shape
            r2, _, r_next2 = f2.shape
            
            f = np.zeros((r1 + r2, n, r_next1 + r_next2))
            f[:r1, :, :r_next1] = f1  
            f[r1:, :, r_next1:] = f2  
        if f is None:
            print('???')
            return -1
        factors.append(f)
    return factors

In [8]:
def tt_sub(factors_1, factors_2):
    factors = []
    new_cores = []
    ndim = len(factors_1)
    for i in range(ndim):
        f1 = factors_1[i]
        f2 = -factors_2[i]
        f = None
        if i == 0:
            # first factor
            f = np.concatenate([f1, f2], axis=2)
        elif i == ndim - 1:
            # last factor
            f = np.concatenate([f1, f2], axis=0)
        else:
            r1, n, r_next1 = f1.shape
            r2, _, r_next2 = f2.shape
            
            f = np.zeros((r1 + r2, n, r_next1 + r_next2))
            f[:r1, :, :r_next1] = f1  
            f[r1:, :, r_next1:] = f2  
        if f is None:
            print('???')
            return -1
        factors.append(f)
    return factors

### TT-SVD

In [9]:
def tt_svd(tensor, err, verbose=False):
    ndim = tensor.ndim
    dims = tensor.shape
    left = dims[0]
    cur_ranks = 1
    factors = []
    ranks = [1]
    
    errors = []
    cur_err = err / np.sqrt(ndim)
    A_k = tensor.copy()
    if verbose:
        print(f'{ndim=}')
        print(f'{dims=}')
        print()
    for k in range(ndim - 1):
        cur_err = err*err
        for e in errors:
            cur_err -= e*e
        cur_err = np.sqrt(cur_err/(ndim - k))
        errors.append(cur_err)
        A_k = A_k.reshape((ranks[-1] * dims[k], -1))
        if verbose:
            print(f'numiter {k + 1}')
            print('A_k shape =', A_k.shape)
        U, s, V = np.linalg.svd(A_k, full_matrices=False)
        sing_sum = 0.0
        r = s.shape[0]
        while r > 0 and sing_sum + s[r-1]*s[r-1] < cur_err:
            r -= 1
            sing_sum += s[r]*s[r]
        if r < 1:
            print('something is very bad')
        U = U[:,:r]
        s = s[:r]
        V = V[:r,:]
        if verbose:
            print('svd shapes')
            print('U', U.shape)
            print('s', s.shape)
            print('V', V.shape)
        factors.append(U.reshape((ranks[-1], dims[k], r)))
        ranks.append(r)
        A_k = np.diag(s) @ V
        if verbose:
            print('new factor shape =', factors[-1].shape)
            print()
    factors.append(A_k.reshape((ranks[-1], dims[-1], 1)))
    if verbose:
        print('end of iterations')
        print('last factor shape =', A_k.shape)
    print('\n\nsummary:')
    print('ranks', ranks + [1])
    total_tt = 0
    total = np.prod(dims)
    for f in factors:
        total_tt += np.prod(f.shape)
    print(f'{total_tt} elements in tt approx')
    print(f'{total} elements in input tensor')
    print(f'{total_tt / total * 100:.5f}% ') 
    return factors

In [11]:
sizes = np.array((20, 30, 40))
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = np.sin(np.sum(I))
st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=1)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')

ndim=3
dims=(20, 30, 40)

numiter 1
A_k shape = (20, 1200)
svd shapes
U (20, 2)
s (2,)
V (2, 1200)
new factor shape = (1, 20, 2)

numiter 2
A_k shape = (60, 40)
svd shapes
U (60, 2)
s (2,)
V (2, 40)
new factor shape = (2, 30, 2)

end of iterations
last factor shape = (2, 40)


summary:
ranks [1, 2, 2, 1]
240 elements in tt approx
24000 elements in input tensor
1.00000% 
tt_svd time = 0.58767 seconds
absolute error = 5.733251645867613e-27
relative error = 4.777581766213576e-31


In [12]:
sizes = np.array((20, 30, 10, 20, 30))
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = np.sin(np.sum(I))

st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=0)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')



summary:
ranks [1, 2, 2, 2, 2, 1]
340 elements in tt approx
3600000 elements in input tensor
0.00944% 
tt_svd time = 0.74657 seconds
absolute error = 2.482658347617053e-19
relative error = 1.3792547344231947e-25


In [13]:
sizes = np.array((20, 30, 10, 20, 30))
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = np.sum(I)

st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=0)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')



summary:
ranks [1, 2, 2, 2, 2, 1]
340 elements in tt approx
3600000 elements in input tensor
0.00944% 
tt_svd time = 0.22744 seconds
absolute error = 1.1429451480899751e-15
relative error = 1.0650872687447352e-25


In [14]:
sizes = np.array((20, 30, 10, 20, 30))
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = 1 / (1 + np.sum(I))

st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=0)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')



summary:
ranks [1, 8, 9, 9, 9, 1]
5020 elements in tt approx
3600000 elements in input tensor
0.13944% 
tt_svd time = 0.29997 seconds
absolute error = 1.367531411111444e-11
relative error = 7.722935565398141e-15


In [15]:
sizes = [4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = np.sin(np.sum(I))

st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=0)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')



summary:
ranks [1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1]
144 elements in tt approx
1048576 elements in input tensor
0.01373% 
tt_svd time = 0.03397 seconds
absolute error = 4.977127823266276e-19
relative error = 9.49311843262455e-25


In [16]:
sizes = [200,200,200]
tensor = np.zeros(sizes)
for I in itertools.product(*(range(i) for i in sizes)):
    tensor[I] = np.sin(np.sum(I))

st = time.time()
factors = tt_svd(tensor, 1e-10, verbose=0)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.linalg.norm(tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(tensor)**2}')



summary:
ranks [1, 2, 2, 1]
1600 elements in tt approx
8000000 elements in input tensor
0.02000% 
tt_svd time = 0.48391 seconds
absolute error = 1.72910324082855e-23
relative error = 4.322757500929852e-30


### TT random tensor

In [10]:
def generate_tt(ranks, dims):
    ranks = [1] + ranks + [1]
    factors = []
    for i in range(len(dims)):
        shape = (ranks[i], dims[i], ranks[i + 1])
        f = np.zeros(shape)
        for I in itertools.product(*(range(i) for i in shape)):
            f[I] = np.random.rand() * 10
        factors.append(f)
    return factors

In [18]:
ranks = [10, 10, 10]
dims = [50, 50, 50, 50]
random_factors = generate_tt(ranks, dims)
random_tensor = tensor_from_tt(random_factors)
st = time.time()
tt_svd_factors = tt_svd(random_tensor, 1e-10, verbose=False)
tt_tensor = tensor_from_tt(tt_svd_factors)
print(f'tt_svd time = {time.time() - st:.5f} seconds')
err = np.linalg.norm(random_tensor - tt_tensor)
print(f'absolute error = {err**2}')
print(f'relative error = {err**2 / np.linalg.norm(random_tensor)**2}')



summary:
ranks [1, 10, 10, 10, 1]
11000 elements in tt approx
6250000 elements in input tensor
0.17600% 
tt_svd time = 0.44000 seconds
absolute error = 8.66952031709355e-12
relative error = 3.2802516240818e-30


### TT-orthogonalization

In [11]:
# left to right orthogonalization
def tt_orthogonalize_lr(factors):
    new_factors = []
    ranks = ranks_from_factors(factors)
    new_ranks = [1]
    dims = dims_from_factors(factors)
    cur = factors[0]
    for k in range(len(factors) - 1):
        prev_r, n, next_r = cur.shape
        cur = cur.reshape((-1, next_r))
        Q, R = np.linalg.qr(cur)
        new_factors.append(Q.reshape((prev_r, n, -1)))
        new_ranks.append(Q.shape[-1])
        cur = np.tensordot(R, factors[k + 1], axes=(-1,0))
    new_factors.append(cur.reshape((new_ranks[-1], dims[-1], 1)))
    return new_factors

# right to left orthogonalization
def tt_orthogonalize_rl(factors):
    new_factors = []
    ranks = ranks_from_factors(factors)
    new_ranks = [1]
    dims = dims_from_factors(factors)
    cur = factors[-1]
    for k in range(len(factors) - 1, 0, -1):
        r_prev, n, r_next = cur.shape
        cur = cur.reshape((r_prev, -1))
        Q, R = np.linalg.qr(cur.T)
        Q = Q.T
        L = R.T
        new_ranks.append(Q.shape[0])
        new_factors.append(Q.reshape((-1, n, r_next)))
        cur = np.tensordot(factors[k - 1], L, axes=(-1,0))
    new_factors.append(cur.reshape((1, dims[0], new_ranks[-1])))
    return new_factors[::-1]

In [20]:
# orthogonalized left to right => (T<d).T @ T<d = I
ranks = [4, 6, 3, 5, 3, 4]
dims = [6, 8, 7, 5, 7, 6, 8]
random_factors = generate_tt(ranks, dims)
random_factors_norm = tt_orthogonalize_lr(random_factors)
t_less_dim = t_less_k(random_factors, len(dims))
t_norm_less_dim = t_less_k(random_factors_norm, len(dims))
print(t_less_dim.shape)
print(t_norm_less_dim.shape)
print('before normalization (T<d).T @ T<d')
print(t_less_dim.T @ t_less_dim)
print()
print('after normalization (T<d).T @ T<d')
print(t_norm_less_dim.T @ t_norm_less_dim)
print()
print(f'must be close to I_{ranks[-1]}')
print(np.isclose(t_norm_less_dim.T @ t_norm_less_dim, np.eye(ranks[-1])))
print()
random_tensor = tensor_from_tt(random_factors)
random_tensor_norm = tensor_from_tt(random_factors_norm)
print('random tensor - random tensor after normalization norm')
print('absolute error =', np.linalg.norm(random_tensor - random_tensor_norm)**2)
print('relative error =', np.linalg.norm(random_tensor - random_tensor_norm)**2 / np.linalg.norm(random_tensor)**2)

(70560, 4)
(70560, 4)
before normalization (T<d).T @ T<d
[[1.51466013e+19 1.88704508e+19 2.04295224e+19 1.75425082e+19]
 [1.88704508e+19 2.73938095e+19 2.77563194e+19 2.18512833e+19]
 [2.04295224e+19 2.77563194e+19 3.26129228e+19 2.43045506e+19]
 [1.75425082e+19 2.18512833e+19 2.43045506e+19 2.08377425e+19]]

after normalization (T<d).T @ T<d
[[ 1.00000000e+00 -4.84138622e-17  4.00595762e-17 -7.64870751e-18]
 [-4.84138622e-17  1.00000000e+00 -1.25276173e-18 -5.56195714e-17]
 [ 4.00595762e-17 -1.25276173e-18  1.00000000e+00 -6.49470983e-17]
 [-7.64870751e-18 -5.56195714e-17 -6.49470983e-17  1.00000000e+00]]

must be close to I_4
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

random tensor - random tensor after normalization norm
absolute error = 1.4565857908710457e-08
relative error = 2.118560892988933e-31


In [22]:
# orthogonalized right to left => (T>1).T @ T>1 = I
ranks = [4, 6, 3, 5, 3, 5]
dims = [5, 7, 9, 6, 7, 9, 6]
random_factors = generate_tt(ranks, dims)
random_factors_norm = tt_orthogonalize_rl(random_factors)
t_more_1 = t_more_k(random_factors, 1)
t_norm_more_1 = t_more_k(random_factors_norm, 1)
print(t_more_1.shape)
print(t_norm_more_1.shape)
print('before normalization T>1 @ (T>1).T')
print(t_more_1 @ t_more_1.T)
print()
print('after normalization T>1 @ (T>1).T')
print(t_norm_more_1 @ t_norm_more_1.T)
print()
print(f'must be close to I_{ranks[0]}')
print(np.isclose(t_norm_more_1 @ t_norm_more_1.T, np.eye(ranks[0])))
print()
random_tensor = tensor_from_tt(random_factors)
random_tensor_norm = tensor_from_tt(random_factors_norm)
print('random tensor - random tensor after normalization norm')
print('absolute error =', np.linalg.norm(random_tensor - random_tensor_norm)**2)
print('relative error =', np.linalg.norm(random_tensor - random_tensor_norm)**2 / np.linalg.norm(random_tensor)**2)

(4, 142884)
(4, 142884)
before normalization T>1 @ (T>1).T
[[1.13274544e+20 1.10764463e+20 1.04350709e+20 1.12042046e+20]
 [1.10764463e+20 1.14512158e+20 1.07236950e+20 1.16834402e+20]
 [1.04350709e+20 1.07236950e+20 1.09188313e+20 1.15224196e+20]
 [1.12042046e+20 1.16834402e+20 1.15224196e+20 1.26066373e+20]]

after normalization T>1 @ (T>1).T
[[ 1.00000000e+00 -2.71402909e-16  6.86503263e-17 -1.65340831e-16]
 [-2.71402909e-16  1.00000000e+00 -1.84368579e-16 -1.16700811e-16]
 [ 6.86503263e-17 -1.84368579e-16  1.00000000e+00  7.07950137e-17]
 [-1.65340831e-16 -1.16700811e-16  7.07950137e-17  1.00000000e+00]]

must be close to I_4
[[ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]
 [ True  True  True  True]]

random tensor - random tensor after normalization norm
absolute error = 4.6884066073893393e-07
relative error = 1.8400996838772246e-30


### TT дожатие

In [12]:
def tt_compression(factors, err, verbose=False):
    new_factors = [factors[0]]
    ranks = ranks_from_factors(factors)
    dims = dims_from_factors(factors)
    ndim = len(dims)
    new_ranks = [1]
    errors = []
    cur_err = err / np.sqrt(ndim)
    for k in range(0, len(factors) - 1):
        cur_err = err*err
        for e in errors:
            cur_err -= e*e
        cur_err = np.sqrt(cur_err/(ndim - k))
        errors.append(cur_err)

        cur = new_factors[k]
        if verbose:
            print(f'numiter {k + 1}')
            print('cur shape =', cur.shape)
        U, s, V = np.linalg.svd(cur.reshape( (new_ranks[k] * dims[k], -1) ), full_matrices=False)
        sing_sum = 0.0
        r = s.shape[0]
        while r > 0 and sing_sum + s[r-1]*s[r-1] < cur_err:
            r -= 1
            sing_sum += s[r]*s[r]
        if r < 1:
            print('something is very bad')
        U = U[:,:r]
        s = s[:r]
        V = V[:r,:]
        if verbose:
            print('svd shapes')
            print('U', U.shape)
            print('s', s.shape)
            print('V', V.shape)
            
        new_factors[k] = U.reshape((new_ranks[k], dims[k], r))
        new_factors.append(np.tensordot(np.diag(s) @ V, factors[k + 1], axes=(1,0)))
        new_ranks.append(r)
        if verbose:
            print('new factor shape =', new_factors[-2].shape)
            print()

    if verbose:
        print('end of iterations')
        print('last factor shape =', new_factors[-1].shape)
    print('\n\nsummary:')
    print('old ranks', ranks)
    print('new ranks', new_ranks + [1])
    print()
    return new_factors

чтобы проверить сделаем два тензора ортогонализованных справа налево, суммируем их и дожмем

In [24]:
ranks = [5, 5, 5, 5]
dims = [30, 30, 30, 30, 30]

f1 = tt_orthogonalize_rl(generate_tt(ranks, dims))
f2 = tt_orthogonalize_rl(generate_tt(ranks, dims))

In [25]:
f = tt_add(tt_add(tt_add(f1, f2), tt_add(f1, f2)), f2)
f = tt_orthogonalize_rl(f)
f_compressed = tt_compression(f, 1e-10)
t = tensor_from_tt(f)
t_compressed = tensor_from_tt(f_compressed)
print('absolute error =', np.linalg.norm(t - t_compressed)**2)
print('realtive error =', np.linalg.norm(t - t_compressed)**2 / np.linalg.norm(t)**2)



summary:
old ranks [1, 25, 25, 25, 25, 1]
new ranks [1, 10, 10, 10, 10, 1]

absolute error = 1.5060371596196425e-08
realtive error = 6.544462280735583e-30


In [26]:
f = tt_add(tt_add(tt_add(f1, f2), tt_add(f1, f2)), f2)
#f = tt_orthogonalize_rl(f)
f_compressed = tt_compression(f, 1e-10)
t = tensor_from_tt(f)
t_compressed = tensor_from_tt(f_compressed)
print('absolute error =', np.linalg.norm(t - t_compressed)**2)
print('realtive error =', np.linalg.norm(t - t_compressed)**2 / np.linalg.norm(t)**2)



summary:
old ranks [1, 25, 25, 25, 25, 1]
new ranks [1, 10, 10, 11, 11, 1]

absolute error = 2.700911684789305e-09
realtive error = 1.1736771919469627e-30


In [27]:
ranks = [[2, 3, 4, 5],
         [5, 4, 2, 3]]
dims = [21, 17, 34, 62, 50]

f1 = generate_tt(ranks[0], dims)
f2 = generate_tt(ranks[1], dims)

In [28]:
f = tt_add(f1, f1)
f = tt_orthogonalize_rl(f)
f_compressed = tt_compression(f, 1e-10)
t = tensor_from_tt(f)
t_compressed = tensor_from_tt(f_compressed)
print('absolute error =', np.linalg.norm(t - t_compressed)**2)
print('realtive error =', np.linalg.norm(t - t_compressed)**2 / np.linalg.norm(t)**2)



summary:
old ranks [1, 4, 6, 8, 10, 1]
new ranks [1, 2, 3, 4, 5, 1]

absolute error = 2.6899266768565207e-11
realtive error = 7.700568733285868e-31


In [29]:
f = tt_add(f2, f2)
f = tt_orthogonalize_rl(f)
f_compressed = tt_compression(f, 1e-10)
t = tensor_from_tt(f)
t_compressed = tensor_from_tt(f_compressed)
print('absolute error =', np.linalg.norm(t - t_compressed)**2)
print('realtive error =', np.linalg.norm(t - t_compressed)**2 / np.linalg.norm(t)**2)



summary:
old ranks [1, 10, 8, 4, 6, 1]
new ranks [1, 5, 4, 2, 3, 1]

absolute error = 3.119181386051331e-11
realtive error = 1.175858124800576e-30


In [30]:
ranks = [[1, 2, 3, 4, 5, 1],
         [1, 5, 4, 2, 3, 1]]
dims = [21, 17, 34, 62, 50]
f1 = []
for k in range(len(dims)):
    shape = (ranks[0][k], dims[k], ranks[0][k + 1])
    cur = np.zeros(shape)
    for I in itertools.product(*(range(i) for i in shape)):
        cur[I] = np.sin(np.sum(I))
    f1.append(cur)

f2 = []
for k in range(len(dims)):
    shape = (ranks[1][k], dims[k], ranks[1][k + 1])
    cur = np.zeros(shape)
    for I in itertools.product(*(range(i) for i in shape)):
        cur[I] = np.sin(np.sum(I) + 17)
    f2.append(cur)

f = tt_add(f1, f2)
f = tt_orthogonalize_rl(f)
f_compressed = tt_compression(f, 1e-10)
t = tensor_from_tt(f)
t_compressed = tensor_from_tt(f_compressed)
print('absolute error =', np.linalg.norm(t - t_compressed)**2)
print('realtive error =', np.linalg.norm(t - t_compressed)**2 / np.linalg.norm(t)**2)



summary:
old ranks [1, 7, 7, 6, 8, 1]
new ranks [1, 2, 4, 4, 2, 1]

absolute error = 2.9166658329335427e-20
realtive error = 1.1809628508593548e-29


### TT-CROSS

#### maxvol

In [13]:
'''
####
####
####
####
####
####
'''
def maxvol_cols(A):
    m, r = A.shape
    rows = np.zeros(m, dtype=bool) 
    rows[np.random.choice(m, r, replace=False)] = True
    sub = A[rows, :]
    it = 0
    while np.linalg.cond(sub) >= 1e16 and it < m:
        rows = np.zeros(m, dtype=bool) 
        rows[np.random.choice(m, r, replace=False)] = True
        sub = A[rows, :]
        it += 1
    CA = A @ np.linalg.inv(sub)
    it = 0
    while it < m:
        i, j = np.unravel_index(np.argmax(np.abs(CA), axis=None), CA.shape)
        if np.isclose(np.abs(CA[i, j]), 1.0):
            break
        if rows[i]:
            break
        num = 0
        cnt = 0
        k = -1    
        while cnt != j + 1:
            k += 1
            if rows[k]:
                cnt += 1
        rows[k] = False
        rows[i] = True
        ej = np.zeros(r)
        ej[j] = 1
        CA -= 1/CA[i, j] * np.outer(CA[:, j], ( CA[i, :] - ej))
        it += 1
    return rows

'''
#################
#################
#################
'''
def maxvol_rows(A):
    r, n = A.shape
    cols = np.zeros(n, dtype=bool)
    cols[np.random.choice(n, r, replace=False)] = True
    sub = A[:, cols]
    it = 0
    while np.linalg.cond(sub) >= 1e16 and it < n:
        cols = np.zeros(n, dtype=bool)
        cols[np.random.choice(n, r, replace=False)] = True
        sub = A[:, cols]
        it += 1
    AR = np.linalg.inv(sub) @ A
    it = 0
    while it < n:
        i, j = np.unravel_index(np.argmax(np.abs(AR), axis=None), AR.shape)
        if np.isclose(np.abs(AR[i, j]), 1.0):
            break
        if cols[j]:
            break
        cnt = 0
        k = -1
        while cnt != i + 1:
            k += 1
            if cols[k]:
                cnt += 1
        cols[k] = False
        cols[j] = True
        ei = np.zeros(r)
        ei[i] = 1
        AR -= 1/AR[i, j] * np.outer(AR[:, j] - ei, AR[i, :])
        it += 1
    return cols

'''
def cartesian_prod(arrays, out=None):
    la = len(arrays)
    L = *map(len, arrays), la
    dtype = np.result_type(*arrays)
    arr = np.empty(L, dtype=dtype)
    arrs = *itertools.accumulate(itertools.chain((arr,), itertools.repeat(0, la-1)), np.ndarray.__getitem__),
    idx = slice(None), *itertools.repeat(None, la-1)
    for i in range(la-1, 0, -1):
        arrs[i][..., i] = arrays[i][idx[:la-i]]
        arrs[i-1][1:] = arrs[i]
    arr[..., 0] = arrays[0][idx]
    return arr.reshape(-1, la)

def calculate_tensor_J(fv, row_hat, J):
    r = len(J)
    calc_shape = tuple(dims[:-1]) + (r,)
    calc = np.zeros(calc_shape)
    J = np.array(J).reshape(-1)
    cart = cartesian_prod(arrays=(*row_hat.T, *J.T))
    left_ind = np.ndindex(calc_shape)
    for I in left_ind:
        calc[I] = fv(cart[np.ravel_multi_index(I, calc_shape)])
    return calc

def calculate_tensor_I(fv, col_hat, I):
    r = len(I)
    calc_shape = (r,) + tuple(dims[1:])
    calc = np.zeros(calc_shape)
    I = np.array(I).reshape(-1)
    cart = cartesian_prod(arrays=(I, *[np.arange(d) for d in dims[1:]]))
    left_ind = np.ndindex(calc_shape)
    for I in left_ind:
        calc[I] = fv(cart[np.ravel_multi_index(I, calc_shape)])
    return calc

def fix_cols(cols):
    for i in range(len(cols[-1])):
        cols[-1][i][0] = i
    return cols
'''


'\ndef cartesian_prod(arrays, out=None):\n    la = len(arrays)\n    L = *map(len, arrays), la\n    dtype = np.result_type(*arrays)\n    arr = np.empty(L, dtype=dtype)\n    arrs = *itertools.accumulate(itertools.chain((arr,), itertools.repeat(0, la-1)), np.ndarray.__getitem__),\n    idx = slice(None), *itertools.repeat(None, la-1)\n    for i in range(la-1, 0, -1):\n        arrs[i][..., i] = arrays[i][idx[:la-i]]\n        arrs[i-1][1:] = arrs[i]\n    arr[..., 0] = arrays[0][idx]\n    return arr.reshape(-1, la)\n\ndef calculate_tensor_J(fv, row_hat, J):\n    r = len(J)\n    calc_shape = tuple(dims[:-1]) + (r,)\n    calc = np.zeros(calc_shape)\n    J = np.array(J).reshape(-1)\n    cart = cartesian_prod(arrays=(*row_hat.T, *J.T))\n    left_ind = np.ndindex(calc_shape)\n    for I in left_ind:\n        calc[I] = fv(cart[np.ravel_multi_index(I, calc_shape)])\n    return calc\n\ndef calculate_tensor_I(fv, col_hat, I):\n    r = len(I)\n    calc_shape = (r,) + tuple(dims[1:])\n    calc = np.zeros

#### cross

In [14]:
def tt_cross(f, dims, maxrank, rtol=1e-4, verbose=False):
    '''
    f must take np.array() as an input and output single number
    '''
    fv = np.vectorize(f, signature='(m)->()')
    ndim = len(dims)
    maxrank = min(maxrank, max(*dims))
    maxiter = max(*dims) * 2
    if verbose:
        print(f'{dims=}')
        print(f'{ndim=}')
        print(f'{maxrank=}')
    r = 1
    rerr = 1e16
    prev_rerr = 1e16
    cols = np.array([np.random.choice(dims[-1], r, replace=False)] + [np.random.choice(d, r) for d in dims[-2:0:-1]])
    cols = np.array([cols[:, i] for i in range(r)])
    for i in range(len(cols[0]) - 1, -1, -1):
        cols = cols[cols[:,i].argsort(kind='mergesort')]
    cols = [np.array([list(reversed(j[:k+1])) for j in cols]) for k in reversed(range(ndim - 1))]
    rows = []
    ranks = np.ones(ndim + 1, dtype=np.int32)
    update_ranks = np.ones(ndim + 1, dtype=bool)
    old_factors = []
    for i in range(ndim):
        old_factors.append(np.zeros((1, dims[i], 1)))
    factors = []
    if verbose:
        print('start cols')
        print(cols)
        print(maxrank, rtol, maxiter)
    it = 0
    while r <= maxrank and rerr > rtol and it < maxiter:
        print(f'iteration {it + 1}/{maxiter}')
        # update ranks
        prev_ranks = ranks.copy()
        # r += 1
        for i in range(1, ndim):
            if ranks[i] < maxrank and update_ranks[i] and ranks[i] < dims[i - 1]:
                ranks[i] += 1
        # r_k <= n_k * r_k-1
        for i in range(1, ndim):
            if ranks[i] > dims[i - 1] * ranks[i - 1]: 
                ranks[i] = dims[i - 1] * ranks[i - 1]
        # r_k-1 <= n_k * r_k
        for i in range(ndim, 0, -1):
            if dims[i - 1] * ranks[i] < ranks[i - 1]: 
                ranks[i - 1] = dims[i - 1] * ranks[i]
        r = np.max(ranks)
        if verbose:
            print('new maxrank', r)
            print('new ranks', ranks)
            
        if np.sum(ranks - prev_ranks) <= 0:
            pass    
        print('global rank update =', (ranks - prev_ranks)[1:-1])

        
        # have columns. search rows
        new_cols = cols.copy()
        new_rows = []
        update_ranks = np.ones(ndim + 1, dtype=bool)
        row_hat = np.array([[i] for i in range(dims[0])])
        if verbose:
            print('start row iter')
        for k in range(ndim - 1):
            if verbose:
                print(f'row iter {k + 1}')

            # add new random column
            if len(new_cols[k]) < ranks[k + 1]:
                if k == len(cols) - 1:
                    chosen = set(new_cols[-1].flatten())
                    free = list(set(range(dims[-1])) - chosen)
                    new = np.random.choice(free)
                    new_cols[-1] = np.vstack( (new_cols[-1], [new]) )
                else:
                    prev = new_cols[k + 1]
                    d = dims[k + 1]
                    available = np.array([[i, *j] for i in range(d) for j in prev])
                    new = np.random.choice(len(available))
                    if (available[new] == new_cols[k]).all(axis=1).any():
                        new = np.random.choice(len(available))
                    new_cols[k] = np.vstack( (new_cols[k], available[new]) )
                    
            ind = np.array([[*i, *j] for i in row_hat for j in new_cols[k]]) # rn x r
            calc = fv(ind).reshape((ranks[k] * dims[k], ranks[k + 1]))
            calc_rank = np.linalg.matrix_rank(calc)
            if calc_rank < ranks[k + 1]:
                new_cols[k] = cols[k]
                ranks[k + 1] -= 1
                update_ranks[k + 1] = False
                ind = np.array([[*i, *j] for i in row_hat for j in new_cols[k]]) # rn x r
                calc = fv(ind).reshape((ranks[k] * dims[k], ranks[k + 1]))
            if verbose:
                print(cols[k].shape)
                print(ranks[k])
                print(dims[k])
                print(ind.shape)
                print(calc.shape)
            row = maxvol_cols(calc)
            if verbose:
                print(row.shape)
                print(np.sum(row))
                print(row_hat.shape)
                
            row = row_hat[row]
            if verbose:
                print(row.shape)
            new_rows.append(row)
            row_hat = np.array([[*row[i], j] for i in range(len(row)) for j in range(dims[k + 1])])
        cols = new_cols
        # have rows, search columns
        new_cols = []
        factors = collections.deque()
        col_hat = np.array([[i] for i in range(dims[-1])])
        if verbose:
            print('start col iter')
        for k in range(ndim - 1, 0, -1):
            if verbose:
                print(f'col iter {ndim - k}')
            ind = np.array([[*i, *j] for i in new_rows[k - 1] for j in col_hat]) # r_i x d_i x r_i+1             
            calc = fv(ind).reshape((-1, dims[k] * ranks[k + 1]))
            calc_rank = np.linalg.matrix_rank(calc)
            if calc_rank < ranks[k]:
                print('DAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAANGER')
                new_rows[k - 1] = rows[k - 1]
                ranks[k] -= 1
                update_ranks[k] = False
                ind = np.array([[*i, *j] for i in new_rows[k - 1] for j in col_hat]) # r_i x d_i x r_i+1             
                calc = fv(ind).reshape((-1, dims[k] * ranks[k + 1]))
            if verbose:
                print(new_rows[k - 1].shape)
                print(ranks[k + 1])
                print(dims[k])
                print(ind.shape)
                print(calc.shape)
            orig_col = maxvol_rows(calc)
            if verbose:
                print(orig_col.shape)
                print(np.sum(orig_col))
                print(row_hat.shape)
            col = col_hat[orig_col]
            if verbose:
                print(col.shape)
            new_cols.append(col)
            
            col_hat = np.array([[i, *col[j]] for i in range(dims[k - 1]) for j in range(len(col))])
            factors.appendleft( (np.linalg.inv(calc[:,orig_col]) @ calc).reshape((ranks[k], dims[k], ranks[k + 1])) )
        factors.appendleft( fv(col_hat).reshape((ranks[0], dims[0], ranks[1])) )        
        cols = list(reversed(new_cols))
        rows = new_rows
        
        print('ranks', ranks)
        print('dims and factors shapes')
        print(dims)
        for i in factors:
            print(i.shape)
        # error
        old_factors[0] *= -1
        tt_diff = tt_add(old_factors, factors)
        if verbose:
            print('diff shapes')
            for i in tt_diff:
                print(i.shape)
        '''
        tt_diff_compr = tt_compression(tt_diff, 1e-10)
        print('diff shapes after compression')
        for i in tt_diff_compr:
            print(i.shape)
        tt_diff_ort = tt_orthogonalize_lr(tt_diff_compr)
        '''
        tt_diff_ort = tt_orthogonalize_lr(tt_diff)
        rerr = np.power(np.linalg.norm(tt_diff_ort[-1]), 2)
        print(f'{rerr=}')
        if rerr > prev_rerr:
            print()
            print()
            print('summary:')
            print('tt ranks:', ranks[1:-1])
            return old_factors
        prev_rerr = rerr
        old_factors = factors
        it += 1
        if np.sum(update_ranks) == 2:
            pass
            # all ranks staying same
            # break
        print()
    print()
    print('summary:')
    print('tt ranks:', ranks[1:-1])
    return factors

In [15]:
dims = np.array((10, 10, 10, 10, 10))
def f(a):
    return 1 / (np.sum(a) + 1)
    
tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [17]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims), 1e-2, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
tensor_norm = np.power(np.linalg.norm(tensor), 2)
print(f'initial tensor norm = {tensor_norm}')
print(f'relative error = {err / tensor_norm}')

start cross
iteration 1/20
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10 10 10 10 10]
(1, 10, 2)
(2, 10, 2)
(2, 10, 2)
(2, 10, 2)
(2, 10, 1)
rerr=311.72957724555323

iteration 2/20
global rank update = [1 1 1 1]
ranks [1 3 3 3 3 1]
dims and factors shapes
[10 10 10 10 10]
(1, 10, 3)
(3, 10, 3)
(3, 10, 3)
(3, 10, 3)
(3, 10, 1)
rerr=6.246428502576554

iteration 3/20
global rank update = [1 1 1 1]
ranks [1 4 4 4 4 1]
dims and factors shapes
[10 10 10 10 10]
(1, 10, 4)
(4, 10, 4)
(4, 10, 4)
(4, 10, 4)
(4, 10, 1)
rerr=0.32745397689208156

iteration 4/20
global rank update = [1 1 1 1]
ranks [1 5 5 5 5 1]
dims and factors shapes
[10 10 10 10 10]
(1, 10, 5)
(5, 10, 5)
(5, 10, 5)
(5, 10, 5)
(5, 10, 1)
rerr=0.038144162633155655

iteration 5/20
global rank update = [1 1 1 1]
ranks [1 6 6 6 6 1]
dims and factors shapes
[10 10 10 10 10]
(1, 10, 6)
(6, 10, 6)
(6, 10, 6)
(6, 10, 6)
(6, 10, 1)
rerr=0.0007783428387427349


summary:
tt ranks: [6 6 6 6]
tt_cross time = 0.

In [18]:
dims = np.array((10, 13, 11, 12, 14))
def f(a):
    return np.sin(np.sum(a))
    
tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [19]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims)/2, 1e-5, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
print(f'relative error = {err / np.power(np.linalg.norm(tensor), 2)}')

start cross
iteration 1/28
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10 13 11 12 14]
(1, 10, 2)
(2, 13, 2)
(2, 11, 2)
(2, 12, 2)
(2, 14, 1)
rerr=120120.00318607355

iteration 2/28
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10 13 11 12 14]
(1, 10, 2)
(2, 13, 2)
(2, 11, 2)
(2, 12, 2)
(2, 14, 1)
rerr=1.261545854835002e-25


summary:
tt ranks: [2 2 2 2]
tt_cross time = 0.02300 seconds
absolute error = 1.1665792487876463e-25
relative error = 9.711781700342872e-31


In [20]:
dims = np.array((10, 13, 11, 12, 14))
weights = np.array((2.71, 3.14, 1.55, 2.2, 4.1234))
def f(a):
    return np.dot(a, weights)

tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [21]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims)/2, 1e-5, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
print(f'relative error = {err / np.power(np.linalg.norm(tensor), 2)}')

start cross
iteration 1/28
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10 13 11 12 14]
(1, 10, 2)
(2, 13, 2)
(2, 11, 2)
(2, 12, 2)
(2, 14, 1)
rerr=1583638018.9854305

iteration 2/28
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10 13 11 12 14]
(1, 10, 2)
(2, 13, 2)
(2, 11, 2)
(2, 12, 2)
(2, 14, 1)
rerr=2.6533220578961204e-19


summary:
tt ranks: [2 2 2 2]
tt_cross time = 0.01925 seconds
absolute error = 7.576326728572415e-23
relative error = 4.7841278358714097e-32


In [22]:
dims = np.array((200, 200, 200))
weights = np.array((2.71, 3.14, 1.55))
def f(a):
    return np.dot(a, weights)

tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [23]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims)/2, 1e-5, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
print(f'relative error = {err / np.power(np.linalg.norm(tensor), 2)}')

start cross
iteration 1/400
global rank update = [1 1]
ranks [1 2 2 1]
dims and factors shapes
[200 200 200]
(1, 200, 2)
(2, 200, 2)
(2, 200, 1)
rerr=4859920449199.997

iteration 2/400
global rank update = [1 1]
ranks [1 2 2 1]
dims and factors shapes
[200 200 200]
(1, 200, 2)
(2, 200, 2)
(2, 200, 1)
rerr=1.318682611190938e-19


summary:
tt ranks: [2 2]
tt_cross time = 0.16159 seconds
absolute error = 2.340469054166689e-19
relative error = 4.8158587751204884e-32


In [24]:
dims = np.array((4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4)) # 4**12
weights = np.array((2.71, 3.14, 2.55, 4.32, 6.23, 2.65, 3.42, 2.432, 7.23, 3.123, 4.123, 6.321))
def f(a):
    return np.dot(a, weights)

tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [25]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims), 1e-5, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
print(f'relative error = {err / np.power(np.linalg.norm(tensor), 2)}')

start cross
iteration 1/8
global rank update = [1 1 1 1 1 1 1 1 1 1 1]
ranks [1 2 2 2 2 2 2 2 2 2 2 2 1]
dims and factors shapes
[4 4 4 4 4 4 4 4 4 4 4 4]
(1, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 1)
rerr=92592071503.7733

iteration 2/8
global rank update = [1 1 1 1 1 1 1 1 1 1 1]
ranks [1 2 2 2 2 2 2 2 2 2 2 2 1]
dims and factors shapes
[4 4 4 4 4 4 4 4 4 4 4 4]
(1, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 1)
rerr=4.331772230932457e-14


summary:
tt ranks: [2 2 2 2 2 2 2 2 2 2 2]
tt_cross time = 0.03300 seconds
absolute error = 4.3577930718882016e-14
relative error = 4.706443004367371e-25


In [26]:
dims = np.array((4, 4, 4, 4))
weights = np.array((100, 1, 50, 1000))
def f(a):
    return np.dot(a, weights)

tensor = np.zeros(dims)
for I in itertools.product(*(range(i) for i in dims)):
    tensor[I] = f(I)

In [27]:
print('start cross')
st = time.time()
factors = tt_cross(f, dims, max(dims)/2, 1e-5, verbose=False)
print(f'tt_cross time = {time.time() - st:.5f} seconds')
tt_tensor = tensor_from_tt(factors)
err = np.power(np.linalg.norm(tensor - tt_tensor), 2)
print(f'absolute error = {err}')
print(f'relative error = {err / np.power(np.linalg.norm(tensor), 2)}')

start cross
iteration 1/8
global rank update = [1 1 1]
ranks [1 2 2 2 1]
dims and factors shapes
[4 4 4 4]
(1, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 1)
rerr=1087085696.000001

iteration 2/8
global rank update = [0 0 0]
ranks [1 2 2 2 1]
dims and factors shapes
[4 4 4 4]
(1, 4, 2)
(2, 4, 2)
(2, 4, 2)
(2, 4, 1)
rerr=8.426135696445293e-16


summary:
tt ranks: [2 2 2]
tt_cross time = 0.01000 seconds
absolute error = 8.512370029478303e-16
relative error = 7.830449853953651e-25


### TT-multifunc

TODO

In [165]:
def get_R(factors, cols):
    ndim = len(factors)
    dims = dims_from_factors(factors)
    R = [None] * (ndim - 1)
    R[-1] = factors[-1].reshape((-1, dims[-1]))[:, [j[0] for j in cols[-1]]]
    for k in reversed(range(ndim - 2)):
        rk = factors[k].shape[-1]
        R[k] = np.zeros((rk, len(cols[k])))
        for i, b in enumerate(cols[k]):
            R[k][:, i] = factors[k + 1][:, b[0], :] @ R[k + 1][:, cols[k + 1].tolist().index(b[1:].tolist())]
    return R

def tt_multifunc(tt_list, func, maxrank=10, rtol=1e-5, verbose=False):
    for i in range(len(tt_list) - 1):
        dims1 = dims_from_factors(tt_list[i])
        ranks1 = ranks_from_factors(tt_list[i])
        dims2 = dims_from_factors(tt_list[i + 1])
        ranks2 = ranks_from_factors(tt_list[i + 1])
        if dims1 != dims2 or ranks1 != ranks2:
            print(f'bad tensor dimensions {i + 1}, {i + 2}')
            print(f'dims {i + 1} -> {dims1}')
            print(f'dims {i + 2} -> {dims2}')
            print(f'ranks {i + 1} -> {ranks1}')
            print(f'ranks {i + 2} -> {ranks2}')

    dims = dims1
    ndim = len(dims)
    maxiter = max(*dims) * 2
    rerr = 1e16
    prev_rerr = 1e16
    p = len(tt_list)
    
    r = 1
    cols = np.array([np.random.choice(dims[-1], r, replace=False)] + [np.random.choice(d, r) for d in dims[-2:0:-1]])
    cols = np.array([cols[:, i] for i in range(r)])
    for i in range(len(cols[0]) - 1, -1, -1):
        cols = cols[cols[:,i].argsort(kind='mergesort')]
    cols = [np.array([list(reversed(j[:k+1])) for j in cols]) for k in reversed(range(ndim - 1))]
    rows = [None] * (ndim - 1)
    ranks = np.ones(ndim + 1, dtype=np.int32)
    update_ranks = np.ones(ndim + 1, dtype=bool)
    old_factors = []
    for i in range(ndim):
        old_factors.append(np.zeros((1, dims[i], 1)))
    factors = [None] * ndim

    it = 0
    while r <= maxrank and rerr > rtol and it < maxiter:
        print(f'iteration {it + 1}/{maxiter}')
        # update ranks
        prev_ranks = ranks.copy()
        # r += 1
        for i in range(1, ndim):
            if ranks[i] < maxrank and ranks[i] < dims[i - 1]:
                ranks[i] += 1
        # r_k <= n_k * r_k-1
        for i in range(1, ndim):
            if ranks[i] > dims[i - 1] * ranks[i - 1]: 
                ranks[i] = dims[i - 1] * ranks[i - 1]
        # r_k-1 <= n_k * r_k
        for i in range(ndim, 0, -1):
            if dims[i - 1] * ranks[i] < ranks[i - 1]: 
                ranks[i - 1] = dims[i - 1] * ranks[i]
        r = np.max(ranks)
        if verbose:
            print(f'new maxrank {r} < {maxrank} = maxrank')
            print('new ranks', ranks)
            
        if np.sum(ranks - prev_ranks) <= 0:
            pass    
        print('global rank update =', (ranks - prev_ranks)[1:-1])

        # update columns with random one
        for k in range(ndim - 1):
            # add new random column
            if len(cols[k]) < ranks[k + 1]:
                if k == len(cols) - 1:
                    chosen = set(cols[-1].flatten())
                    free = list(set(range(dims[-1])) - chosen)
                    new = np.random.choice(free)
                    cols[-1] = np.vstack( (cols[-1], [new]) )
                else:
                    prev = cols[k + 1]
                    d = dims[k + 1]
                    available = np.array([[i, *j] for i in range(d) for j in prev])
                    new = np.random.choice(len(available))
                    if (available[new] == cols[k]).all(axis=1).any():
                        new = np.random.choice(len(available))
                    cols[k] = np.vstack( (cols[k], available[new]) )

        # have columns. search rows
        row_hat = np.array([[i] for i in range(dims[0])])
        L = [[None] * ndim for _ in range(p)]
        R = [None] * p
        S = [None] * p
        for t in range(p):
            L[t][0] = np.eye(1,1)
        for t in range(p):
            R[t] = get_R(tt_list[t], cols)
        if verbose:
            print('start row iter')
        for k in range(ndim - 1):
            if verbose:
                print('!!!', k)
            for t in range(p):
                S[t] = np.einsum('ab,bic,cd->aid', L[t][k], tt_list[t][k], R[t][k]).reshape(-1, len(cols[k]))
            mat = np.apply_along_axis(func, axis=0, arr=np.stack(S, axis=0))
            q, _ = np.linalg.qr(mat)
            row = maxvol_cols(q)
            if verbose:
                print(cols[k].shape)
                print(ranks[k])
                print(dims[k])
            if verbose:
                print(row.shape)
                print(np.sum(row))
                print(row_hat.shape)
            row = row_hat[row]
            if verbose:
                print(row.shape)
            rows[k] = row
            row_hat = np.array([[*row[i], j] for i in range(len(row)) for j in range(dims[k + 1])])

            # update L
            for t in range(p):
                rk = tt_list[t][k].shape[-1]
                if k == 0:
                    L[t][1] = tt_list[t][0].reshape(dims[0], -1)[[i[0] for i in rows[0]], :]
                    continue
                L[t][k + 1] = np.zeros((len(rows[k]), rk))
                for i, b in enumerate(rows[k]):
                    L[t][k + 1][i, :] = L[t][k][rows[k - 1].tolist().index(b[:-1].tolist()), :] @ tt_list[t][k][:, b[-1], :]
        
        # have rows, search columns
        # factors = [None] * ndim
        col_hat = np.array([[i] for i in range(dims[-1])])
        L = [[None] * ndim for _ in range(p)]
        R = [None] * p
        S = [None] * p
        tt_list_r = [reverse_tt(f) for f in tt_list]
        for t in range(p):
            L[t][0] = np.eye(1,1)
        print(rows)
        rows = rows[::-1]
        for i in range(len(rows)):
            rows[i] = np.flip(rows[i], axis=1)
        for t in range(p):
            R[t] = get_R(tt_list_r[t], rows)
        if verbose:
            print('start col iter')
        dims = dims[::-1]
        ranks = ranks[::-1]
        for k in range(ndim - 1):
            if verbose:
                print('!!!', k)
            for t in range(p):
                S[t] = np.einsum('ab,bic,cd->aid', L[t][k], tt_list_r[t][k], R[t][k]).reshape(-1, len(rows[k]))
            mat = np.apply_along_axis(func, axis=0, arr=np.stack(S, axis=0))
            q, _ = np.linalg.qr(mat)
            if verbose:
                print(mat.shape)
                print(q.shape)
                print(rows[k].shape)
                print(ranks[k])
                print(dims[k])
            orig_col = maxvol_cols(q)
            if verbose:
                print(orig_col.shape)
                print(np.sum(orig_col))
                print(col_hat.shape)
            col = col_hat[orig_col]
            if verbose:
                print(col.shape)
            cols[k] = col
            col_hat = np.array([[*col[i], j] for i in range(len(col)) for j in range(dims[k + 1])])

            # update L
            for t in range(p):
                rk = tt_list_r[t][k].shape[-1]
                if k == 0:
                    L[t][1] = tt_list_r[t][0].reshape(dims[0], -1)[[i[0] for i in cols[0]], :]
                    continue
                L[t][k + 1] = np.zeros((len(cols[k]), rk))
                for i, b in enumerate(cols[k]):
                    L[t][k + 1][i, :] = L[t][k][cols[k - 1].tolist().index(b[:-1].tolist()), :] @ tt_list_r[t][k][:, b[-1], :]
            factors[k] = mat @ np.linalg.pinv(mat[orig_col,:])
            factors[k] = factors[k].reshape((-1, dims[k], ranks[k + 1]))
            print(f'FACTOR_{k}', factors[k].shape)
        # update last factor
        for t in range(p):
            S[t] = L[t][-1] @ tt_list_r[t][-1].reshape(-1, dims[-1])
        mat = np.apply_along_axis(func, axis=0, arr=np.stack(S, axis=0))
        factors[-1] = mat.reshape(-1, dims[-1], 1)
        print(f'FACTOR_{-1}', factors[-1].shape)
        factors = reverse_tt(factors)
        rows = rows[::-1]
        dims = dims[::-1]
        ranks = ranks[::-1]
        for i in range(len(rows)):
            rows[i] = np.flip(rows[i], axis=1)
        cols = cols[::-1]
        for i in range(len(cols)):
            cols[i] = np.flip(cols[i], axis=1)
        print('ranks', ranks)
        print('dims and factors shapes')
        print(dims)
        for i in factors:
            print(i.shape)
        print('old_factors shapes')
        for i in old_factors:
            print(i.shape)
        # error
        old_factors[0] *= -1
        tt_diff = tt_add(old_factors, factors)
        if verbose:
            print('diff shapes')
            for i in tt_diff:
                print(i.shape)
        '''
        tt_diff_compr = tt_compression(tt_diff, 1e-10)
        print('diff shapes after compression')
        for i in tt_diff_compr:
            print(i.shape)
        tt_diff_ort = tt_orthogonalize_lr(tt_diff_compr)
        '''
        tt_diff_ort = tt_orthogonalize_lr(tt_diff)
        rerr = np.power(np.linalg.norm(tt_diff_ort[-1]), 2)
        print(f'{rerr=}')
        if rerr > prev_rerr:
            print()
            print()
            print('summary:')
            print('tt ranks:', ranks[1:-1])
            return old_factors
        prev_rerr = rerr
        for i in range(len(factors)):
            old_factors[i] = factors[i]
        it += 1
        if np.sum(update_ranks) == 2:
            pass
            # all ranks staying same
            # break
        print()
    print()
    print('summary:')
    print('tt ranks:', ranks[1:-1])
    return factors

In [176]:
def f1(x):
    return np.sin(sum(x))
def f2(x):
    return np.cos(sum(x))
np.random.seed(42)
dims = [10, 8, 11, 10]
tt_list = []
for k, f in enumerate([f1, f2]):
    tt_list.append(tt_cross(f, dims, maxrank=5, rtol=1e-5, verbose=False))
    T = np.fromfunction(lambda *args: f(np.array(args)), dims)
    print([U.shape for U in tt_list[k]])
    print(f'{k + 1} cross rerr =', np.linalg.norm(T - tensor_from_tt(tt_list[k])))
    print()

print()
print('MULTIFUNC TEST')
print()

def func(x):
    return 2 * x[0] - 3 * x[1]
factors_mf = tt_multifunc(tt_list, func, maxrank=5, rtol=1e-5, verbose=True)
tensor_mf = tensor_from_tt(factors_mf)
print([U.shape for U in factors_mf])
tensors = [tensor_from_tt(fact) for fact in tt_list]
tensor = np.apply_along_axis(func, axis=0, arr=np.stack(tensors, axis=0))
err = np.power(np.linalg.norm(tensor - tensor_mf), 2)
tensor_norm = np.power(np.linalg.norm(tensor), 2)
print('absolute error =', err)
print('relative error =', err / tensor_norm)

iteration 1/22
global rank update = [1 1 1]
ranks [1 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 1)
rerr=4399.736114948231

iteration 2/22
global rank update = [1 1 1]
ranks [1 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 1)
rerr=4.32096508347541e-27


summary:
tt ranks: [2 2 2]
[(1, 10, 2), (2, 8, 2), (2, 11, 2), (2, 10, 1)]
1 cross rerr = 1.3984102892106106e-14

iteration 1/22
global rank update = [1 1 1]
ranks [1 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 1)
rerr=4400.263885051768

iteration 2/22
global rank update = [1 1 1]
ranks [1 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 1)
rerr=1.8006656180894913e-27


summary:
tt ranks: [2 2 2]
[(1, 10, 2), (2, 8, 2), (2, 11, 2), (2, 10, 1)]
2 cross rerr = 1.3163317491848607e-14


MULTIFUNC TEST

iteration 1/22
new maxrank 2 < 5 = maxrank
new ranks [1 2 2 2 1]
glob

In [177]:
def f1(x):
    return np.sin(sum(x))
def f2(x):
    return np.cos(sum(x))
def f3(x):
    return np.sin(sum(x)) + np.cos(sum(x))
np.random.seed(42)
dims = [10, 8, 11, 10, 9]
tt_list = []
for k, f in enumerate([f1, f2, f3]):
    tt_list.append(tt_cross(f, dims, maxrank=5, rtol=1e-5, verbose=False))
    T = np.fromfunction(lambda *args: f(np.array(args)), dims)
    print([U.shape for U in tt_list[k]])
    print(f'{k + 1} cross rerr =', np.linalg.norm(T - tensor_from_tt(tt_list[k])))
    print()

print()
print('MULTIFUNC TEST')
print()

def func(x):
    return 2 * x[0] + 3 * x[1] - 5 * x[2]
factors_mf = tt_multifunc(tt_list, func, maxrank=5, rtol=1e-5, verbose=True)
tensor_mf = tensor_from_tt(factors_mf)
print([U.shape for U in factors_mf])
tensors = [tensor_from_tt(fact) for fact in tt_list]
tensor = np.apply_along_axis(func, axis=0, arr=np.stack(tensors, axis=0))
err = np.power(np.linalg.norm(tensor_mf - tensor), 2)
tensor_norm = np.power(np.linalg.norm(tensor), 2)
print('absolute error =', err)
print('relative error =', err / tensor_norm)

iteration 1/22
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10, 9]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 2)
(2, 9, 1)
rerr=39600.07938872676

iteration 2/22
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10, 9]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 2)
(2, 9, 1)
rerr=3.5714116406970086e-26


summary:
tt ranks: [2 2 2 2]
[(1, 10, 2), (2, 8, 2), (2, 11, 2), (2, 10, 2), (2, 9, 1)]
1 cross rerr = 1.1499291229922721e-13

iteration 1/22
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10, 9]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 2)
(2, 9, 1)
rerr=39599.92061127323

iteration 2/22
global rank update = [1 1 1 1]
ranks [1 2 2 2 2 1]
dims and factors shapes
[10, 8, 11, 10, 9]
(1, 10, 2)
(2, 8, 2)
(2, 11, 2)
(2, 10, 2)
(2, 9, 1)
rerr=1.7831467785369238e-26


summary:
tt ranks: [2 2 2 2]
[(1, 10, 2), (2, 8, 2), (2, 11, 2), (2, 10, 2), (2, 9, 1)]
2 cross rerr = 4.138719552435