# Test code run times

## Replacing index assignment

In [2]:
import numpy as np
import jax.numpy as jnp

In [3]:
%%timeit
# Baseline: NumPy in-place index assignment on a freshly created 10-element array.
# Array creation (np.arange) is inside the timed cell, so it is part of every measurement.
index_to_replace = 4
replacement_item = 14
items = np.arange(10)
items[index_to_replace] = replacement_item  # writes into `items` in place

345 ns ± 10.6 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


### My attempts (without index assignment)

In [13]:
%%timeit
# Alternative without ndarray index assignment: round-trip through a Python list
# (ndarray -> list, remove the old element, insert the new one, list -> ndarray).
index_to_replace = 4
replacement_item = 14
items = np.arange(10)
items = list(items)  # Python list of the array's elements
items.pop(index_to_replace)  # remove the element at the index; later items shift left
items.insert(index_to_replace,replacement_item)  # put the replacement back at the same index
items = np.array(items)  # allocate a new array from the list

1.29 μs ± 5.97 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [16]:
%%timeit
# Alternative without index assignment: rebuild the array element by element,
# substituting the replacement value at the target index.
index_to_replace = 4
replacement_item = 14
items = np.arange(10)
new_items = []
for index,item in enumerate(items):
    if index == index_to_replace:
        new_items.append(replacement_item)
    else:
        new_items.append(item)  # keep the original element
items = np.array(new_items)  # allocate a new array from the Python list

1.74 μs ± 34.3 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


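### `jax`

JAX arrays are immutable, so index assignment is replaced by the functional `.at[index].set(value)` update, which returns a new array. JAX also dispatches operations asynchronously. When timing JAX code, call `.block_until_ready()` on the result so the measurement includes the computation and not only the dispatch.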

In [9]:
%%timeit
index_to_replace = 4
replacement_item = 14
items = jnp.arange(10)
items = items.at[index_to_replace].set(replacement_item)

224 μs ± 555 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


### Indexing with tuple vs converting to `raw_type`

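If you need to convert a multidimensional index to `raw_type` before inserting, convert first and then make a tuple of `raw_type`s. Do not index with the converted array directly. `x[np.array([5, 5, 5])]` is advanced (integer-array) indexing along the first axis, so it writes to the whole `x[5]` plane rather than the single element `x[5, 5, 5]`.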

In [5]:
%%timeit
# Index with a tuple of ints: basic indexing, assigns the single element x[5, 5, 5].
x = np.zeros((10,10,10))
x[(5,5,5)] = 10

376 ns ± 30.8 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [6]:
%%timeit
x = np.zeros((10,10,10))
x[np.array([5,5,5])] = 10

1.38 μs ± 6.64 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


## Replacing in-place operations

### Addition

In [17]:
%%timeit
# Augmented assignment on a Python int. ints are immutable, so `+=` rebinds `a`
# to a new object, just like the explicit `a = a + 10` form.
a = 5
a += 10

16.9 ns ± 0.0211 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [18]:
%%timeit
# Explicit rebinding form; compare with the `a += 10` cell above.
a = 5
a = a + 10

16.3 ns ± 0.013 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


### Multiplication + index assignment

In [5]:
%%timeit
# NumPy baseline: augmented in-place assignment on a single element.
a = np.arange(10)
a[5] *= 10  # reads a[5], multiplies, and writes the result back into `a` in place

365 ns ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [6]:
%%timeit
a = jnp.arange(10)
temp = a[5]*10
a = a.at[5].set(temp)

282 μs ± 25.1 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [7]:
%%timeit
a = jnp.arange(10)
a = a.at[5].multiply(10)

239 μs ± 1.14 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
