From 95a182488062403f3f6d734be89e202b012afb16 Mon Sep 17 00:00:00 2001 From: Humphrey Yang Date: Fri, 19 Sep 2025 10:04:37 +1000 Subject: [PATCH] update jax shape discussion --- lectures/jax_intro.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index faae2d9b..531200e4 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -274,12 +274,18 @@ for A in matrices: print(A) ``` -One point to remember is that JAX expects tuples to describe array shapes, even for flat arrays. Hence, to get a one-dimensional array of normal random draws we use `(len, )` for the shape, as in +To get a one-dimensional array of normal random draws, we can either use `(len, )` for the shape, as in ```{code-cell} ipython3 random.normal(key, (5, )) ``` +or simply use `5` as the shape argument: + +```{code-cell} ipython3 +random.normal(key, 5) +``` + ## JIT compilation The JAX just-in-time (JIT) compiler accelerates logic within functions by fusing linear