# Verification notebook

Here, every component implemented in TensorFlow, JAX and NumPy is cross-verified.

First we check that all implementations produce the same output, then we compare their run times.

# Verify Jax vs Np

In [None]:
import sys
# Server vs Local 
if '/Code' in os.getcwd():
    os.chdir('/Code/ActualThesisWork')
    %pip install monotonic-nn==0.3.5 # Just cloned it into a dir on server
    %pip install tensorflow_probability==0.23 # Comes with 04.03.24
else:
    os.chdir('/scratch/midway3/fsemler')

sys.path.append(os.getcwd())

In [None]:
import time
from itertools import product

import numpy as np
import jax
import jax.numpy as jnp
# To compile the numpy code 
import numba

# to generate random numbers
import scipy
from scipy import stats

import layer_reimplementation_jax as ld_jax
import layer_reimplementation_np as ld_np
import layer_definitions as ld

In [None]:
# Some things we will need for testing

# TODO: fill in the per-PMT I0 values.
i0s = None
# TODO: fill in the top-array PMT positions. len(pmt_pos_top) is used as the
# PMT count below and pmt_pos_top.shape[0] in the timing / FLOP cells.
pmt_pos_top = None
r_tpc = 66.4

n_events = 1_000

# Fail with an explicit message rather than an obscure TypeError from len(None).
if pmt_pos_top is None:
    raise NotImplementedError(
        "pmt_pos_top must be filled in before running the verification cells"
    )
n_pmts = len(pmt_pos_top)

def points_in_circle(n_points, seed = 0, tpc_r=None):
    """
    Sample points uniformly over the area of a disc centred on the origin.

    Parameters:
    -----------
    n_points : int
        Number of points to draw.
    seed : int
        Seed for numpy.random.default_rng, so identical calls give identical points.
    tpc_r : float or None
        Disc radius. When None, the module-level ``r_tpc`` is used.

    Returns:
    --------
    np.ndarray of shape (n_points, 2) holding the (x, y) coordinates.
    """
    if tpc_r is None:
        tpc_r = r_tpc
    rng = np.random.default_rng(seed)
    # sample angles uniformly
    theta = rng.random(n_points) * 2 * np.pi            # [0, 2π)
    # sample radii with sqrt so the density is uniform in area, not in radius
    r     = tpc_r * np.sqrt(rng.random(n_points))       # [0, R]
    # convert to Cartesian
    x = r * np.cos(theta)
    y = r * np.sin(theta)
    return np.stack([x, y], axis=1)


# Alias: the timing cells further down call the sampler by this name.
sample_points_in_circle = points_in_circle

In [None]:
# Some utility functions

def compare_arrays(a, b, rtol: float = 1e-6, atol: float = 1e-8) -> None:
    """
    Compare two array-like inputs (NumPy, JAX, or TensorFlow) for near-equality.
    Prints whether all elements are close within the given tolerances.
    If not, prints a table of absolute difference statistics.

    Parameters:
    -----------
    a, b : array-like or tensor
        Inputs to compare. Can be numpy.ndarray, JAX arrays, TensorFlow tensors,
        or anything np.asarray accepts.
    rtol : float
        Relative tolerance for np.allclose.
    atol : float
        Absolute tolerance for np.allclose (default matches np.allclose's own).
    """
    def to_numpy(x):
        # Duck typing: objects exposing a callable .numpy() (e.g. TensorFlow eager
        # tensors) are converted through it; everything else (NumPy, JAX, lists)
        # goes through np.asarray. No tf/jax import is needed just to compare.
        numpy_method = getattr(x, "numpy", None)
        if callable(numpy_method):
            x = numpy_method()
        return np.asarray(x)

    a_np = to_numpy(a).astype(np.float64)
    b_np = to_numpy(b).astype(np.float64)

    if a_np.shape != b_np.shape:
        print(f"Shape mismatch: {a_np.shape} vs {b_np.shape}")
        return

    if np.allclose(a_np, b_np, rtol=rtol, atol=atol):
        print(f"All elements close within rtol={rtol}")
    else:
        diff = np.abs(a_np - b_np)
        # Single-line table
        metrics = ["max diff", "min diff", "mean diff", "std diff"]
        values = [f"{diff.max():.6g}", f"{diff.min():.6g}", f"{diff.mean():.6g}", f"{diff.std():.6g}"]
        # Print headers
        print(" | ".join(metrics))
        # Print values
        print(" | ".join(values))

