diff --git a/qdax/types.py b/qdax/types.py index 5000869b..4581c880 100644 --- a/qdax/types.py +++ b/qdax/types.py @@ -45,5 +45,5 @@ def __init__(self) -> None: Mask: TypeAlias = jnp.ndarray # Others -RNGKey: TypeAlias = jax.random.KeyArray +RNGKey: TypeAlias = jax.Array Metrics: TypeAlias = Dict[str, jnp.ndarray]