In [2]:
# Optional: hide every GPU from CUDA so JAX falls back to the CPU.
# Keep this cell above `import jax` / any JAX computation: CUDA reads
# CUDA_VISIBLE_DEVICES when it initializes, so changing it after JAX has
# started using the GPU has no effect.
# Set DISABLE_GPU = False to keep the GPU visible (instead of skipping the cell,
# which breaks Restart & Run All).
import os

DISABLE_GPU = True

if DISABLE_GPU:
  # "-1" matches no device index, so no GPU is visible to the process.
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

In [3]:
import jax
import jax.numpy as jnp



2022-09-13 20:56:34.792548: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:265] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


DeviceArray([11., 20., 29.], dtype=float32)

In [None]:
# Toy 1-D signal 0..4 and a 3-tap kernel used throughout the notebook.
x = jnp.arange(0, 5)
w = jnp.asarray([2.0, 3.0, 4.0])

In [4]:
def convolve(x, w):
  """Sliding dot product of a 1-D signal with a kernel ("valid" mode).

  Slides ``w`` along ``x`` and takes the dot product at every offset where
  the kernel fits entirely inside the signal. The kernel is not flipped, so
  this is strictly a cross-correlation; for ``len(w) <= len(x)`` it matches
  ``jnp.correlate(x, w, mode="valid")``.

  Deliberately written as a plain Python loop over one example, so the
  notebook can show ``jax.vmap`` batching an unbatched function.

  Args:
    x: 1-D array, the input signal.
    w: 1-D array, the kernel. Any kernel length is supported.

  Returns:
    1-D array with ``max(len(x) - len(w) + 1, 0)`` elements (empty when the
    kernel does not fit inside the signal).
  """
  window = len(w)
  # One output per offset at which the whole kernel still lies inside x.
  output = [jnp.dot(x[i:i + window], w) for i in range(len(x) - window + 1)]
  return jnp.array(output)

In [5]:
# Run the unbatched function on a single (signal, kernel) pair.
convolve(x=x, w=w)

DeviceArray([11., 20., 29.], dtype=float32)

In [6]:
# Now suppose we want to apply `convolve` to a whole batch of examples:
# stack two copies of each input along a new leading (batch) axis.

xs = jnp.stack([x] * 2)
ws = jnp.stack([w] * 2)


In [7]:
# Naive batching: loop over the batch in Python, one `convolve` call per example.

def manually_batched_convolve(xs, ws):
  """Apply ``convolve`` to each (signal, kernel) pair of a batch, one at a time.

  This is the baseline that ``jax.vmap`` replaces: a Python loop that calls
  ``convolve`` once per batch element and stacks the results.

  Args:
    xs: batch of signals; the leading axis is the batch axis.
    ws: batch of kernels; the leading axis is the batch axis.

  Returns:
    The per-example ``convolve`` outputs stacked along a new leading axis.

  Raises:
    ValueError: if ``xs`` and ``ws`` have different batch sizes (otherwise
      extra kernels would be silently ignored), or if the batch is empty
      (``jnp.stack`` needs at least one array).
  """
  if xs.shape[0] != ws.shape[0]:
    raise ValueError(
        f"batch sizes differ: xs has {xs.shape[0]} rows, ws has {ws.shape[0]}")
  return jnp.stack([convolve(x_i, w_i) for x_i, w_i in zip(xs, ws)])

In [8]:
# Batched result via the explicit Python loop.
manually_batched_convolve(xs=xs, ws=ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

Doing it this way is really inefficient: the Python loop calls `convolve` once per batch element. This is where automatic vectorization with `jax.vmap` comes in.

In [9]:
# Let JAX vectorize `convolve`: map over axis 0 of every input and stack
# the outputs along axis 0 (these are vmap's defaults, spelled out here).
auto_batch_convolve = jax.vmap(convolve, in_axes=0, out_axes=0)

In [10]:
# Same batched result as the manual loop, but vmap does the batching for us.
auto_batch_convolve(xs, ws)

DeviceArray([[11., 20., 29.],
             [11., 20., 29.]], dtype=float32)

In [12]:
# vmap can map over a different axis: here the batch lives on axis 1 of each
# input (in_axes=1), and the outputs are stacked along axis 1 (out_axes=1).

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

# Transpose the 2-D batches so the batch index moves from the first axis
# to the second, matching in_axes=1.
xst = xs.T
wst = ws.T

auto_batch_convolve_v2(xst, wst)


DeviceArray([[11., 11.],
             [20., 20.],
             [29., 29.]], dtype=float32)

In [13]:
# We can even jit the vmapped function :)
# Build it from `convolve` directly instead of re-wrapping the existing
# `auto_batch_convolve_v2` name, so re-running this cell is idempotent
# (no jit-of-jit stacking on every re-run).
auto_batch_convolve_v2 = jax.jit(jax.vmap(convolve, in_axes=1, out_axes=1))


In [14]:
# Call the jitted, vmapped version; the result should match the non-jitted call.
auto_batch_convolve_v2(xst, wst)


DeviceArray([[11., 11.],
             [20., 20.],
             [29., 29.]], dtype=float32)