In [None]:
# Factory functions whose JAX and NumPy versions this notebook cross-checks.
# The first group is checked one cell at a time (the input functions all at
# once through get_input_functions); the second group is covered together,
# through the last entry, make_likelihood_fn.
jnp_func_to_test = (
    "get_input_functions", "make_normalization_layer", "make_dense_layer",
    "make_mono_activations", "make_I0_layer", "make_radial_lce_layer",
) + (
    "make_lut_table_with_std", "make_lut_table_fixed_std", "make_exact_lr",
    "make_lut_trainable_std_lr", "make_lut_fixed_std_lr", "make_likelihood_fn",
)

In [None]:
# get_input_functions returns a collection of input functions (4 of them);
# each one is evaluated on the same sampled positions in every backend.
jax_input_fns = ld_jax.get_input_functions(pmt_pos_top)
np_input_fns  = ld_np.get_input_functions(pmt_pos_top)
# numba.njit compiles one function at a time, so compile each input function
# separately instead of passing it the whole collection.
nb_input_fns  = [numba.njit(fn) for fn in np_input_fns]

pos = points_in_circle(n_events)

j_out   = [fn(pos) for fn in jax_input_fns]
np_out  = [fn(pos) for fn in np_input_fns]
nb_out  = [fn(pos) for fn in nb_input_fns]

del pos

# zip would silently drop unmatched entries, so check the counts first.
if not len(j_out) == len(np_out) == len(nb_out):
    raise ValueError(
        f"backends returned different numbers of input functions: "
        f"{len(j_out)}, {len(np_out)}, {len(nb_out)}"
    )

for idx, (j_res, np_res, nb_res) in enumerate(zip(j_out, np_out, nb_out)):
    print(f"--- input function {idx} ---")
    print("jax v np")
    compare_arrays(j_res, np_res)
    print("np v nb")
    compare_arrays(np_res, nb_res)

In [None]:
# Poisson(mu=1000) counts for 512 events x n_pmts PMTs as a stand-in layer
# input; seeded so a mismatch can be reproduced.
random_input = stats.poisson(mu=1000).rvs(size=(512, n_pmts), random_state=0)

j_out  = ld_jax.make_normalization_layer()(random_input)
np_out = ld_np.make_normalization_layer()(random_input)
nb_out = numba.njit(ld_np.make_normalization_layer())(random_input)

del random_input

print("jax v np")
compare_arrays(j_out, np_out)
print("np v nb")
compare_arrays(np_out, nb_out)

In [None]:
# Cross-check make_dense_layer, for every monotonic base activation, between the
# JAX implementation, the NumPy implementation and a Numba-compiled NumPy layer.
# Layer sizes, weights and inputs are all drawn from one seeded generator.
batch_size = 246
rng        = np.random.default_rng(0)
for base_act in ('tanh', 'exponential', 'relu', 'sigmoid'):
    # Both layer dimensions are drawn from 1..32 (upper bound exclusive).
    input_dim, units = rng.integers(1, 33), rng.integers(1, 33)

    # float32 weights shared by every backend.
    kernel_np = rng.standard_normal((input_dim, units), dtype=np.float32)
    bias_np   = rng.standard_normal((units,), dtype=np.float32)

    jax_act = ld_jax.make_mono_activations(base_act)
    np_act  = ld_np.make_mono_activations(base_act)

    jax_layer = ld_jax.make_dense_layer(jnp.array(kernel_np), jnp.array(bias_np), jax_act)
    np_layer  = ld_np.make_dense_layer(kernel_np, bias_np, np_act)
    nb_layer  = numba.njit(np_layer)

    X_np = rng.standard_normal((batch_size, n_pmts, input_dim), dtype=np.float32)

    # Evaluated in order: JAX (converted back to NumPy), NumPy, Numba.
    outputs = {
        "jax": np.array(jax_layer(jnp.array(X_np))),
        "np": np_layer(X_np),
        "nb": nb_layer(X_np),
    }

    print(f"\n=== base_act={base_act!r}, input_dim={input_dim}, units={units} ===")
    for label, lhs, rhs in (("jax vs np: ", "jax", "np"), ("np vs nb:  ", "np", "nb")):
        print(label, end="")
        compare_arrays(outputs[lhs], outputs[rhs])

