In [61]:
import jax
import jax.numpy as np

from jax import random
from jax import grad, jit, vmap

In [44]:
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

# Getting Started

In [17]:
x = np.arange(12)
x

DeviceArray([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

ndarray’s shape (the length along each axis)

In [18]:
x.shape

(12,)

the total number of elements in an ndarray

In [19]:
x.size

12

In [26]:
x = x.reshape(3, 4)
x

DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)

In [30]:
x = x.reshape(-1, 4)
x

DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)

In [31]:
np.empty((3, 4))

DeviceArray([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]], dtype=float32)

In [32]:
np.zeros((2, 3, 4))

DeviceArray([[[0., 0., 0., 0.],
              [0., 0., 0., 0.],
              [0., 0., 0., 0.]],

             [[0., 0., 0., 0.],
              [0., 0., 0., 0.],
              [0., 0., 0., 0.]]], dtype=float32)

In [33]:
np.ones((2, 3, 4))

DeviceArray([[[1., 1., 1., 1.],
              [1., 1., 1., 1.],
              [1., 1., 1., 1.]],

             [[1., 1., 1., 1.],
              [1., 1., 1., 1.],
              [1., 1., 1., 1.]]], dtype=float32)

randomly sample the values for each element in an ndarray from some probability distribution

>Each of its elements is randomly sampled from a standard Gaussian (normal) distribution with a mean of  0  and a standard deviation of  1 .

In [47]:
x = random.normal(key, (3, 4))
x

DeviceArray([[ 1.1901636 , -1.0996889 ,  0.44367835,  0.5984696 ],
             [-0.39189562,  0.6926197 ,  0.4601835 , -2.0685785 ],
             [-0.21438184, -0.98983073, -0.6789305 ,  0.27362567]],            dtype=float32)

In [48]:
np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])

DeviceArray([[2, 1, 4, 3],
             [1, 2, 3, 4],
             [4, 3, 2, 1]], dtype=int32)

# 2.1.2. Operations

### Elementwise operations
> We can call elementwise operations on any two tensors of **the same shape**

In [49]:
x = np.array([1, 2, 4, 8])
y = np.array([2, 2, 2, 2])
x + y, x - y, x * y, x / y, x ** y  # The ** operator is exponentiation

(DeviceArray([ 3,  4,  6, 10], dtype=int32),
 DeviceArray([-1,  0,  2,  6], dtype=int32),
 DeviceArray([ 2,  4,  8, 16], dtype=int32),
 DeviceArray([0.5, 1. , 2. , 4. ], dtype=float32),
 DeviceArray([ 1,  4, 16, 64], dtype=int32))

In [50]:
np.exp(x)

DeviceArray([2.7182817e+00, 7.3890562e+00, 5.4598152e+01, 2.9809580e+03],            dtype=float32)

### Concatenation
> We can also concatenate multiple ndarrays together, stacking them end-to-end to form a larger ndarray. We just need to provide a list of ndarrays and tell the system along which axis to concatenate. 

rows (axis  0 , the first element of the shape) vs. columns (axis  1 , the second element of the shape)

In [51]:
x = np.arange(12).reshape(3, 4)
y = np.array([[2, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
np.concatenate([x, y], axis=0), np.concatenate([x, y], axis=1)

(DeviceArray([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11],
              [ 2,  1,  4,  3],
              [ 1,  2,  3,  4],
              [ 4,  3,  2,  1]], dtype=int32),
 DeviceArray([[ 0,  1,  2,  3,  2,  1,  4,  3],
              [ 4,  5,  6,  7,  1,  2,  3,  4],
              [ 8,  9, 10, 11,  4,  3,  2,  1]], dtype=int32))

In [52]:
x, y, x == y

(DeviceArray([[ 0,  1,  2,  3],
              [ 4,  5,  6,  7],
              [ 8,  9, 10, 11]], dtype=int32),
 DeviceArray([[2, 1, 4, 3],
              [1, 2, 3, 4],
              [4, 3, 2, 1]], dtype=int32),
 DeviceArray([[False,  True, False,  True],
              [False, False, False, False],
              [False, False, False, False]], dtype=bool))

In [53]:
x.sum()

DeviceArray(66, dtype=int32)

# 2.1.3. Broadcasting Mechanism

>when shapes differ, we can still perform elementwise operations by invoking the broadcasting mechanism. 

In [54]:
a = np.arange(3).reshape(3, 1)
b = np.arange(2).reshape(1, 2)
a, b

(DeviceArray([[0],
              [1],
              [2]], dtype=int32),
 DeviceArray([[0, 1]], dtype=int32))

We broadcast the entries of both matrices into a larger  3×2  matrix as follows: for matrix a it replicates the columns and for matrix b it replicates the rows before adding up both elementwise

In [55]:
a + b

DeviceArray([[0, 1],
             [1, 2],
             [2, 3]], dtype=int32)

# 2.1.4. Indexing and Slicing

In [56]:
x[-1], x[1:3]

(DeviceArray([ 8,  9, 10, 11], dtype=int32),
 DeviceArray([[ 4,  5,  6,  7],
              [ 8,  9, 10, 11]], dtype=int32))

In [64]:
x

DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]], dtype=int32)

