## [Deep learning with JAX](https://github.com/che-shr-cat/JAX-in-Action)
- Chapter 6: Vectorizing your code

In [1]:
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
from jax import random
from jax.extend import backend  
jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
from jax import numpy as jnp

print(f"number of cores:", jax.local_device_count())
print(f"using: ", backend.get_backend().platform) 

number of cores: 8
using:  cpu


In [2]:
def dot(v1, v2):
    """Return ``jnp.dot(v1, v2)``.

    For 1-D inputs this is the inner product; for 2-D inputs it is a matrix
    product. Kept as a named function so it can be handed to ``jax.vmap``,
    ``jax.jit`` and ``jax.make_jaxpr``.
    """
    result = jnp.dot(v1, v2)
    return result

dot(jnp.array([1.0, 1, 1]), jnp.array([1, 2, -1]))  # 1*1 + 1*2 + 1*(-1) = 2

Array(2., dtype=float64)

In [8]:
from jax import random
rng_key = random.PRNGKey(42)

# Every row is [1, 2, 3], so each row-wise dot product equals 1 + 4 + 9 = 14,
# which makes the results below easy to check by eye.
# For random data instead use: random.normal(rng_key, shape=(10000, 3))
vs = jnp.tile(jnp.array([1.0, 2.0, 3.0]), (10000, 1))
v1s, v2s = vs[:5000], vs[5000:]
v1s.shape, v2s.shape

((5000, 3), (5000, 3))

In [9]:
v1s[:5]  # first five rows, all columns

Array([[1., 2., 3.],
       [1., 2., 3.],
       [1., 2., 3.],
       [1., 2., 3.],
       [1., 2., 3.]], dtype=float64)

In [10]:
# Input shapes: vs is (10000, 3); v1s and v2s are (5000, 3) each.
# Passing whole matrices to `dot` yields matrix products rather than row-wise dots:
#   dot(v1s, v2s.T) -> (5000, 5000) table of all row pairs
#   dot(v1s.T, v2s) -> (3, 3)
dot(v1s, v2s.T).shape, dot(v1s.T, v2s).shape

((5000, 5000), (3, 3))

In [11]:
v1s[0]  # the first row: a length-3 vector

Array([1., 2., 3.], dtype=float64)

### Using a Python list comprehension

In [12]:
# Baseline: call `dot` once per pair of rows in plain Python (no vectorization).
data = list(map(dot, v1s, v2s))
data[:5]

[Array(14., dtype=float64),
 Array(14., dtype=float64),
 Array(14., dtype=float64),
 Array(14., dtype=float64),
 Array(14., dtype=float64)]

### Manual Vectorization

In [13]:
def dot_vectorized(v1s, v2s):
    """Row-wise dot products of two 2-D arrays with the same shape.

    Output element ``b`` is ``sum_k v1s[b, k] * v2s[b, k]``. The einsum is traced
    to a single batched ``dot_general``.
    """
    return jnp.einsum("bk,bk->b", v1s, v2s)

r = dot_vectorized(v1s, v2s)
(r.shape, r[:5])  # one dot product per pair of rows

((5000,), Array([14., 14., 14., 14., 14.], dtype=float64))

### Automatic vectorization

In [14]:
# vmap maps `dot` over axis 0 of both arguments (in_axes=0 is the default).
dot_vmapped = jax.vmap(dot, in_axes=0)
r = dot_vmapped(v1s, v2s)
(r.shape, r[:5])

((5000,), Array([14., 14., 14., 14., 14.], dtype=float64))

In [15]:
# jit-compiled versions of the hand-written and the vmap-based implementations
dot_vectorized_jitted, dot_vmapped_jitted = jax.jit(dot_vectorized), jax.jit(dot_vmapped)

#### Timing the list comprehension, manual vectorization, and automatic vectorization

In [16]:
%timeit [dot(v1, v2) for v1, v2 in zip(v1s, v2s)]    # list comprehension
%timeit dot_vectorized(v1s, v2s).block_until_ready() # manual
%timeit dot_vmapped(v1s, v2s).block_until_ready()    # automatic

100 ms ± 3.31 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
283 μs ± 23.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
371 μs ± 15.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [17]:
%timeit dot_vectorized_jitted(v1s, v2s).block_until_ready()
%timeit dot_vmapped_jitted(v1s, v2s).block_until_ready()

43.7 μs ± 2.56 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
42.9 μs ± 1.44 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [18]:
# Trace `dot` on one pair of vectors. With x64 enabled, the float list becomes f64 and
# the int list i64; the resulting jaxpr is a single dot_general contracting axis 0.
jax.make_jaxpr(dot)(jnp.array([1.0, 1, 1]), jnp.array([1, 2,-1]))

{ lambda ; a:f64[3] b:i64[3]. let
    c:f64[] = dot_general[
      dimension_numbers=(([0], [0]), ([], []))
      preferred_element_type=float64
    ] a b
  in (c,) }

In [19]:
# The einsum traces to one dot_general that contracts axis 1 and batches over axis 0.
jax.make_jaxpr(dot_vectorized)(v1s, v2s)

{ lambda ; a:f64[5000,3] b:f64[5000,3]. let
    c:f64[5000] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float64
    ] a b
  in (c,) }

In [20]:
# vmap(dot) produces the same batched dot_general as the hand-written einsum version.
jax.make_jaxpr(dot_vmapped)(v1s, v2s)

{ lambda ; a:f64[5000,3] b:f64[5000,3]. let
    c:f64[5000] = dot_general[
      dimension_numbers=(([1], [1]), ([0], [0]))
      preferred_element_type=float64
    ] a b
  in (c,) }

In [21]:
# The jitted function appears as a `pjit` call, named after the wrapped function,
# around the same inner dot_general computation.
jax.make_jaxpr(dot_vectorized_jitted)(v1s, v2s)

{ lambda ; a:f64[5000,3] b:f64[5000,3]. let
    c:f64[5000] = pjit[
      name=dot_vectorized
      jaxpr={ lambda ; d:f64[5000,3] e:f64[5000,3]. let
          f:f64[5000] = dot_general[
            dimension_numbers=(([1], [1]), ([0], [0]))
            preferred_element_type=float64
          ] d e
        in (f,) }
    ] a b
  in (c,) }

Controlling array axes to map over

In [22]:
# jit(vmap(dot)) shows up as a `pjit` named `dot` (the vmapped function keeps the
# original name) wrapping the same batched dot_general.
jax.make_jaxpr(dot_vmapped_jitted)(v1s, v2s)

{ lambda ; a:f64[5000,3] b:f64[5000,3]. let
    c:f64[5000] = pjit[
      name=dot
      jaxpr={ lambda ; d:f64[5000,3] e:f64[5000,3]. let
          f:f64[5000] = dot_general[
            dimension_numbers=(([1], [1]), ([0], [0]))
            preferred_element_type=float64
          ] d e
        in (f,) }
    ] a b
  in (c,) }

You can control which array axes to map over. For this, the `vmap()` function has a parameter
called `in_axes`. This parameter can be an integer, None, or a (possibly nested) standard
Python container such as a tuple, list, or dict.

If the `in_axes` parameter is an integer (the default value is 0), that array axis is mapped over for every function argument.
If you need a different axis for different parameters, use a tuple of integers and `None` values whose length equals the number of positional arguments of the original
function. `None` means that this particular argument is not mapped over: it is passed unchanged to every call. The general
rule is that the structure of `in_axes` must match the structure of the corresponding inputs.

In [23]:
print(v1s.shape, v2s.shape)
# Map over axis 0 (rows) of both inputs: 5000 dot products of length-3 vectors.
print(jax.vmap(dot, in_axes=(0, 0))(v1s, v2s).shape)
# Map over axis 1 (columns) of both inputs: 3 dot products of length-5000 vectors.
print(jax.vmap(dot, in_axes=(1, 1))(v1s, v2s).shape)

(5000, 3) (5000, 3)
(5000,)
(3,)


In [24]:
# Plain assignment does NOT copy: `x2 = x1` only binds a second name to the same array.
# JAX arrays are immutable, so `.at[...].set(...)` returns a NEW array instead of
# modifying the existing one (unlike `x[0] = ...` on a NumPy array). Rebinding `x2`
# to that new array therefore leaves `x1` unchanged.
x1 = jnp.array([1.0, 1.0, 1.0])
x2 = x1
assert x2 is x1  # same object: nothing was copied

x2 = x2.at[0].set(2.0)
print(x1, x2, sep="\n")

[1. 1. 1.]
[2. 1. 1.]


In [25]:
def scaled_dot(v1, v2, koeff):
    """Return the dot product of ``v1`` and ``v2`` multiplied by ``koeff``."""
    dot_product = jnp.dot(v1, v2)
    return koeff * dot_product

v1s_, v2s_ = v1s, v2s.T  # v2s_ is transposed: its batch dimension is now axis 1
k = 1.0                  # scalar scale factor

v1s_.shape, v2s_.shape

((5000, 3), (3, 5000))

In [26]:
# Map v1s_ over axis 0 (rows) and v2s_ over axis 1 (columns); None passes k unchanged to every call.
scaled_dot_batched = jax.vmap(scaled_dot, in_axes=(0, 1, None))
tmp = scaled_dot_batched(v1s_, v2s_, k)
print(tmp.shape, tmp[:5])


(5000,) [14. 14. 14. 14. 14.]


In [27]:
# `in_axes` can mirror a Python container: here a dict of arrays plus an unmapped scalar.
def scaled_dot(data, koeff):
    """Return ``koeff * dot(data['a'], data['b'])``.

    Note: this redefines (shadows) the three-argument ``scaled_dot`` defined earlier.
    """
    v1, v2 = data['a'], data['b']
    return koeff * jnp.dot(v1, v2)

# in_axes mirrors the argument structure: 'a' mapped over axis 0, 'b' over axis 1, koeff unmapped.
scaled_dot_batched = jax.vmap(scaled_dot, in_axes=({'a': 0, 'b': 1}, None))
tmp = scaled_dot_batched(dict(a=v1s_, b=v2s_), k)
print(tmp.shape, tmp[:5])


(5000,) [14. 14. 14. 14. 14.]


- Using the `out_axes` parameter

The `out_axes` parameter specifies where the mapped axis should appear in the output. In the example below, `out_axes=1` places the mapped axis at position 1, the second axis of the output.

In [28]:
def scale(v, koeff):
    """Return ``v`` multiplied by the factor ``koeff``."""
    scaled = koeff * v
    return scaled
# Map v over axis 0, leave koeff unmapped, and put the mapped axis at position 1 of the output.
scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=1)

print(v1s.shape, scale_batched(v1s, 2.0).shape)

(5000, 3) (3, 5000)


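**Using named (keyword) arguments**
`in_axes` only describes positional arguments; keyword arguments passed to a vmapped function are always mapped over their leading axis (axis 0). A scalar keyword argument therefore cannot be passed directly. There are two ways to use named arguments with `vmap`:
1. bind them in advance with `functools.partial`
2. broadcast them to the batch dimension, so that they have a leading axis to map over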


1.

In [29]:
from functools import partial
scale2 = partial(scale, koeff=2.0)  # bind koeff so vmap only sees the array argument
scale_batched = jax.vmap(scale2)    # in_axes=0 and out_axes=0 are the defaults
scale_batched(v1s).shape

(5000, 3)

In [30]:
# using decorator
from functools import partial

@partial(jax.vmap, in_axes=(0, None), out_axes=1)
def scale(v, koeff):
    """Multiply each mapped slice of ``v`` by ``koeff``; the mapped axis becomes output axis 1."""
    return koeff * v

tmp = scale(v1s, 2.0)
print(tmp.shape, tmp[0, :5])  # output is transposed (3, 5000); show the first five entries of row 0

(3, 5000) [2. 2. 2. 2. 2.]


2.

In [33]:
def scale(v, koeff=1.0):
    """Return ``koeff * v``; with the default ``koeff=1.0`` this leaves ``v`` unchanged."""
    result = koeff * v
    return result

# in_axes=0 (the original `(0)` is just 0, not a tuple) describes the positional argument only.
# Keyword arguments are always mapped over their leading axis, so koeff must carry one value
# per example: broadcast the scalar to the batch size.
scale_batched = jax.vmap(scale, in_axes=0, out_axes=1)
scale_batched(v1s, koeff=jnp.broadcast_to(2.0, (len(v1s),))).shape

(3, 5000)

In [34]:
def scale(v, koeff=1.0):
    """Multiply ``v`` by ``koeff`` (defaults to 1.0)."""
    product = koeff * v
    return product

scale_batched = jax.vmap(scale, in_axes=(0, None), out_axes=0)
# scale_batched(v1s, koeff=2.0).shape # raises: keyword arguments are always mapped over
#                                     # their leading axis, and the scalar 2.0 has no axis
scale_batched(v1s, 2.0).shape # works: koeff is positional, so in_axes=None leaves it unmapped

(5000, 3)

#### [using collective operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)

In [39]:
print(f"{sum(range(5))=}")
arr = jnp.arange(5)
# psum over the named vmap axis 'batch' sums x across all mapped elements,
# so each element is divided by the total of the whole array.
norm = jax.vmap(lambda x: x / jax.lax.psum(x, axis_name='batch'), axis_name='batch')
norm(arr)

sum(range(5))=10


Array([0. , 0.1, 0.2, 0.3, 0.4], dtype=float64)

In [40]:
# Augmenting a single element of data

from jax import lax

# Stand-in "augmentations": each one just adds a different constant, so the chosen
# branch is visible in the output. Real noise/flip/rotation/colour transforms would
# replace these. lax.switch requires every branch to return the same shape and dtype.
add_noise_func = lambda x: x + 10
horizontal_flip_func = lambda x: x + 1
rotate_func = lambda x: x + 2
adjust_colors_func = lambda x: x + 3

augmentations = [
    add_noise_func, 
    horizontal_flip_func, 
    rotate_func, 
    adjust_colors_func
    ]

def random_augmentation(image, augmentations, rng_key):
    """Apply one randomly chosen function from ``augmentations`` to ``image``.

    The branch index is drawn with ``rng_key`` from ``[0, len(augmentations))``
    and dispatched through ``lax.switch``, so the choice works on a traced index
    (e.g. inside ``jax.vmap``).
    """
    branch_index = random.randint(rng_key, shape=(), minval=0, maxval=len(augmentations))
    return lax.switch(branch_index, augmentations, image)

image = jnp.arange(100)
augmented_image = random_augmentation(image, augmentations, random.PRNGKey(42))

# A batch of 10 copies of the same image: shape (10, 100).
images = jnp.tile(image, (10, 1))
print(images.shape)

# One key per image (mapped over axis 0); the list of augmentation functions is shared (None).
rng_keys = random.split(random.PRNGKey(42), images.shape[0])
random_augmentation_batch = jax.vmap(random_augmentation, in_axes=(0, None, 0))
augmented_images = random_augmentation_batch(images, augmentations, rng_keys)
print(augmented_images.shape)

(10, 100)
(10, 100)


In [41]:
# Calculating per-sample gradients

from jax import grad, vmap, jit 
import matplotlib.pyplot as plt

# Synthetic regression data: linear trend + cosine wave + Gaussian noise with std 10.
x = jnp.linspace(0, 10 * jnp.pi, 1000)
e = 10 * random.normal(random.PRNGKey(42), shape=x.shape)
y = 65 + 1.8 * x + 40 * jnp.cos(x) + e

model_parameters = jnp.array([1.0, 1.0])  # initial [w, b] for the linear model
def predict(theta, x):
    """Linear model: ``theta`` unpacks to ``(w, b)``; returns ``w * x + b``."""
    w, b = theta
    return w * x + b

def loss_fn(model_parameters, x, y):
    """Squared error ``(predict(model_parameters, x) - y) ** 2``, without averaging.

    Leaving the error un-reduced lets ``grad`` be taken for a single example,
    which ``vmap`` then turns into per-sample gradients.
    """
    residual = predict(model_parameters, x) - y
    return residual ** 2

# grad(loss_fn) differentiates w.r.t. the parameters for ONE (x, y) pair; vmap maps it over
# the examples (parameters shared via in_axes=None) and jit compiles the result.
# The output has one gradient row [dL/dw, dL/db] per example.
grads_fn = jit(vmap(grad(loss_fn), in_axes=(None, 0, 0)))
batch_x, batch_y = x[:32], y[:32]
grads_fn(model_parameters, batch_x, batch_y).shape

# print(x.shape, y.shape)
# plt.plot(x, y)

(32, 2)

- Vectorizing loops

In [42]:
import jax
import jax.numpy as jnp

# A computation written for a single element; vmap turns it into a batched call
# instead of a Python loop over the values.
def compute_element(x):
    """Square a single element."""
    return x ** 2

values = jnp.arange(1, 6)  # [1, 2, 3, 4, 5]
vectorized_fn = jax.vmap(compute_element)
result = vectorized_fn(values)

print(result)


[ 1  4  9 16 25]
