In [1]:
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrn

import opt_einsum as oe

from utils.jax_ops import *

In [2]:
import numpy as np

# Draw a fresh 4-digit seed for every run, but report it so that a particular
# run can be reproduced later by hard-coding the printed value here.
seed = np.random.randint(1_000, 10_000)
print(f"seed = {seed}")

# Root JAX PRNG key; later cells split it whenever they need randomness.
key = jrn.key(seed)

In [3]:
# Problem size: every dense tensor has n indices of dimension d each
# (shape (d,) * n). chi is the bond-dimension cap handed to the MPS helpers
# as max_bond.
d = 4
n = 10
N_summands = 2
chi = 32

# Draw the random data from the JAX key stream instead of NumPy's unseeded
# global RNG, so that the tensors are fully determined by `seed`.
# jrn.uniform samples from [0, 1), the same range np.random.random uses; the
# result is copied into a writable float64 NumPy array so downstream code
# receives the same kind of input as before.
summands = []
for _ in range(N_summands):
    key, subkey = jrn.split(key)
    summands.append(np.array(jrn.uniform(subkey, (d,) * n), dtype=np.float64))


key, subkey = jrn.split(key)
candidate = np.array(jrn.uniform(subkey, (d,) * n), dtype=np.float64)

In [4]:
# Convert every dense summand into an MPS with bond dimension capped at chi.
mps_summands = [MPS(dense, max_bond=chi) for dense in summands]

# The starting guess gets the same treatment and is then canonicalized with
# argument 0 (presumably the canonical-center site -- confirm in utils.jax_ops).
mps_candidate = canonicalize(MPS(candidate, max_bond=chi), 0)

In [5]:
# Exact sum of the dense input tensors.
actual_sum = sum(summands, jnp.zeros((d,) * n))

# Sum of the summands after each one went through its own MPS conversion
# (and hence its own max_bond cap) and was contracted back to a dense tensor.
mps_sum = sum(
    (contract_MPS(mps) for mps in mps_summands),
    jnp.zeros((d,) * n),
)

In [6]:
# Einsum spec that contracts all n tensors of one MPS with all n tensors of
# another. Labels are consecutive characters starting at "a":
#   site k of A -> (3k,     3k + 1, 3k + 3)
#   site k of B -> (3k + 2, 3k + 1, 3k + 5)
# The middle label 3k + 1 is shared by A and B at the same site, and the last
# label of site k equals the first label of site k + 1 within each chain.
# No "->" is given, so the output indices follow einsum's implicit-mode rule.
# NOTE(review): for n >= 8 the labels run past "z"; this relies on the einsum
# backend accepting non-letter labels -- confirm.
offs = ord("a")
a_terms = [
    "".join(chr(offs + 3 * k + shift) for shift in (0, 1, 3)) for k in range(n)
]
b_terms = [
    "".join(chr(offs + 3 * k + shift) for shift in (2, 1, 5)) for k in range(n)
]
inner_ein_str = ",".join(a_terms + b_terms)


def inner(A_MPS, B_MPS):
    """Contract every tensor of ``A_MPS`` with every tensor of ``B_MPS``.

    The operands are passed to ``oe.contract`` in the order A-tensors then
    B-tensors, using the module-level ``inner_ein_str`` spec. No complex
    conjugation is applied here. Size-1 axes of the result are squeezed away,
    so with size-1 boundary bonds the result is presumably a 0-d array.
    """
    operands = [*A_MPS, *B_MPS]
    contracted = oe.contract(inner_ein_str, *operands)
    return contracted.squeeze()


# One einsum spec per site: identical to the full A/B contraction labelling
# (A site k -> 3k, 3k+1, 3k+3; B site k -> 3k+2, 3k+1, 3k+5, counted from "a"),
# except that the A tensor at `site` is left out, so its three labels stay
# uncontracted. No "->" is given; the output order follows einsum's
# implicit-mode rule.
inner_wo_center_ein_strs = []
offs = ord("a")
# The B-side terms do not depend on which A tensor is removed.
b_terms = [
    "".join(chr(offs + 3 * k + shift) for shift in (2, 1, 5)) for k in range(n)
]
for site in range(n):
    a_terms = [
        "".join(chr(offs + 3 * k + shift) for shift in (0, 1, 3))
        for k in range(n)
        if k != site
    ]
    ein_str = ",".join(a_terms + b_terms)
    inner_wo_center_ein_strs.append(ein_str)


def inner_wo_center(A_MPS, B_MPS, site):
    """Contract ``A_MPS`` with ``B_MPS`` while leaving out ``A_MPS[site]``.

    Uses the precomputed spec ``inner_wo_center_ein_strs[site]``; operands are
    the A tensors before ``site``, the A tensors after ``site``, then all of
    ``B_MPS``. The result is reshaped to the shape of the omitted tensor
    ``A_MPS[site]``.
    """
    spec = inner_wo_center_ein_strs[site]
    a_without_site = [*A_MPS[:site], *A_MPS[site + 1 :]]
    contracted = oe.contract(spec, *a_without_site, *B_MPS)
    return contracted.reshape(A_MPS[site].shape)

