In [None]:
%config InlineBackend.figure_formats = ['svg']
import quimb as qu
import quimb.tensor as qtn
import numpy as np
import collections
import autoray as ar
from autoray import do
import xyzpy as xyz
import cotengra as ctg
import opt_einsum as oe

In [None]:
from tnmpa import (
    setup_vbp,
    iterate_vbp,
    get_messages,
)

In [None]:
# this speeds up (multi-threads) numpy einsum, if torch installed
import torch

def einsum_via_torch(eq, x, y=None, *operands):
    """Evaluate an einsum over NumPy arrays using ``torch.einsum``.

    The function is meant to stand in for a general ``einsum``, so it accepts
    any number of operands rather than exactly two.

    Args:
        eq: the einsum equation string, handed to ``torch.einsum`` unchanged.
        x: first NumPy operand.
        y: second NumPy operand; omit it for a single-operand equation
            (e.g. a transpose or trace).
        *operands: any further NumPy operands.

    Returns:
        The contraction result as a NumPy array.
    """
    arrays = (x,) if y is None else (x, y, *operands)
    # `torch.from_numpy` and `Tensor.numpy` share memory with the arrays,
    # so the round trip through torch does not copy the data
    return torch.einsum(eq, *map(torch.from_numpy, arrays)).numpy()

# this subs in the implementation for quimb + cotengra:
# from here on autoray's 'einsum' for the 'numpy' backend is `einsum_via_torch`
ar.register_function('numpy', 'einsum', einsum_via_torch)

Create a random SAT-style cubic tensor network for testing.

In [None]:
# fixed seed so the random network is the same on every run
rng = np.random.default_rng(666)


def _binary_fill(shape):
    """Return an array of the given shape with entries drawn from {0.0, 1.0}."""
    return rng.choice([0.0, 1.0], size=shape)


# edges of a 50 x 50 x 50 cubic lattice, every bond of dimension 3
cubic_edges = qtn.edges_3d_cubic(50, 50, 50)
tn = qtn.TN_from_edges_and_fill_fn(_binary_fill, edges=cubic_edges, D=3)
tn

Setup:

In [None]:
%%time
# one-off preprocessing of `tn`; the six returned objects are reused by the
# iteration and message-extraction cells below
inputs, outs, exprs, maskin, maskout, output_locs = setup_vbp(tn)

Time first run:

In [None]:
%%time
# a single iteration; `inputs` and `outs` are rebound to the returned values
inputs, outs = iterate_vbp(inputs, outs, exprs, maskin, maskout)

Time 10 runs:

In [None]:
%%time
for _ in range(10):
    # nb: the bottleneck is `numpy.einsum`, which is *not* parallelized
    inputs, outs = iterate_vbp(inputs, outs, exprs, maskin, maskout)

In [None]:
# the stack of messages of all corner adjacent bonds
# (`outs[2]` is transposed before being plotted)
xyz.visualize_tensor(outs[2].T, spacing_factor=0.0, figsize=(10, 10))

Extract the messages from the arrays:

In [None]:
# presumably maps the stacked arrays in `outs` back to individual messages
# using the `output_locs` returned by `setup_vbp` — confirm against tnmpa
messages = get_messages(outs, output_locs)

# jax compiled

In [None]:
# re-initialize: rebind all six objects from a fresh `setup_vbp(tn)` call,
# discarding the state left behind by the iterations above
inputs, outs, exprs, maskin, maskout, output_locs = setup_vbp(tn)

In [None]:
import jax
import functools


@jax.jit
def iterate_vbp_jax(inputs, outs):
    """Perform one jit-compiled belief propagation iteration.

    JAX arrays are immutable, so every step below rebinds its result
    explicitly; an augmented assignment on a loop variable or a bare
    ``x.at[idx].set(v)`` expression would silently leave the stored arrays
    unchanged.

    The module-level ``exprs``, ``maskin`` and ``maskout`` (as returned by
    ``setup_vbp``) are read as globals and therefore baked into the compiled
    function.

    Args:
        inputs: mapping ``n -> sequence of arrays`` that is unpacked into
            ``exprs[n]``.
        outs: mapping ``n -> array`` of output messages; each array is
            normalized along ``axis=1`` (assumed to be 2D — TODO confirm).

    Returns:
        ``(inputs, outs)``: new containers holding the updated input arrays
        and the freshly computed, normalized output messages.
    """
    # work on fresh, mutable containers rather than the caller's pytrees
    new_inputs = {n: list(arrays) for n, arrays in inputs.items()}
    new_outs = dict(outs)

    # compute new output messages
    for n, arrays in inputs.items():
        new_outs[n] = exprs[n](*arrays)

    # renormalize to distribution (the result must be stored back explicitly)
    new_outs = {
        n: out / out.sum(axis=1).reshape(-1, 1)
        for n, out in new_outs.items()
    }

    # copy output messages into inputs
    for no, ni, j in maskin:
        # `.at[...].set(...)` returns a new array, so rebind it
        new_inputs[ni][j] = new_inputs[ni][j].at[maskin[no, ni, j]].set(
            new_outs[no][maskout[no, ni, j]]
        )

    # hand back the container types we were given, so the pytree structure is
    # stable and repeated calls reuse the already compiled function
    new_inputs = {
        n: tuple(arrays) if isinstance(inputs[n], tuple) else arrays
        for n, arrays in new_inputs.items()
    }

    return new_inputs, new_outs

In [None]:
%%time
# first compile is slow
# (this first call of the `jax.jit`-decorated function is timed on its own,
# separately from the repeated calls in the next cell)
inputs, outs = iterate_vbp_jax(inputs, outs)

In [None]:
%%time
# next runs should be quick...
for _ in range(10):
    inputs, outs = iterate_vbp_jax(inputs, outs)
    
# for benchmarking: wait on the first array in `outs` so the measured time
# includes the computation itself and not only the calls returning
_ = next(iter(outs.values())).block_until_ready()

In [None]:
# same view as plotted for the NumPy run above, now of the JAX result
xyz.visualize_tensor(outs[2].T, spacing_factor=0.0, figsize=(10, 10))

Speedup in this case is ~50x:

In [None]:
# presumably (NumPy time for 10 iterations, in s) / (JAX time for 10
# iterations, in s), copied by hand from the `%%time` outputs — re-measure
# after any change
11.9 / 218e-3

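Messages look slightly different from the NumPy run. This may be due to precision (JAX uses 32-bit floats by default), but it is also worth checking that the normalization and message-copy steps in `iterate_vbp_jax` really take effect: JAX arrays are immutable, so every in-place-style update has to be assigned back.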