<a href="https://colab.research.google.com/github/anshulsawant/WhatDoesThisReallyDo/blob/main/What_does_this_function_really_do_JAX.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
## My attempt to address my biggest NumPy fear: I assume that a function processes its inputs in a particular way,
## but what if it doesn't? I hope this is not all going to come crashing down.
import jax
from jax import numpy as np

In [14]:
## jax.lax.select(pred, on_true, on_false): elementwise, take the on_true entry where pred
## is True and the on_false entry otherwise. Here all three arrays have shape (2, 3).
x = (np.arange(6) * 2 + 1).reshape(2, 3)   # odd numbers 1, 3, ..., 11
y = (np.arange(6) * 2).reshape(2, 3)       # even numbers 0, 2, ..., 10
p = np.array([True, False, True, False, True, False]).reshape(2, 3)
z = jax.lax.select(p, x, y)

## Arithmetic form of the same selection. It holds here because every entry is a finite
## integer, but it is NOT always true: with floats, an inf/nan in the branch that is *not*
## selected still leaks through (0 * inf == nan). np.where(p, x, y) is the robust equivalent.
assert np.all(z == p*x + (1-p)*y)
x, y, z

(Array([[ 1,  3,  5],
        [ 7,  9, 11]], dtype=int32),
 Array([[ 0,  2,  4],
        [ 6,  8, 10]], dtype=int32),
 Array([[ 1,  2,  5],
        [ 6,  9, 10]], dtype=int32))

# JAX vmap

In [52]:
## jax.vmap: build a matrix product out of a vector dot product, then check it
## against `@` and `einsum`.
def dot(x, y):
  """Dot product of two 1-D arrays of equal length."""
  return np.sum(x*y)

def mult(X, Y):
  """Plain matrix product X @ Y."""
  return X@Y

def mult_einsum(X, Y):
  """X @ Y.T written as an einsum: out[i, k] = sum_j X[i, j] * Y[k, j]."""
  return np.einsum('ij,kj -> ik', X, Y)

mult = jax.jit(mult)
mult_einsum = jax.jit(mult_einsum)

# Use two *different* matrices: with X == Y the product X @ Y.T is symmetric, so an
# accidentally transposed result would still compare equal and hide the bug.
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.randint(key_x, (1000, 1000), 0, 10)
Y = jax.random.randint(key_y, (1000, 1000), 0, 10)

# mv(X, y)[i] = dot(X[i], y), i.e. a matrix-vector product X @ y.
mv = jax.vmap(dot, in_axes=(0, None))
# Map mv over the rows of Y and place each result as a *column* (out_axes=1):
# mm(X, Y)[:, j] = X @ Y[j], so mm(X, Y)[i, j] = dot(X[i], Y[j]) == (X @ Y.T)[i, j].
# With the default out_axes=0 the results would be stacked as rows, giving the
# transpose Y @ X.T instead.
mm = jax.jit(jax.vmap(mv, in_axes=(None, 0), out_axes=1))

Y_t = Y.transpose()
X.shape, np.all(mm(X,Y) == mult(X, Y_t)), np.all(mm(X,Y) == mult_einsum(X,Y))

((1000, 1000), Array(True, dtype=bool), Array(True, dtype=bool))

In [32]:
%%timeit
mm(X,Y)

1.54 ms ± 407 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [31]:
%%timeit
np.matmul(X, Y_t)

1.38 ms ± 91.5 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [30]:
%%timeit
mult_einsum(X, Y)

2.5 ms ± 10 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [51]:
def f(x, y):
  """Per-slice kernel: elementwise sum of two same-shaped arrays."""
  return x + y

# vmap `f` over axis 1 of x and axis 2 of y (these two axes must have the same size n).
# The mapped axis becomes the leading axis of the result (default out_axes=0).
#
# Args:
#   x: shape (m, n, p)
#   y: shape (m, p, n)
# Returns:
#   z of shape (n, m, p) with z[j, i, k] = x[i, j, k] + y[i, k, j]
#   (each slice z[j] = x[:, j, :] + y[:, :, j]).
S = jax.jit(jax.vmap(f, in_axes=(1, 2)))

x = np.reshape(np.arange(24), (2, 3, 4))
y = np.reshape(np.arange(2*3*4), (2, 4, 3))

# Check the documented contract: the result equals x and y each transposed to (n, m, p).
assert S(x, y).shape == (3, 2, 4)
assert np.all(S(x, y) == np.transpose(x, (1, 0, 2)) + np.transpose(y, (2, 0, 1)))
S(x,y), x[:,0,:], y[:,:,0]

(Array([[[ 0,  4,  8, 12],
         [24, 28, 32, 36]],
 
        [[ 5,  9, 13, 17],
         [29, 33, 37, 41]],
 
        [[10, 14, 18, 22],
         [34, 38, 42, 46]]], dtype=int32),
 Array([[ 0,  1,  2,  3],
        [12, 13, 14, 15]], dtype=int32),
 Array([[ 0,  3,  6,  9],
        [12, 15, 18, 21]], dtype=int32))