In [1]:
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrn

import opt_einsum as oe

from utils.jax_ops import *
from utils.mpos import kron_delta

In [2]:
import numpy as np

# A fresh seed is drawn on every run and displayed as this cell's output, so a
# run can be reproduced by replacing the draw with the displayed value.
# NumPy's global RNG is seeded from it as well, so every later `np.random`
# call depends only on `seed` (not on leftover kernel state).
seed = int(np.random.randint(1_000, 10_000))
np.random.seed(seed)

key = jrn.key(seed)
seed

In [None]:
d = 2  # local (physical) dimension of every tensor leg
n = 5  # number of legs / MPS sites
N_factors = 2  # number of random tensors that are multiplied elementwise
chi = None  # max bond passed to MPS/compress_MPS; None presumably means "no truncation"

# Every random tensor is drawn from its own JAX subkey, so the data is fully
# determined by `key` (and hence by `seed`); the global NumPy RNG is not used.
# np.array(...) copies into a writable NumPy float64 array, keeping the same
# array type these tensors had when drawn with np.random.
factors = []
for _ in range(N_factors):
    key, subkey = jrn.split(key)
    factors.append(np.array(jrn.uniform(subkey, (d,) * n)))

key, subkey = jrn.split(key)
candidate = np.array(jrn.uniform(subkey, (d,) * n))
kron = jnp.array(kron_delta(d))
# The same (immutable) JAX array is reused once per site.
multi_kron = [jnp.array(kron_delta(d, rank=N_factors + 1))] * n

In [4]:
# One MPS per random factor, plus the initial guess that quickmul will fit.
# quickmul calls canonicalize(candidate, 0) itself, so the candidate is not
# canonicalized here.
mps_factors = [MPS(factor, max_bond=chi) for factor in factors]

mps_candidate = MPS(candidate, max_bond=chi)

In [5]:
def _hadamard_product(tensors, shape):
    """Elementwise product of `tensors`, accumulated left to right onto ones(shape)."""
    accumulated = jnp.ones(shape)
    for tensor in tensors:
        accumulated = accumulated * tensor
    return accumulated


# Reference products: from the raw dense factors, and from the dense tensors
# reconstructed from their MPS (the latter is what the timing cells compare to).
actual_product = _hadamard_product(factors, (d,) * n)
mps_product = _hadamard_product(
    [contract_MPS(mps_factor) for mps_factor in mps_factors], (d,) * n
)

In [None]:
offs = ord("a")
# Site i of A uses letters (3i, 3i+1, 3i+3) and site i of B uses (3i+2, 3i+1, 3i+5).
# Both share the physical letter 3i+1, and neighbouring sites of the same MPS share
# a bond letter (A: 3i+3 == 3(i+1), B: 3i+5 == 3(i+1)+2). The outermost bond letters
# occur only once, so the implicit-output einsum leaves them open.
_inner_a_terms = [
    chr(offs + 3 * i) + chr(offs + 3 * i + 1) + chr(offs + 3 * i + 3) for i in range(n)
]
_inner_b_terms = [
    chr(offs + 3 * i + 2) + chr(offs + 3 * i + 1) + chr(offs + 3 * i + 5)
    for i in range(n)
]
inner_ein_str = ",".join(_inner_a_terms + _inner_b_terms)


def inner(A_MPS, B_MPS):
    """Contract two MPS over every physical and internal bond index.

    Uses the precomputed `inner_ein_str`, so both MPS must have exactly `n` sites.
    No complex conjugation is applied. The four outermost bond indices stay open
    in the einsum and are removed by `squeeze` (presumably they have size 1).
    """
    operands = [*A_MPS, *B_MPS]
    contracted = oe.contract(inner_ein_str, *operands)
    return contracted.squeeze()


offs = ord("a")
inner_wo_center_ein_strs = []
for site in range(n):
    # Same letter layout as the full inner product, except that A's tensor at
    # `site` is omitted, so its letters (3*site, 3*site + 1, 3*(site + 1)) stay open.
    kept_a_terms = [
        chr(offs + 3 * i) + chr(offs + 3 * i + 1) + chr(offs + 3 * (i + 1))
        for i in range(n)
        if i != site
    ]
    b_terms = [
        chr(offs + 3 * i + 2) + chr(offs + 3 * i + 1) + chr(offs + 3 * (i + 1) + 2)
        for i in range(n)
    ]
    inner_wo_center_ein_strs.append(",".join(kept_a_terms + b_terms))


