In [4]:
from functools import lru_cache
from copy import copy, deepcopy
import itertools
from opt_einsum import contract

In [5]:
# Progress markers read by the logging in `my_cach`:
level = -1      # depth of the tree node being processed (written by Node.calc_pivots)
cur_rank = -1   # index of the rank-one term being built (written by rank_r_approx)


def my_cach(func):
    """Memoise ``func`` and report every new running maximum of its values.

    ``func`` must accept a single hashable multi-index (a tuple of ints).
    The returned wrapper ``f`` accepts either

    * a 1-D index (list/tuple/array) -> returns ``func(tuple(I))``;
    * a 2-D array of indices, one per row -> returns a list with one value
      per row (an empty list for an empty batch).

    Anything else raises ``TypeError``.

    Extra attributes on the wrapper:
        f.max  -- largest value seen so far (starts at ``-inf``).
        f.func -- the ``lru_cache``-wrapped function (exposes ``cache_info()``).
    """
    # Local import: the notebook's import cell only brings in `lru_cache`.
    from functools import wraps

    func = lru_cache(maxsize=int(1e6))(func)

    @wraps(func)  # keep the name/docstring of the decorated function
    def f(I):
        I = np.asarray(I, dtype=int)
        if I.ndim == 1:
            y = func(tuple(I))
            if y > f.max:
                f.max = y
                print(f">>> max: {f.max} ({func.cache_info().misses} evals) (rank={cur_rank}) (I = {I})")
            return y
        elif I.ndim == 2:
            y = [func(tuple(i)) for i in I]
            # An empty batch has no maximum; `max([])` would raise ValueError.
            if y:
                max_y = max(y)
                if max_y > f.max:
                    f.max = max_y
                    print(f">>> max: {f.max}")
            return y
        else:
            raise TypeError('Bad argument')

    f.max = -np.inf
    f.func = func

    return f
                
        

In [17]:
def my_max(l, m=0):
    """Return ``max(l)``, or the fallback ``m`` when ``l`` has length zero."""
    return max(l) if len(l) != 0 else m
    
    
    
        
def next_approx(A, Xj, idx_row, idx_col, idx, vals_row, vals_col, vals, tol=1e-8):
    """Perform one cross-approximation update of the current approximant.

    Implements the iteration of formulas (5)--(6) from
    ``Black box approximation of tensors in hierarchical Tucker format'':
    with pivot ``P`` (rows ``idx_row = vals_row``, columns
    ``idx_col = vals_col``) and residual ``R = A - Xj``,

        X_{j+1}(I) = Xj(I) + R(I with pivot rows) * R(I with pivot cols) / R(P)

    Args:
        A: callable returning a tensor entry for a multi-index.
        Xj: callable, the current approximation.
        idx_row, vals_row: positions of the row modes and their pivot values.
        idx_col, vals_col: positions of the column modes and their pivot values.
        idx, vals: optional positions/values held fixed in every evaluation;
            either may be ``None`` (or both empty) when nothing is fixed.
        tol: pivots whose residual is below this absolute value are rejected.

    Returns:
        ``(X_next, False)`` on success, or ``(Xj, True)`` when the pivot
        residual is (numerically) zero and the iteration should stop.
    """
    has_fixed = idx is not None and vals is not None

    # Shortest index vector able to hold every referenced position.
    # `idx` may legitimately be None, so it must not be measured blindly.
    positions = [*idx_row, *idx_col, *(idx if idx is not None else [])]
    d = int(max(positions)) + 1

    pivot = np.zeros(d, dtype=int)
    pivot[idx_row] = vals_row
    pivot[idx_col] = vals_col
    if has_fixed:
        pivot[idx] = vals

    lamb = A(pivot) - Xj(pivot)

    # Dividing by a vanishing pivot residual would blow the update up
    # (and later produce "Singular matrix" errors), so signal a stop instead.
    if abs(lamb) < tol:
        return Xj, True

    def f(I):
        res = Xj(I)

        # Work on a private copy: the caller's index array must not be modified.
        I = np.array(I, dtype=int)
        if has_fixed:
            I[idx] = vals

        I1 = np.copy(I)
        I2 = np.copy(I)

        I1[idx_col] = vals_col  # column modes pinned to the pivot
        I2[idx_row] = vals_row  # row modes pinned to the pivot

        return res + (A(I2) - Xj(I2))*(A(I1) - Xj(I1))/lamb

    return f, False
    