##### [JAX] NEW

In [65]:
jax.ops.index_update(x, jax.ops.index[1, 2], 9.)

DeviceArray([[ 0,  1,  2,  3],
             [ 4,  5,  9,  7],
             [ 8,  9, 10, 11]], dtype=int32)

In [66]:
jax.ops.index_update(x, jax.ops.index[0:2, :], 12.)

DeviceArray([[12, 12, 12, 12],
             [12, 12, 12, 12],
             [ 8,  9, 10, 11]], dtype=int32)

# 2.1.5. Saving Memory

In [67]:
before = id(y)
y = y + x
id(y) == before

False

In [70]:
z = np.zeros_like(y)
print('id(z):', id(z))
jax.ops.index_update(z, jax.ops.index[:], x + y)
print('id(z):', id(z))

id(z): 140203182873592
id(z): 140203182873592


In [71]:
before = id(x)
x += y
id(x) == before

False

# 2.1.6. Conversion to Other Python Objects¶


In [77]:
a = np.array([3.5])
a, a.item(), float(a), int(a)

(DeviceArray([3.5], dtype=float32), 3.5, 3.5, 3)

### Exercises

1. Run the code in this section. Change the conditional statement x == y in this section to x < y or x > y, and then see what kind of ndarray you can get.

In [92]:
x, y, x > y

(DeviceArray([[ 2,  3,  8,  9],
              [ 9, 12, 15, 18],
              [20, 21, 22, 23]], dtype=int32),
 DeviceArray([[ 2,  2,  6,  6],
              [ 5,  7,  9, 11],
              [12, 12, 12, 12]], dtype=int32),
 DeviceArray([[False,  True,  True,  True],
              [ True,  True,  True,  True],
              [ True,  True,  True,  True]], dtype=bool))

2. Replace the two ndarrays that operate by element in the broadcasting mechanism with other shapes, e.g., three dimensional tensors. Is the result the same as expected?

In [86]:
a = np.arange(100).reshape(5, 2, -1)
b = np.arange(50).reshape(5, 2, -1)
a, b

(DeviceArray([[[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
               [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]],
 
              [[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
               [30, 31, 32, 33, 34, 35, 36, 37, 38, 39]],
 
              [[40, 41, 42, 43, 44, 45, 46, 47, 48, 49],
               [50, 51, 52, 53, 54, 55, 56, 57, 58, 59]],
 
              [[60, 61, 62, 63, 64, 65, 66, 67, 68, 69],
               [70, 71, 72, 73, 74, 75, 76, 77, 78, 79]],
 
              [[80, 81, 82, 83, 84, 85, 86, 87, 88, 89],
               [90, 91, 92, 93, 94, 95, 96, 97, 98, 99]]], dtype=int32),
 DeviceArray([[[ 0,  1,  2,  3,  4],
               [ 5,  6,  7,  8,  9]],
 
              [[10, 11, 12, 13, 14],
               [15, 16, 17, 18, 19]],
 
              [[20, 21, 22, 23, 24],
               [25, 26, 27, 28, 29]],
 
              [[30, 31, 32, 33, 34],
               [35, 36, 37, 38, 39]],
 
              [[40, 41, 42, 43, 44],
               [45, 46, 47, 48, 49]]], dtype=int32))

In [88]:
try:
    a + b
except Exception as e:
  print("Exception {}".format(e))

Exception add got incompatible shapes for broadcasting: (5, 2, 10), (5, 2, 5).
