In [1]:
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrn

import opt_einsum as oe

# import numpy as np
from utils.jax_ops import *

In [2]:
import numpy as np

seed = np.random.randint(1e3, 1e4)

key = jrn.key(seed)

In [3]:
d = 4
n = 13
N_summands = 100
chi = None

summands = []
for i in range(N_summands):
    key, subkey = jrn.split(key)
    summands = [jrn.uniform(subkey, (d,) * n)]


key, subkey = jrn.split(key)
candidate = jrn.uniform(subkey, (d,) * n)

In [None]:
mps_summands = []
for summand in summands:
    mps_summands += [MPS(summand, max_bond=chi)]

mps_candidate = MPS(candidate, max_bond=chi)
mps_candidate = canonicalize(mps_candidate, 0)

In [None]:
actual_sum = jnp.zeros((d,) * n)
for summand in summands:
    actual_sum += summand
mps_sum = jnp.zeros((d,) * n)
for summand in mps_summands:
    mps_sum += contract_MPS(summand)

In [None]:
def inner(A_MPS, B_MPS):
    ein_str = ""
    offs = ord("a")
    for i in range(len(A_MPS)):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(len(B_MPS)):
        ein_str += chr(3 * i + 2 + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 5 + offs)
        ein_str += ","
    return oe.contract(ein_str, *A_MPS, *B_MPS).squeeze()


def inner_wo_center(A_MPS, B_MPS, site):
    ein_str = ""
    offs = ord("a")
    for i in range(site):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(site + 1, len(A_MPS)):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(len(B_MPS)):
        ein_str += chr(3 * i + 2 + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 5 + offs)
        ein_str += ","
    ein_str = ein_str[:-1]
    return oe.contract(ein_str, *A_MPS[:site], *A_MPS[site + 1 :], *B_MPS).reshape(
        A_MPS[site].shape
    )

In [None]:
def slowsum(summands):
    result = summands[0]
    for summand in summands[1:]:
        result = add_MPS_MPS(summand, result)
        result = compress_MPS(result, max_bond=chi)
    return result

In [None]:
def quicksum(candidate, summands, sweeps=1):

    for i in range(sweeps):
        for site in range(len(candidate) - 1):
            update = jnp.zeros(candidate[site].shape)
            for summand in summands:
                update = inner_wo_center(candidate, summand, site)
            candidate[site] = 0.5 * update
            candidate = right_shift_canonical_center(candidate, site)

        for site in range(len(candidate) - 1, 0, -1):
            update = jnp.zeros(candidate[site].shape)
            for summand in summands:
                update = inner_wo_center(candidate, summand, site)
            candidate[site] = 0.5 * update
            candidate = left_shift_canonical_center(candidate, site)

    candidate[0] = 2 * candidate[0]
    return candidate

In [None]:
jitslow = jax.jit(slowsum)
%time result = jitslow(mps_summands)[0].block_until_ready()
result = jitslow(mps_summands)
contr_res = contract_MPS(result)
print(jnp.mean((contr_res-mps_sum)**2))
print(jnp.allclose(mps_sum, contr_res))

CPU times: user 18.7 ms, sys: 3.26 ms, total: 21.9 ms
Wall time: 17.3 ms
0.0
True


In [None]:
%timeit res = jitslow(mps_summands)[0].block_until_ready()

973 μs ± 15.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
jitslow(mps_summands)[0].block_until_ready()

Array([[[-266.33335778,  100.48609932,  -20.43973416,   76.48708127],
        [-266.67782865,  -33.99638226,  123.05040435,    6.21520889],
        [-266.45059848,   31.61285099,  -29.48033044, -120.21301419],
        [-266.65408238,  -97.95455472,  -73.18837856,   37.5104326 ]]],      dtype=float64)

In [None]:
jitquick = jax.jit(quicksum)
%time result = jitquick(mps_candidate, mps_summands)[0].block_until_ready()
result = jitquick(mps_candidate, mps_summands)
contr_res = contract_MPS(result)
print(jnp.mean((contr_res-mps_sum)**2))
print(jnp.allclose(mps_sum, contr_res))

CPU times: user 3.24 s, sys: 126 ms, total: 3.36 s
Wall time: 847 ms
2.1830290798939257e-30
True


In [None]:
%timeit res = jitquick(mps_candidate, mps_summands)[0].block_until_ready()

433 ms ± 4.72 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
jitquick(mps_candidate, mps_summands)[0].block_until_ready()

Array([[[-266.34269518,   96.86677591,   50.00556801,   66.9089648 ],
        [-266.66665863,   20.77519605,   17.752829  , -124.87982464],
        [-266.4381814 , -108.92955403,   57.68219208,   33.6654854 ],
        [-266.6683209 ,   -8.77405794, -125.22886741,   24.32733139]]],      dtype=float64)