In [1]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0


In [None]:
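# Optional: install a CUDA-enabled JAX build. Left commented out so "Run All" does not reinstall packages.
# NOTE(review): current JAX install docs use `%pip install -U "jax[cuda12]"` (wheels on PyPI); the
# `jax[cuda]` extra and the find-links URL below look outdated — confirm before re-enabling.
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html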

In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Sat Dec  6 13:37:00 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        On  |   00000000:01:00.0 Off |                  N/A |
|  0%   39C    P8             22W /  400W |      15MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [1]:
import jax

# Devices of JAX's default backend; a CudaDevice here means JAX picked up the GPU.
print(jax.devices(backend=None))

[CudaDevice(id=0)]


In [2]:
# Devices of the CPU backend (the CPU copy of the benchmark data is placed on the first one).
jax.devices(backend="cpu")

[CpuDevice(id=0)]

In [3]:
# Devices of the GPU backend (the GPU copy of the benchmark data is placed on the first one).
jax.devices(backend="gpu")

[CudaDevice(id=0)]

In [4]:
import numpy as np
import jax.numpy as jnp

In [5]:
def f(x):
  """Benchmark workload: a few elementwise ops plus one transpose.

  Computes ``(x + x*x + 3) * (x*x + x*x.T)`` elementwise, with the same
  operations in the same order as a step-by-step version. Only plain operators
  and ``.T`` are used, so the function runs unchanged on NumPy arrays and on JAX
  arrays, both eagerly and under ``jax.jit``.

  Args:
    x: NumPy or JAX array. This notebook uses a square 2-D matrix; non-square
      2-D inputs generally fail to broadcast in ``x * x.T``.

  Returns:
    The elementwise result (same shape as ``x`` for square input).
  """
  return (x + x * x + 3) * (x * x + x * x.T)

# Generate the benchmark input. A seeded Generator makes re-runs use the same
# matrix, and sampling float32 directly avoids a float64 temporary plus a cast.
rng = np.random.default_rng(0)
x = rng.standard_normal((3000, 3000), dtype=np.float32)

# Put one copy of the data on each backend. device_put straight from the NumPy
# array transfers host -> target device; going through jnp.array(x) first would
# stage the data on the default device (the GPU here) even for the CPU copy.
cpu_device = jax.devices('cpu')[0]
gpu_device = jax.devices('gpu')[0]
jax_x_gpu = jax.device_put(x, gpu_device)
jax_x_cpu = jax.device_put(x, cpu_device)

# Compile f for the CPU and GPU backends with JAX.
# NOTE(review): jax.jit documents `backend=` as experimental; since the inputs
# above are committed to explicit devices, plain jax.jit(f) should follow them
# -- confirm before switching.
jax_f_cpu = jax.jit(f, backend='cpu')
jax_f_gpu = jax.jit(f, backend='gpu')

# Warm-up: the first call traces and compiles each function. JAX dispatches
# execution asynchronously, so wait for the results; otherwise warm-up work can
# still be running when the first %timeit starts.
jax_f_cpu(jax_x_cpu).block_until_ready()
jax_f_gpu(jax_x_gpu).block_until_ready();

In [6]:
%timeit -n100 f(x)

32.3 ms ± 224 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
%timeit -n100 f(jax_x_cpu).block_until_ready()

31.8 ms ± 300 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [8]:
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

7.55 ms ± 321 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
%timeit -n100 f(jax_x_gpu).block_until_ready()

1.21 ms ± 328 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

119 μs ± 32.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
