In [None]:
import numpy as np
import jax

jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from utils.jax_ops import truncated_svd

n = 4
# NOTE(review): n and vec_shape are not used by any cell shown here
# (2**n = 16 != prod(tt_shape) = 120); confirm before removing.
vec_shape = (2**n,)
# Dimensions of the four physical legs of the tensor being decomposed.
tt_shape = (2, 3, 4, 5)

# Seed the generator so A (and every value printed further down) is
# reproducible under Restart Kernel -> Run All.
SEED = 0
rng = np.random.default_rng(SEED)
A = jnp.array(rng.random(tt_shape))

In [None]:
# (     ) - O
#  | | |    |
# Peel the last physical leg off A: SVD of A viewed as a matrix of
# (first three legs flattened) x (last leg). Assuming truncated_svd returns
# (U, s, Vh), the Vh factor becomes leaf_11 and s is folded back into `rest`.
rest, s, leaf_11 = truncated_svd(A.reshape(-1, tt_shape[-1]))
bond_dim_11 = s.shape[0]
rest = jnp.matmul(rest, jnp.diag(s)).reshape((*tt_shape[:3], bond_dim_11))

In [None]:
# (     ) - O
#  | | |    |
#      O
#      |
# Split the third physical leg off as leaf_12. `rest` is currently
# (phys_0, phys_1, phys_2, bond_11); move phys_2 last and flatten the rest.
# The row count must use bond_dim_11 (the bond actually kept by the previous
# SVD), not tt_shape[-1]: the two differ whenever truncated_svd drops
# singular values, and the reshape would then fail.
rest, s, leaf_12 = truncated_svd(
    rest.transpose(0, 1, 3, 2).reshape(
        np.prod(tt_shape[:2]) * bond_dim_11, tt_shape[2]
    )
)
bond_dim_12 = len(s)
# Absorb the singular values into `rest` and restore the leg order
# (phys_0, phys_1, bond_12, bond_11).
rest = (
    (rest @ jnp.diag(s))
    .reshape(tt_shape[:2] + (bond_dim_11,) + (bond_dim_12,))
    .transpose(0, 1, 3, 2)
)

In [None]:
# (   ) - O - O
#  | |    |   |
#         O
#         |
# Group the two bond legs (bond_12, bond_11) into one matrix axis and split
# them away from the two remaining physical legs; the right factor becomes
# the 3-leg branch_1 tensor (bond_1, bond_12, bond_11).
rest, s, branch_1 = truncated_svd(rest.reshape(-1, bond_dim_12 * bond_dim_11))
bond_dim_1 = s.shape[0]
branch_1 = branch_1.reshape(bond_dim_1, bond_dim_12, bond_dim_11)
rest = jnp.matmul(rest, jnp.diag(s)).reshape((*tt_shape[:2], bond_dim_1))

In [None]:
# (   ) - O - O - O
#  | |        |   |
#             O
#             |
# Split the bond between `rest` and branch_1 once more, so that a separate
# 2-leg trunk tensor (bond_2, bond_1) sits between the two subtrees.
rest, s, trunk = truncated_svd(rest.reshape(-1, bond_dim_1))
bond_dim_2 = s.shape[0]
rest = jnp.matmul(rest, jnp.diag(s)).reshape((*tt_shape[:2], bond_dim_2))

In [6]:
# `rest` should now be (phys_0, phys_1, bond_2); fail fast on Run All if not.
assert rest.shape == tt_shape[:2] + (bond_dim_2,), rest.shape
rest.shape

(2, 3, 6)

In [None]:
# (   ) - O - O - O
#  | |        |   |
#    O        O
#    |        |
# Move the second physical leg to the end and split it off as leaf_21.
rest, s, leaf_21 = truncated_svd(
    jnp.swapaxes(rest, 1, 2).reshape(tt_shape[0] * bond_dim_2, tt_shape[1])
)
bond_dim_21 = s.shape[0]
# Fold in the singular values and put the legs back as (phys_0, bond_21, bond_2).
rest = jnp.swapaxes(
    jnp.matmul(rest, jnp.diag(s)).reshape(tt_shape[0], bond_dim_2, bond_dim_21),
    1,
    2,
)

