From eb2cf317c7a5b6ffe385f0606a99e7c509b771b3 Mon Sep 17 00:00:00 2001 From: LisaCoiffard <91796648+LisaCoiffard@users.noreply.github.com> Date: Tue, 13 Feb 2024 14:04:21 +0000 Subject: [PATCH] fix: Change deprecated jax type jax.random.KeyArray to jax.Array (#175) --- qdax/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qdax/types.py b/qdax/types.py index 5000869b..4581c880 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -45,5 +45,5 @@ def __init__(self) -> None: Mask: TypeAlias = jnp.ndarray # Others -RNGKey: TypeAlias = jax.random.KeyArray +RNGKey: TypeAlias = jax.Array Metrics: TypeAlias = Dict[str, jnp.ndarray]