# Data manipulation

In [1]:
import jax 
import jax.numpy as jnp

In [2]:
# Twelve consecutive integers 0..11 (int32, as shown in the output below).
x = jnp.arange(0, 12)
x

Array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)

In [3]:
jnp.size(x)  # total number of elements

12

In [5]:
# Same 12 values laid out as 3 rows x 4 columns.
X = jnp.reshape(x, (3, 4))
X

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [6]:
x.reshape((3, -1))  # -1 tells JAX to infer that dimension: 12 elements / 3 rows = 4 columns

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [8]:
jnp.zeros(shape=(2, 3, 4))  # 2 blocks of 3x4 zeros (float32 per the output below)

Array([[[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]],

       [[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]]], dtype=float32)

In [14]:
# Explicit PRNG key derived from seed 0; it is passed to every random call below.
key = jax.random.PRNGKey(seed=0)
key

Array([0, 0], dtype=uint32)

In [15]:
jax.random.normal(key, shape=(3, 4))  # 3x4 standard-normal samples drawn from `key`

Array([[ 1.1901639 , -1.0996888 ,  0.44367844,  0.5984697 ],
       [-0.39189556,  0.69261974,  0.46018356, -2.068578  ],
       [-0.21438177, -0.9898306 , -0.6789304 ,  0.27362573]],      dtype=float32)

In [16]:
jnp.asarray([[1, 2, 3], [43, 2, 1]])  # build a 2x3 array from nested Python lists

Array([[ 1,  2,  3],
       [43,  2,  1]], dtype=int32)

In [17]:
X[-1, :]  # last row, every column

Array([ 8,  9, 10, 11], dtype=int32)

In [18]:
X[1:3, :]  # rows 1 and 2 (the stop index 3 is exclusive), every column

Array([[ 4,  5,  6,  7],
       [ 8,  9, 10, 11]], dtype=int32)

In [21]:
# Functional update: .at[...].set(...) returns a new array with element (1, 2)
# replaced; X itself keeps its original values.
X_new_1 = X.at[(1, 2)].set(400)
X_new_1

Array([[  0,   1,   2,   3],
       [  4,   5, 400,   7],
       [  8,   9,  10,  11]], dtype=int32)

In [24]:
# Slice update: columns 1 and 2 of row 0 are set to the same scalar in a new copy of X.
X_new_2 = X.at[0, slice(1, 3)].set(123321)
X_new_2

Array([[     0, 123321, 123321,      3],
       [     4,      5,      6,      7],
       [     8,      9,     10,     11]], dtype=int32)

In [25]:
jax.numpy.exp(x)  # element-wise e**x; the int32 input comes back as float32 (see dtype below)

Array([1.0000000e+00, 2.7182817e+00, 7.3890562e+00, 2.0085537e+01,
       5.4598152e+01, 1.4841316e+02, 4.0342880e+02, 1.0966332e+03,
       2.9809580e+03, 8.1030840e+03, 2.2026465e+04, 5.9874141e+04],      dtype=float32)

In [28]:
# Element-wise arithmetic between a float vector and an int vector. The int
# operand is promoted, so every result below is float32.
# Fresh names (rather than rebinding x / y) keep the earlier x = jnp.arange(12)
# intact, so the cells above still reproduce their outputs if re-run after this one.
a = jnp.array([1., 2, 4, 8])
b = jnp.array([2, 2, 2, 2])
a + b, a - b, a * b, a / b, a ** b

(Array([ 3.,  4.,  6., 10.], dtype=float32),
 Array([-1.,  0.,  2.,  6.], dtype=float32),
 Array([ 2.,  4.,  8., 16.], dtype=float32),
 Array([0.5, 1. , 2. , 4. ], dtype=float32),
 Array([ 1.,  4., 16., 64.], dtype=float32))

In [35]:
# Two 3x4 int32 matrices stacked along axis 0 (rows), giving a 6x4 result.
x = jnp.arange(12).reshape(3, 4)
y = jnp.ones((3, 4), dtype=jnp.int32)
jnp.concatenate([x, y], axis=0)

Array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [ 1,  1,  1,  1],
       [ 1,  1,  1,  1],
       [ 1,  1,  1,  1]], dtype=int32)

In [45]:
jnp.greater(x, y)  # element-wise comparison -> boolean array of the same shape

Array([[False, False,  True,  True],
       [ True,  True,  True,  True],
       [ True,  True,  True,  True]], dtype=bool)

In [37]:
jnp.sum(x)  # reduce over every element -> 0-d array

Array(66, dtype=int32)

In [42]:
# Round trip between host and device. Each step is bound to its own name so
# both sides can be inspected; the nested one-liner hid the host-side value.
# Per the JAX API docs, jax.device_get fetches x to the host (as a NumPy array)
# and jax.device_put places it back on the default device as a jax.Array.
x_host = jax.device_get(x)
x_device = jax.device_put(x_host)
type(x_device)

jaxlib.xla_extension.ArrayImpl