# Load libraries

In [None]:
# Environment bootstrap for the notebook: activate the project in the current
# directory, install the Julia dependencies, rebuild PyCall against the system
# Python, and install the Python packages used by the einsum engine.
using Pkg
Pkg.activate(".");  Pkg.resolve()

Pkg.add([
    "Gen",
    "PyCall",
    "FunctionalCollections",
    "StatsFuns",
    "BenchmarkTools",
    "Test",
    "Printf"
])
# Select the interpreter before the explicit rebuild below.
# NOTE(review): PyCall's build step is expected to pick up ENV["PYTHON"] — confirm.
ENV["PYTHON"] = "/usr/bin/python3"
Pkg.build("PyCall")
using PyCall
# NOTE(review): pip runs through whichever `python3` is first on PATH; this is
# presumably the same interpreter as ENV["PYTHON"] — verify on the target machine.
run(`python3 -m pip install -U opt_einsum cupy-cuda12x`)

# JAX is optional: a failed install is downgraded to a warning and the cell continues.
try
    run(`python3 -m pip install -U "jax[cpu]"`)
catch e
    @warn "JAX install (CPU) failed; continuing with NumPy/CuPy only" exception=e
end


In [None]:
# Create a separate, non-shared environment named "gve-test" and install the
# fork of GenVariableElimination.jl under test into it.
using Pkg

Pkg.activate("gve-test"; shared=false)
Pkg.add(["Gen", "BenchmarkTools"])
# Gen is added a second time with an explicit version, pinning it to 0.4.8.
Pkg.add(PackageSpec(name="Gen", version="0.4.8"))
# Install the package from the `feature/einsum-jax` revision of the fork.
Pkg.add(PackageSpec(
    url = "https://github.com/abtinmU/GenVariableElimination.jl.git",
    rev = "feature/einsum-jax",
))
Pkg.precompile()

# Print the installed version and source as a sanity check.
Pkg.status("GenVariableElimination")

In [3]:
# Load the installed package, then show where it was loaded from and which
# version/revision the active environment records for it.
using GenVariableElimination
pathof(GenVariableElimination)
Pkg.status("GenVariableElimination")