In [8]:
# `rest` should now be (phys_0, bond_21, bond_2); fail fast on Run All if not.
assert rest.shape == (tt_shape[0], bond_dim_21, bond_dim_2), rest.shape
rest.shape

(2, 3, 6)

In [None]:
# O - O - O - O - O
# |   |       |   |
#     O       O
#     |       |
# Final split: the first physical leg becomes leaf_22 and the two bond legs
# (bond_21, bond_2) become branch_2. Here the *left* factor is kept as the
# leaf, with the singular values absorbed into it; it is stored transposed
# as (bond_22, phys_0).
leaf_22, s, branch_2 = truncated_svd(rest.reshape(tt_shape[0], -1))
bond_dim_22 = s.shape[0]
branch_2 = branch_2.reshape((bond_dim_22, bond_dim_21, bond_dim_2))
leaf_22 = jnp.matmul(leaf_22, jnp.diag(s)).T

In [10]:
# One output line per tree level: trunk, then branches, then leaves.
print(trunk.shape)
print(*(branch.shape for branch in (branch_2, branch_1)))
print(*(leaf.shape for leaf in (leaf_22, leaf_21, leaf_12, leaf_11)))

(6, 6)
(2, 3, 6) (6, 4, 5)
(2, 2) (3, 3) (4, 4) (5, 5)


In [None]:
def contract_tree(trunk, branch_2, branch_1, leaf_22, leaf_21, leaf_12, leaf_11):
    """Contract the tree tensor network back into a dense 4-leg tensor.

    Index letters used in the contraction (bonds a-f, physical legs g-j):
        trunk    (a, b)     a: bond to branch_2, b: bond to branch_1
        branch_2 (c, d, a)  c: bond to leaf_22,  d: bond to leaf_21
        branch_1 (b, e, f)  e: bond to leaf_12,  f: bond to leaf_11
        leaf_22  (c, g)     leaf_21  (d, h)
        leaf_12  (e, i)     leaf_11  (f, j)

    Returns:
        Array with axes (g, h, i, j): the physical legs carried by
        leaf_22, leaf_21, leaf_12 and leaf_11, in that order.
    """
    # Spell out the output subscripts: implicit-mode einsum orders free
    # indices alphabetically, which would silently permute the result if the
    # index letters were ever renamed.
    return jnp.einsum(
        "ab,cda,bef,cg,dh,ei,fj->ghij",
        trunk,
        branch_2,
        branch_1,
        leaf_22,
        leaf_21,
        leaf_12,
        leaf_11,
    )

In [None]:
# Rebuild the dense tensor from the tree; keywords make the leg wiring explicit.
A_contracted = contract_tree(
    trunk=trunk,
    branch_2=branch_2,
    branch_1=branch_1,
    leaf_22=leaf_22,
    leaf_21=leaf_21,
    leaf_12=leaf_12,
    leaf_11=leaf_11,
)

In [13]:
# The contracted tree must reproduce A. Check the shape explicitly first,
# because allclose broadcasts and could otherwise accept a mismatched shape.
assert A_contracted.shape == A.shape, (A_contracted.shape, A.shape)
reconstruction_ok = jnp.allclose(A, A_contracted)
assert reconstruction_ok, "contracted tree does not reproduce A"
reconstruction_ok

Array(True, dtype=bool)

In [14]:
# Squared Frobenius norms of A and of leaf_22. Every other tensor came from the
# Vh factor of an SVD (assuming truncated_svd returns (U, s, Vh) with
# orthonormal Vh rows), so the full norm should be carried by leaf_22.
norm_sq_A = A.flatten() @ A.flatten()
norm_sq_leaf_22 = leaf_22.flatten() @ leaf_22.flatten()
assert jnp.allclose(norm_sq_A, norm_sq_leaf_22), (norm_sq_A, norm_sq_leaf_22)
norm_sq_A, norm_sq_leaf_22

(Array(40.42318013, dtype=float64), Array(40.42318013, dtype=float64))