def find_pivots(A, Xj, idx_init, t, tp, f, shape, l_max=3):
    """Greedy search for a pivot with a large residual ``|A - Xj|`` (Alg. 1).

    Starting from ``idx_init`` the modes in ``t`` and ``tp`` are optimised one
    at a time (each mode is scanned over its full range while the others stay
    fixed); this sweep is repeated ``l_max`` times.  Afterwards the remaining
    ("father") modes are set to the best of the candidate value rows in ``f``.

    Args:
        A: callable returning a tensor entry for a multi-index.
        Xj: callable, the current approximation.
        idx_init: starting multi-index (not modified; a copy is made).
        t, tp: positions of the node's own modes and of its sibling's modes.
        f: iterable of candidate value rows for the father modes (may be empty).
        shape: mode sizes of the tensor.
        l_max: number of coordinate sweeps.

    Returns:
        ``(pivot, residual)`` -- the pivot multi-index as an array and the
        residual ``|A(pivot) - Xj(pivot)|`` at that pivot.
    """
    idx_init = np.array(idx_init)
    d = len(shape)  # dimension
    t_and_tp = list(t) + list(tp)

    def residual(index):
        return abs(A(index) - Xj(index))

    # Residual at the current `idx_init`; None until something was evaluated.
    best = None

    # Coordinate sweeps over the node's own and the sibling's modes.
    for _ in range(l_max):
        for idx_m in t_and_tp:
            idx = list(idx_init)
            y = []
            for im in range(shape[idx_m]):
                idx[idx_m] = im
                y.append(residual(idx))
            k = np.argmax(y)
            idx_init[idx_m] = k
            best = y[k]

    # Father modes: pick the best of the admissible candidate rows.
    fp_idx = np.setdiff1d(np.arange(d), t_and_tp)
    y = []
    for im in f:
        idx = np.array(idx_init)
        idx[fp_idx] = im
        y.append(residual(idx))

    if y:
        k = np.argmax(y)
        idx_init[fp_idx] = f[k]
        # The pivot moved, so report the residual at its new position rather
        # than a stale maximum from the sweeps above.
        best = y[k]

    if best is None:
        # No sweep ran (l_max == 0 or no free modes) and no father candidates:
        # fall back to the residual at the starting point.
        best = residual(idx_init)

    return idx_init, best
        
                
def X0(*_unused):
    """Initial approximation: identically zero, whatever arguments it is given."""
    return 0
    
    
def rank_r_approx(A, t, tp, f, f_vals, shape, r=3, l_max=3, debug=False, tryes=1):
    """Build an approximation of ``A`` from up to ``r`` cross-update steps.

    Each step runs ``find_pivots`` from ``tryes`` random starting points,
    keeps the pivot with the largest reported residual and applies
    ``next_approx`` with rows ``t`` and columns ``tp + f``.  The loop ends
    early (printing "Emerge stop") when ``next_approx`` rejects the pivot.

    Side effect: the module-level ``cur_rank`` is set to the current step.

    Returns:
        ``(Xj, pivots)`` -- the final approximation callable and the list of
        accepted pivot multi-indices.
    """
    global cur_rank

    approx = X0
    pivots = []

    for step in range(r):
        cur_rank = step

        # Independent pivot searches, each from a fresh random start.
        trials = []
        for _ in range(tryes):
            start = [np.random.randint(0, m) for m in shape]
            trials.append(
                find_pivots(A, approx, start, t, tp, f_vals, shape=shape, l_max=l_max)
            )

        winner = np.argmax([score for _, score in trials])
        p = trials[winner][0]
        pivots.append(p)
        if debug:
            print(p)

        # Sibling and father modes together play the role of the columns.
        ptf = list(tp) + list(f)
        approx, stop = next_approx(A, approx, t, ptf, [], p[t], p[ptf], [])
        if stop:
            # The pivot was rejected, so it must not be recorded.
            pivots.pop()
            print("Emerge stop")
            break

    return approx, pivots