[32m[1mStatus[22m[39m `/content/gve-test/Project.toml`
  [90m[43085b26] [39mGenVariableElimination v0.2.0 `https://github.com/abtinmU/GenVariableElimination.jl.git#feature/einsum-jax`


# correctness tests
These tests check the correctness of the internal data structures. Similar tests already exist for the original GenVariableElimination.jl; I adapted them to validate the new einsum engine and to compare its results against the original native engine.

In [10]:
##############################
# Compares engines  :native vs :einsum
##############################

using Gen
using Test
using Printf
using GenVariableElimination
const M = GenVariableElimination

# Positional-argument proposal built on the fixed-trace backwards sampler.
# It constructs the sampler for `trace` over `addrs` with the given latent and
# observation specifications, then splices every choice made by that sampler
# into its own trace. `engine` may be a Symbol or a String; the package keyword
# always receives a Symbol.
@gen function backwards_sampler_dml_positional(
    trace,
    addrs,
    latents::Dict{Any,M.Latent},
    observations::Dict{Any,M.Observation},
    engine::Union{Symbol,String}
)
    engine_sym = Symbol(engine)
    fixed_trace_sampler = M.generate_backwards_sampler_fixed_trace(
        trace, addrs, latents, observations; engine=engine_sym
    )
    {*} ~ fixed_trace_sampler()
end

# -------- utilities --------
"""
    normed(arr)

Return `arr` divided by the sum of its elements.
"""
function normed(arr)
    total = sum(arr)
    return arr / total
end
"""
    pretty(b)

Return `"PASS"` when `b` is true and `"FAIL"` otherwise.
"""
function pretty(b)
    if b
        return "PASS"
    else
        return "FAIL"
    end
end

"""
    call_variable_elimination(fg, order; engine=:native, backend=:auto,
                              optimize="auto", dtype=:float64, jit=true, cache=true)

Forward `fg`, `order` and every keyword unchanged to `M.variable_elimination`
and return its result.
"""
function call_variable_elimination(fg, order; engine=:native, backend=:auto, optimize="auto", dtype=:float64, jit=true, cache=true)
    options = (; engine, backend, optimize, dtype, jit, cache)
    return M.variable_elimination(fg, order; options...)
end

"""
    call_eliminate(fg, addr; engine=:native, backend=:auto, optimize="auto",
                   dtype=:float64, jit=true, cache=true)

Eliminate the variable at `addr` from `fg` and return the resulting factor graph.

With `engine` equal to `:einsum` (or the String `"einsum"`) the call is routed to
`M.eliminate_einsum` with the remaining keywords; any other engine uses
`M.eliminate`, which takes no keywords, so `backend`, `optimize`, `dtype`, `jit`
and `cache` are ignored on that path.
"""
function call_eliminate(fg, addr; engine=:native, backend=:auto, optimize="auto", dtype=:float64, jit=true, cache=true)
    # Normalize first: the sampler kernels in this file accept the engine as a
    # Symbol or a String, and a raw `== :einsum` would send "einsum" down the
    # native path without any indication.
    if Symbol(engine) === :einsum
        return M.eliminate_einsum(fg, addr; backend=backend, optimize=optimize, dtype=dtype, jit=jit, cache=cache)
    end
    return M.eliminate(fg, addr)
end

"""
    call_generate_fixed_trace(trace, addrs; engine=:native, backend=:auto,
                              optimize="auto", dtype=:float64, jit=true, cache=true,
                              latents=nothing, observations=nothing)

Build a fixed-trace backwards sampler for `trace` over `addrs` via
`M.generate_backwards_sampler_fixed_trace`, forwarding the engine keywords.

When `latents` and `observations` are both given they are passed positionally;
when both are omitted the two-argument form is used.

# Throws
- `ArgumentError` if exactly one of `latents` / `observations` is supplied.
  Previously a lone `observations` was silently ignored and a lone `latents`
  forwarded `nothing` to the package.
"""
function call_generate_fixed_trace(trace, addrs; engine=:native, backend=:auto, optimize="auto", dtype=:float64, jit=true, cache=true, latents=nothing, observations=nothing)
    if (latents === nothing) != (observations === nothing)
        throw(ArgumentError("`latents` and `observations` must be supplied together"))
    end
    options = (; engine, backend, optimize, dtype, jit, cache)
    if latents === nothing
        return M.generate_backwards_sampler_fixed_trace(trace, addrs; options...)
    end
    return M.generate_backwards_sampler_fixed_trace(trace, addrs, latents, observations; options...)
end

"""
    call_generate_fixed_structure(trace, addrs; engine=:native, backend=:auto,
                                  optimize="auto", dtype=:float64, jit=true, cache=true)

Forward `trace`, `addrs` and every keyword unchanged to
`M.generate_backwards_sampler_fixed_structure` and return the sampler it builds.
"""
function call_generate_fixed_structure(trace, addrs; engine=:native, backend=:auto, optimize="auto", dtype=:float64, jit=true, cache=true)
    options = (; engine, backend, optimize, dtype, jit, cache)
    return M.generate_backwards_sampler_fixed_structure(trace, addrs; options...)
end

#  tests

"""
    test_compiling_factor_graph() -> Bool

Compile a trace of a four-variable Bernoulli model (`x`, `y`, `z`, `w`) into a
factor graph and check that

- the graph reports 4 factors and 4 variable nodes,
- every variable node maps `true`/`false` to indices 1/2 and back,
- the four factors, after normalization, match the hand-written tables.
"""
function test_compiling_factor_graph()::Bool
    @gen function foo()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
    end
    trace = Gen.simulate(foo, ())
    latents = Dict{Any,Latent}(
        :x => M.Latent([true, false], []),
        :y => M.Latent([true, false], [:x]),
        :z => M.Latent([true, false], [:x, :y]),
        :w => M.Latent([true, false], [:z]),
    )
    observations = Dict{Any,Observation}()
    fg = M.compile_trace_to_factor_graph(trace, latents, observations)

    # Small local helpers so the checks below read as tables.
    idx(addr) = M.addr_to_idx(fg, addr)
    covers(fn, addrs) =
        length(M.vars(fn)) == length(addrs) && all(a -> idx(a) in M.vars(fn), addrs)
    value_at(factor, assignment...) =
        M.factor_value(fg, factor, Dict((idx(a) => v for (a, v) in assignment)...))

    # ---- variable nodes ----
    function node_ok(addr)
        node = M.idx_to_var_node(fg, idx(addr))
        node.addr == addr || return false
        M.num_values(node) == 2 || return false
        M.idx_to_value(node, 1) == true || return false
        M.idx_to_value(node, 2) == false || return false
        M.value_to_idx(node, true) == 1 || return false
        return M.value_to_idx(node, false) == 2
    end
    counts_ok = fg.num_factors == 4 && length(fg.var_nodes) == 4
    # `&` (not `&&`): the per-node checks run even when the counts are wrong.
    ok_nodes = counts_ok & all(map(node_ok, [:x, :y, :z, :w]))

    # ---- collect the distinct factor nodes reachable from the variable nodes ----
    probe = first(values(fg.var_nodes))
    FactorType = typeof(first(collect(probe.factor_nodes)))
    all_factors = Set{FactorType}()
    foreach(node -> union!(all_factors, M.factor_nodes(node)), values(fg.var_nodes))
    ok_nfactors = length(all_factors) == 4
    find_factor(addrs...) = first(filter(fn -> covers(fn, addrs), all_factors))

    # f1: prior on x
    f1 = find_factor(:x)
    F1 = [value_at(f1, :x => xv) for xv in (true, false)]
    ok_f1 = isapprox(normed(F1), normed([0.6, 0.4]))

    # f2: y given x, rows ordered (x, y) = TT, TF, FT, FF
    f2 = find_factor(:x, :y)
    F2 = [value_at(f2, :x => xv, :y => yv) for xv in (true, false) for yv in (true, false)]
    ok_f2 = isapprox(normed(F2), normed([0.2, 0.8, 0.9, 0.1]))

    # f3: z given (x, y); z varies slowest, then x, then y
    f3 = find_factor(:x, :y, :z)
    F3 = [value_at(f3, :x => xv, :y => yv, :z => zv)
          for zv in (true, false) for xv in (true, false) for yv in (true, false)]
    ok_f3 = isapprox(normed(F3), normed([0.4, 0.9, 0.9, 0.9, 0.6, 0.1, 0.1, 0.1]))

    # f4: w given z, rows ordered (z, w) = TT, TF, FT, FF
    f4 = find_factor(:z, :w)
    F4 = [value_at(f4, :z => zv, :w => wv) for zv in (true, false) for wv in (true, false)]
    ok_f4 = isapprox(normed(F4), normed([0.4, 0.6, 0.5, 0.5]))

    return all((ok_nodes, ok_nfactors, ok_f1, ok_f2, ok_f3, ok_f4))
end

"""
    test_variable_elimination(; engine=:native) -> Bool

Eliminate `w` and then `x` from the four-variable Bernoulli factor graph using
`engine`, and check after each step the factor/variable counts and the newly
created factor (normalized) against hand-computed values.
"""
function test_variable_elimination(; engine=:native)::Bool
    @gen function foo()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
    end
    trace = Gen.simulate(foo, ())
    latents = Dict{Any,Latent}(
        :x => M.Latent([true, false], []),
        :y => M.Latent([true, false], [:x]),
        :z => M.Latent([true, false], [:x, :y]),
        :w => M.Latent([true, false], [:z]),
    )
    observations = Dict{Any,Observation}()
    fg = M.compile_trace_to_factor_graph(trace, latents, observations)

    # Distinct factor nodes reachable from the variable nodes of `graph`.
    function gather_factors(graph)
        probe = first(values(graph.var_nodes))
        found = Set{typeof(first(collect(probe.factor_nodes)))}()
        for node in values(graph.var_nodes)
            union!(found, M.factor_nodes(node))
        end
        return found
    end
    # The factor whose variable set is exactly `addrs`.
    function pick_factor(graph, factors, addrs)
        matches = filter(
            fn -> length(M.vars(fn)) == length(addrs) &&
                  all(a -> M.addr_to_idx(graph, a) in M.vars(fn), addrs),
            factors,
        )
        return first(matches)
    end
    value_at(graph, factor, assignment...) = M.factor_value(
        graph, factor, Dict((M.addr_to_idx(graph, a) => v for (a, v) in assignment)...)
    )

    # Step 1: eliminating w adds f5, a unary factor on z.
    fg1 = call_eliminate(fg, :w; engine=engine)
    ok1 = fg1.num_factors == 5 && length(fg1.var_nodes) == 3
    f5 = pick_factor(fg1, gather_factors(fg1), (:z,))
    F5 = [value_at(fg1, f5, :z => zv) for zv in (true, false)]
    ok_f5 = isapprox(normed(F5), normed([0.4+0.6, 0.5+0.5]))

    # Step 2: eliminating x adds f6, a binary factor on (y, z).
    fg2 = call_eliminate(fg1, :x; engine=engine)
    ok2 = fg2.num_factors == 6 && length(fg2.var_nodes) == 2
    f6 = pick_factor(fg2, gather_factors(fg2), (:y, :z))
    # Rows ordered (y, z) = TT, TF, FT, FF; each entry sums over x = true, false.
    expected_F6 = [
        (0.6 * 0.2 * 0.4) + (0.4 * 0.9 * 0.9),
        (0.6 * 0.2 * 0.6) + (0.4 * 0.9 * 0.1),
        (0.6 * 0.8 * 0.9) + (0.4 * 0.1 * 0.9),
        (0.6 * 0.8 * 0.1) + (0.4 * 0.1 * 0.1),
    ]
    F6 = [value_at(fg2, f6, :y => yv, :z => zv) for yv in (true, false) for zv in (true, false)]
    ok_f6 = isapprox(normed(F6), normed(expected_F6))

    return all((ok1, ok_f5, ok2, ok_f6))
end

"""
    test_conditional_dist(; engine=:native) -> Bool

Eliminate `w` and then `x` from the four-variable Bernoulli factor graph and
check `M.conditional_dist` for `z` given `y = true` and `y = false` against
hand-computed distributions.

# Keywords
- `engine`: elimination engine, routed through `call_eliminate`. The default
  `:native` keeps the previous behaviour (`M.eliminate`); pass `:einsum` to
  check the conditional distribution on graphs produced by the einsum engine.
"""
function test_conditional_dist(; engine=:native)::Bool
    @gen function foo()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
    end
    trace = Gen.simulate(foo, ())
    latents = Dict{Any,Latent}()
    latents[:x] = M.Latent([true,false], [])
    latents[:y] = M.Latent([true,false], [:x])
    latents[:z] = M.Latent([true,false], [:x,:y])
    latents[:w] = M.Latent([true,false], [:z])
    observations = Dict{Any,Observation}()
    fg = M.compile_trace_to_factor_graph(trace, latents, observations)

    # Previously hard-wired to `M.eliminate`, so the engine under test was never used.
    fg = call_eliminate(fg, :w; engine=engine)
    fg = call_eliminate(fg, :x; engine=engine)

    # Expected factor over (y, z) after eliminating x; rows TT, TF, FT, FF.
    F6 = [ (0.6 * 0.2 * 0.4) + (0.4 * 0.9 * 0.9),
           (0.6 * 0.2 * 0.6) + (0.4 * 0.9 * 0.1),
           (0.6 * 0.8 * 0.9) + (0.4 * 0.1 * 0.9),
           (0.6 * 0.8 * 0.1) + (0.4 * 0.1 * 0.1) ]
    # Expected unary factor over z after eliminating w.
    F5 = [0.4+0.6, 0.5+0.5]

    # One slot per original variable index; only the slot for y is filled.
    # Named `assignment` so the local does not shadow `Base.values`.
    assignment = Vector{Any}(undef, 4)
    y_idx = M.addr_to_idx(fg, :y)

    assignment[y_idx] = true
    actual1 = M.conditional_dist(fg, assignment, :z)
    expected1 = normed([F6[1]*F5[1], F6[2]*F5[2]])

    assignment[y_idx] = false
    actual2 = M.conditional_dist(fg, assignment, :z)
    expected2 = normed([F6[3]*F5[1], F6[4]*F5[2]])

    return isapprox(actual1, expected1) && isapprox(actual2, expected2)
end

"""
    test_mh_always_accepts_dml(; engine=:native) -> Bool

Build a fixed-trace backwards sampler (NumPy backend, `float64`) for a
five-variable Bernoulli model with one observation and check, over 100
iterations, that the log acceptance ratio
`weight + score(backward) - score(forward)` is finite and within `1e-10` of zero.
"""
function test_mh_always_accepts_dml(; engine=:native)::Bool
    @gen function bar()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
        obs ~ bernoulli((x && w) ? 0.4 : 0.1)
    end
    trace = Gen.simulate(bar, ())

    latents = Dict{Any,Latent}()
    latents[:x] = M.Latent([true,false], [])
    latents[:y] = M.Latent([true,false], [:x])
    latents[:z] = M.Latent([true,false], [:x,:y])
    latents[:w] = M.Latent([true,false], [:z])
    observations = Dict{Any,Observation}()
    observations[:obs] = M.Observation([:x,:w])

    order = [:w,:x,:z,:y]
    sampler = call_generate_fixed_trace(
        trace, order;
        engine=engine,
        backend=:numpy,
        dtype=:float64,
        latents=latents,
        observations=observations
    )

    # The tolerance is the same for both engines and does not change between
    # iterations, so it is set once (the old per-iteration ternary had two
    # identical branches).
    tol = 1e-10

    ok = true
    t = trace
    for _ in 1:100
        # Forward proposal, apply it to the model trace, then score the reverse move.
        qf = Gen.simulate(sampler, ())
        new_t, weight, _, discard = Gen.update(t, Gen.get_args(t), map(_->Gen.NoChange(), Gen.get_args(t)), Gen.get_choices(qf))
        qb, _ = Gen.generate(sampler, (), discard)
        log_ratio = weight + Gen.get_score(qb) - Gen.get_score(qf)
        ok = ok && isfinite(log_ratio) && abs(log_ratio) <= tol
        t = new_t
    end
    return ok
end

"""
    test_hmm_ffbs(; engine=:native) -> Bool

Simulate a 3-state, 10-step HMM with random row-normalized parameters and run
10 MH steps that propose all hidden states at once through
`backwards_sampler_dml_positional`. Return `true` only if every step is accepted.
"""
function test_hmm_ffbs(; engine=:native)::Bool
    # Random parameters, drawn in the order prior, A, B; rows normalized to sum to 1.
    raw_prior = rand(3)
    prior = raw_prior ./ sum(raw_prior)
    raw_A = rand(3, 3)
    A = raw_A ./ sum(raw_A, dims=2)
    raw_B = rand(3, 3)
    B = raw_B ./ sum(raw_B, dims=2)
    T = 10

    @gen function hmm()
        z = ({(:z,1)} ~ categorical(prior))
        {(:x,1)} ~ categorical(B[z,:])
        for t in 2:T
            z = ({(:z,t)} ~ categorical(A[z,:]))
            {(:x,t)} ~ categorical(B[z,:])
        end
    end

    # Each hidden state has support 1:3 and depends on its predecessor;
    # each observation depends on the hidden state of the same step.
    latents = Dict{Any,Latent}((:z, 1) => M.Latent(collect(1:3), []))
    for t in 2:T
        latents[(:z, t)] = M.Latent(collect(1:3), [(:z, t - 1)])
    end
    observations = Dict{Any,Observation}(
        (:x, t) => M.Observation([(:z, t)]) for t in 1:T
    )
    order = collect(Any, ((:z, t) for t in 1:T))

    trace = Gen.simulate(hmm, ())
    all_accepted = true

    let addrs = order, latent_spec = latents, obs_spec = observations, eng = engine
        @gen function dml_mh_sampler(tr)
            {*} ~ backwards_sampler_dml_positional(tr, addrs, latent_spec, obs_spec, eng)
        end
        for _ in 1:10
            trace, was_accepted = Gen.mh(trace, dml_mh_sampler, ())
            all_accepted = all_accepted & was_accepted
        end
    end
    return all_accepted
end

"""
    compute_next_probs(x1)

Return a length-20 probability vector in which entry `x1` carries ten times the
weight of every other entry.
"""
function compute_next_probs(x1)
    weights = fill(1/20, 20)
    weights[x1] = weights[x1] * 10
    total = sum(weights)
    return weights ./ total
end
# Static-DSL model used by the structure-specialized sampler tests.
# It samples three Beta(0.5, 0.5) parameters, a chain x1..x6 of 20-valued
# discrete variables, and a chain x7..x20 of Bernoulli variables.
@gen (static) function static_model()
    # Continuous parameters. p1 is sampled but not referenced by any later statement.
    p1 ~ beta(0.5, 0.5)
    p2 ~ beta(0.5, 0.5)
    p3 ~ beta(0.5, 0.5)
    # Discrete chain: each variable's distribution is compute_next_probs of its predecessor.
    x1 ~ uniform_discrete(1, 20)
    x2 ~ categorical(compute_next_probs(x1))
    x3 ~ categorical(compute_next_probs(x2))
    x4 ~ categorical(compute_next_probs(x3))
    x5 ~ categorical(compute_next_probs(x4))
    x6 ~ categorical(compute_next_probs(x5))
    # Boolean chain: success probability is p2 or p3 depending on the previous variable.
    x7 ~ bernoulli(x6 > 5 ? p2 : p3)
    x8 ~ bernoulli(x7 ? p2 : p3)
    x9 ~ bernoulli(x8 ? p2 : p3)
    x10 ~ bernoulli(x9 ? p2 : p3)
    x11 ~ bernoulli(x10 ? p2 : p3)
    x12 ~ bernoulli(x11 ? p2 : p3)
    x13 ~ bernoulli(x12 ? p2 : p3)
    x14 ~ bernoulli(x13 ? p2 : p3)
    x15 ~ bernoulli(x14 ? p2 : p3)
    x16 ~ bernoulli(x15 ? p2 : p3)
    x17 ~ bernoulli(x16 ? p2 : p3)
    x18 ~ bernoulli(x17 ? p2 : p3)
    x19 ~ bernoulli(x18 ? p2 : p3)
    # The last variable uses fixed probabilities instead of p2/p3.
    x20 ~ bernoulli(x19 ? 0.5 : 0.1)
end
"""
    test_sml_basic_block(; engine=:native) -> Bool

Build a structure-specialized backwards sampler over `:x1, :x2, :x3` of
`static_model` and check that

1. 100 consecutive `Gen.mh` steps using it as the proposal are all accepted, and
2. for 100 manually constructed moves the log acceptance ratio
   `weight + score(backward) - score(forward)` has absolute value below `1e-10`.
"""
function test_sml_basic_block(; engine=:native)::Bool
    # No `@load_generated_functions()` here: the macro is a deprecated no-op in
    # the Gen version this notebook runs and only emitted a warning.
    trace = Gen.simulate(static_model, ())

    # structure-specialized sampler
    sampler = call_generate_fixed_structure(trace, [:x1,:x2,:x3]; engine=engine)
    ok = true
    for _ in 1:100
        trace, acc = Gen.mh(trace, sampler, ())
        ok &= acc
    end

    # Tolerance on |log acceptance ratio|; loop-invariant, so set once.
    tol = 1e-10
    for _ in 1:100
        # Forward proposal conditioned on the current trace.
        qf = Gen.simulate(sampler, (trace,))
        new_t, weight, _, discard =
            Gen.update(trace, Gen.get_args(trace),
                      map(_->Gen.NoChange(), Gen.get_args(trace)),
                      Gen.get_choices(qf))
        # Reverse proposal: score the discarded choices from the new trace.
        qb, _ = Gen.generate(sampler, (new_t,), discard)

        ok &= (abs(weight + Gen.get_score(qb) - Gen.get_score(qf)) < tol)

        trace = new_t
    end

    return ok
end

# Per-datum model: a fair coin `z` selects between Normal(inlier_mean, 1.0)
# when `z` is true and Normal(0.0, 10.0) when it is false for the value `x`.
@gen (static) function local_model(inlier_mean)
    z ~ bernoulli(0.5)
    x ~ normal(z ? inlier_mean : 0.0, z ? 1.0 : 10.0)
end
# Top-level model: draws a shared `inlier_mean` from Normal(0, 10.0) and applies
# `local_model` to `n` copies of it through `Map`, under the address `data`.
@gen (static) function global_model(n)
    inlier_mean ~ normal(0, 10.0)
    data ~ Map(local_model)(fill(inlier_mean, n))
end
"""
    test_map(; engine=:native) -> Bool

Build a structure-specialized backwards sampler over the `z` choices of all
`n = 10` `Map` applications in `global_model`. For 100 rounds, run one `Gen.mh`
step with that sampler followed by a selection-based `Gen.mh` step on
`:inlier_mean`. Return `true` only if every sampler step is accepted; the
outcome of the `:inlier_mean` step is not checked.
"""
function test_map(; engine=:native)::Bool
    # No `@load_generated_functions()` here: the macro is a deprecated no-op in
    # the Gen version this notebook runs and only emitted a warning.
    n = 10
    trace = Gen.simulate(global_model, (n,))
    sampler = call_generate_fixed_structure(trace, [(:data=>i=>:z) for i in 1:n]; engine=engine)
    ok = true
    for _ in 1:100
        trace, acc = Gen.mh(trace, sampler, ())
        ok &= acc
        # Move the continuous parameter between block updates.
        trace, _ = Gen.mh(trace, Gen.select(:inlier_mean))
    end
    return ok
end


# Random parameters for the static HMM below: a length-3 vector and two 3x3
# matrices, each normalized in place so the vector and every matrix row sum to 1.
# These are non-constant globals, drawn from the unseeded default RNG.
prior = rand(3); prior ./= sum(prior)
A = rand(3,3);  A ./= sum(A, dims=2)
B = rand(3,3);  B ./= sum(B, dims=2)
# One HMM step for `Unfold`: samples the next hidden state `z` from row
# `z_prev` of `A`, an observation `x` from row `z` of `B`, and returns `z`.
# The step index `t` is not used in the body.
@gen (static) function step(t::Int, z_prev::Int)
    z ~ categorical(A[z_prev,:])
    x ~ categorical(B[z,:])
    return z
end
# Static HMM: samples the initial hidden state and observation, then unfolds
# `step` with length argument `T` starting from `z_init`, under the address `steps`.
@gen (static) function hmm(T::Int)
    z_init ~ categorical(prior)
    x_init ~ categorical(B[z_init,:])
    steps ~ (Unfold(step))(T, z_init)
end
"""
    test_unfold_and_fixed_trace(; engine=:native) -> Bool

Two checks on the static `hmm` model with 10 unfolded steps:

1. a structure-specialized sampler over all hidden states is accepted in each
   of 100 `Gen.mh` steps;
2. for a fixed-trace sampler, `score(model) - score(proposal)` is approximately
   equal for two independent proposals, with the model constrained to the
   observations of the reference trace plus the proposed hidden states.
"""
function test_unfold_and_fixed_trace(; engine=:native)::Bool
    # No `@load_generated_functions()` here: the macro is a deprecated no-op in
    # the Gen version this notebook runs and only emitted a warning.

    # One length and one address list shared by both parts (previously the
    # literal 10 and the list were written out twice).
    Tlen = 10
    hidden_addrs = [:z_init, ([:steps=>t=>:z for t in 1:Tlen]...)]

    # Unfold acceptance
    trace = Gen.simulate(hmm, (Tlen,))
    sampler = call_generate_fixed_structure(trace, hidden_addrs; engine=engine)
    ok1 = true
    for _ in 1:100
        trace, acc = Gen.mh(trace, sampler, ())
        ok1 &= acc
    end

    # fixed trace log-ml equality
    trace2 = Gen.simulate(hmm, (Tlen,))
    observations = Gen.choicemap()
    observations[:x_init] = trace2[:x_init]
    for t in 1:Tlen
        observations[:steps=>t=>:x] = trace2[:steps=>t=>:x]
    end
    sampler2 = call_generate_fixed_trace(trace2, hidden_addrs; engine=engine)
    q1 = Gen.simulate(sampler2, ())
    p1, = Gen.generate(hmm, (Tlen,), merge(observations, Gen.get_choices(q1)))
    q2 = Gen.simulate(sampler2, ())
    p2, = Gen.generate(hmm, (Tlen,), merge(observations, Gen.get_choices(q2)))
    ok2 = isapprox(Gen.get_score(p1) - Gen.get_score(q1), Gen.get_score(p2) - Gen.get_score(q2))
    return ok1 && ok2
end

#  run & print
using BenchmarkTools, Printf

using Statistics  # for mean and std

"""
    fmtarr(x)

Format the elements of `x` (flattened with `vec(collect(x))`) with `%.6g` and
return them as one line of the form `"[a, b, c]"`.
"""
function fmtarr(x)
    pieces = [@sprintf("%.6g", value) for value in vec(collect(x))]
    return string("[", join(pieces, ", "), "]")
end

"""
    _var_order_labels(fg, f)

Return the addresses of the variables of factor `f`, in the order given by
`M.vars(f)`, looked up in factor graph `fg`.
"""
function _var_order_labels(fg, f)
    var_indices = M.vars(f)
    return [M.idx_to_var_node(fg, i).addr for i in var_indices]
end


"""
    print_diff(name; expected, native, einsum)

Print a diagnostics section titled `name`: the expected values, the values from
the native and einsum engines, and the element-wise absolute deviation of each
engine from the expected values, followed by a blank line.
"""
function print_diff(name; expected, native, einsum)
    println("Failure diagnostics: ", name)
    for (label, arr) in (("  expected: ", expected),
                         ("  native:   ", native),
                         ("  einsum:   ", einsum))
        println(label, fmtarr(arr))
    end
    # Both deviations are computed before either one is printed.
    einsum_dev = abs.(einsum .- expected)
    native_dev = abs.(native .- expected)
    println("  |einsum - expected|: ", fmtarr(einsum_dev))
    println("  |native - expected|: ", fmtarr(native_dev))
    println()
end



"""
    run_timed(fn; bench=false, seconds=1.0)

Call the zero-argument function `fn` and return `(result, seconds_elapsed)`.

With `bench=false` the call is timed once with `@timed`. With `bench=true`, `fn`
is called once as a warm-up, timed with `@belapsed` (budget `seconds`, one
evaluation per sample), and then called once more to obtain the result.
"""
function run_timed(fn; bench::Bool=false, seconds::Float64=1.0)
    if !bench
        timing = @timed fn()
        return timing.value, timing.time
    end
    fn()  # warm-up
    elapsed = @belapsed $fn() seconds=seconds evals=1
    outcome = fn()
    return outcome, elapsed
end


# Four-variable Bernoulli model shared by the diagnostics below; it has the same
# statements as the local `foo` models defined inside several tests.
# Dependencies: y on x, z on (x, y), w on z.
@gen function _foo_model()
    x ~ bernoulli(0.6)
    y ~ bernoulli(x ? 0.2 : 0.9)
    z ~ bernoulli((x && y) ? 0.4 : 0.9)
    w ~ bernoulli(z ? 0.4 : 0.5)
end

"""
    _make_fg_for_foo()

Simulate `_foo_model` and compile the trace into a factor graph in which all
four Boolean variables are latent and nothing is observed.
"""
function _make_fg_for_foo()
    latents = Dict{Any,Latent}(
        :x => M.Latent([true, false], []),
        :y => M.Latent([true, false], [:x]),
        :z => M.Latent([true, false], [:x, :y]),
        :w => M.Latent([true, false], [:z]),
    )
    no_observations = Dict{Any,Observation}()
    model_trace = Gen.simulate(_foo_model, ())
    return M.compile_trace_to_factor_graph(model_trace, latents, no_observations)
end

"""
    diagnose_variable_elimination()

Print, for the native and einsum engines, the normalized factors obtained after
eliminating `w` (unary factor on `z`) and then `x` (binary factor on `y`, `z`)
next to the hand-computed expectations, followed by the variable order each
engine reports for the binary factor.
"""
function diagnose_variable_elimination()
    base_graph = _make_fg_for_foo()

    # Hand-computed references, the same numbers the test uses.
    F5_expected = [0.4 + 0.6, 0.5 + 0.5]
    F6_expected = [
        (0.6 * 0.2 * 0.4) + (0.4 * 0.9 * 0.9),
        (0.6 * 0.2 * 0.6) + (0.4 * 0.9 * 0.1),
        (0.6 * 0.8 * 0.9) + (0.4 * 0.1 * 0.9),
        (0.6 * 0.8 * 0.1) + (0.4 * 0.1 * 0.1),
    ]

    # The factor of `graph` whose variable set is exactly `addrs`.
    function pick_factor(graph, addrs)
        probe = first(values(graph.var_nodes))
        found = Set{typeof(first(collect(probe.factor_nodes)))}()
        for node in values(graph.var_nodes)
            union!(found, M.factor_nodes(node))
        end
        matches = filter(
            fn -> length(M.vars(fn)) == length(addrs) &&
                  all(a -> M.addr_to_idx(graph, a) in M.vars(fn), addrs),
            found,
        )
        return first(matches)
    end

    # helper to pull factor arrays
    function f5_f6_for(engine_sym)
        fg1 = call_eliminate(base_graph, :w; engine=engine_sym)  # f5 is a unary factor on z
        f5 = pick_factor(fg1, (:z,))
        F5 = [M.factor_value(fg1, f5, Dict(M.addr_to_idx(fg1, :z) => zv))
              for zv in (true, false)]

        fg2 = call_eliminate(fg1, :x; engine=engine_sym)  # f6 is a binary factor on y,z
        f6 = pick_factor(fg2, (:y, :z))
        # Rows ordered (y, z) = TT, TF, FT, FF.
        F6 = [M.factor_value(fg2, f6, Dict(M.addr_to_idx(fg2, :y) => yv,
                                           M.addr_to_idx(fg2, :z) => zv))
              for yv in (true, false) for zv in (true, false)]
        return F5, F6
    end

    F5_native, F6_native = f5_f6_for(:native)
    F5_einsum, F6_einsum = f5_f6_for(:einsum)

    print_diff("F5 over z after eliminating w";
        expected=normed(F5_expected),
        native=normed(F5_native),
        einsum=normed(F5_einsum))

    print_diff("F6 over y,z after eliminating x";
        expected=normed(F6_expected),
        native=normed(F6_native),
        einsum=normed(F6_einsum))

    # Show which var order each engine produced for F6 (on a freshly built graph).
    function print_f6_order(prefix, engine_sym)
        fg1 = call_eliminate(_make_fg_for_foo(), :w; engine=engine_sym)
        fg2 = call_eliminate(fg1, :x; engine=engine_sym)
        f6 = pick_factor(fg2, (:y, :z))
        println(prefix, _var_order_labels(fg2, f6))
    end
    print_f6_order("  native F6 var order: ", :native)
    print_f6_order("  einsum F6 var order: ", :einsum)

    return nothing
end

"""
    diagnose_conditional_dist()

Print, for the native and einsum engines, the conditional distribution of `z`
given `y = true` and given `y = false` (after eliminating `w` and `x`) next to
the hand-computed expectations.
"""
function diagnose_conditional_dist()
    # expected numbers as in the test
    F5 = [0.4 + 0.6, 0.5 + 0.5]
    F6 = [
        (0.6 * 0.2 * 0.4) + (0.4 * 0.9 * 0.9),
        (0.6 * 0.2 * 0.6) + (0.4 * 0.9 * 0.1),
        (0.6 * 0.8 * 0.9) + (0.4 * 0.1 * 0.9),
        (0.6 * 0.8 * 0.1) + (0.4 * 0.1 * 0.1),
    ]
    expected_ytrue  = normed([F6[1] * F5[1], F6[2] * F5[2]])
    expected_yfalse = normed([F6[3] * F5[1], F6[4] * F5[2]])

    # Conditional distributions of z for y = true and y = false under one engine.
    function cond_for(engine_sym)
        after_w = call_eliminate(_make_fg_for_foo(), :w; engine=engine_sym)
        graph = call_eliminate(after_w, :x; engine=engine_sym)

        # One slot per original variable index; only the slot for y is filled.
        slots = Vector{Any}(undef, 4)
        slots[M.addr_to_idx(graph, :y)] = true
        given_true = M.conditional_dist(graph, slots, :z)
        slots[M.addr_to_idx(graph, :y)] = false
        given_false = M.conditional_dist(graph, slots, :z)
        return given_true, given_false
    end

    a1_native, a2_native = cond_for(:native)
    a1_einsum, a2_einsum = cond_for(:einsum)

    print_diff("P(z|y=true)"; expected=expected_ytrue, native=a1_native, einsum=a1_einsum)
    print_diff("P(z|y=false)"; expected=expected_yfalse, native=a2_native, einsum=a2_einsum)
end


"""
    diagnose_mh_always_accepts_dml()

Re-run the log-acceptance-ratio experiment of `test_mh_always_accepts_dml` for
both engines and print the tolerance, the largest `|log_ratio|` per engine, and
for the five worst einsum iterations the ratio with its three components
(`weight`, `score(qb)`, `score(qf)`).

The samplers are built with `backend=:numpy` and `dtype=:float64`, the same
configuration the test uses, so the printed numbers belong to the code path
that failed.
"""
function diagnose_mh_always_accepts_dml()
    @gen function bar()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
        obs ~ bernoulli((x && w) ? 0.4 : 0.1)
    end
    trace0 = Gen.simulate(bar, ())

    latents = Dict{Any,Latent}()
    latents[:x] = M.Latent([true,false], [])
    latents[:y] = M.Latent([true,false], [:x])
    latents[:z] = M.Latent([true,false], [:x,:y])
    latents[:w] = M.Latent([true,false], [:z])
    observations = Dict{Any,Observation}()
    observations[:obs] = M.Observation([:x,:w])
    order = [:w,:x,:z,:y]

    # Collect 100 log ratios and their (weight, score(qb), score(qf)) triples.
    function collect_ratios(engine_sym)
        sampler = call_generate_fixed_trace(trace0, order; engine=engine_sym,
                                            backend=:numpy, dtype=:float64,
                                            latents=latents, observations=observations)
        t = trace0
        ratios = Float64[]
        triples = Tuple{Float64,Float64,Float64}[]
        # Same tolerance for every engine (the old ternary had identical branches).
        tol = 1e-10
        for _ in 1:100
            qf = Gen.simulate(sampler, ())
            new_t, weight, _, discard =
                Gen.update(t, Gen.get_args(t), map(_->Gen.NoChange(), Gen.get_args(t)), Gen.get_choices(qf))
            qb, _ = Gen.generate(sampler, (), discard)
            lr = weight + Gen.get_score(qb) - Gen.get_score(qf)
            push!(ratios, lr); push!(triples, (weight, Gen.get_score(qb), Gen.get_score(qf)))
            t = new_t
        end
        return ratios, triples, tol
    end

    rn, tn, toln = collect_ratios(:native)
    re, te, tole = collect_ratios(:einsum)

    absn = abs.(rn); abse = abs.(re)
    # Indices of the (up to) five largest einsum deviations.
    worst_e = sortperm(abse, rev=true)[1:min(end, 5)]
    println("Failure diagnostics: MH always accepts (DML)")
    println("  tol_native=", toln, "  tol_einsum=", tole)
    println("  native max |log_ratio|: ", maximum(absn))
    println("  einsum  max |log_ratio|: ", maximum(abse))
    for i in worst_e
        w, qb, qf = te[i]
        println("  iter=", i,
                "  einsum log_ratio=", re[i],
                "  components: weight=", w, " score(qb)=", qb, " score(qf)=", qf)
    end
    println()
end


# build per-trace samplers with engine baked in

# Generic structure-specialized proposal: builds the fixed-structure backwards
# sampler for `tr` over `addrs` with engine `eng` (Symbol or String) and splices
# all of its choices into this kernel's trace.
@gen function _sml_kernel(tr, addrs, eng)
    engine_sym = Symbol(eng)
    structure_sampler = M.generate_backwards_sampler_fixed_structure(
        tr, addrs; engine=engine_sym
    )
    # The fixed-structure sampler takes the current trace as its argument.
    {*} ~ structure_sampler(tr)
end

# Generic fixed-trace proposal: builds the fixed-trace backwards sampler for
# `tr` with the given latent/observation specifications and engine `eng`
# (Symbol or String) and splices all of its choices into this kernel's trace.
@gen function _dml_kernel(tr, addrs, latents, observations, eng)
    engine_sym = Symbol(eng)
    trace_sampler = M.generate_backwards_sampler_fixed_trace(
        tr, addrs, latents, observations; engine=engine_sym
    )
    # The fixed-trace sampler is called without arguments.
    {*} ~ trace_sampler()
end



"""
    test_dml_public_sampler_accepts(; engine=:native) -> Bool

Run 100 `Gen.mh` steps on a five-variable Bernoulli model with one observation,
using `_dml_kernel` as the proposal over all four latent variables. Return
`true` only if every step is accepted.
"""
function test_dml_public_sampler_accepts(; engine=:native)::Bool
    @gen function bar()
        x ~ bernoulli(0.6)
        y ~ bernoulli(x ? 0.2 : 0.9)
        z ~ bernoulli((x && y) ? 0.4 : 0.9)
        w ~ bernoulli(z ? 0.4 : 0.5)
        obs ~ bernoulli((x && w) ? 0.4 : 0.1)
    end
    trace = Gen.simulate(bar, ())

    # Latent specifications: support and parent addresses of each variable.
    latents = Dict{Any,M.Latent}()
    latents[:x] = M.Latent([true, false], [])
    latents[:y] = M.Latent([true, false], [:x])
    latents[:z] = M.Latent([true, false], [:x, :y])
    latents[:w] = M.Latent([true, false], [:z])
    observations = Dict{Any,M.Observation}()
    observations[:obs] = M.Observation([:x, :w])
    order = [:w, :x, :z, :y]

    kernel_args = (order, latents, observations, engine)
    all_accepted = true
    for _ in 1:100
        trace, was_accepted = Gen.mh(trace, _dml_kernel, kernel_args)
        all_accepted = all_accepted & was_accepted
    end
    return all_accepted
end

"""
    test_sml_generic_kernel_public(; engine=:native) -> Bool

Run 100 `Gen.mh` steps on `static_model` using `_sml_kernel` as the proposal
over `:x1, :x2, :x3`. Return `true` only if every step is accepted.
"""
function test_sml_generic_kernel_public(; engine=:native)::Bool
    # No `@load_generated_functions()` here: the macro is a deprecated no-op in
    # the Gen version this notebook runs and only emitted a warning.
    trace = Gen.simulate(static_model, ())
    ok = true
    for _ in 1:100
        # Qualified as `Gen.mh`, matching every other test in this file.
        trace, acc = Gen.mh(trace, _sml_kernel, ([:x1,:x2,:x3], engine))
        ok &= acc
    end
    return ok
end



"""
    run_suite(label; engine=:native, bench=false, seconds=1.0)

Run every correctness test once through `run_timed` and return the named tuple
`(label, results, timings)`, where `results` maps test name to pass/fail and
`timings` maps test name to seconds. `engine` is forwarded to the tests that
accept it; `bench` and `seconds` are forwarded to `run_timed`.
"""
function run_suite(label; engine=:native, bench=false, seconds=1.0)
    results = Dict{String,Bool}()
    timings = Dict{String,Float64}()

    # Test name => zero-argument thunk, in execution order.
    cases = [
        "compiling factor graph from trace" => () -> test_compiling_factor_graph(),
        "variable elimination"              => () -> test_variable_elimination(engine=engine),
        "conditional_dist"                  => () -> test_conditional_dist(),
        "MH always accepts (DML)"           => () -> test_mh_always_accepts_dml(engine=engine),
        "HMM FFBS"                          => () -> test_hmm_ffbs(engine=engine),
        "SML basic block"                   => () -> test_sml_basic_block(engine=engine),
        "Map"                               => () -> test_map(engine=engine),
        "Unfold and fixed-trace"            => () -> test_unfold_and_fixed_trace(engine=engine),
        "DML public sampler accepts"        => () -> test_dml_public_sampler_accepts(engine=engine),
        "SML public sampler accepts"        => () -> test_sml_generic_kernel_public(engine=engine),
    ]

    for (name, thunk) in cases
        passed, elapsed = run_timed(thunk; bench=bench, seconds=seconds)
        results[name] = passed
        timings[name] = elapsed
    end

    return (; label, results, timings)
end


# Run the whole suite once per engine, in benchmark mode.
native_suite = run_suite("Native"; engine=:native, bench=true, seconds=1.0)
einsum_suite = run_suite("Einsum"; engine=:einsum, bench=true, seconds=1.0)

# Report order; every name must be a key recorded by `run_suite`.
all_tests = [
    "compiling factor graph from trace",
    "variable elimination",
    "conditional_dist",
    "MH always accepts (DML)",
    "HMM FFBS",
    "SML basic block",
    "Map",
    "Unfold and fixed-trace",
    "DML public sampler accepts",
    "SML public sampler accepts"
]

println("Test results (native vs einsum):")
# Size the name column to the longest test name so every row lines up; the
# previous fixed width of 28 was shorter than the longest name.
name_width = maximum(length, all_tests)
fmt = "- %-$(name_width)s  native=%-4s (%8.4f s)   einsum=%-4s (%8.4f s)   speedup=%6.2fx\n"
row_format = Printf.Format(fmt)  # parsed once instead of once per row
for tname in all_tests
    n  = native_suite.results[tname];  tn = native_suite.timings[tname]
    e  = einsum_suite.results[tname];  te = einsum_suite.timings[tname]
    # speedup > 1 means the einsum engine was faster than the native one
    sp = tn / te
    Printf.format(stdout, row_format, tname, pretty(n), tn, pretty(e), te, sp)
end
println()

# Print diagnostics only for failures
diagnosers = Dict(
    "variable elimination" => diagnose_variable_elimination,
    "conditional_dist"     => diagnose_conditional_dist,
    "MH always accepts (DML)" => diagnose_mh_always_accepts_dml,
)

for tname in all_tests
    n = native_suite.results[tname]
    e = einsum_suite.results[tname]
    # A test counts as failed if either engine failed it.
    if !(n && e)
        println("\nDetailed numbers for failed test: ", tname)
        if haskey(diagnosers, tname)
            diagnosers[tname]()
        else
            println("(no numeric diff printer wired for this test)")
        end
    end
end

println("\nDone.")

[33m[1m└ [22m[39m[90m@ Gen ~/.julia/packages/Gen/mP0Sq/src/Gen.jl:33[39m
[33m[1m└ [22m[39m[90m@ Gen ~/.julia/packages/Gen/mP0Sq/src/Gen.jl:33[39m
[33m[1m└ [22m[39m[90m@ Gen ~/.julia/packages/Gen/mP0Sq/src/Gen.jl:33[39m
[33m[1m└ [22m[39m[90m@ Gen ~/.julia/packages/Gen/mP0Sq/src/Gen.jl:33[39m


Test results (native vs einsum):
- compiling factor graph from trace  native=PASS (  0.0006 s)   einsum=PASS (  0.0006 s)   speedup=  1.01x
- variable elimination          native=PASS (  0.0008 s)   einsum=PASS (  0.0030 s)   speedup=  0.28x
- conditional_dist              native=PASS (  0.0007 s)   einsum=PASS (  0.0007 s)   speedup=  1.02x
- MH always accepts (DML)       native=PASS (  0.0408 s)   einsum=PASS (  0.0424 s)   speedup=  0.96x
- HMM FFBS                      native=PASS (  0.1310 s)   einsum=PASS (  0.5530 s)   speedup=  0.24x
- SML basic block               native=PASS (  4.9426 s)   einsum=PASS (  6.0255 s)   speedup=  0.82x
- Map                           native=PASS (  0.4994 s)   einsum=PASS (  2.2621 s)   speedup=  0.22x
- Unfold and fixed-trace        native=PASS (  1.6553 s)   einsum=PASS (  4.5736 s)   speedup=  0.36x
- DML public sampler accepts    native=PASS (  0.2282 s)   einsum=PASS (  1.0802 s)   speedup=  0.21x
- SML public sampler accepts    native=PASS 

# speed checks
The following cells compare the speed of the native engine with that of the einsum + JAX engine.

## speed check 1

In [15]:
using BenchmarkTools, Printf

# Hidden Markov model with categorical transitions and emissions.
#   prior : initial distribution over the hidden state
#   A     : transition matrix; row `A[z, :]` is the distribution of the next state
#   B     : emission matrix; row `B[z, :]` is the distribution of the observation
# Hidden states are traced at `(:z, t)` and observations at `(:x, t)` for t = 1..T.
# The first step is always sampled, even when T < 2. `K` is not used in the body.
@gen function hmm_big(K::Int, T::Int, A, B, prior)
    # Step 1: the initial state comes from the prior.
    state = ({(:z, 1)} ~ categorical(prior))
    {(:x, 1)} ~ categorical(B[state, :])

    # Steps 2..T: transition from the previous state, then emit.
    for step in 2:T
        transition_row = A[state, :]
        state = ({(:z, step)} ~ categorical(transition_row))
        emission_row = B[state, :]
        {(:x, step)} ~ categorical(emission_row)
    end
    return nothing
end

"""
    bench_hmm_ve_big(; K=128, T=300, einsum_backend=:jax, einsum_dtype=:float32,
                     optimize="greedy", jit=true, cache=true)

Time `variable_elimination` on a single randomly generated HMM factor graph, once with
`engine=:native` and once with `engine=:einsum`, print a one-line summary, and return
both timings.

# Keywords
- `K`: number of hidden states. The emission matrix is also `K x K`.
- `T`: number of time steps (one latent `(:z, t)` and one observation `(:x, t)` each).
- `einsum_backend`, `einsum_dtype`, `optimize`: forwarded to the einsum engine as
  `backend`, `dtype` and `optimize`.
- `jit`, `cache`: forwarded to the einsum engine. Both default to `true`, which is what
  this benchmark previously hard-coded.

# Returns
A named tuple `(t_native, t_einsum)` holding the `@belapsed` results in seconds.

Model parameters and the trace are drawn from the global RNG, so repeated calls
benchmark different random instances.
"""
function bench_hmm_ve_big(; K::Int=128, T::Int=300,
                           einsum_backend::Symbol=:jax,
                           einsum_dtype::Symbol=:float32,
                           optimize::Union{String,Symbol}="greedy",
                           jit::Bool=true, cache::Bool=true)

    # random normalized params (each row of A and B sums to 1)
    prior = rand(K); prior ./= sum(prior)
    A = rand(K, K);  A ./= sum(A, dims=2)
    B = rand(K, K);  B ./= sum(B, dims=2)

    # build one trace (observations live inside it)
    tr = Gen.simulate(hmm_big, (K, T, A, B, prior))

    # factor graph: z_1 has no parents, z_t depends on z_{t-1}, x_t depends on z_t
    latents = Dict{Any,Latent}()
    latents[(:z, 1)] = Latent(collect(1:K), Any[])
    for t in 2:T
        latents[(:z, t)] = Latent(collect(1:K), Any[(:z, t - 1)])
    end
    observations = Dict{Any,Observation}()
    for t in 1:T
        observations[(:x, t)] = Observation(Any[(:z, t)])
    end

    fg = compile_trace_to_factor_graph(tr, latents, observations)
    order = Any[(:z, t) for t in 1:T]  # eliminate the states in time order

    # warm-up (JIT / planner / array caches) so compilation is not timed below
    variable_elimination(fg, order; engine=:native)
    variable_elimination(fg, order; engine=:einsum,
                         backend=einsum_backend,
                         dtype=einsum_dtype,
                         optimize=optimize,
                         jit=jit, cache=cache)

    # timings; locals are `$`-interpolated so BenchmarkTools does not treat them as globals
    # NOTE(review): assumes variable_elimination does not mutate `fg`, so the same graph
    # can be reused across samples — confirm.
    t_native = @belapsed variable_elimination($fg, $order; engine=:native)
    t_einsum = @belapsed variable_elimination($fg, $order; engine=:einsum,
                                              backend=$einsum_backend,
                                              dtype=$einsum_dtype,
                                              optimize=$optimize,
                                              jit=$jit, cache=$cache)

    @printf("HMM VE  K=%d  T=%d  |  native=%.3fs   einsum(%s,%s)=%.3fs   speedup=%.2fx\n",
            K, T, t_native, String(einsum_backend), String(einsum_dtype),
            t_einsum, t_native / t_einsum)

    return (t_native=t_native, t_einsum=t_einsum)
end

bench_hmm_ve_big(K=16,  T=50)           # warm-up-ish

HMM VE  K=16  T=50  |  native=89.295s   einsum(jax,float32)=0.494s   speedup=180.85x


(t_native = 89.294777604, t_einsum = 0.493738768)

In [16]:
bench_hmm_ve_big(K=48, T=50)

HMM VE  K=48  T=50  |  native=621.312s   einsum(jax,float32)=1.087s   speedup=571.72x


(t_native = 621.312468637, t_einsum = 1.086737142)

In [17]:
bench_hmm_ve_big(K=62, T=50)

HMM VE  K=62  T=50  |  native=1030.990s   einsum(jax,float32)=1.381s   speedup=746.44x


(t_native = 1030.990331316, t_einsum = 1.381216314)

## speed check 2

In [18]:
using Random, BenchmarkTools

# Global parameters used by the static @gen functions
const CHAIN = Dict{Symbol,Any}()

"""
    setup_chain!(K, L; seed=42)

Fill the global `CHAIN` dictionary with the parameters of a chain model and return
`nothing`.

A `MersenneTwister(seed)` generates, in this order, a length-`K` prior, a `K x K`
transition matrix `A` and a `K x K` emission matrix `B`. The prior is normalized to sum
to 1 and every row of `A` and `B` is normalized to sum to 1. `K` and `L` are stored
unchanged under `:K` and `:L`. Existing entries under these keys are overwritten.
"""
function setup_chain!(K::Int, L::Int; seed::Int=42)
    rng = MersenneTwister(seed)

    # Draw order (prior, A, B) is fixed so a given seed always yields the same values.
    prior_raw = rand(rng, K)
    trans_raw = rand(rng, K, K)
    emit_raw = rand(rng, K, K)

    params = (
        :K => K,
        :L => L,
        :prior => prior_raw ./ sum(prior_raw),
        :A => trans_raw ./ sum(trans_raw, dims=2),  # transitions
        :B => emit_raw ./ sum(emit_raw, dims=2),    # emissions
    )
    for (key, value) in params
        CHAIN[key] = value
    end
    return nothing
end

#  Static SML model (top-level!)
# One step of the chain, used as the kernel of `Unfold` in `chain_model`.
#   t      : not used in the body (presumably the step index supplied by `Unfold`)
#   z_prev : hidden state of the previous step
# Samples the next hidden state `z` from row `z_prev` of the global transition matrix
# CHAIN[:A], then an observation `x` from row `z` of the global emission matrix CHAIN[:B].
# Returns `z`. The parameters are read from the global CHAIN dictionary, so
# `setup_chain!` must have filled it before the model is run.
@gen (static) function chain_step(t::Int, z_prev::Int)
    z ~ categorical(CHAIN[:A][z_prev, :])
    x ~ categorical(CHAIN[:B][z, :])
    return z
end

# Full chain model. The first hidden state `z1` is drawn from the global prior
# CHAIN[:prior] and its observation `x1` from row `z1` of CHAIN[:B]; the remaining
# steps are produced by unfolding `chain_step` for L-1 iterations starting from `z1`,
# traced under the address `:zs`.
@gen (static) function chain_model(L::Int)
    z1 ~ categorical(CHAIN[:prior])
    x1 ~ categorical(CHAIN[:B][z1, :])
    zs ~ (Unfold(chain_step))(L-1, z1)
end

#  Factor graph + elimination order
"""
    make_fg_chain(K, L)

Build the factor graph for a chain model with `K` states and `L` steps.

Calls `setup_chain!(K, L)` (overwriting the global `CHAIN` parameters), simulates one
trace of `chain_model`, declares every hidden state as a latent and every emission as an
observation, and compiles the factor graph from that trace.

Returns a named tuple `(fg, order, trace, latents, observations)`, where `order` lists
the hidden-state addresses from first to last: `:z1`, then `:zs => t => :z` for
`t = 1:L-1`.
"""
function make_fg_chain(K::Int, L::Int)
    setup_chain!(K, L)
    trace = Gen.simulate(chain_model, (L,))

    n_states = CHAIN[:K]
    z_addr(t) = :zs => t => :z  # hidden state of unfold step t
    x_addr(t) = :zs => t => :x  # observation of unfold step t

    latents = Dict{Any,Latent}()
    observations = Dict{Any,Observation}()
    order = Any[:z1]

    # First step: z1 has no parent and x1 depends on z1.
    latents[:z1] = Latent(collect(1:n_states), Any[])
    observations[:x1] = Observation(Any[:z1])

    # Unfold steps: each z depends on the previous z, each x on its own z.
    for t in 1:L-1
        z_t = z_addr(t)
        prev = t == 1 ? :z1 : z_addr(t - 1)
        latents[z_t] = Latent(collect(1:n_states), Any[prev])
        observations[x_addr(t)] = Observation(Any[z_t])
        push!(order, z_t)
    end

    fg = compile_trace_to_factor_graph(trace, latents, observations)

    return (fg=fg, order=order, trace=trace, latents=latents, observations=observations)
end

# identical FG+order, native vs einsum
"""
    bench_chain_sml(K, L; seconds=1.0, backend=:auto, dtype=:float32, jit=true, cache=true)

Benchmark `variable_elimination` on the chain factor graph produced by
`make_fg_chain(K, L)`, using the same graph and elimination order for the `:native` and
`:einsum` engines.

`seconds` is passed to `@belapsed` as its time budget, and each sample uses one
evaluation. `backend`, `dtype`, `jit` and `cache` are forwarded to the einsum engine.

Prints a short report and returns a named tuple `(tn, te, speedup, fg, order, extras)`,
where `tn` and `te` are the native and einsum timings, `speedup` is `tn / te`, and
`extras` is the full result of `make_fg_chain`.
"""
function bench_chain_sml(K::Int, L::Int;
        seconds::Float64=1.0, backend::Symbol=:auto,
        dtype::Symbol=:float32, jit::Bool=true, cache::Bool=true)

    chain = make_fg_chain(K, L)
    fg = chain.fg
    order = chain.order

    # Run each engine once before timing (important for JAX/XLA and opt_einsum
    # expression caching); the results are discarded.
    variable_elimination(fg, order; engine=:native)
    variable_elimination(fg, order; engine=:einsum,
                         backend=backend, dtype=dtype, jit=jit, cache=cache)

    native_time = @belapsed variable_elimination($fg, $order; engine=:native) seconds=seconds evals=1
    einsum_time = @belapsed variable_elimination($fg, $order; engine=:einsum,
                                                 backend=$backend, dtype=$dtype, jit=$jit, cache=$cache) seconds=seconds evals=1
    ratio = native_time / einsum_time

    println("SML chain VE benchmark  (K=$K, L=$L, backend=$(backend), dtype=$(dtype), jit=$(jit))")
    @printf("  native (CPU greedy):      %8.4f s\n", native_time)
    @printf("  einsum (opt_einsum/JAX):  %8.4f s\n", einsum_time)
    @printf("  speedup (native/einsum):  %8.3fx\n", ratio)

    return (tn=native_time, te=einsum_time, speedup=ratio, fg=fg, order=order, extras=chain)
end

bench_chain_sml(16, 32; seconds=1.0, backend=:auto, dtype=:float32, jit=true);

SML chain VE benchmark  (K=16, L=32, backend=auto, dtype=float32, jit=true)
  native (CPU greedy):        8.0185 s
  einsum (opt_einsum/JAX):    0.1048 s
  speedup (native/einsum):    76.543x


In [19]:
bench_chain_sml(32, 64; seconds=2.0, backend=:auto, dtype=:float32, jit=true);

SML chain VE benchmark  (K=32, L=64, backend=auto, dtype=float32, jit=true)
  native (CPU greedy):      752.2822 s
  einsum (opt_einsum/JAX):    2.5412 s
  speedup (native/einsum):   296.038x


## speed check 3

In [9]:
# Scaled Map model (overhead sanity)

# Reuse the already-defined local_model/global_model if present
# (the guard skips the definition when `Main` already has a binding named `local_model`).
if !isdefined(Main, :local_model)
# Per-datum model: `z` is a fair coin flip. When `z` is true, `x` is drawn from a normal
# with mean `inlier_mean` and standard deviation 1.0; otherwise from a normal with
# mean 0.0 and standard deviation 10.0.
@gen (static) function local_model(inlier_mean)
    z ~ bernoulli(0.5)
    x ~ normal(z ? inlier_mean : 0.0, z ? 1.0 : 10.0)
end
end
# Define global_model only when `Main` does not already have a binding with that name.
if !isdefined(Main, :global_model)
# Top-level model: draws a shared `inlier_mean` from a normal with mean 0 and standard
# deviation 10.0, then applies `local_model` through `Map` to a length-`n` vector in
# which every entry is `inlier_mean`. The mapped calls are traced under `:data`.
@gen (static) function global_model(n)
    inlier_mean ~ normal(0, 10.0)
    data ~ Map(local_model)(fill(inlier_mean, n))
end
end

"""
    bench_map_large(; n, dtype=:float32, backend=:auto, optimize="auto",
                    jit=true, cache=true, steps=5)

Time a short MH run on `global_model` with `n` data points, once with the `:native`
engine and once with the `:einsum` engine, and report the ratio.

For each engine a fresh trace is simulated and a backwards sampler over all
`:data => i => :z` addresses is built; neither is included in the timing. The timed
region runs `steps` sweeps, each consisting of one MH move with that sampler followed by
one MH move on `:inlier_mean`. There is no warm-up call, so the first sweep's
compilation cost is part of the measurement.

# Keywords
- `n`: number of data points (required).
- `dtype`, `backend`, `optimize`, `jit`, `cache`: forwarded to
  `generate_backwards_sampler_fixed_structure` for both engines.
- `steps`: number of timed sweeps. Defaults to 5, the value previously hard-coded.

# Returns
A named tuple `(native, einsum, speedup)` with the two wall-clock times in seconds and
`native / einsum`.
"""
function bench_map_large(; n, dtype=:float32,
                          backend=:auto, optimize="auto",
                          jit=true, cache=true, steps::Integer=5)

    # Inner runner so both engines share the same procedure and printing.
    # Note that each call simulates its own random trace.
    function run_one(engine)
        trace = Gen.simulate(global_model, (n,))
        sampler = generate_backwards_sampler_fixed_structure(
            trace, [(:data=>i=>:z) for i in 1:n];
            engine=engine, backend=backend, optimize=optimize,
            dtype=dtype, jit=jit, cache=cache
        )
        t = @elapsed begin
            for _ in 1:steps
                # block move on all z's, then a default MH move on the shared mean
                trace, _ = Gen.mh(trace, sampler, ())
                trace, _ = Gen.mh(trace, Gen.select(:inlier_mean))
            end
        end
        println("map_large        n=$n  engine=$engine  dtype=$dtype  steps=$steps  time=$(round(t, digits=3)) s")
        return t
    end

    t_native = run_one(:native)
    t_einsum = run_one(:einsum)
    speedup  = t_native / t_einsum
    @printf "einsum speedup: %.2fx (native/einsum)\n" speedup

    return (; native=t_native, einsum=t_einsum, speedup=speedup)
end

bench_map_large(n=50, dtype=:float32)

map_large        n=50  engine=native  dtype=float32  steps=5  time=51.995 s
map_large        n=50  engine=einsum  dtype=float32  steps=5  time=7.982 s
einsum speedup: 6.51x (native/einsum)


(native = 51.995298365, einsum = 7.982030794, speedup = 6.514043820036909)

In [10]:
bench_map_large(n=100, dtype=:float32)

map_large        n=100  engine=native  dtype=float32  steps=5  time=330.215 s
map_large        n=100  engine=einsum  dtype=float32  steps=5  time=61.036 s
einsum speedup: 5.41x (native/einsum)


(native = 330.214679557, einsum = 61.036224029, speedup = 5.41014266216904)