# Higher Order SVD (HOSVD) and TRG (HOTRG)

## References
* PRB **86**, 045139 (2012).

## Higer Order SVD (HOSVD)

Given a rank-4 tensor $M_{abcd}$, we want to find unitary matrices $U^L, U^R, U^U, U^D$ such that
$$
  \large
  M_{abcd} = \sum_{ijkl} S_{ijkl} U^L_{ai} U^R_{bj} U^U_{ck} U^{D}_{dl},
$$
where $S$ is the core tensor.

$$
  \large
  X_{aa^\prime} \equiv
  \sum_{bcd} M_{abcd} M_{a^\prime bcd} 
  = \sum_{bcd} 
    \sum_{ijkl} S_{ijkl} U^L_{ai} U^R_{bj} U^U_{ck} U^{D}_{dl}
    \sum_{i^\prime j^\prime k^\prime l^\prime } S_{i^\prime j^\prime k^\prime l^\prime } U^L_{a^\prime i^\prime } U^R_{bj^\prime } U^U_{ck^\prime } U^{D}_{dl^\prime } 
  = \sum_{ijkl} \sum_{i^\prime j^\prime k^\prime l^\prime } S_{ijkl} U^L_{ai} S_{i^\prime j^\prime k^\prime l^\prime } U^L_{a^\prime i^\prime } 
  \delta_{jj^\prime} \delta_{kk^\prime} \delta_{ll^\prime}
  = \sum_{ii^\prime} \sum_{jkl} S_{ijkl} U^L_{ai} S_{i^\prime jkl} U^L_{a^\prime i^\prime jkl} 
$$
so that
$$
  \large
  \sum_{aa^\prime} U^{L}_{ai} U^{L}_{a^\prime i^\prime} X_{aa^\prime}
  = \sum_{jkl} S_{ijkl} S_{i^\prime jkl}.
$$

In [1]:
import Tor10
import copy
import numpy as np

def Tprint(T):
    print(T.Print_diagram())
    print(T)
    
def Tprint_diag(T):
    diag = np.array([T[i,i].item() for i in range(T.bonds[0].dim)])
    print(diag)

In [2]:
def check_unitary(U0):
    U = copy.deepcopy(U0)
    Ut = U.Whole_transpose()
    U.SetLabels([0, -1])
    Ut.SetLabels([-1, 1])
    X = Tor10.Contract(Ut, U)
    print(Tor10.Contract(Ut, U))
    print(Tor10.Contract(U, Ut))    

## Use labels to find the corresponding U

In [3]:
def find_U(M, labels):
    MP = copy.deepcopy(M)
    MP.SetName('MP')    
    MP = MP.Permute(labels, rowrank=1)
    
#     MPt = copy.deepcopy(MP)
    MPt = MP.Whole_transpose()
    MPt.SetName('MPt')    
    MPt.SetLabel(1111, 3)
    
#     MP.Print_diagram()
#     MPt.Print_diagram()
    
    U, _, _ = Tor10.Svd(Tor10.Contract(MP, MPt))
#     U.Print_diagram()
    return U

In [4]:
bd = Tor10.Bond(3)
M = Tor10.UniTensor([bd, bd, bd, bd], rowrank=4, name='M')
M.Rand()
M.Print_diagram()


U0 = find_U(M, [0,1,2,3])
U1 = find_U(M, [1,0,2,3])
U2 = find_U(M, [2,0,1,3])
U3 = find_U(M, [3,0,1,2])
check_unitary(U0)
check_unitary(U1)
check_unitary(U2)
check_unitary(U3)

U0.SetLabel(int(U0.labels[0])+10, 1)
U1.SetLabel(int(U1.labels[0])+10, 1)
U2.SetLabel(int(U2.labels[0])+10, 1)
U3.SetLabel(int(U3.labels[0])+10, 1)

# U0.Print_diagram()
# U1.Print_diagram()
# U2.Print_diagram()
# U3.Print_diagram()

U0t = U0.Whole_transpose()
U1t = U1.Whole_transpose()
U2t = U2.Whole_transpose()
U3t = U3.Whole_transpose()

# U0t.Print_diagram()
# U1t.Print_diagram()
# U2t.Print_diagram()
# U3t.Print_diagram()

U01 = Tor10.Contract(U0, U1)
# U01.Print_diagram()
U23 = Tor10.Contract(U2, U3)
# U23.Print_diagram()
U0123 = Tor10.Contract(U01, U23)
U0123.Print_diagram()

S = Tor10.Contract(U0123, M)
S.Print_diagram()

