In [2]:
import jax 
import jax.numpy as jnp
import numpy as np

# 1. ```jax.Array``` and ```jax.numpy.ndarray``` are the same type; concrete arrays are instances of ```jaxlib.xla_extension.ArrayImpl``` (this module path may differ across jaxlib versions)

In [6]:
# Three arrays built three ways: a JAX random draw, a JAX array from a
# Python list, and a plain NumPy array for comparison.
x = jax.random.normal(jax.random.PRNGKey(0), (5,))
x_jnp = jnp.array([1, 2, 3])
x_np = np.array([1, 2, 3])

# Concrete runtime type of each array (JAX, JAX, NumPy).
for arr in (x, x_jnp, x_np):
    print(type(arr))

# Each JAX array is checked against both public names; jnp.ndarray is
# documented as an alias of jax.Array, so all four checks should agree.
for arr in (x, x_jnp):
    for cls in (jnp.ndarray, jax.Array):
        print(isinstance(arr, cls))

<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'jaxlib.xla_extension.ArrayImpl'>
<class 'numpy.ndarray'>
True
True
True
True


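## 1.1. ```jnp.ndarray``` is an alias for ```jax.Array```. [JAX official doc](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.html)
## 1.2. The concrete class ```ArrayImpl``` inherits directly from ```object``` (i.e. its ```__bases__``` is ```(object,)```), so ```jax.Array``` is not a real base class of it; ```isinstance(x, jax.Array)``` presumably succeeds via ```abc``` virtual-subclass registration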

In [15]:
# Inspect the concrete class behind both JAX arrays and its direct base classes.
impl_type = type(x_jnp)
print(impl_type)
print(impl_type.__bases__)
print(type(x).__bases__)

<class 'jaxlib.xla_extension.ArrayImpl'>
(<class 'object'>,)
(<class 'object'>,)


# 2. Shapes

In [17]:
a = jnp.array([1, 2, 3])  # 1-D array with three elements
b = jnp.zeros((1,))       # 1-D array with a single element

# Integer indexing drops the indexed axis, giving a 0-d array; wrapping the
# result in jnp.array keeps it 0-d, while expand_dims restores a length-1 axis.
# The `=` f-string specifier prints each expression's text followed by its value.
shape_report = [
    f'{a.shape=}',
    f'{a[0].shape=}',
    f'{jnp.array(a[0]).shape=}',
    f'{jnp.expand_dims(a[0], axis=0).shape=}',
    f'{b.shape=}',
    f'{b[0].shape=}',
]
print('\n'.join(shape_report))

a.shape=(3,)
a[0].shape=()
jnp.array(a[0]).shape=()
jnp.expand_dims(a[0], axis=0).shape=(1,)
b.shape=(1,)
b[0].shape=()
