<a href="https://colab.research.google.com/github/Peter-obi/JAX/blob/main/Working_with_jax_arrays_gpu.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import jax
import jax.numpy as jnp

In [2]:
np.array ([1,2,3,4,5])

array([1, 2, 3, 4, 5])

In [3]:
jnp.array([1,2,3,4,5])

Array([1, 2, 3, 4, 5], dtype=int32)

In [4]:
np.sum([1,2,3,4,5])

np.int64(15)

In [5]:
try:
  jnp.sum([1,2,3,4,5])
except TypeError as e:
  print(e)

sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [6]:
jnp.sum(jnp.array([1,2,3,4,5]))

Array(15, dtype=int32)

In [7]:
arr = jnp.array([[1,2,3,4,5]])

In [8]:
arr.ndim

2

In [9]:
arr.shape

(1, 5)

In [10]:
arr.dtype

dtype('int32')

In [11]:
arr.size

5

In [12]:
arr.nbytes

20

##Jax x Devices

In [13]:
jax.devices()

[CudaDevice(id=0)]

In [14]:
jax.local_devices()

[CudaDevice(id=0)]

In [15]:
jax.device_count('gpu')

1

In [16]:
arr = jnp.array([[1,2,3,4,5]])

In [17]:
arr.device

CudaDevice(id=0)

In [19]:
arr_cpu = jax.device_put(arr, jax.devices('cpu')[0]) #create a copy of data on the specified device and return it
arr_cpu.device

CpuDevice(id=0)

In [21]:
arr.device

CudaDevice(id=0)

In [22]:
arr_host = jax.device_get(arr) #to transfer data from a device to the Python process on host.

In [23]:
type(arr_host)

numpy.ndarray

In [24]:
arr_host

array([[1, 2, 3, 4, 5]], dtype=int32)

##Committed vs Uncommitted

In [25]:
arr.device #uncommitted

CudaDevice(id=0)

In [26]:
arr_cpu.device

CpuDevice(id=0)

In [27]:
arr + arr_cpu

Array([[ 2,  4,  6,  8, 10]], dtype=int32)

In [29]:
arr_gpu = jax.device_put(arr, jax.devices('gpu')[0]) #committed

In [31]:
try:
  arr_gpu + arr_cpu
except ValueError as e:
  print(e) #cant do operations on tensors committed to different devices

Received incompatible devices for jitted computation. Got argument x of add with shape int32[1,5] and device ids [0] on platform GPU and argument y of add with shape int32[1,5] and device ids [0] on platform CPU


##Asynchronous Dispatch

In [32]:
a = jnp.array(range(1000000)).reshape((1000,1000))
a.device

CudaDevice(id=0)

In [39]:
%time x = jnp.dot(a,a)

CPU times: user 1.13 ms, sys: 0 ns, total: 1.13 ms
Wall time: 786 Âµs


In [46]:
%time x = jnp.dot(a,a).block_until_ready()

CPU times: user 1.26 ms, sys: 0 ns, total: 1.26 ms
Wall time: 7.12 ms


In [36]:
a_cpu = jax.device_put(a, jax.devices('cpu')[0])
a_cpu

Array([[     0,      1,      2, ...,    997,    998,    999],
       [  1000,   1001,   1002, ...,   1997,   1998,   1999],
       [  2000,   2001,   2002, ...,   2997,   2998,   2999],
       ...,
       [997000, 997001, 997002, ..., 997997, 997998, 997999],
       [998000, 998001, 998002, ..., 998997, 998998, 998999],
       [999000, 999001, 999002, ..., 999997, 999998, 999999]],      dtype=int32)

In [44]:
%time x = jnp.dot(a_cpu, a_cpu).block_until_ready()

CPU times: user 242 ms, sys: 0 ns, total: 242 ms
Wall time: 123 ms