-----------------------
tensor Name : M
tensor Rank : 4
has_symmetry: False
on device     : cpu
is_diag       : False
            -------------      
           /             \     
     0 ____| 3           |        
           |             |     
     1 ____| 3           |        
           |             |     
     2 ____| 3           |        
           |             |     
     3 ____| 3           |        
           \             /     
            -------------      
Tensor name: 
is_diag    : False
tensor([[ 1.0000e+00,  1.7682e-16,  1.6561e-16],
        [ 1.7682e-16,  1.0000e+00, -4.3008e-17],
        [ 1.6561e-16, -4.3008e-17,  1.0000e+00]], dtype=torch.float64)

Tensor name: 
is_diag    : False
tensor([[ 1.0000e+00,  1.7682e-16,  1.6561e-16],
        [ 1.7682e-16,  1.0000e+00, -4.3008e-17],
        [ 1.6561e-16, -4.3008e-17,  1.0000e+00]], dtype=torch.float64)

Tensor name: 
is_diag    : False
tensor([[ 1.0000e+00,  5.8256e-17, -5.1833e-17],
        [ 5.8256e-17,  1.0000e

### All orthogonality

$$
  \langle S_{:,j,:,:} | S_{:,j^\prime,:,:}\rangle 
  =\sum_{ikl} S_{ijkl} S_{ij^\prime kl} = 0 \; \text{if} \; i \neq j.
$$

In [35]:
# S.labels = [10,11,12,13]

# i,ip
S0t = copy.deepcopy(S)
S0t.SetLabels([-10,11,12,13])
SS0 = Tor10.Contract(S0t, S)
SS0.SetName('SS0')
print(SS0)
# Tprint_diag(SS0)

MP = copy.deepcopy(M)
MP.SetName('MP')    
MP = MP.Permute([0,1,2,3], rowrank=1)
MPt = MP.Whole_transpose()
MPt.SetName('MPt')    
MPt.SetLabel(1111, 3)
# MP.Print_diagram()
# MPt.Print_diagram()
X = Tor10.Contract(MP, MPt)
X.SetName('X')
X.Print_diagram()

U0t = U0.Whole_transpose()
U0t.Print_diagram()
U0.SetLabels([1111,11])
U0.Print_diagram()
print(Tor10.Contract(Tor10.Contract(U0t,X),U0))



# j,jp
S1t = copy.deepcopy(S)
S1t.SetLabels([10,-11,12,13])
SS1 = Tor10.Contract(S1t, S)
# print(SS1)
Tprint_diag(SS1)

# k,kp
S2t = copy.deepcopy(S)
S2t.SetLabels([10,11,-12,13])
SS2 = Tor10.Contract(S2t, S)
# print(SS2)
Tprint_diag(SS2)

# l,lp
S3t = copy.deepcopy(S)
S3t.SetLabels([10,11,12,-13])
SS3 = Tor10.Contract(S3t, S)
print(SS3)
Tprint_diag(SS3)

Tensor name: SS0
is_diag    : False
tensor([[ 2.7290e+01, -8.2278e-17,  1.0526e-15],
        [-8.2278e-17,  3.5206e+00, -1.3059e-16],
        [ 1.0526e-15, -1.3059e-16,  2.0183e+00]], dtype=torch.float64)

-----------------------
tensor Name : X
tensor Rank : 2
has_symmetry: False
on device     : cpu
is_diag       : False
            -------------      
           /             \     
     0 ____| 3         3 |____ 1111
           \             /     
            -------------      
-----------------------
tensor Name : 
tensor Rank : 2
has_symmetry: False
on device     : cpu
is_diag       : False
            -------------      
           /             \     
    11 ____| 3         3 |____ 1111
           \             /     
            -------------      
-----------------------
tensor Name : 
tensor Rank : 2
has_symmetry: False
on device     : cpu
is_diag       : False
            -------------      
           /             \     
   1111 ____| 3         3 |____ 11 
           \  

In [17]:
SS0.bonds[0].dim

3

In [21]:
def trunc_err(SS, D):
    if D>SS.bonds[0].dim:
        return
    err = 0.0
    for i in range(D, SS.bonds[0].dim):
        err += SS[i,i].item()
    return err

In [23]:
for Dc in range(1,3+1):
    print('Dc={:d}'.format(Dc))
    for SS in [SS0, SS1, SS2, SS3]:
        print(trunc_err(SS, Dc))

Dc=1
5.538888437547049
3.4852579709834766
4.843043482038766
5.083866044367964
Dc=2
2.018256437914992
0.9675600037110549
2.15970465209085
1.9780187027023937
Dc=3
0.0
0.0
0.0
0.0