In [None]:
# Poisson(mu=1000) counts as layer input, plus random per-PMT I0 values in [0, 1).
# A separate name is used so the real i0s from the setup cell is not overwritten.
random_input = stats.poisson(mu=1000).rvs(size=(512, n_pmts), random_state=1)
random_i0s = np.random.default_rng(1).random(n_pmts)

j_out  = ld_jax.make_I0_layer(random_i0s)(random_input)
np_out = ld_np.make_I0_layer(random_i0s)(random_input)
nb_out = numba.njit(ld_np.make_I0_layer(random_i0s))(random_input)

del random_input, random_i0s

print("jax v np")
compare_arrays(j_out, np_out)
print("np v nb")
compare_arrays(np_out, nb_out)

In [None]:
# Pretrained guess -> easiest
params = [1.6266745e+00,  9.4918861e+00, -4.2176653e-05,  7.7804564e-03,]

# One radius per event, distributed like the radii of points sampled uniformly
# over the TPC cross-section (r = R * sqrt(U)); own seeded generator so this
# cell does not depend on RNG state left by earlier cells.
# NOTE(review): assumes the radial LCE layer takes a 1-D array of radii - confirm.
radial_rng = np.random.default_rng(2)
random_input = r_tpc * np.sqrt(radial_rng.random(n_events))

j_out  = ld_jax.make_radial_lce_layer(params)(random_input)
np_out = ld_np.make_radial_lce_layer(params)(random_input)
nb_out = numba.njit(ld_np.make_radial_lce_layer(params))(random_input)

del random_input, radial_rng

print("jax v np")
compare_arrays(j_out, np_out)
print("np v nb")
compare_arrays(np_out, nb_out)

In [None]:
# And the likelihood functions
#
# Build the LUT tables and likelihood-ratio functions in JAX and NumPy from the
# same inputs, numba-compile the NumPy versions, and compare outputs pairwise.

# ---- 1) shared parameters ----
P = 16                     # small PMT count for speed
switching_signal = 10.0
p_dpe = 0.2
m, z = 2, 2                # small subdivisions for test
nan_safe_value = 1e5
B = 8                      # number of samples in the likelihood-ratio checks
# Seeded generator so a failing comparison can be reproduced.
lik_rng = np.random.default_rng(0)


def mle_estimator_j(x):
    """MLE estimator handed to every likelihood factory: x / 1.2."""
    return x / 1.2


# Numba's nopython mode cannot call plain Python functions from compiled code,
# so the NumPy / Numba path gets a jitted version of the same estimator.
mle_estimator_np = numba.njit(mle_estimator_j)

# 1a) build domains in NumPy
n_pe_domain_np = np.arange(
    0.0,
    switching_signal + 5 * np.sqrt(switching_signal) + 2
)
n_ph_domain_np = np.arange(
    0.0,
    switching_signal / (1 + p_dpe)
    + 5 * np.sqrt(switching_signal / (1 + p_dpe))
    + 2
)
x_domain_np     = np.linspace(0, switching_signal, 40*m + 1)
sigma_domain_np = np.linspace(0.05, 1.0, z)

# 1b) corresponding JAX arrays
n_pe_domain_j  = jnp.array(n_pe_domain_np,  dtype=jnp.float32)
n_ph_domain_j  = jnp.array(n_ph_domain_np,  dtype=jnp.float32)
x_domain_j     = jnp.array(x_domain_np,     dtype=jnp.float32)
sigma_domain_j = jnp.array(sigma_domain_np, dtype=jnp.float32)

