In [4]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

import numpy as np

In [5]:
# Benchmark configuration: side length of the square matrices, and the
# JAX PRNG key (seed 0) used to draw random JAX arrays.
size = 3_000
key = random.PRNGKey(0)

In [6]:
# NumPy baseline: build a (size, size) float32 matrix with np.random.rand and
# time the product of the matrix with its own transpose.
x = np.random.rand(size,size).astype(np.float32)
%timeit  np.dot(x,x.T)

108 ms ± 4.27 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
# JAX version of the same product. NOTE: this rebinds `x` (previously the NumPy
# matrix) to a (size, size) float32 array drawn with random.normal from `key`.
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() makes the timed expression wait for the result, so the
# measurement covers the computation itself rather than only its dispatch.
%timeit  jnp.dot(x, x.T).block_until_ready()

93.9 ms ± 1.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
@jit
def jit_dot(x, y):
    """Return ``jnp.dot(x, y)``; the ``@jit`` decorator wraps this function with ``jax.jit``."""
    product = jnp.dot(x, y)
    return product

# Time the jitted function on x and its transpose; block_until_ready() waits
# for the result before the timer stops.
%timeit  jit_dot(x, x.T).block_until_ready()

94.1 ms ± 1.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
def jit_dot(x, y):
    """Return ``jnp.dot(x, y)``.

    Note: despite its name, this definition has no ``@jit`` decorator, so it is
    the plain Python function. It rebinds the name ``jit_dot`` that the earlier
    decorated definition used.
    """
    product = jnp.dot(x, y)
    return product

build_jit_dot = jit(jit_dot)
%timeit  build_jit_dot(x, x.T)

jit_dot
88.6 ms ± 1.41 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:

@jit
def f(x, y):
  """Return ``jnp.dot(x + 1, y + 1)``, printing the arguments and result on the way.

  Because the function is wrapped with ``jit``, the print statements run while
  the function is being traced, so they show tracer objects rather than
  concrete numbers (see the captured output of this cell).
  """
  print("Running f():")
  print(f"  x = {x}")
  print(f"  y = {y}")
  shifted_x = x + 1
  shifted_y = y + 1
  out = jnp.dot(shifted_x, shifted_y)
  print(f"  result = {out}")
  return out

# NOTE: this rebinds `x`, which the earlier timing cells used for the
# (size, size) benchmark matrix.
# First call: f is traced for these argument shapes, so its print statements run.
x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)
# Same shapes again: nothing is printed for this call (see the captured output).
f(x, y)

# New shapes: f is traced again and the print statements run a second time.
x = np.random.randn(6, 8)
y = np.random.randn(8)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>
Running f():
  x = Traced<ShapedArray(float32[6,8])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[8])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[6])>with<DynamicJaxprTrace(level=1/0)>


Array([ 6.5362043,  7.500759 ,  3.7097254,  5.4297004, 20.999733 ,
       -1.6489006], dtype=float32)