In [1]:
import os

os.environ["JAX_PLATFORM_NAME"] = "cpu"

import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.random as jrn

import opt_einsum as oe

from utils.jax_ops import *
from utils.scale_ordering import *
from utils.mpos import kron_delta

In [2]:
from utils.temporally_decaying_jets import initial_fields

In [3]:
import numpy as np

seed = np.random.randint(1e3, 1e4)

key = jrn.key(seed)

In [None]:
d = 4
n = 11
chi = 4 ** int(n / 2 - 1)

key, subkey = jrn.split(key)
rand = np.random.random((d,) * n)

tdj = transform_tensor(initial_fields(n=n)[0])

key, subkey = jrn.split(key)
candidate = np.random.random((d,) * n)

In [5]:
mps_tdj = MPS(tdj, max_bond=None)

mps_rand = MPS(rand, max_bond=None)

mps_candidate = MPS(candidate, max_bond=chi)
mps_candidate = canonicalize(mps_candidate, 0)

In [None]:
inner_ein_str = ""
offs = ord("a")
for i in range(n):
    inner_ein_str += chr(3 * i + offs)
    inner_ein_str += chr(3 * i + 1 + offs)
    inner_ein_str += chr(3 * i + 3 + offs)
    inner_ein_str += ","
for i in range(n):
    inner_ein_str += chr(3 * i + 2 + offs)
    inner_ein_str += chr(3 * i + 1 + offs)
    inner_ein_str += chr(3 * i + 5 + offs)
    inner_ein_str += ","
inner_ein_str = inner_ein_str[:-1]


def inner(A_MPS, B_MPS):
    return oe.contract(inner_ein_str, *A_MPS, *B_MPS).squeeze()


inner_wo_center_ein_strs = []
for site in range(n):
    ein_str = ""
    offs = ord("a")
    for i in range(site):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(site + 1, n):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(n):
        ein_str += chr(3 * i + 2 + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 5 + offs)
        ein_str += ","
    ein_str = ein_str[:-1]
    inner_wo_center_ein_strs += [ein_str]

print(inner_wo_center_ein_strs[0])


def inner_wo_center(A_MPS, B_MPS, site):
    return oe.contract(
        inner_wo_center_ein_strs[site], *A_MPS[:site], *A_MPS[site + 1 :], *B_MPS
    ).reshape(A_MPS[site].shape)

deg,ghj,jkm,mnp,pqs,stv,vwy,yz|,|},,cbf,fei,ihl,lko,onr,rqu,utx,xw{,{z~,~},


In [None]:
inner_wo_center_tensor_ein_strs = []
for site in range(n):
    ein_str = ""
    offs = ord("a")
    for i in range(site):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(site + 1, n):
        ein_str += chr(3 * i + offs)
        ein_str += chr(3 * i + 1 + offs)
        ein_str += chr(3 * i + 3 + offs)
        ein_str += ","
    for i in range(n):
        ein_str += chr(3 * i + 1 + offs)
    inner_wo_center_tensor_ein_strs += [ein_str]

print(inner_wo_center_tensor_ein_strs[0])


def inner_wo_center_tensor(MPS, tensor, site):
    return oe.contract(
        inner_wo_center_tensor_ein_strs[site], *MPS[:site], *MPS[site + 1 :], tensor
    ).reshape(MPS[site].shape)

deg,ghj,jkm,mnp,pqs,stv,vwy,yz|,|},,behknqtwz}


In [8]:
def slowcomp(initial):
    return compress_MPS(initial, max_bond=chi)

In [9]:
def slowcomp_tens_to_MPS(initial):
    return MPS(initial, max_bond=chi)

In [10]:
def quickcomp(candidate, initial, sweeps=1):
    # candidate = canonicalize(candidate, 0)
    for i in range(sweeps):
        for site in range(len(candidate) - 1):
            update = inner_wo_center(candidate, initial, site)
            candidate[site] = 0.5 * update
            candidate = right_shift_canonical_center(candidate, site)

        for site in range(len(candidate) - 1, 0, -1):
            update = inner_wo_center(candidate, initial, site)
            candidate[site] = 0.5 * update
            candidate = left_shift_canonical_center(candidate, site)
    candidate[0] = 2 * candidate[0]
    return candidate