# Per-PMT stds, kept as a NumPy array for the NumPy/Numba path and a JAX
# array for the JAX path (the NumPy factories must not receive JAX arrays).
stds_np = np.ones((P,), dtype=np.float32) * 0.5
stds_j  = jnp.array(stds_np)

# ---- 2) Test LUT-table generators ----
# 2a) variable-std LUT
jax_lut_var = ld_jax.make_lut_table_with_std(
    n_pe_domain_j, n_ph_domain_j,
    x_domain_j, sigma_domain_j,
    switching_signal, p_dpe
)
np_lut_var  = ld_np.make_lut_table_with_std(
    n_pe_domain_np, n_ph_domain_np,
    x_domain_np, sigma_domain_np,
    switching_signal, p_dpe
)
nb_lut_var_fn = numba.njit(ld_np.make_lut_table_with_std)
nb_lut_var    = nb_lut_var_fn(
    n_pe_domain_np, n_ph_domain_np,
    x_domain_np, sigma_domain_np,
    switching_signal, p_dpe
)
print("LUT_var JAX vs NP:");    compare_arrays(np.array(jax_lut_var), np_lut_var)
print("LUT_var NP vs NB:");     compare_arrays(np_lut_var, nb_lut_var)

# 2b) fixed-std LUT (one std per PMT instead of a sigma domain)
jax_lut_fix = ld_jax.make_lut_table_fixed_std(
    n_pe_domain_j, n_ph_domain_j,
    x_domain_j, stds_j,
    switching_signal, p_dpe
)
np_lut_fix  = ld_np.make_lut_table_fixed_std(
    n_pe_domain_np, n_ph_domain_np,
    x_domain_np, stds_np,
    switching_signal, p_dpe
)
nb_lut_fix_fn = numba.njit(ld_np.make_lut_table_fixed_std)
nb_lut_fix    = nb_lut_fix_fn(
    n_pe_domain_np, n_ph_domain_np,
    x_domain_np, stds_np,
    switching_signal, p_dpe
)
print("LUT_fix JAX vs NP:");  compare_arrays(np.array(jax_lut_fix), np_lut_fix)
print("LUT_fix NP vs NB:");   compare_arrays(np_lut_fix, nb_lut_fix)

# ---- 3) Test exact-LR generator ----
jax_exact = ld_jax.make_exact_lr(
    n_pe_domain_j, n_ph_domain_j, p_dpe,
    switching_signal, nan_safe=True, nan_safe_value=nan_safe_value
)
np_exact  = ld_np.make_exact_lr(
    n_pe_domain_np, n_ph_domain_np, p_dpe,
    switching_signal, nan_safe=True, nan_safe_value=nan_safe_value
)
nb_exact_fn = numba.njit(np_exact)
# small random sample
x_test   = lik_rng.uniform(0, switching_signal, size=(B,))
mu_test  = lik_rng.uniform(0, switching_signal, size=(B,))
std_test = np.full((B,), 0.5)
j_ex  = np.array(jax_exact(jnp.array(x_test), jnp.array(mu_test), jnp.array(std_test)))
n_ex  = np_exact(x_test, mu_test, std_test)
nb_ex = nb_exact_fn(x_test, mu_test, std_test)
print("exactLR JAX vs NP:");    compare_arrays(j_ex, n_ex)
print("exactLR NP vs NB:");     compare_arrays(n_ex, nb_ex)