def split_idx(idx, random=False):
    """Split ``idx`` into a first half and a second half.

    Returns ``[]`` when ``idx`` has fewer than two elements, otherwise the
    pair ``(idx[:n//2], idx[n//2:])``; with ``random=True`` the elements are
    permuted with ``np.random.permutation`` before splitting.
    """
    size = len(idx)
    if size >= 2:
        items = np.random.permutation(idx) if random else idx
        half = size // 2
        return items[:half], items[half:]
    return []


class Node():
    """Vertex of the binary dimension tree used for the hierarchical Tucker (HT) approximation.

    Attributes set by ``__init__``:
        idx   -- tensor modes owned by this node.
        L, R  -- child nodes, ``None`` for a leaf.
        level -- depth in the tree (the root has level 0).

    Attributes added later:
        _tp                 -- sibling's modes (``split_children``).
        d, shape, A, idx_p,
        Xj, build_idx       -- ``calc_pivots``.
        S                   -- ``make_mat_S``.
        B                   -- ``make_cores`` (the node's HT core).
    """

    def __init__(self, idx=None, level=0):
        if idx is None:
            idx = []

        self.idx = idx
        self.L = None
        self.R = None
        self.level = level

    def split_children(self, random=False):
        """Recursively split ``idx`` in halves until every leaf owns one mode.

        With ``random=True`` the modes are permuted before each split.
        """
        n = len(self.idx)
        if n > 1:
            idx = self.idx
            if random:
                idx = np.random.permutation(self.idx)
            self.L = Node(idx[:n//2], self.level + 1)
            self.R = Node(idx[n//2:], self.level + 1)

            # Each child remembers the modes of its sibling.
            self.L._tp = self.R.idx
            self.R._tp = self.L.idx

            self.L.split_children(random=random)
            self.R.split_children(random=random)

    def print(self, sp_len=0):
        """Print the subtree sideways (left child above, right child below)."""
        idx_str = str(self.idx)
        sp_len_ch = sp_len + len(idx_str) + 1
        if self.L is not None:
            self.L.print(sp_len_ch)

        # `build_idx` only exists once `calc_pivots` has visited this node.
        add_str = f"({self.build_idx})" if hasattr(self, "build_idx") else ""
        print(f"{' ' * sp_len}-> {idx_str} ({add_str})")
        if self.R is not None:
            self.R.print(sp_len_ch)

    @property
    def tp(self):
        """Modes of the sibling node; an empty list for a node without sibling (the root)."""
        return getattr(self, "_tp", [])

    def calc_pivots(self, A, shape, f, f_vals, rank=3, l_max=3, check=False, max_level=10000000):
        """Choose pivots for this node and, recursively, for its descendants.

        Args:
            A: callable returning a tensor entry for a multi-index.
            shape: mode sizes of the tensor.
            f: positions of the father modes (neither own nor sibling modes).
            f_vals: candidate value rows for the father modes.
            rank: maximal number of pivots per node.
            l_max: number of sweeps in the pivot search.
            check: if true, print the relative error of the node's
                approximation over the whole index set (expensive).
            max_level: nodes deeper than this level are skipped.

        Side effect: the module-level ``level`` is set to this node's level.
        """
        if self.level > max_level:
            return

        global level
        level = self.level

        self.d = d = len(shape)
        self.shape = shape
        self.A = A

        # All modes other than our own (denoted t' in the paper).
        self.idx_p = np.setdiff1d(np.arange(d), self.idx)
        t = self.idx   # our modes
        tp = self.tp   # sibling's modes

        if len(t) > 0 and len(tp) > 0:
            self.Xj, p = rank_r_approx(A, t, tp, f, f_vals, shape, r=rank, l_max=l_max,
                                       debug=False)
            p = np.array(p)

            # Pivot values restricted to our modes and to the complementary
            # modes (P_t in the paper).
            self.build_idx = p[:, self.idx], p[:, self.idx_p]

            if check:
                check_approx(A, shape, self.Xj)
        else:
            # Root: there is no sibling, hence no matricization to approximate.
            self.build_idx = None
            p = None
            self.Xj = None

        for nn, nn_sibl in [(self.L, self.R), (self.R, self.L)]:
            if nn is not None and len(nn.idx) > 0:
                siblis_idx = []
                if nn_sibl is not None and len(nn_sibl.idx) > 0:
                    siblis_idx = list(nn_sibl.idx)
                nn_father_idx = np.setdiff1d(np.arange(d), list(nn.idx) + siblis_idx)
                if p is not None:
                    # The child may only use father values taken from our pivots.
                    f_vals = p[:, nn_father_idx]
                else:
                    f_vals = []
                    nn_father_idx = []

                nn.calc_pivots(A, shape, nn_father_idx, f_vals,
                               rank=rank, l_max=l_max, check=check, max_level=max_level)

    def make_mat_S(self):
        """Evaluate ``Xj`` on the pivot cross and store it as ``S``, for the whole subtree.

        ``S[j, k]`` combines the own-mode part of pivot ``j`` with the
        complementary part of pivot ``k``.  Nodes whose ``build_idx`` is
        ``None`` get no ``S``.
        """
        t = self.idx
        tp = self.idx_p
        idx = np.empty(self.d, dtype=int)

        if self.build_idx is not None:
            bt, btp = self.build_idx
            res = self.S = np.empty([len(bt), len(btp)])
            for jt, ti in enumerate(bt):
                for jtp, tpi in enumerate(btp):
                    idx[t] = ti
                    idx[tp] = tpi
                    res[jt, jtp] = self.Xj(idx)

        for nn in [self.L, self.R]:
            if nn is None:
                continue
            nn.make_mat_S()

    def make_cores(self, _build_S=True):
        """Compute the HT core ``B`` of this node and of all its descendants.

        ``_build_S`` is internal: the entry call builds the ``S`` matrices of
        the whole subtree once; recursive calls pass ``False`` because
        ``make_mat_S`` already recursed through every descendant.
        """
        if _build_S:
            self.make_mat_S()

        if self.build_idx is not None:  # we are not root
            if self.L is None and self.R is None:
                # Leaf: the frame is Xj over the full range of the single own
                # mode, with the other modes taken from the pivots.
                bt, btp = self.build_idx
                B = np.empty([self.shape[self.idx[0]], len(btp)])
                I = np.empty(self.d, dtype=int)
                for i in range(B.shape[0]):
                    for j, val in enumerate(btp):
                        I[self.idx[0]] = i
                        I[self.idx_p] = val
                        B[i, j] = self.Xj(I)

                self.B = B.T

            else:
                # Interior node (neither root nor leaf).
                S1, S2 = self.L.S, self.R.S
                bt, btp = self.build_idx

                Pt1 = self.L.idx
                Pt2 = self.R.idx
                Ptp = self.idx_p

                A_cur = np.empty([len(self.L.build_idx[0]), len(self.R.build_idx[0]), len(btp)])
                I = np.empty(self.d, dtype=int)

                for ind1, iPt1_val in enumerate(self.L.build_idx[0]):
                    for ind2, iPt2_val in enumerate(self.R.build_idx[0]):
                        for ind3, iPtp_val in enumerate(btp):
                            I[Pt1] = iPt1_val
                            I[Pt2] = iPt2_val
                            I[Ptp] = iPtp_val

                            A_cur[ind1, ind2, ind3] = self.Xj(I)

                make_stable = True
                if make_stable:
                    # Linear solves instead of explicit inverses.
                    self.B = stable_double_inv(A_cur, S1, S2)
                else:
                    self.B = contract("rlk,ir,jl->ikj", A_cur, np.linalg.inv(S1), np.linalg.inv(S2))

        else:
            # Root node: uses the tensor itself on the children's pivots.
            S1, S2 = self.L.S, self.R.S

            Pt1 = self.L.idx
            Pt2 = self.R.idx
            ii = 0
            A_cur = np.empty([len(self.L.build_idx[ii]), len(self.R.build_idx[ii])])
            I = np.empty(self.d, dtype=int)
            for ind1, iPt1_val in enumerate(self.L.build_idx[ii]):
                for ind2, iPt2_val in enumerate(self.R.build_idx[ii]):
                    I[Pt1] = iPt1_val
                    I[Pt2] = iPt2_val

                    A_cur[ind1, ind2] = self.A(I)

            make_stable = True
            if make_stable:
                self.B = stable_double_inv(A_cur[..., None], S1, S2)
            else:
                self.B = contract("rl,ir,jl->ij", A_cur, np.linalg.inv(S1), np.linalg.inv(S2))[:, None, :]

        for nn in [self.L, self.R]:
            if nn is None:
                continue
            # The S matrices of the subtree are already in place.
            nn.make_cores(_build_S=False)

    def transfer_cores(self):
        """Copy the cores into a tree of ``HNode`` objects and return its root.

        NOTE(review): ``HNode`` is not defined in the visible part of the
        notebook -- it must be provided elsewhere.
        """
        me = HNode(self.B)

        ch = []
        for nn in [self.L, self.R]:
            if nn is None:
                continue

            ch.append(nn.transfer_cores())
        if ch:
            me.set_children(*ch)

        return me

    def __getitem__(self, I):
        """Evaluate the HT representation at a batch of multi-indices.

        ``I`` is a 2-D integer array with one multi-index per row; its columns
        correspond, in order, to the modes listed in ``self.idx``.  The root
        returns one value per row; other nodes return an array with one
        column per row of ``I``.
        """
        I = np.asarray(I)

        if self.L is None and self.R is None:
            # leaf!
            res = self.B[:, I.reshape(-1)]
        else:
            # Select each child's columns by mode, not by position: after a
            # random split the children's modes are not the first/second half
            # of `self.idx`.
            column_of = {int(mode): col for col, mode in enumerate(self.idx)}
            cols_L = [column_of[int(mode)] for mode in self.L.idx]
            cols_R = [column_of[int(mode)] for mode in self.R.idx]
            res = contract("in,ikj,jn->kn", self.L[I[:, cols_L]], self.B, self.R[I[:, cols_R]])
            if self.level == 0:  # root node
                res = res[0]

        return res

    def _leaf_modes(self):
        """Modes of the leaves below this node, in left-to-right order."""
        if self.L is None and self.R is None:
            return [int(self.idx[0])]
        modes = []
        for nn in [self.L, self.R]:
            if nn is not None:
                modes.extend(nn._leaf_modes())
        return modes

    def full(self, res=None, num_up=0):
        """Contract all cores into the dense tensor (call on the root without arguments).

        ``res`` and ``num_up`` are used by the recursion only: ``res``
        accumulates alternating operands and index-label lists for
        ``contract``; ``num_up`` is the label of the bond to the parent.
        Recursive calls return ``None``.
        """
        need_eins = False
        if res is None:
            # A dummy length-1 vector closes the root's (size one) upward bond.
            res = [np.array([1]), [0]]
            need_eins = True

        if self.level == 0:
            res.extend([self.B, [1, 0, 2]])
            self.L.full(res, 1)
            self.R.full(res, 2)
        elif self.L is None and self.R is None:
            next_num = find_next_free_num(res)
            res.extend([self.B, [num_up, next_num]])
        else:
            next_num = find_next_free_num(res)
            res.extend([self.B, [next_num, num_up, next_num + 1]])
            self.L.full(res, next_num)
            self.R.full(res, next_num + 1)

        if need_eins:
            dense = contract(*res)
            # The free labels are handed out leaf by leaf from left to right,
            # so the axes come out in leaf order; put them back in mode order
            # (a no-op unless the tree was split randomly).
            return np.transpose(dense, np.argsort(self._leaf_modes()))


def find_next_free_num(arr):
    """Return the smallest index label not yet used in ``arr``.

    ``arr`` alternates operands and label lists (operand, labels, operand,
    labels, ...); only the label lists at the odd positions are inspected.
    """
    used = [label for labels in arr[1::2] for label in labels]
    return max(used) + 1
    
def stable_double_inv(A_cur, S1, S2):
    """Apply ``inv(S1)`` to the first and ``inv(S2)`` to the second axis of ``A_cur``.

    Uses ``np.linalg.solve`` rather than explicit inverses.  For ``A_cur`` of
    shape ``(n1, n2, n3)`` the result has shape ``(n1, n3, n2)`` and equals
    ``einsum("rlk,ir,jl->ikj", A_cur, inv(S1), inv(S2))``.
    """
    n1, n2, n3 = A_cur.shape

    # Solve along the first axis, then bring the second axis to the front.
    step1 = np.linalg.solve(S1, A_cur.reshape(n1, -1)).reshape(n1, n2, n3)
    step1 = np.transpose(step1, [1, 0, 2])

    # Solve along the (former) second axis.
    step2 = np.linalg.solve(S2, step1.reshape(n2, -1)).reshape(n2, n1, n3)

    # Final axis order: (first, third, second).
    return np.copy(np.transpose(step2, [1, 2, 0]))
    
def check_approx(A, shape, Xj):
    """Print the relative 2-norm error of ``Xj`` against ``A`` over the full index grid.

    Every multi-index of ``shape`` is visited, so the cost is exponential in
    the number of modes.
    """
    exact = []
    residual = []
    for multi_index in itertools.product(*map(range, shape)):
        value = A(multi_index)
        residual.append(value - Xj(multi_index))
        exact.append(value)

    print(np.linalg.norm(residual)/np.linalg.norm(exact))

    

# Tensors we will approximate

def test_T_simple(I):
    return sum(I)


def test_T(I):
    return sum(I) + np.prod(I) + sum(I)*np.prod(I)

def approx_by_HT(f, shape, rank, l_max=3):
    """Build a hierarchical Tucker approximation of the black-box tensor ``f``.

    Creates a balanced dimension tree over ``len(shape)`` modes, selects
    pivots with at most ``rank`` terms per node (``l_max`` sweeps per pivot
    search), computes the cores and returns the root ``Node``.
    """
    root = Node(np.arange(len(shape)))
    root.split_children()
    root.calc_pivots(f, shape, [], [], rank=rank, l_max=l_max)
    root.make_cores()
    return root

    

# Experiments

## Simple: function

In [14]:
@my_cach
def f_Simple(I):
    """Black-box tensor: fourth power of the sum of the multi-index entries."""
    # Alternative that was tried: 100*np.exp(-np.sum(I)/10) + np.sum(I)**2
    total = np.sum(I)
    return total**4

d = 8
shape = d*[10]

# Probe points: the all-ones, all-twos and all-fives multi-indices.
test_I = np.array([[1]*d, [2]*d, [5]*d], dtype=int)


for rank in (3, 4, 5):
    tree = approx_by_HT(f_Simple, shape, rank=rank, l_max=3)
    print(f"rank = {rank}")
    # Values of the HT approximation versus the exact values at the probes.
    tree_res = tree[test_I]
    tru_res = [f_Simple(probe) for probe in test_I]

    print(f"tree: {tree_res}")
    print(f"tru: {tru_res}")

>>> max: 1336336 (1 evals) (rank=0) (I = [0 5 5 0 9 9 2 4])
>>> max: 1500625 (2 evals) (rank=0) (I = [1 5 5 0 9 9 2 4])
>>> max: 1679616 (3 evals) (rank=0) (I = [2 5 5 0 9 9 2 4])
>>> max: 1874161 (4 evals) (rank=0) (I = [3 5 5 0 9 9 2 4])
>>> max: 2085136 (5 evals) (rank=0) (I = [4 5 5 0 9 9 2 4])
>>> max: 2313441 (6 evals) (rank=0) (I = [5 5 5 0 9 9 2 4])
>>> max: 2560000 (7 evals) (rank=0) (I = [6 5 5 0 9 9 2 4])
>>> max: 2825761 (8 evals) (rank=0) (I = [7 5 5 0 9 9 2 4])
>>> max: 3111696 (9 evals) (rank=0) (I = [8 5 5 0 9 9 2 4])
>>> max: 3418801 (10 evals) (rank=0) (I = [9 5 5 0 9 9 2 4])
>>> max: 3748096 (16 evals) (rank=0) (I = [9 6 5 0 9 9 2 4])
>>> max: 4100625 (17 evals) (rank=0) (I = [9 7 5 0 9 9 2 4])
>>> max: 4477456 (18 evals) (rank=0) (I = [9 8 5 0 9 9 2 4])
>>> max: 4879681 (19 evals) (rank=0) (I = [9 9 5 0 9 9 2 4])
>>> max: 5308416 (25 evals) (rank=0) (I = [9 9 6 0 9 9 2 4])
>>> max: 5764801 (26 evals) (rank=0) (I = [9 9 7 0 9 9 2 4])
>>> max: 6250000 (27 evals) (rank

## Simple: small random array

In [49]:
d = 4
shape = d*[2]
# Small dense random tensor, so the reconstruction can be compared entry by entry.
T = np.random.rand(*shape)
print(T.shape)

@my_cach
def f_Tens(I):
    """Black-box access to a single entry of the random tensor ``T``."""
    position = tuple(I)
    return T[position]

for rank in (3, 4, 5, 10):
    tree = approx_by_HT(f_Tens, shape, rank=rank, l_max=3)
    print(f"rank = {rank}")
    # Frobenius norm of the difference between reconstruction and original.
    print("error: ", np.linalg.norm(tree.full() - T))


(2, 2, 2, 2)
>>> max: 0.5683041340861806 (1 evals) (rank=0) (I = [0 0 1 1])
>>> max: 0.717866208338905 (3 evals) (rank=0) (I = [0 1 1 1])
>>> max: 0.9001589374764813 (7 evals) (rank=1) (I = [0 1 0 0])
>>> max: 0.907977117602931 (11 evals) (rank=1) (I = [1 0 1 0])
Emerge stop
Emerge stop
Emerge stop
Emerge stop
rank = 3
error:  0.5860256098867653
Emerge stop
Emerge stop
Emerge stop
Emerge stop
Emerge stop
rank = 4
error:  0.20588218907120726
Emerge stop
Emerge stop
Emerge stop
Emerge stop
Emerge stop
Emerge stop
rank = 5
error:  1.9755164663124212
Emerge stop
Emerge stop
Emerge stop
Emerge stop
Emerge stop
Emerge stop
rank = 10
error:  0.20588218907120726


In [51]:
# Show the original tensor next to its dense HT reconstruction for an
# entry-by-entry comparison.
# NOTE(review): `T` and `tree` are notebook globals -- presumably the ones left
# behind by the random-array experiment cell; confirm the execution order.
print(T)
print(tree.full())

[[[[0.4715815303 0.8582040562]
   [0.708102174  0.5683041341]]

  [[0.9001589375 0.0210129598]
   [0.445707823  0.7178662083]]]


 [[[0.510444374  0.8634722902]
   [0.9079771176 0.5414409143]]

  [[0.8956558041 0.7083274594]
   [0.7378443958 0.3426696452]]]]
[[[[0.4715815303 0.8582040562]
   [0.9139843631 0.5683041341]]

  [[0.9001589375 0.0210129598]
   [0.445707823  0.7178662083]]]


 [[[0.510444374  0.8634722902]
   [0.9079771176 0.5414409143]]

  [[0.8956558041 0.7083274594]
   [0.7378443958 0.3426696452]]]]
