In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [2]:
key = random.PRNGKey(0)
x = random.normal(key, (10,))
print(x)

[-0.3721109   0.26423115 -0.18252768 -0.7368197  -0.44030377 -0.1521442
 -0.67135346 -0.5908641   0.73168886  0.5673026 ]


In [3]:
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on the GPU

813 µs ± 1.21 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [5]:
from jax import device_put
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()

835 µs ± 18.6 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [6]:
import jax
import jax.numpy as jnp
import jax.profiler

def func1(x):
  return jnp.tile(x, 10) * 0.5

def func2(x):
  y = func1(x)
  return y, jnp.tile(x, 10) + 1

x = jax.random.normal(jax.random.PRNGKey(42), (1000, 1000))
y, z = func2(x)

z.block_until_ready()

jax.profiler.save_device_memory_profile("memory.prof")

In [7]:
pprof --web memory.prof

SyntaxError: invalid syntax (3932396307.py, line 1)

In [9]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [4]:
import jax
jax.default_backend()

'gpu'

In [5]:
print(jax.config.jax_platforms)

None


In [6]:
import jax

# 设置 JAX 使用的平台为 GPU
jax.config.update('jax_platform_name', 'gpu')

# 检查当前的平台
print("JAX backend:", jax.lib.xla_bridge.get_backend().platform)


JAX backend: gpu


In [8]:
jax.config.update('jax_platform_name', 'gpu')
print(jax.config.jax_platforms)

None