# ---- 4) Test trainable-std LUT LR ----
jax_train = ld_jax.make_lut_trainable_std_lr(
    jax_lut_var, x_domain_j, sigma_domain_j,
    switching_signal, mle_estimator_j,
    return_ratio=False, nan_safe=True, nan_safe_value=nan_safe_value
)
np_train = ld_np.make_lut_trainable_std_lr(
    np_lut_var, x_domain_np, sigma_domain_np,
    switching_signal, mle_estimator_np,
    return_ratio=False, nan_safe=True, nan_safe_value=nan_safe_value
)
nb_train_fn = numba.njit(np_train)
# pred / obs / std, each of shape (B, P)
pred = lik_rng.uniform(0, switching_signal, size=(B, P))
obs  = lik_rng.uniform(0, switching_signal, size=(B, P))
std_ = np.tile(stds_np[None, :], (B, 1))
j_tr  = np.array(jax_train(jnp.array(pred), jnp.array(obs), jnp.array(std_)))
n_tr  = np_train(pred, obs, std_)
nb_tr = nb_train_fn(pred, obs, std_)
print("trainLR JAX vs NP:");    compare_arrays(j_tr, n_tr)
print("trainLR NP vs NB:");     compare_arrays(n_tr, nb_tr)

# ---- 5) Test make_likelihood_fn over every mode / flag combination ----
modes = ["exact", "LUT_trainable_std", "LUT_fixed_std"]
flags = list(product(modes, (False, True), (False, True)))

for mode, return_ratio, nan_safe in flags:
    print(f"\n=== mode={mode}, return_ratio={return_ratio}, nan_safe={nan_safe} ===")

    # Backend-agnostic settings; every array argument is added per backend so
    # the NumPy/Numba factory only ever sees NumPy arrays.
    common = dict(
        mode=mode,
        return_ratio=return_ratio,
        nan_safe=nan_safe,
        nan_safe_value=nan_safe_value,
        switching_signal=switching_signal,
        p_dpe=p_dpe,
    )
    jargs  = {**common, "mle_estimator": mle_estimator_j,  "std": stds_j}
    npargs = {**common, "mle_estimator": mle_estimator_np, "std": stds_np}
    if mode == "LUT_trainable_std":
        jargs.update(lut_table=jax_lut_var, x_domain=x_domain_j,
                     sigma_domain=sigma_domain_j)
        npargs.update(lut_table=np_lut_var, x_domain=x_domain_np,
                      sigma_domain=sigma_domain_np)
    elif mode == "LUT_fixed_std":
        jargs.update(lut_table=jax_lut_fix, x_domain=x_domain_j)
        npargs.update(lut_table=np_lut_fix, x_domain=x_domain_np)
    # "exact" needs no LUT arguments.

    # instantiate
    jfn  = ld_jax.make_likelihood_fn(**jargs)
    npfn = ld_np.make_likelihood_fn(**npargs)
    nbfn = numba.njit(npfn)

    # run
    j_out  = np.array(jfn(jnp.array(pred), jnp.array(obs)))
    np_out = npfn(pred, obs)
    nb_out = nbfn(pred, obs)

    print("jax vs np: ", end=""); compare_arrays(j_out,  np_out)
    print("np vs nb: ", end="");  compare_arrays(np_out, nb_out)

    # clean up big locals
    del jfn, npfn, nbfn, j_out, np_out, nb_out

# once you're done with both LUT tables, delete them too:
del jax_lut_var, np_lut_var, jax_lut_fix, np_lut_fix

# Verify Tf vs Jax

In [None]:
#TODO

It is probably easiest to wait until the model is trained, then generate two full models and verify that their outputs match.

# Now Timing

We only time the full model; there is no need to time the individual components.

In [None]:
# TODO: set to the trained model weights passed to gen_jax_model / gen_np_model.
weights_path = None
if weights_path is None:
    raise NotImplementedError(
        "weights_path must be set before running the timing / FLOP cells"
    )

In [None]:
def gen_tf_model(pmt_pos_top, weights_path,
    include_wall=True, include_perp=True,
    include_anode=True, multiplication_layers=False,
    radialLCE=False,
    ):
    """
    Build the TensorFlow reference model for the timing comparison (stub).

    Takes the same arguments as the JAX / NumPy model builders used in the
    timing cell, so all three can be built from one configuration.

    Raises:
    -------
    NotImplementedError
        Always, until the TF model construction is written. It subclasses
        Exception, so existing ``except Exception`` handlers still catch it,
        while callers can now skip the TF path specifically.
    """
    # FIXME: build the TF model from pmt_pos_top / weights_path and the flags.
    raise NotImplementedError("Not Implemented")