def inner_wo_center(A_MPS, B_MPS, site):
    """Environment of A's tensor at `site`: contract B with every other A tensor.

    Implicit einsum output orders open letters alphabetically, which puts A's
    left-bond, physical and right-bond letters for `site` in that order. Any other
    open boundary bonds are presumably size 1, so the result is reshaped to
    `A_MPS[site].shape`.
    """
    other_a_tensors = list(A_MPS[:site]) + list(A_MPS[site + 1 :])
    environment = oe.contract(
        inner_wo_center_ein_strs[site], *other_a_tensors, *B_MPS
    )
    return environment.reshape(A_MPS[site].shape)

In [7]:
# Einsum strings for local_multiply, one per site. Site i owns the `multiplier`
# consecutive letters starting at offs + multiplier * i:
#   +0                           candidate left bond (shared with site i-1's right bond)
#   +1                           candidate physical letter (first leg of multi_kron[i])
#   +2 .. +N_factors+1           physical letter of factor j (remaining multi_kron legs)
#   +N_factors+2 .. +2N_factors+1  left bond of factor j
# The candidate tensor at `site` is left out, so its three letters stay open and,
# with implicit (alphabetical) output order, come out as left bond, physical, right bond.
# Term order must match local_multiply's operand order: candidate tensors (without
# `site`), then multi_kron, then all sites of factor 0, then of factor 1, ...
multiply_ein_strs = []
offs = ord("a")
multiplier = 2 * N_factors + 2
for site in range(n):
    candidate_terms = [
        chr(offs + multiplier * i)
        + chr(offs + multiplier * i + 1)
        + chr(offs + multiplier * (i + 1))
        for i in range(n)
        if i != site
    ]
    kron_terms = [
        "".join(chr(offs + multiplier * i + 1 + k) for k in range(N_factors + 1))
        for i in range(n)
    ]
    # j is the outer loop: each factor's n site tensors are consecutive operands.
    factor_terms = [
        chr(offs + multiplier * i + N_factors + j + 2)
        + chr(offs + multiplier * i + j + 2)
        + chr(offs + multiplier * (i + 1) + N_factors + j + 2)
        for j in range(N_factors)
        for i in range(n)
    ]
    multiply_ein_strs.append(",".join(candidate_terms + kron_terms + factor_terms))


def local_multiply(A_MPS, mps_factors, site):
    """Local update for `site`: contract the candidate (minus its `site` tensor)
    with the per-site `multi_kron` tensors and every tensor of every factor MPS.

    Uses the precomputed `multiply_ein_strs[site]`; operands are passed in the
    order that string expects. The result is reshaped to `A_MPS[site].shape`
    (extra open boundary bonds are presumably size 1).
    """
    factor_tensors = [tensor for mps_factor in mps_factors for tensor in mps_factor]
    local_tensor = oe.contract(
        multiply_ein_strs[site],
        *A_MPS[:site],
        *A_MPS[site + 1 :],
        *multi_kron,
        *factor_tensors,
    )
    return local_tensor.reshape(A_MPS[site].shape)

In [8]:
def slowmul(factors):
    """Baseline product: fold the factor MPS together one at a time.

    Each step multiplies the next factor into the running result with
    `multiply_MPS_MPS(kron, ...)` and then compresses it to `chi`. A single
    factor is returned as-is (uncompressed).
    """
    accumulated = factors[0]
    for idx in range(1, len(factors)):
        accumulated = compress_MPS(
            multiply_MPS_MPS(kron, factors[idx], accumulated), max_bond=chi
        )
    return accumulated

