In [1]:
import numpy as np
import time

def updateL(L, A, W, A_star):
    """
    Contract L, A, W and A_star into the updated left block tensor.

    Index labels (as used in the einsum expressions below):
      L      : [a, i, b]
      A      : [b, j, c]
      W      : [c, j, k, d]
      A_star : [d, k, e]

    The function evaluates

      L_new[a, i, e] = sum over b, j, c, k, d of
                       L[a, i, b] * A[b, j, c] * W[c, j, k, d] * A_star[d, k, e]

    as three pairwise contractions:
      1. L with A over b                      -> [a, i, j, c]
      2. the result with W over c and j       -> [a, i, k, d]
      3. the result with A_star over d and k  -> [a, i, e]

    Args:
      L: array with three axes labelled [a, i, b].
      A: array with three axes labelled [b, j, c].
      W: array with four axes labelled [c, j, k, d].
      A_star: array with three axes labelled [d, k, e].

    Returns:
      Array of shape (a, i, e).

    Raises:
      ValueError: raised by np.einsum when the operands' ranks or shared
        index sizes do not match the labels above.
    """
    # Step 1: contract L and A over the shared index "b".
    T1 = np.einsum("aib,bjc->aijc", L, A)

    # Step 2: contract with W over "c" AND "j". Summing "j" here (instead of
    # carrying it into the next step) keeps the intermediate at [a, i, k, d]
    # rather than the larger [a, i, j, k, d].
    T2 = np.einsum("aijc,cjkd->aikd", T1, W)

    # Step 3: contract with A_star over "d" and "k".
    L_new = np.einsum("aikd,dke->aie", T2, A_star)

    return L_new

def main():
    """Build random float32 tensors, time one updateL call and print the result."""
    # Example size for each index label:
    # L: (a, i, b), A: (b, j, c), W: (c, j, k, d), A_star: (d, k, e)
    dims = {"a": 5, "i": 2, "b": 6, "j": 2, "c": 7, "k": 2, "d": 5, "e": 5}

    def random_tensor(labels):
        # Uniform random values in [0, 1), cast to single precision.
        shape = tuple(dims[label] for label in labels)
        return np.random.rand(*shape).astype(np.float32)

    # Drawn in the order L, A, W, A_star so the random stream is consumed
    # in a fixed sequence.
    tensor_l = random_tensor("aib")
    tensor_a = random_tensor("bjc")
    tensor_w = random_tensor("cjkd")
    tensor_a_star = random_tensor("dke")

    # Time only the contraction itself.
    t_begin = time.perf_counter()
    updated = updateL(tensor_l, tensor_a, tensor_w, tensor_a_star)
    t_finish = time.perf_counter()
    elapsed = t_finish - t_begin

    print("Updated left block tensor shape:", updated.shape)
    print("Time taken for contraction: {:.6f} seconds".format(elapsed))
    print("Result L_new:")
    print(updated)

# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Updated left block tensor shape: (5, 2, 5)
Time taken for contraction: 0.000146 seconds
Result L_new:
[[[34.663723 26.858585 33.36639  32.125824 27.188856]
  [55.537937 42.99234  53.6211   51.601807 43.741142]]

 [[56.357765 43.743336 54.18474  52.343353 44.26127 ]
  [58.09102  45.2945   56.1056   54.23579  45.752384]]

 [[69.852066 54.016068 67.40987  65.05545  54.792088]
  [71.77674  55.322716 69.28301  66.63314  56.37684 ]]

 [[47.829716 36.662415 45.95102  44.417404 37.429245]
  [42.410645 33.19967  40.975796 39.55999  33.468327]]

 [[61.807793 47.470493 59.46248  57.267056 48.38548 ]
  [62.56972  48.62532  60.348206 58.214478 49.143665]]]