In [None]:
# 1) Build the full model in every framework from one shared configuration.
model_cfg = dict(
    include_wall=True, include_perp=True,
    include_anode=True, multiplication_layers=False,
    radialLCE=False,
)
# gen_jax_model lives in layer_reimplementation_jax (it is imported from there
# in the FLOP cell below).
# NOTE(review): assumes gen_np_model is provided by layer_reimplementation_np - confirm.
jax_model_fn = ld_jax.gen_jax_model(pmt_pos_top, weights_path, **model_cfg)
np_model_fn  = ld_np.gen_np_model(pmt_pos_top, weights_path, **model_cfg)
try:
    tf_model = gen_tf_model(pmt_pos_top, weights_path, **model_cfg)
except NotImplementedError:
    # gen_tf_model is still a stub: benchmark the other backends instead of aborting.
    print("gen_tf_model not implemented yet - skipping TensorFlow timings")
    tf_model = None

n_pmts_model = pmt_pos_top.shape[0]


def _make_batch(n, seed=0):
    """Return (xy, obs): n positions inside the TPC and Poisson(1000) counts per PMT."""
    xy = points_in_circle(n, seed=seed)
    obs = stats.poisson(mu=1000).rvs(size=(n, n_pmts_model))
    return xy, obs


def _fmt(mean, std):
    """Format a (mean, std) pair in seconds as 'mean±std' in milliseconds."""
    return f"{mean*1e3:7.2f}±{std*1e3:5.2f}"


# 2) Compile / jit / njit
jax_compiled = jax.jit(jax_model_fn)
nb_model     = numba.njit(np_model_fn)

# 3) Warm-up so one-off compilation is not counted in the timings
xy_warm, obs_warm = _make_batch(8, seed=0)
jax_compiled(jnp.array(xy_warm), jnp.array(obs_warm)).block_until_ready()
np_model_fn(xy_warm, obs_warm)
nb_model(xy_warm, obs_warm)
if tf_model is not None:
    tf_model(xy_warm, obs_warm)  # eager TF call
del xy_warm, obs_warm


# 4) Helper timer
def time_trials(fn, args, n_trials=10, is_jax=False, is_tf=False):
    """
    Return (mean, std) wall-clock seconds of fn(*args) over n_trials calls.

    JAX dispatches asynchronously, so JAX results are blocked on before the
    clock stops; TF results are materialised with .numpy() for the same reason.
    """
    times = np.empty(n_trials, dtype=np.float64)
    for k in range(n_trials):
        t0 = time.perf_counter()
        out = fn(*args)
        if is_jax:
            out.block_until_ready()
        elif is_tf and callable(getattr(out, "numpy", None)):
            out.numpy()
        times[k] = time.perf_counter() - t0
    return times.mean(), times.std()


# 5) Benchmark over batch sizes
print("#Events |    JAX (ms)    |   NumPy (ms)  |  Numba (ms)  |   TF (ms)")
for batch_events in [100, 1_000, 5_000, 10_000]:
    xy, obs = _make_batch(batch_events)
    args_j = (jnp.array(xy), jnp.array(obs))
    args_p = (xy, obs)

    # jax.jit compiles once per input shape, so warm up at this batch size to
    # keep compilation out of the timed trials.
    jax_compiled(*args_j).block_until_ready()

    j_stats = time_trials(jax_compiled, args_j, n_trials=10, is_jax=True)
    p_stats = time_trials(np_model_fn,  args_p, n_trials=10)
    n_stats = time_trials(nb_model,     args_p, n_trials=10)
    if tf_model is not None:
        tf_col = _fmt(*time_trials(tf_model, args_p, n_trials=10, is_tf=True))
    else:
        tf_col = f"{'n/a':>13}"

    print(f"{batch_events:7d} | {_fmt(*j_stats)} | {_fmt(*p_stats)} | "
          f"{_fmt(*n_stats)} | {tf_col}")

    # also compare outputs once for equivalence
    j_out  = np.array(jax_compiled(*args_j))
    np_out = np_model_fn(xy, obs)
    nb_out = nb_model(xy, obs)

    print("  compare JAX vs NP: ", end=""); compare_arrays(j_out,  np_out)
    print("  compare NP vs NB:  ", end=""); compare_arrays(np_out, nb_out)
    if tf_model is not None:
        tf_out = tf_model(xy, obs).numpy()
        print("  compare NP vs TF:  ", end=""); compare_arrays(np_out, tf_out)
        del tf_out

    del xy, obs, args_j, args_p, j_out, np_out, nb_out

