In [1]:
import os
import time

# JAX reads XLA_PYTHON_CLIENT_MEM_FRACTION when it creates its GPU client, so it
# must be in the environment before that happens. Setting it before importing
# jax (or any package that might touch the device at import time, such as
# temgym_core) is the only placement that is guaranteed to take effect.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"  # preallocate 50% of GPU memory

import jax

jax.config.update("jax_enable_x64", True)  # Ensure float64 so kernel dtypes match
import jax.numpy as jnp
from temgym_core.evaluate import eval_gaussians_differentiable

In [None]:
# ----- Synthetic inputs for eval_gaussians_differentiable -----
# eval_gaussians_differentiable is imported from temgym_core.evaluate in the
# setup cell. NOTE(review): described as a @custom_vjp wrapper that uses a GPU
# kernel for the forward pass and a reference implementation for the backward
# pass -- confirm in temgym_core.
#
# Every array is drawn from its own subkey. Reusing one key for all draws makes
# the samples identical rather than independent (C.real == C.imag,
# eta.real == eta.imag, Q_inv.real == Q_inv.imag, k == C.real,
# r_m == eta.real), which degenerates the test problem.
#
# Memory: a dense (P, N) complex128 array is P * N * 16 bytes = 2 GiB at these
# sizes. Lower P_side / N if a run hits RESOURCE_EXHAUSTED on the GPU.

key = jax.random.key(0)
P_side = 128
P = P_side * P_side  # number of rows of r2
N = 2**13            # leading dimension of r_m, k, C, eta and Q_inv


def _complex_normal(subkey, shape):
    """Return a complex128 array whose real and imaginary parts are independent N(0, 1) draws."""
    key_re, key_im = jax.random.split(subkey)
    re = jax.random.normal(key_re, shape, dtype=jnp.float64)
    im = jax.random.normal(key_im, shape, dtype=jnp.float64)
    return (re + 1j * im).astype(jnp.complex128)


key_r2, key_rm, key_k, key_C, key_eta, key_Q = jax.random.split(key, 6)

r2    = jax.random.normal(key_r2, (P, 2), dtype=jnp.float64)
r_m   = jax.random.normal(key_rm, (N, 2), dtype=jnp.float64)
k     = jax.random.normal(key_k,  (N,),   dtype=jnp.float64)
C     = _complex_normal(key_C,   (N,))
eta   = _complex_normal(key_eta, (N, 2))
Q_inv = _complex_normal(key_Q,   (N, 2, 2))

def forward_E(r_m, C, eta, Q_inv, k, r2):
    """Run the forward model and return its output E.

    Pure pass-through to ``eval_gaussians_differentiable`` with the argument
    order unchanged. Both the synthetic measurement and ``loss`` go through
    this function, so they share a single forward model.
    """
    field = eval_gaussians_differentiable(r_m, C, eta, Q_inv, k, r2)
    return field

# Synthetic "measurement": |E|^2 evaluated at the generating inputs, so `loss`
# evaluates to zero at exactly these inputs.
E0 = forward_E(r_m, C, eta, Q_inv, k, r2)
I_meas = jnp.abs(E0) ** 2
I_meas.block_until_ready()  # wait for the device computation to finish here

def loss(r_m, C, eta, Q_inv, k, r2):
    """Mean squared error between the model intensity |forward_E(...)|^2 and I_meas.

    I_meas is read from the enclosing (notebook-global) scope.
    """
    intensity = jnp.abs(forward_E(r_m, C, eta, Q_inv, k, r2)) ** 2
    residual = intensity - I_meas
    return jnp.mean(residual ** 2)

# --- Reverse-mode directional derivative check on r_m ---
# I_meas was computed from these exact inputs, so the residual |E|^2 - I_meas
# is zero at r_m. The gradient of `loss` is proportional to that residual, so
# it is zero there as well. A check at r_m would compare 0 with a
# round-off-level finite difference and never exercise the custom VJP.
# Instead, evaluate at a perturbed point r_m_probe where the residual is
# non-zero.
# A dedicated root key keeps the probe offset and the direction dr independent
# of every input array drawn from `key`.
PROBE_SCALE = 0.1  # size of the offset that moves r_m off the zero-loss point
check_key_probe, check_key_dir = jax.random.split(jax.random.key(1))
r_m_probe = r_m + PROBE_SCALE * jax.random.normal(check_key_probe, r_m.shape, dtype=jnp.float64)
dr = jax.random.normal(check_key_dir, r_m.shape, dtype=jnp.float64) * 1e-3

# compute grad wrt r_m at the probe point
grad_r_m = jax.grad(loss, argnums=0)(r_m_probe, C, eta, Q_inv, k, r2)

# reverse-mode directional derivative = <grad, dr>
dd_rev = jnp.vdot(grad_r_m, dr).real  # real scalar

# Central finite difference along dr: O(eps**2) truncation error, versus O(eps)
# for a one-sided difference.
eps = 1e-4
fd = (loss(r_m_probe + eps * dr, C, eta, Q_inv, k, r2)
      - loss(r_m_probe - eps * dr, C, eta, Q_inv, k, r2)) / (2 * eps)
rel_err = abs(float(dd_rev) - float(fd)) / max(abs(float(fd)), 1e-300)
print("directional derivative (reverse-mode) vs finite diff:",
      float(dd_rev), float(fd), f"(rel. err {rel_err:.2e})")

# --- Timing reverse-mode ---
# time.perf_counter is a monotonic, high-resolution clock intended for
# measuring intervals; time.time follows the adjustable wall clock.
# The first call of each jitted function includes XLA compilation; call
# grads_jit again to measure steady-state execution.
loss_jit = jax.jit(loss)
t0 = time.perf_counter()
v = loss_jit(r_m, C, eta, Q_inv, k, r2).block_until_ready()
t1 = time.perf_counter()
print(f"Loss fwd (jit; first call compiles): {t1-t0:.3f}s  value={float(v):.6e}")

# Gradients w.r.t. r_m, C, eta, Q_inv and k (argnums 0-4); r2 is held fixed.
grads_jit = jax.jit(jax.grad(loss, argnums=(0, 1, 2, 3, 4)))
t0 = time.perf_counter()
g = jax.block_until_ready(grads_jit(r_m, C, eta, Q_inv, k, r2))  # waits on every leaf of the tuple
t1 = time.perf_counter()
print(f"Grad (jit; first call compiles): {t1-t0:.3f}s")


2025-09-25 14:45:01.557020: W external/xla/xla/tsl/framework/bfc_allocator.cc:501] Allocator (GPU_0_bfc) ran out of memory trying to allocate 4.00GiB (rounded to 4294967296)requested by op 
2025-09-25 14:45:01.557343: W external/xla/xla/tsl/framework/bfc_allocator.cc:512] *******************************____*_*__***********____***********__****************************____
E0925 14:45:01.557401 3869473 pjrt_stream_executor_client.cc:3026] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4294967296 bytes. [tf-allocator-allocation-error='']


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4294967296 bytes.

: 

In [None]:
# Re-time the jitted gradient. Requires grads_jit from the previous cell; if
# its first call already completed, compilation is cached and this measures
# execution only. perf_counter is the monotonic clock meant for intervals.
t0 = time.perf_counter()
g = jax.block_until_ready(grads_jit(r_m, C, eta, Q_inv, k, r2))  # waits on every leaf of the tuple
t1 = time.perf_counter()
print(f"Grad (jit): {t1-t0:.3f}s")

Grad (jit): 0.210s
