<a href="https://colab.research.google.com/github/ai2ys/Python-Cheat-Sheet-As-Jupyter-Notebooks/blob/master/python_cheat_sheet_jax_numpy_array.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# JAX Cheat Sheet - JAX NumPy Array

This is my 'cheat sheet' for creating and handling arrays using the JAX NumPy API.


In [1]:
import jax
import jax.numpy as jnp
jax.__version__

'0.4.4'

## Array creation routines
The following link gives an overview of the available functions, including array creation routines such as `zeros`, `ones`, and `empty`.

https://jax.readthedocs.io/en/latest/jax.numpy.html

### Create an array of zeros, ones, or an empty array
- `jnp.zeros`
- `jnp.ones`
- `jnp.empty` ➡ in plain NumPy this routine can be faster than `ones` and `zeros` because it skips initializing the values; in JAX it is not.

Array creation in JAX will not be faster using `empty`, because XLA cannot create uninitialized arrays. `jnp.empty` therefore returns an array filled with zeros.

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.empty.html#jax.numpy.empty
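A minimal check of the statement above: because XLA cannot hand out uninitialized memory, the array returned by `jnp.empty` should compare equal to one created with `jnp.zeros`. The small shape `(2, 3)` is chosen only for readability.

```python
import jax.numpy as jnp

# jnp.empty cannot return uninitialized memory under XLA, so it yields zeros
array_empty = jnp.empty((2, 3))
print(array_empty)
print(jnp.array_equal(array_empty, jnp.zeros((2, 3))))  # True
```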

In [2]:
rows = 2000
columns = 3000

Using `jnp.empty`

In [3]:
%time array_empty = jnp.empty((rows, columns))
print('shape', array_empty.shape)

CPU times: user 809 ms, sys: 1.17 s, total: 1.98 s
Wall time: 2.82 s
shape (2000, 3000)


Using `jnp.zeros`

In [4]:
%time array_zeros = jnp.zeros((rows, columns))
print('shape', array_zeros.shape)

CPU times: user 1.47 ms, sys: 0 ns, total: 1.47 ms
Wall time: 2.06 ms
shape (2000, 3000)


Using `jnp.ones`

In [5]:
%time array_ones = jnp.ones((rows, columns))
print('shape', array_ones.shape)

CPU times: user 2.11 ms, sys: 46 µs, total: 2.16 ms
Wall time: 1.92 ms
shape (2000, 3000)
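The timings above should be read with care. The `jnp.empty` cell was the first JAX operation in this notebook, so its measurement likely includes one-time backend initialization. In addition, JAX dispatches work asynchronously, so `%time` may only measure the dispatch. A fairer comparison is sketched below, assuming a warm-up call is enough to exclude initialization; the helper `time_creation` is a hypothetical name introduced here, and `block_until_ready()` waits until the result has actually been computed.

```python
import timeit

import jax.numpy as jnp

rows, columns = 2000, 3000


def time_creation(fn, repeats=10):
    """Hypothetical helper: best-of-`repeats` wall time (seconds) of fn((rows, columns))."""
    # Warm-up call so one-time backend initialization is not measured
    fn((rows, columns)).block_until_ready()
    # block_until_ready() waits for the asynchronously dispatched computation to finish
    timings = timeit.repeat(
        lambda: fn((rows, columns)).block_until_ready(), number=1, repeat=repeats
    )
    return min(timings)


for name, fn in [("empty", jnp.empty), ("zeros", jnp.zeros), ("ones", jnp.ones)]:
    print(f"{name}: {time_creation(fn) * 1e3:.2f} ms")
```

With the warm-up in place, `empty` and `zeros` are expected to perform similarly, since both produce a zero-filled array.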


### Create a 1D array of equally spaced numbers

- `jax.numpy.arange`<br>https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.arange.html
- `jax.numpy.linspace`<br>https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.linspace.html

Using `jnp.arange(10)` creates an array containing all integer values from `0` to `9`. Keep in mind that the stop value (here: `10`) is not included. If not specified explicitly, the start value defaults to `0` and the step size to `1`.

In [6]:
jnp.arange(10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

Create an array with a specified start, stop, and step size. As mentioned previously the stop value is not included.

In [7]:
jnp.arange(start=5, stop=30, step=5)

Array([ 5, 10, 15, 20, 25], dtype=int32)
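A short sketch of the two rules used so far: `jnp.arange(stop)` is shorthand for `start=0, step=1`, and (for a positive step) the number of elements is `ceil((stop - start) / step)`. The values reuse the examples above.

```python
import math

import jax.numpy as jnp

# arange(stop) uses the defaults start=0 and step=1
a = jnp.arange(10)
b = jnp.arange(0, 10, 1)
print(jnp.array_equal(a, b))  # True

# The number of elements is ceil((stop - start) / step); the stop value is excluded
start, stop, step = 5, 30, 5
c = jnp.arange(start, stop, step)
print(c.shape[0] == math.ceil((stop - start) / step))  # True
```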

Here is an example with a non-integer step size. The NumPy documentation recommends using `linspace` instead of `arange` for non-integer step sizes, because floating-point rounding can make the results inconsistent.

In [8]:
jnp.arange(start=0, stop=1, step=0.1, dtype=float)

Array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], dtype=float32)
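Following that recommendation, a minimal sketch: derive the number of samples from the intended step size once, then let `jnp.linspace` place the points. With `endpoint=False` this reproduces the `arange` result up to float32 rounding.

```python
import jax.numpy as jnp

start, stop, step = 0.0, 1.0, 0.1
# Derive the number of samples from the step size instead of relying on arange
num = round((stop - start) / step)  # 10
via_arange = jnp.arange(start, stop, step)
via_linspace = jnp.linspace(start, stop, num=num, endpoint=False)
print(via_arange.shape, via_linspace.shape)
# Values agree up to floating-point rounding
print(jnp.allclose(via_arange, via_linspace))
```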

In contrast to `jnp.arange`, `jnp.linspace` takes the number of samples `num` instead of a step size and offers the option to include or exclude the endpoint. Keep in mind that you have to adjust the parameters to get the same spacing.

In [9]:
start = 0
num = 10

In [10]:
jnp.linspace(start=start, stop=1, endpoint=False, num=num)

Array([0.        , 0.1       , 0.2       , 0.3       , 0.4       ,
       0.5       , 0.6       , 0.7       , 0.8       , 0.90000004],      dtype=float32)

In [11]:
jnp.linspace(start=start, stop=0.9, endpoint=True, num=num)

Array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], dtype=float32)

In [12]:
jnp.linspace(start=start, stop=1, endpoint=True , num=num)

Array([0.        , 0.11111111, 0.22222222, 0.33333334, 0.44444445,
       0.5555556 , 0.6666667 , 0.7777778 , 0.8888889 , 1.        ],      dtype=float32)
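The spacing differences above follow from how `num` points are distributed: with `endpoint=True` the points span `[start, stop]`, so the spacing is `(stop - start) / (num - 1)`; with `endpoint=False` the stop value is left out and the spacing is `(stop - start) / num`. A small check, reusing `start=0`, `stop=1`, and `num=10` from the cells above:

```python
import jax.numpy as jnp

start, stop, num = 0.0, 1.0, 10
incl = jnp.linspace(start, stop, num=num, endpoint=True)
excl = jnp.linspace(start, stop, num=num, endpoint=False)

# endpoint=True: spacing is (stop - start) / (num - 1) = 1/9
print(jnp.allclose(jnp.diff(incl), (stop - start) / (num - 1)))  # True
# endpoint=False: spacing is (stop - start) / num = 0.1
print(jnp.allclose(jnp.diff(excl), (stop - start) / num))  # True
```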

## Selecting array elements, crop and subsample arrays
Crops can be extracted from arrays, and arrays can be subsampled as well. Here is an example for subsampling an array.

Indices in Python are zero-based and the `stop`-index value will not be included.
- `array[index]` &rarr; accessing an array element by its index
- `array[start:stop]` &rarr; extracting a slice of the array containing all values from index `start` to `stop-1`
- `array[start:stop:step]` &rarr; extracting a slice of the array containing only every `step`-th value from index `start` to `stop-1`.

In [13]:
data_1D = jnp.arange(101)
data_1D

Array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
        13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
        26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
        39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
        52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
        65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
        78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
        91,  92,  93,  94,  95,  96,  97,  98,  99, 100], dtype=int32)

Explicitly defining `start` and `end`. Both could be omitted here, because `start=0` is the first index and `end=101` is the length of the array, so the slice runs through the last index `100`.

In [14]:
# start=0, end=101, step=10
data_1D[0:101:10]

Array([  0,  10,  20,  30,  40,  50,  60,  70,  80,  90, 100], dtype=int32)

Omitting the values for `start` and `end` selects all values from the first to the last element.

In [15]:
# step=10
data_1D[::10]

Array([  0,  10,  20,  30,  40,  50,  60,  70,  80,  90, 100], dtype=int32)

Example for creating a crop. Keep in mind that the endpoint is not included.

In [16]:
start = 12
end = 50
data_1D = jnp.arange(101)
data_1D[start:end]

Array([12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28,
       29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45,
       46, 47, 48, 49], dtype=int32)
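The same slicing rules apply per axis for multi-dimensional arrays. A minimal sketch, assuming a hypothetical 6x8 array `image` holding the values `0` to `47`: one slice per axis gives a crop, and a step per axis gives a subsample.

```python
import jax.numpy as jnp

# Hypothetical 2-D example: a 6x8 array holding the values 0..47
image = jnp.arange(48).reshape(6, 8)

# Crop rows 1..3 and columns 2..5 (the stop indices 4 and 6 are excluded)
crop = image[1:4, 2:6]
print(crop.shape)  # (3, 4)

# Subsample: keep every 2nd row and every 2nd column
sub = image[::2, ::2]
print(sub.shape)  # (3, 4)
```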

### Selecting values using negative indices
Using index `-1` will select the last element in the array. Using index `-2` will select the penultimate element and so on.

In [17]:
data_1D = jnp.arange(10)
print(data_1D)
print(f"index -1: {data_1D[-1]} - last element")
print(f"index -2: {data_1D[-2]} - penultimate element")

[0 1 2 3 4 5 6 7 8 9]
index -1: 9 - last element
index -2: 8 - penultimate element


But keep in mind that **using the index `-1` as `end`-point** to crop a portion of the source array **excludes the last element** of the source array.

In [18]:
data_1D[0:-1]

Array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int32)

The zero as start point can be omitted.

In [19]:
data_1D[:-1]

Array([0, 1, 2, 3, 4, 5, 6, 7, 8], dtype=int32)

To extract a subset that includes the last value of the array, either omit the `end`-point or set it explicitly to the length of the array.

In [20]:
# omitting the end point to include the last value
data_1D[5:]

Array([5, 6, 7, 8, 9], dtype=int32)

In [21]:
# explicitly setting the end point to the array length (10) to include the last value
data_1D[5:10]

Array([5, 6, 7, 8, 9], dtype=int32)
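Negative indices can also be used as the `start` of a slice. Combined with an omitted `end`, this is a compact way to take the last few values. The sketch below also checks that index `-k` refers to the same element as index `len(array) - k`.

```python
import jax.numpy as jnp

data_1D = jnp.arange(10)

# A negative start counts from the end; omitting end keeps the last value
print(data_1D[-3:])  # [7 8 9]

# Index -k refers to the same element as index len(data_1D) - k
k = 2
print(data_1D[-k] == data_1D[len(data_1D) - k])  # True
```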

### Modifying single array values

In NumPy it is possible to access and modify array values using indices like the following:


In [22]:
import numpy as np
array_a = np.array([0,1,2,3,4])
print(f'array_a: {array_a}')
array_a[1:3] = -1
print(f'array_a: {array_a}')

array_a: [0 1 2 3 4]
array_a: [ 0 -1 -1  3  4]


JAX arrays are immutable, so JAX does not support item assignment as NumPy does. The same result can be achieved with the `at` property, which returns an updated copy of the array:


In [23]:
array_a = jnp.array([0,1,2,3,4])
print(f'array_a: {array_a}')
array_a = array_a.at[1:3].set(-1)
print(f'array_a: {array_a}')

array_a: [0 1 2 3 4]
array_a: [ 0 -1 -1  3  4]
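To illustrate the immutability, a minimal sketch: direct assignment on a JAX array raises a `TypeError`, while an `.at[...]` update (here `.add` instead of `.set`) returns a new array and leaves the original untouched. The variable name `updated` is only for illustration.

```python
import jax.numpy as jnp

array_a = jnp.array([0, 1, 2, 3, 4])

# Direct item assignment is not supported on JAX arrays and raises a TypeError
try:
    array_a[1:3] = -1
except TypeError as err:
    print("TypeError:", err)

# .at[...] returns a new array; the original array is left unchanged
updated = array_a.at[1:3].add(10)
print(updated)  # [ 0 11 12  3  4]
print(array_a)  # [0 1 2 3 4]
```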


For more information about the usage of the `at` property, check out:

https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html#jax.numpy.ndarray.at

## Array multiplications
Types of multiplication:

- Multiplication of matrix `A` by a scalar `b`<br>
$A \cdot b$<br>
```python
# Python
A * b
# in case b is a scalar, equivalent to
jnp.multiply(A, b)
```

- Hadamard product or element-wise product of two matrices `A` and `B`<br>
$A \circ B = A \odot B$<br>
```python
# Python
A * B
# equivalent to
jnp.multiply(A, B)
```

- Dot product of two vectors, or matrix product of two 2-D matrices `A` and `B`<br>
$A \cdot B$
```python
# Python
A.dot(B)
# equivalent to
jnp.dot(A, B)
# equivalent to (for 1-D and 2-D arrays; for higher dimensions jnp.dot and @ differ)
A @ B
```
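The three kinds of multiplication listed above can be compared on two small 2x2 matrices. The values of `A`, `B`, and `b` are chosen only for illustration; note how the element-wise product and the matrix product give different results.

```python
import jax.numpy as jnp

A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
b = 3

print(A * b)         # scalar multiplication: [[ 3  6] [ 9 12]]
print(A * B)         # element-wise product:  [[ 5 12] [21 32]]
print(A @ B)         # matrix product:        [[19 22] [43 50]]
print(jnp.dot(A, B)) # same as A @ B for 2-D arrays
```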

In [24]:
array_a = jnp.arange(start=0,stop=4)
array_b = jnp.arange(start=2,stop=6)
scalar=3
print(f'array_a: {array_a}')
print(f'array_b: {array_b}')
print(f'scalar: {scalar}')

array_a: [0 1 2 3]
array_b: [2 3 4 5]
scalar: 3


In [25]:
# scalar multiplication
array_a * scalar

Array([0, 3, 6, 9], dtype=int32)

In [26]:
# element-wise multiplication (equivalent to Matlab: a .* b)
print(array_a * array_b)
print(jnp.multiply(array_a, array_b))

[ 0  3  8 15]
[ 0  3  8 15]


In [27]:
# dot product
print(jnp.dot(array_a, array_b))
print(array_a @ array_b)

26
26
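The value `26` can be reproduced by hand: the dot product of two 1-D arrays is the sum of their element-wise products, $0 \cdot 2 + 1 \cdot 3 + 2 \cdot 4 + 3 \cdot 5 = 26$. The same computation in code:

```python
import jax.numpy as jnp

array_a = jnp.arange(start=0, stop=4)
array_b = jnp.arange(start=2, stop=6)

# Dot product of 1-D arrays = sum of the element-wise products
# 0*2 + 1*3 + 2*4 + 3*5 = 0 + 3 + 8 + 15 = 26
manual = jnp.sum(array_a * array_b)
print(manual, jnp.dot(array_a, array_b))  # 26 26
```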


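## Stacking arrays in depth

`jnp.dstack` stacks arrays along the third axis (depth). As in NumPy, 1-D arrays of shape `(N,)` are first reshaped to `(1, N, 1)`, which is why the result below has shape `(1, 10, 3)`.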

In [28]:
array_list = [
    jnp.arange(start=0, stop=10), 
    jnp.arange(start=10, stop=20), 
    jnp.arange(start=20, stop=30)]
arrays_stacked = jnp.dstack(array_list)
print(f'array_list:\n{array_list}')
print(f'arrays_stacked:\n{arrays_stacked}')
print(f'shape: {arrays_stacked.shape}')

array_list:
[Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32), Array([10, 11, 12, 13, 14, 15, 16, 17, 18, 19], dtype=int32), Array([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], dtype=int32)]
arrays_stacked:
[[[ 0 10 20]
  [ 1 11 21]
  [ 2 12 22]
  [ 3 13 23]
  [ 4 14 24]
  [ 5 15 25]
  [ 6 16 26]
  [ 7 17 27]
  [ 8 18 28]
  [ 9 19 29]]]
shape: (1, 10, 3)
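If the leading axis of length 1 is not wanted, one alternative (a suggestion, not the only option) is `jnp.stack` with `axis=-1`, which stacks the 1-D arrays along a new last axis and yields shape `(10, 3)`. The sketch checks that it matches the first slice of the `dstack` result.

```python
import jax.numpy as jnp

array_list = [jnp.arange(0, 10), jnp.arange(10, 20), jnp.arange(20, 30)]

stacked_d = jnp.dstack(array_list)             # shape (1, 10, 3)
stacked_last = jnp.stack(array_list, axis=-1)  # shape (10, 3)
print(stacked_d.shape, stacked_last.shape)

# Both hold the same values; dstack only adds the leading axis of length 1
print(jnp.array_equal(stacked_d[0], stacked_last))  # True
```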