# 6) cleanup
del jax_model_fn, np_model_fn, tf_model, jax_compiled, nb_model

# Compute JAX Flops From XLA dump


TODO: below is a basic script that dumps each model variant's compiled program so the variants can be compared with one another. Check that it all works, then do the following:

- Comparison baseline bs = 1024, n_pmts = n_alive 

- Within each category, vary the options and plot/print the FLOP count for each.

In [None]:
from layer_reimplementation_jax import gen_jax_model

# FLOP estimates below are for this batch size.
BATCH = 2
xy_dummy  = jnp.zeros((BATCH, 2),  dtype=jnp.float32)
obs_dummy = jnp.zeros((BATCH, pmt_pos_top.shape[0]), dtype=jnp.float32)

# Define the variants you care about
variants = [
    ("direct_only", dict(include_wall=False, include_perp=False, include_anode=False,
                         multiplication_layers=False, radialLCE=False)),
    ("with_wall",   dict(include_wall=True,  include_perp=False, include_anode=False,
                         multiplication_layers=False, radialLCE=False)),
    ("with_perp",   dict(include_wall=False, include_perp=True,  include_anode=False,
                         multiplication_layers=False, radialLCE=False)),
    ("with_anode",  dict(include_wall=False, include_perp=False, include_anode=True,
                         multiplication_layers=False, radialLCE=False)),
    ("all_add",     dict(include_wall=True,  include_perp=True,  include_anode=True,
                         multiplication_layers=False, radialLCE=False)),
    ("all_mul",     dict(include_wall=True,  include_perp=True,  include_anode=True,
                         multiplication_layers=True,  radialLCE=False)),
    ("radialLCE",   dict(include_wall=False, include_perp=False, include_anode=False,
                         multiplication_layers=False, radialLCE=True)),
]

for name, cfg in variants:
    print(f"\n=== Tracing variant: {name} ===")
    # 1) build the JAX model function
    model_fn = gen_jax_model(
        pmt_pos_top, weights_path,
        **cfg
    )

    # 2) lower it ahead of time for the dummy inputs (replacement for the
    #    deprecated jax.xla_computation)
    lowered = jax.jit(model_fn).lower(xy_dummy, obs_dummy)

    # 3a) human-readable HLO text
    hlo_text = lowered.as_text(dialect="hlo")

    # 3b) serialized HLO module proto
    hlo_bytes = lowered.compiler_ir(dialect="hlo").as_serialized_hlo_module_proto()

    # 3c) XLA's cost-model estimate for the compiled program. Depending on the
    #     JAX version cost_analysis() returns a dict or a list of dicts.
    cost = lowered.compile().cost_analysis()
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else None
    flops = cost.get("flops") if cost else None
    print(f"estimated FLOPs (batch={BATCH}): {flops}")

    # 4) either inspect in Python...
    print("HLO snippet:")
    for line in hlo_text.splitlines()[:10]:
        print("  ", line)
    # ...or save to files for later profiling:
    with open(f"hlo_{name}.txt", "w") as ftxt:
        ftxt.write(hlo_text)
    with open(f"hlo_{name}.pb", "wb") as fbin:
        fbin.write(hlo_bytes)

    # clean up before next variant
    del model_fn, lowered, hlo_text, hlo_bytes, cost
