In [1]:
import jax.numpy as jnp
from jax import jit
from fractrics._components.HMM.forward.factor import update as f_update
from fractrics._components.HMM.forward.base import update as b_update

In [21]:
nfactors = 5
nstates = 2
nstates_total = nstates ** nfactors
prob = 1.0 / nstates
T = 10000

data_lik = jnp.abs(jnp.cos(jnp.arange(T)))
matrices = jnp.full([nfactors, nstates, nstates], prob)

joint_matrix = matrices[0]
for i in range(1, matrices.shape[0]):
    joint_matrix = jnp.kron(joint_matrix, matrices[i])
matrices = tuple(jnp.full((nstates, nstates), prob) for _ in range(nfactors)) # make tuple for f_update
log_joint_matrix = jnp.log(joint_matrix)
data_log_lik = jnp.log(data_lik.reshape(T, -1))

prior_tensor = jnp.full((nstates,) * nfactors, 1.0/nstates_total)
prior = prior_tensor.flatten()
lik_shape = (T,) + (nstates,) * nfactors
data_lik = jnp.full(lik_shape, 1.0)  # uniform likelihood

print(matrices)
print(joint_matrix)

(Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32, weak_type=True), Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32, weak_type=True), Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32, weak_type=True), Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32, weak_type=True), Array([[0.5, 0.5],
       [0.5, 0.5]], dtype=float32, weak_type=True))
[[0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]
 [0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]
 [0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]
 ...
 [0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]
 [0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]
 [0.03125 0.03125 0.03125 ... 0.03125 0.03125 0.03125]]


In [22]:
import time
f_factor = jit(f_update)
f_base   = jit(b_update)


def benchmark(fn, *args, runs=200):
    # warmup
    out = fn(*args)
    out[0].block_until_ready()

    start = time.time()
    for _ in range(runs):
        out = fn(*args)
        out[0].block_until_ready()
    return (time.time() - start) / runs

t_factor = benchmark(f_factor, prior_tensor, data_lik, matrices)
t_base   = benchmark(f_base, prior, data_log_lik, log_joint_matrix)

print(f"with {T} observations and {nfactors} factors: \n")
print("mean factor time:", t_factor)
print("mean base time:  ", t_base)


with 10000 observations and 5 factors: 

mean factor time: 0.017981290817260742
mean base time:   0.01441245198249817
