In [None]:
# Show the CUDA toolkit (nvcc) version installed on this runtime.
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Mon_Oct_12_20:09:46_PDT_2020
Cuda compilation tools, release 11.1, V11.1.105
Build cuda_11.1.TC455_06.29190527_0


In [None]:
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sun Jun 26 12:51:39 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   39C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Report which platform JAX uses by default ("gpu" in the recorded run).
# `jax.default_backend()` is the public API for this; `jax.lib.xla_bridge` is an
# internal module that newer JAX releases deprecate. Import jax locally so this
# cell works even though the main `import jax` cell comes later.
import jax
print(jax.default_backend())

gpu


In [None]:
import jax

In [None]:
# Devices JAX uses when no backend is specified (a single GPU in the recorded output).
jax.devices()

[GpuDevice(id=0, process_index=0)]

In [None]:
# The CPU backend is still available alongside the GPU; used below for CPU timings.
jax.devices("cpu")

[CpuDevice(id=0)]

In [None]:
# Explicitly list the GPU devices; the first one is the target for GPU timings.
jax.devices("gpu")

[GpuDevice(id=0, process_index=0)]

In [None]:
import numpy as np
import jax.numpy as jnp

In [None]:
# Benchmark workload: a handful of elementwise ops plus a transpose.
def f(x):
  """Toy elementwise workload used to compare NumPy and JAX (eager / jitted).

  Computes ``(x + x*x + 3) * (x*x + x*x.T)`` elementwise. ``x * x.T`` needs
  ``x.T`` to be shape-compatible with ``x`` (e.g. a square matrix). Works on
  both NumPy arrays and JAX arrays, since only arithmetic operators are used.
  """
  # Kept as the exact same sequence and order of ops as the two-step version,
  # so un-jitted JAX dispatches the same kernels and results are bit-identical.
  return (x + x*x + 3) * (x*x + x*x.T)

# Generate the benchmark input. Seeded so Restart & Run All benchmarks the same
# matrix, and drawn directly as float32 (no float64 temporary of N*N doubles).
N = 3000
SEED = 0
rng = np.random.default_rng(SEED)
x = rng.standard_normal((N, N), dtype=np.float32)

# Transfer the host array straight to each target device. Wrapping it in
# `jnp.array(x)` first would materialize a copy on the default device (the GPU
# here, per `jax.devices()`), turning the CPU copy into a host -> GPU -> CPU trip.
jax_x_gpu = jax.device_put(x, jax.devices('gpu')[0])
jax_x_cpu = jax.device_put(x, jax.devices('cpu')[0])

# compile function to CPU and GPU backends with JAX
# NOTE(review): newer JAX releases deprecate `backend=` in jit in favor of
# placing inputs on the target device; confirm against the installed version.
jax_f_cpu = jax.jit(f, backend='cpu')
jax_f_gpu = jax.jit(f, backend='gpu')

# Warm-up: the first call traces and compiles. Block on the results so the first
# (asynchronously dispatched) execution has finished before any timing cell runs.
jax_f_cpu(jax_x_cpu).block_until_ready()
jax_f_gpu(jax_x_gpu).block_until_ready();

In [None]:
# Baseline: plain NumPy on the host array `x`.
%timeit -n100 f(x)

100 loops, best of 5: 49.8 ms per loop


In [None]:
# JAX without jit, input on the CPU: each op in `f` is dispatched separately.
# block_until_ready() waits for the result so the timing includes the actual compute.
%timeit -n100 f(jax_x_cpu).block_until_ready()

100 loops, best of 5: 59.5 ms per loop


In [None]:
# JAX with jit on the CPU backend: `f` runs as a single compiled computation.
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

100 loops, best of 5: 10.5 ms per loop


In [None]:
# JAX without jit, input living on the GPU: op-by-op dispatch on the GPU.
%timeit -n100 f(jax_x_gpu).block_until_ready()

100 loops, best of 5: 1.87 ms per loop


In [None]:
# JAX with jit on the GPU backend: the fastest configuration in the recorded run.
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

100 loops, best of 5: 649 µs per loop


In [None]:
# Confirm the CPU copy of the input actually lives on the CPU device.
# NOTE(review): newer JAX `jax.Array` replaces `.device()` with `.devices()` /
# the `.device` property; confirm for the installed version.
jax_x_cpu.device()

CpuDevice(id=0)

In [None]:
# Confirm the GPU copy of the input actually lives on the GPU device.
# NOTE(review): newer JAX `jax.Array` replaces `.device()` with `.devices()` /
# the `.device` property; confirm for the installed version.
jax_x_gpu.device()

GpuDevice(id=0, process_index=0)