In [19]:
import numpy as np

def tensor_inner_product(X, Y):
    product = np.einsum('ijkl,ijkl->', np.transpose(X, axes=(1, 0, 2, 3)), Y)
    return product

def tensor_global_arnoldi(A, B, m):
    J1, J2, _, _ = A.shape
    _, _, K1, K2 = B.shape
    
    V_tilde = np.zeros((J1, J2, K1, (m+1)*K2))
    H = np.zeros((m+1, m))
    
    beta = np.linalg.norm(B)
    V1 = B / beta
    V_tilde[:, :, :, :K2] = V1
    
    for j in range(m):
        W = np.einsum('ijkl,klmn->ijmn', A, V_tilde[:, :, :, j*K2:(j+1)*K2])
        for i in range(j+1):
            h_ij = tensor_inner_product(V_tilde[:, :, :, i*K2:(i+1)*K2], W)
            H[i, j] = h_ij
            W = W - h_ij * V_tilde[:, :, :, i*K2:(i+1)*K2]
        
        h_j1_j = np.linalg.norm(W)
        H[j+1, j] = h_j1_j
        
        if h_j1_j == 0:
            break
        else:
            V_tilde[:, :, :, (j+1)*K2:(j+2)*K2] = W / h_j1_j
    
    return V_tilde, H

def check_orthonormality(V_tilde, m, K2):
    identity_matrix = np.eye(m)
    inner_product_matrix = np.zeros((m, m))
    
    for i in range(m):
        for j in range(m):
            Vi = V_tilde[:, :, :, i*K2:(i+1)*K2]
            Vj = V_tilde[:, :, :, j*K2:(j+1)*K2]
            inner_product_matrix[i, j] = tensor_inner_product(Vi, Vj)
    
    print("Inner Product Matrix:\n", inner_product_matrix)
    return np.allclose(inner_product_matrix, identity_matrix, atol=1e-6)

A = np.random.rand(3, 3, 3, 3)
B = np.random.rand(3, 3, 3, 3)
m = 2
V_tilde, H = tensor_global_arnoldi(A, B, m)
print("V_tilde:\n", V_tilde)
print("H:\n", H)

is_orthonormal = check_orthonormality(V_tilde, m, 3)
print("Is orthonormal:", is_orthonormal)


V_tilde:
 [[[[ 1.38955716e-01  2.04047256e-02  1.46518368e-02 -8.10467434e-03
     1.94911634e-01  1.55080756e-01  1.50542038e-02  9.71914587e-02
     1.67381350e-01]
   [ 9.21513162e-02  8.56821339e-02  1.47673142e-02  8.51523623e-02
     1.20506310e-01  1.48151738e-01  1.10484250e-01  1.70687425e-03
     7.55582583e-02]
   [ 8.07266175e-02  7.11188020e-03  4.55148515e-02  1.09969947e-01
     8.34497341e-02  1.00783998e-01  9.77738414e-02  1.37957722e-01
     2.03373630e-01]]

  [[ 1.31776275e-01  1.44584589e-01  1.16486790e-01  2.76154992e-02
    -2.10095456e-03  1.55902557e-02  7.72524227e-02  3.48417996e-02
     7.70760369e-02]
   [ 1.71512998e-01  7.29782018e-02  1.36540509e-01  1.74672774e-02
     8.06935892e-02  3.42382770e-02 -1.22412276e-02  3.03958404e-01
    -2.25487772e-01]
   [ 1.63379038e-01  1.21208560e-01  1.62450452e-01 -4.33358178e-03
    -2.17152878e-02 -1.47795097e-02  5.12030331e-02 -1.63634366e-01
    -3.02198914e-02]]

  [[ 4.27023855e-02  1.72028409e-01  5.52185