# Brief benchmarking

In this notebook, we compare [JAX](https://jax.readthedocs.io), [MXNet](https://mxnet.apache.org), [PyTorch](http://pytorch.org/), [NumPy](https://numpy.org/) and [Numba](https://numba.pydata.org) with respect to runtime in four simple scenarios. **All computations are carried out on multiple CPUs (except for NumPy).**




## TL;DR

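+ MXNet reports **A LOT** lower runtimes than all other libraries in every scenario. Note that its NDArray operations are dispatched asynchronously and the timed cells never call `wait_to_read()`, so these numbers may reflect dispatch time rather than the full computation.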


In [1]:
import jax
import jax.numpy as jnp
import mxnet as mx
from mxnet import nd
import numpy as np
import torch
import numba

In [2]:
size = 2000  # side length of the square matrices

In [3]:

X_jax = jax.random.normal(key=jax.random.PRNGKey(0), shape=(size, size), dtype=jnp.float32)
X_mx = nd.random.normal(shape=(size, size))  # default dtype is float32
X_torch = torch.randn((size, size), dtype=torch.float32)
X_np = np.random.normal(size=(size, size)).astype(np.float32)



# 1. Matrix Multiplication on CPU

In [4]:
# JAX
%timeit jnp.dot(X_jax, X_jax)

126 ms ± 14 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
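
The summary line printed by `%timeit` can be approximated outside IPython with the standard-library `timeit` module. The sketch below is an approximation, not IPython's implementation; the helper name `summarize` and the toy workload are hypothetical.

```python
import statistics
import timeit


def summarize(stmt, repeat=7, number=100):
    """Return (mean, std) of the per-loop time in seconds, in the spirit of %timeit.

    `summarize` is a hypothetical helper, not part of IPython or timeit.
    """
    # timeit.repeat returns one total time per run; each run executes `number` loops.
    totals = timeit.repeat(stmt, repeat=repeat, number=number)
    per_loop = [t / number for t in totals]
    # Spread across runs, computed here as a population standard deviation
    # (an assumption about how closely this mirrors %timeit).
    return statistics.mean(per_loop), statistics.pstdev(per_loop)


mean, std = summarize(lambda: sum(range(1000)), repeat=7, number=100)
print(f"{mean * 1e6:.1f} µs ± {std * 1e6:.1f} µs per loop "
      f"(mean ± std. dev. of 7 runs, 100 loops each)")
```

Reading the output this way makes clear that "7 runs, 100 loops each" means 700 executions in total, with the spread measured across the 7 per-run averages.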


In [5]:
# JAX with block_until_ready()
%timeit jnp.dot(X_jax, X_jax).block_until_ready()

146 ms ± 15.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
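
Why `block_until_ready()` matters: JAX dispatches work asynchronously, so a call can return before its result has been computed. The standard-library sketch below is only an analogy (a thread pool stands in for an asynchronous backend; `slow_op` and the sleep duration are made up), but it shows how timing the call alone can understate the true cost.

```python
import time
from concurrent.futures import ThreadPoolExecutor


def slow_op():
    # Hypothetical stand-in for an expensive computation.
    time.sleep(0.05)
    return 42


with ThreadPoolExecutor(max_workers=1) as pool:
    # Timing only the dispatch: submit() returns a future immediately.
    t0 = time.perf_counter()
    future = pool.submit(slow_op)
    dispatch_time = time.perf_counter() - t0

    # Timing until the result is available, analogous to block_until_ready().
    result = future.result()
    total_time = time.perf_counter() - t0

print(f"dispatch only: {dispatch_time * 1e3:.3f} ms")
print(f"with blocking: {total_time * 1e3:.3f} ms")
```

How large the gap is in practice depends on the backend and the operation, so it is worth measuring both variants, as this notebook does for JAX.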


In [6]:
# JAX with jit
def dot_prod(x):
    return jnp.dot(x, x)
dot_prod_jit = jax.jit(dot_prod)
%timeit dot_prod_jit(X_jax)#.block_until_ready()

129 ms ± 10.5 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [7]:
# MXNet
%timeit nd.dot(X_mx, X_mx)

67 µs ± 7.72 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
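
A quick arithmetic check helps interpret this number. A dense product of two `size × size` matrices takes roughly `2 * size**3` floating-point operations, so a measured time implies a throughput. The sketch below plugs in the JAX and MXNet timings reported above; the `2 * n**3` count is the usual approximation and `implied_gflops` is a hypothetical helper.

```python
size = 2000  # matrix side length used in this notebook


def implied_gflops(seconds, n=size):
    """Throughput implied by timing an n x n matmul (hypothetical helper)."""
    flops = 2 * n**3  # approx. n**3 multiplications plus n**3 additions
    return flops / seconds / 1e9


jax_gflops = implied_gflops(126e-3)   # JAX timing from In [4]
mxnet_gflops = implied_gflops(67e-6)  # MXNet timing from In [7]

print(f"JAX:   {jax_gflops:,.0f} GFLOP/s")
print(f"MXNet: {mxnet_gflops:,.0f} GFLOP/s")
```

The MXNet figure lands in the hundreds of TFLOP/s, far beyond what a CPU delivers. This suggests the 67 µs reflects enqueueing the operation rather than completing it. Calling `wait_to_read()` on the result (or `nd.waitall()`) inside the timed statement would be the MXNet analogue of `block_until_ready()`.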


In [8]:
# PyTorch
%timeit torch.mm(X_torch,X_torch)

344 ms ± 122 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [9]:
# NumPy
%timeit np.dot(X_np, X_np)  # on a single CPU

241 ms ± 40.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
# Numba
@numba.jit
def go_fast(a):
    return np.dot(a, a)

_ = go_fast(X_np)  # warm-up call so that compilation is not timed
%timeit go_fast(X_np)

297 ms ± 66.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# 2. Matrix Addition on CPU

In [11]:
# JAX (sum over axis 1, i.e. a reduction)
%timeit jnp.sum(X_jax, 1)

275 µs ± 51.2 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [12]:
# MXNet (sum over axis 1, i.e. a reduction)
%timeit nd.sum(X_mx, 1)

69.1 µs ± 18.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [13]:
# PyTorch (element-wise addition of the scalar 1)
%timeit torch.add(X_torch,1)

3.4 ms ± 446 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# NumPy (element-wise addition of the scalar 1)
%timeit np.add(X_np, 1)

2.82 ms ± 335 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
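
Note that the cells in this section do not all time the same operation: `sum(X, 1)` reduces along axis 1, while `add(X, 1)` adds the scalar 1 to every element. A small NumPy sketch (the 3×4 array is an arbitrary example) makes the difference visible:

```python
import numpy as np

X = np.arange(12, dtype=np.float32).reshape(3, 4)

row_sums = np.sum(X, 1)  # reduction over axis 1: one value per row
shifted = np.add(X, 1)   # element-wise: adds the scalar 1 to every entry

print(row_sums.shape)  # (3,)
print(shifted.shape)   # (3, 4)
```

Because the reduction writes far fewer output values than the element-wise addition, the JAX and MXNet timings here are not directly comparable with the PyTorch and NumPy ones.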


# 3. $e^X$

In [15]:
# JAX
%timeit jnp.exp(X_jax)

6.39 ms ± 531 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [16]:
# MXNet
%timeit nd.exp(X_mx)

25.1 µs ± 4.65 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [17]:
# PyTorch
%timeit torch.exp(X_torch)

3.29 ms ± 712 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
# NumPy
%timeit np.exp(X_np)

5.99 ms ± 938 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# 4. Sum over $e^X$

In [19]:
# JAX
%timeit jnp.exp(X_jax).sum(axis=1)

7.58 ms ± 840 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [20]:
# MXNet
%timeit nd.exp(X_mx).sum(axis=1)

66.8 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [21]:
# PyTorch
%timeit torch.exp(X_torch).sum(axis=1)

3.81 ms ± 294 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
# NumPy
%timeit np.exp(X_np).sum(axis=1)

8.97 ms ± 1.25 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
