In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrn

import opt_einsum as oe

# import numpy as np
from utils.jax_ops import *

In [2]:
import numpy as np

# Set SEED to an int to replay a specific run; None draws a fresh seed each time.
SEED = None
seed = SEED if SEED is not None else int(np.random.randint(1_000, 10_000))
# Record the seed actually used so this run can be reproduced later.
print(f"seed = {seed}")

key = jrn.key(seed)
# NOTE(review): this subkey is not used; the next cell splits `key` again.
# Kept so the PRNG stream for a given seed is unchanged.
key, subkey = jrn.split(key)

In [3]:
# Four independent random tensors with n legs of dimension 2 each.
# Every draw consumes a fresh subkey split off `key`, in the order A, B, C, D.
n = 10
drawn_tensors = []
for _draw_idx in range(4):
    key, subkey = jrn.split(key)
    drawn_tensors.append(jrn.uniform(subkey, (2,) * n))
A, B, C, D = drawn_tensors
# Drop the loop helpers so the notebook namespace matches the unrolled version.
del drawn_tensors, _draw_idx

In [4]:
# Wrap each dense tensor in an MPS (MPS is presumably provided by the
# `utils.jax_ops` star import — confirm). Construction order: A, B, C, D.
mps_A, mps_B, mps_C, mps_D = map(MPS, (A, B, C, D))

In [5]:
# Bring every MPS into canonical form about site 0 (helper from utils.jax_ops).
mps_A, mps_B, mps_C, mps_D = (
    canonicalize(state, 0) for state in (mps_A, mps_B, mps_C, mps_D)
)
# Dense reference of the fit target A + B + D; mps_C is the state being fitted later.
contr_ApBpD = contract_MPS(mps_A) + contract_MPS(mps_B) + contract_MPS(mps_D)

In [6]:
def inner(A_MPS, B_MPS):
    """Contract two MPS over their shared physical indices.

    Each site tensor is indexed as (left bond, physical, right bond). Neighbouring
    bonds are chained, and A[i] and B[i] share the physical index, so the network
    collapses to the overlap of the two states. The open boundary bonds are left
    as output axes and then removed by ``squeeze`` (they are presumably size 1 —
    TODO confirm for the MPS class). No complex conjugation is applied.

    Args:
        A_MPS: sequence of rank-3 site tensors.
        B_MPS: sequence of rank-3 site tensors, same length as ``A_MPS``.

    Returns:
        The contracted network with size-1 axes squeezed (a scalar when the
        boundary bonds have size 1).

    Raises:
        ValueError: if the two MPS have a different number of sites.
    """
    if len(A_MPS) != len(B_MPS):
        raise ValueError(f"MPS lengths differ: {len(A_MPS)} vs {len(B_MPS)}")
    offs = ord("a")
    # Three symbols per site i:
    #   3i     : left bond of A[i] (= right bond of A[i-1], written as 3(i-1)+3)
    #   3i + 1 : physical index shared by A[i] and B[i]
    #   3i + 2 : left bond of B[i] (= right bond of B[i-1], written as 3(i-1)+5)
    a_terms = [
        chr(3 * i + offs) + chr(3 * i + 1 + offs) + chr(3 * i + 3 + offs)
        for i in range(len(A_MPS))
    ]
    b_terms = [
        chr(3 * i + 2 + offs) + chr(3 * i + 1 + offs) + chr(3 * i + 5 + offs)
        for i in range(len(B_MPS))
    ]
    # Join (not append) the separators: a trailing "," would declare one more
    # term than operands and opt_einsum rejects the call.
    ein_str = ",".join(a_terms + b_terms)
    return oe.contract(ein_str, *A_MPS, *B_MPS).squeeze()


def inner_wo_center(A_MPS, B_MPS, site):
    """Contract ``A_MPS`` (with ``A_MPS[site]`` removed) against ``B_MPS``.

    Site tensors are indexed as (left bond, physical, right bond). A[i] and B[i]
    share the physical index, and neighbouring bonds are chained. Leaving out
    A[site] keeps its two bond indices and B[site]'s physical index open, so the
    result is the tensor that ``A_MPS[site]`` would be contracted against.

    The result is reshaped to ``A_MPS[site].shape``. This relies on opt_einsum's
    implicit output, which orders the open indices by symbol. For this layout that
    yields (left, physical, right) for the site. Any other open boundary bonds are
    presumably size 1 — TODO confirm for the MPS class.

    Args:
        A_MPS: sequence of rank-3 site tensors; entry ``site`` is left out.
        B_MPS: sequence of rank-3 site tensors.
        site: index of the tensor to leave out of ``A_MPS``.

    Returns:
        Array with the same shape as ``A_MPS[site]``.
    """
    base = ord("a")

    def sym(k):
        # Symbol for global index number k.
        return chr(base + k)

    # Sites before and after the removed one, in the same order as the operands.
    kept = [*range(site), *range(site + 1, len(A_MPS))]
    terms = [sym(3 * j) + sym(3 * j + 1) + sym(3 * j + 3) for j in kept]
    terms += [
        sym(3 * j + 2) + sym(3 * j + 1) + sym(3 * j + 5) for j in range(len(B_MPS))
    ]
    environment = oe.contract(
        ",".join(terms), *A_MPS[:site], *A_MPS[site + 1 :], *B_MPS
    )
    return environment.reshape(A_MPS[site].shape)

In [7]:
# Sweep-based fit of mps_C so that C ≈ FIT_SCALE * (A + B + D).
# Each site is replaced by FIT_SCALE times the sum of its environments against
# the target MPS. The shift helpers presumably keep mps_C in mixed-canonical
# form around the current site (utils.jax_ops) — confirm.
N_SWEEPS = 1  # number of right-then-left sweep pairs
FIT_SCALE = 0.5  # fit target is FIT_SCALE * (A + B + D), i.e. 2 * C ≈ A + B + D
fit_targets = (mps_A, mps_B, mps_D)


def site_update(mps, site):
    """Return the new tensor for ``mps[site]``.

    This is FIT_SCALE times the sum of ``inner_wo_center(mps, target, site)``
    over ``fit_targets``, accumulated in the order A, B, D.
    """
    return FIT_SCALE * sum(
        inner_wo_center(mps, target, site) for target in fit_targets
    )


for sweep in range(N_SWEEPS):
    # Left to right: update sites 0..L-2 and move the centre one step right each time.
    for site in range(len(mps_C) - 1):
        mps_C[site] = site_update(mps_C, site)
        mps_C = right_shift_canonical_center(mps_C, site)

    # Right to left: update sites L-1..1 and move the centre back towards site 0.
    for site in range(len(mps_C) - 1, 0, -1):
        mps_C[site] = site_update(mps_C, site)
        mps_C = left_shift_canonical_center(mps_C, site)

    contr_C = contract_MPS(mps_C)
    # Mean squared error of the dense reconstruction against A + B + D.
    # contr_C / FIT_SCALE equals 2 * contr_C exactly when FIT_SCALE is 0.5.
    print(jnp.mean((contr_C / FIT_SCALE - contr_ApBpD) ** 2))

6.80091077577843e-30


In [8]:
# The fitted MPS should reproduce the target: 2 * C ≈ A + B + D.
# Argument order matters: allclose measures tolerance relative to its second argument.
fit_matches = jnp.allclose(contr_ApBpD, 2 * contr_C)
print(fit_matches)
# Enforce the check so a failed fit stops Restart & Run All instead of printing False.
assert fit_matches, "MPS fit did not reproduce A + B + D"

True