In [11]:
def quickcomp_tens_to_MPS(candidate, initial, sweeps=1):
    # candidate = canonicalize(candidate, 0)
    for i in range(sweeps):
        for site in range(len(candidate) - 1):
            update = inner_wo_center_tensor(candidate, initial, site)
            candidate[site] = 0.5 * update
            candidate = right_shift_canonical_center(candidate, site)

        for site in range(len(candidate) - 1, 0, -1):
            update = inner_wo_center_tensor(candidate, initial, site)
            candidate[site] = 0.5 * update
            candidate = left_shift_canonical_center(candidate, site)
    candidate[0] = 2 * candidate[0]
    return candidate

In [12]:
jittest = jax.jit(slowcomp)
%time result = jax.block_until_ready(jittest(mps_rand))
%timeit jax.block_until_ready(jittest(mps_rand))
contr_res = contract_MPS(result)
print(jnp.mean((rand-contr_res)**2))
print(jnp.allclose(rand, contr_res))

CPU times: user 2.91 s, sys: 257 ms, total: 3.16 s
Wall time: 1.38 s
1.12 s ± 27.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
0.061525267785417434
False


In [13]:
jittest2 = jax.jit(quickcomp, static_argnums=2)
%time result = jax.block_until_ready(jittest2(mps_candidate, mps_rand))
%timeit jax.block_until_ready(jittest2(mps_candidate, mps_rand))
contr_res = contract_MPS(result)
print(jnp.mean((rand-contr_res)**2))
print(jnp.allclose(rand, contr_res))

CPU times: user 4.09 s, sys: 196 ms, total: 4.29 s
Wall time: 1.01 s
587 ms ± 6.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
0.06667507500788632
False


In [14]:
jittest = jax.jit(slowcomp)
%time result = jax.block_until_ready(jittest(mps_tdj))
%timeit jax.block_until_ready(jittest(mps_tdj))
contr_res = contract_MPS(result)
print(jnp.mean((tdj-contr_res)**2))
print(jnp.allclose(tdj, contr_res))

CPU times: user 2.16 s, sys: 212 ms, total: 2.37 s
Wall time: 1.03 s
1 s ± 13.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.7172792331750996e-26
True


In [15]:
jittest2 = jax.jit(quickcomp, static_argnums=2)
%time result = jax.block_until_ready(jittest2(mps_candidate, mps_tdj))
%timeit jax.block_until_ready(jittest2(mps_candidate, mps_tdj))
contr_res = contract_MPS(result)
print(jnp.mean((tdj-contr_res)**2))
print(jnp.allclose(tdj, contr_res))

CPU times: user 2.84 s, sys: 164 ms, total: 3 s
Wall time: 720 ms
591 ms ± 3.27 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.719038503442886e-26
True


In [None]:
jittest = jax.jit(slowcomp_tens_to_MPS)
%time result = jax.block_until_ready(jittest(tdj))
%timeit jax.block_until_ready(jittest(tdj))
contr_res = contract_MPS(result)
print(jnp.mean((tdj-contr_res)**2))
print(jnp.allclose(tdj, contr_res))

CPU times: user 3.11 s, sys: 256 ms, total: 3.37 s
Wall time: 1.47 s
1.36 s ± 22.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.724469387572018e-26
True


In [17]:
jittest2 = jax.jit(quickcomp_tens_to_MPS, static_argnums=2)
%time result = jax.block_until_ready(jittest2(mps_candidate, tdj))
%timeit jax.block_until_ready(jittest2(mps_candidate, tdj))
contr_res = contract_MPS(result)
print(jnp.mean((tdj-contr_res)**2))
print(jnp.allclose(tdj, contr_res))

CPU times: user 2.82 s, sys: 177 ms, total: 3 s
Wall time: 729 ms
392 ms ± 44.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
6.1304107473954675e-31
True