In [7]:
def slowsum(summands):
    """Sum a sequence of MPS one at a time, compressing after every addition.

    Starts from ``summands[0]``; each further summand is added with
    ``add_MPS_MPS(summand, running_total)`` and the result is immediately
    compressed with ``max_bond=chi`` (``chi`` is read from the enclosing
    module scope, not passed in). Raises ``IndexError`` for an empty sequence.
    """
    first, remaining = summands[0], summands[1:]
    total = first
    for term in remaining:
        total = compress_MPS(add_MPS_MPS(term, total), max_bond=chi)
    return total

In [None]:
def quicksum(candidate, summands, sweeps=1):
    """Sweep over ``candidate`` and overwrite its tensors from ``summands``.

    For each sweep:
      * left-to-right over sites ``0 .. len-2``: the site tensor is replaced
        by 0.5 times the sum over all summands of
        ``inner_wo_center(candidate, summand, site)``, then
        ``right_shift_canonical_center(candidate, site)`` is applied;
      * right-to-left over sites ``len-1 .. 1``: same replacement, followed by
        ``left_shift_canonical_center(candidate, site)``.
    After all sweeps the tensor at site 0 is multiplied by 2 and the candidate
    is returned.

    Presumably ``candidate`` must arrive canonicalized with its center at
    site 0 (the notebook does this before calling) -- confirm.

    NOTE(review): ``candidate`` is modified by item assignment, so a direct
    (non-jitted) call may alter the caller's object.
    NOTE(review): the 0.5 / 2 factors look like a rescaling that is meant to
    cancel through the center shifts; with ``sweeps=0`` (or a single-site
    candidate) nothing is updated and site 0 is simply doubled -- confirm
    that this is intended.
    """

    def _summed_environment(mps, site):
        # Accumulate the per-summand contribution for one site.
        total = jnp.zeros(mps[site].shape)
        for term in summands:
            total += inner_wo_center(mps, term, site)
        return total

    for _ in range(sweeps):
        # Left-to-right half sweep (last site is handled by the return sweep).
        for site in range(len(candidate) - 1):
            candidate[site] = 0.5 * _summed_environment(candidate, site)
            candidate = right_shift_canonical_center(candidate, site)

        # Right-to-left half sweep, stopping before site 0.
        for site in reversed(range(1, len(candidate))):
            candidate[site] = 0.5 * _summed_environment(candidate, site)
            candidate = left_shift_canonical_center(candidate, site)

    candidate[0] = 2 * candidate[0]
    return candidate

In [12]:
# Benchmark the pairwise add-and-compress routine (slowsum) under jax.jit.
jittest = jax.jit(slowsum)
# Single timed call; any tracing/compilation triggered by this call is part of
# the measurement. block_until_ready makes the timing wait for the finished
# result instead of returning right after dispatch.
%time result = jax.block_until_ready(jittest(mps_summands))
# Repeated timing of the already-jitted function.
%timeit jax.block_until_ready(jittest(mps_summands))
# Accuracy check against mps_sum (the sum of the individually contracted
# summand MPS): mean squared elementwise error, then an allclose verdict.
contr_res = contract_MPS(result)
print(jnp.mean((mps_sum-contr_res)**2))
print(jnp.allclose(mps_sum, contr_res))

CPU times: user 11.7 ms, sys: 2.67 ms, total: 14.4 ms
Wall time: 11.6 ms
4.25 ms ± 149 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
0.002468249618651816
False


In [13]:
# Benchmark the sweep-based routine (quicksum) under jax.jit. Positional
# argument 2 (`sweeps`) is marked static; it is left at its default here.
jittest2 = jax.jit(quicksum, static_argnums=2)
# Single timed call; any tracing/compilation triggered by this call is part of
# the measurement. block_until_ready makes the timing wait for the finished
# result instead of returning right after dispatch.
%time result = jax.block_until_ready(jittest2(mps_candidate, mps_summands))
# Repeated timing of the already-jitted function.
%timeit jax.block_until_ready(jittest2(mps_candidate, mps_summands))
# Accuracy check against mps_sum (the sum of the individually contracted
# summand MPS): mean squared elementwise error, then an allclose verdict.
# Note: `result` and `contr_res` overwrite the values from the slowsum cell.
contr_res = contract_MPS(result)
print(jnp.mean((mps_sum-contr_res)**2))
print(jnp.allclose(mps_sum, contr_res))

CPU times: user 1.89 s, sys: 40.1 ms, total: 1.93 s
Wall time: 529 ms
1.74 ms ± 16.7 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
0.00287149927189354
False