In [9]:
def quickmul(candidate, factors, sweeps=1):
    """Fit `candidate` to the elementwise product of the MPS in `factors` by sweeping.

    Args:
        candidate: MPS used as the initial guess; it is first passed through
            `canonicalize(candidate, 0)`.
        factors: list of MPS to multiply; each must have as many sites as `candidate`.
        sweeps: number of left-to-right plus right-to-left sweeps. It is used in a
            Python `range`, so it must be static under `jax.jit`. With `sweeps <= 0`
            the canonicalized candidate is returned unchanged.

    Returns:
        The updated candidate MPS.
    """
    candidate = canonicalize(candidate, 0)
    n_sites = len(candidate)
    for _ in range(sweeps):
        # Left to right: recompute each site from its environment, then move the
        # canonical centre one site to the right.
        for site in range(n_sites - 1):
            candidate[site] = 0.5 * local_multiply(candidate, factors, site)
            candidate = right_shift_canonical_center(candidate, site)

        # Right to left, stopping at site 1; the last left shift moves the centre
        # back to site 0.
        for site in range(n_sites - 1, 0, -1):
            candidate[site] = 0.5 * local_multiply(candidate, factors, site)
            candidate = left_shift_canonical_center(candidate, site)
    if sweeps > 0:
        # Local updates are stored at half scale. The final 2x on site 0 undoes the
        # 0.5 presumably carried into it by the last left shift. When no sweep ran,
        # nothing was halved, so doubling would corrupt the result.
        candidate[0] = 2 * candidate[0]
    return candidate

In [10]:
# Baseline: jitted pairwise multiply-then-compress (slowmul).
jittest = jax.jit(slowmul)
# The first call traces and compiles, so this %time is dominated by compilation.
# block_until_ready waits for JAX's asynchronous dispatch so the work is actually timed.
%time result = jax.block_until_ready(jittest(mps_factors))
# Later calls reuse the compiled function: steady-state runtime.
%timeit jax.block_until_ready(jittest(mps_factors))
contr_res = contract_MPS(result)
# Compare against mps_product, built from the contracted factor MPS (not the raw `factors`).
print(jnp.mean((mps_product-contr_res)**2))
print(jnp.allclose(mps_product, contr_res))

CPU times: user 398 ms, sys: 15.1 ms, total: 413 ms
Wall time: 114 ms
83.6 μs ± 1.86 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
6.463159463419941e-32
True


In [11]:
# Sweep-based fitting (quickmul). Positional arg 2 (`sweeps`) is static because
# quickmul uses it in a Python `range`; it is not passed here, so its default is used.
jittest2 = jax.jit(quickmul, static_argnums=2)
# The first call includes tracing + compilation; block_until_ready waits for async dispatch.
%time result = jax.block_until_ready(jittest2(mps_candidate, mps_factors))
# Steady-state runtime of the compiled function.
%timeit jax.block_until_ready(jittest2(mps_candidate, mps_factors))
contr_res = contract_MPS(result)
# Same reference as the slowmul cell: mps_product from the contracted factor MPS.
print(jnp.mean((mps_product-contr_res)**2))
print(jnp.allclose(mps_product, contr_res))

CPU times: user 814 ms, sys: 21 ms, total: 835 ms
Wall time: 230 ms
114 μs ± 3.6 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
2.9997769477052185e-32
True


In [None]:
# Sum of squares of the first site tensor; compare with the full inner product below.
(mps_candidate[0] ** 2).sum()

Array(11.53979079, dtype=float64)

In [13]:
# Full contraction of the candidate with itself.
inner(A_MPS=mps_candidate, B_MPS=mps_candidate)

Array(11.53979079, dtype=float64)

In [None]:
# Four copies of the last site tensor, contracted over their shared second and
# third indices; each copy's first index stays open, giving a 4-index result.
oe.contract("abc, dbc, ebc, fbc", *(mps_candidate[-1],) * 4)

Array([[[[ 0.52237165,  0.1033699 ],
         [ 0.1033699 ,  0.47762835]],

        [[ 0.1033699 ,  0.47762835],
         [ 0.47762835, -0.1033699 ]]],


       [[[ 0.1033699 ,  0.47762835],
         [ 0.47762835, -0.1033699 ]],

        [[ 0.47762835, -0.1033699 ],
         [-0.1033699 ,  0.52237165]]]], dtype=float64)