# Pure Python in JAX

This is a fairly advanced topic, and you should only read this if you have gone through the rest of the notebooks and are now somewhat comfortable with JAX.

### Prerequisites
Intermediate loops.

Sometimes, JAX is really not the right tool for the job. This is usually going to be the case when some very nasty loops are involved, or when you want to use variable size data structures inside a specific subroutine. 
It can also be the case that on GPU JAX does the job, but on CPU it is too slow. In this case, you can branch out to pure Python code, and then come back to JAX.

This notebook shows how to use pure Python code inside JAX and explains why you might want to do that, but also what the limitations are.

In the "loops" notebook, we encountered a question about implementing bubble sort in JAX:
```python
def bubble_sort(arr): 
    n = len(arr) 
    res = np.copy(arr)
    for i in range(n-1): 
        for j in range(0, n-i-1): 
            if res[j] > res[j+1]: 
                res[j], res[j+1] = res[j+1], res[j]
    return res   
```
for which a JAX implementation would be (don't look if you have not done the exercise yet!):

In [1]:
import jax
from jax import numpy as jnp
from jax.lax import dynamic_update_slice, while_loop, fori_loop
import numpy as np


@jax.jit
def bubble_sort(arr):
    n = arr.shape[0]

    def inner_cond(carry, i):
        j, _ = carry
        return j < n - i - 1

    def inner_body(carry):
        j, inner_res = carry

        # This is the equivalent of the line res[j], res[j+1] = res[j+1], res[j]
        update = jnp.array([inner_res[j + 1], inner_res[j]])
        inner_res = dynamic_update_slice(inner_res, update, (j,))
        return j + 1, inner_res

    def outer_loop(i, carry):
        _, carry = while_loop(lambda val: inner_cond(val, i), inner_body, (0, carry))
        return carry

    res = fori_loop(0, n - 1, outer_loop, arr)
    return res

In [2]:
bubble_sort(np.random.randn(5000))

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


Array([ 1.6327504 ,  0.83159214,  0.42975882, ..., -0.9669213 ,
        0.3114629 , -1.1753024 ], dtype=float32)

Even though this is a good exercise to understand how to use JAX, one has to admit this is a bit of an overkill. The JAX implementation is much more complex than the pure Python implementation but it is also hardly faster, which is not what we would expect from a library that is supposed to be fast.

In [42]:
def np_bubble_sort(arr):
    n = len(arr)
    res = np.copy(arr)
    for i in range(n - 1):
        for j in range(0, n - i - 1):
            if res[j] > res[j + 1]:
                res[j], res[j + 1] = res[j + 1], res[j]
    return res

In [43]:
arr_test = np.random.randn(1000)

In [45]:
%timeit bubble_sort(arr_test).block_until_ready()
%timeit np_bubble_sort(arr_test)

30.7 ms ± 83.9 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
91.3 ms ± 140 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


and the performance completely goes away when we compile the function using (for example) `numba`:

In [46]:
import numba as nb

nb_bubble_sort = nb.njit(np_bubble_sort, boundscheck=False)

%timeit nb_bubble_sort(arr_test)

453 μs ± 125 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


This lack of performance is due to the fact that JAX is not optimized for this kind of loops and is one of the reasons why you might want to use pure Python code inside JAX.
Thankfully, this is possible and quite easy to do using a `pure_callable` function. This function allows you to use Python code inside a JAX function, *provided that the Python code does not modify any global state*, otherwise the JAX code is not guaranteed to be correct.

In [49]:
@jax.jit
def wrapped_bubble_sort(arr):
    # This uses pure callback
    return jax.pure_callback(nb_bubble_sort, arr, arr)

In [50]:
wrapped_bubble_sort(arr_test)

jax.pure_callback failed
Traceback (most recent call last):
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/callback.py", line 94, in pure_callback_impl
    return tree_util.tree_map(np.asarray, callback(*args))
                                          ^^^^^^^^^^^^^^^
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 442, in _compile_for_args
    raise e
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 375, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^

XlaRuntimeError: INTERNAL: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/home/adrien/miniconda3/lib/python3.12/asyncio/base_events.py", line 639, in run_forever
  File "/home/adrien/miniconda3/lib/python3.12/asyncio/base_events.py", line 1985, in _run_once
  File "/home/adrien/miniconda3/lib/python3.12/asyncio/events.py", line 88, in _run
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/async_helpers.py", line 128, in _pseudo_sync_runner
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
  File "/tmp/ipykernel_178186/3832731969.py", line 1, in <module>
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/pjit.py", line 338, in cache_miss
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/pjit.py", line 188, in _python_pjit_helper
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/core.py", line 2803, in bind
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/core.py", line 442, in bind_with_trace
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/core.py", line 955, in process_primitive
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/pjit.py", line 1738, in _pjit_call_impl
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/pjit.py", line 1714, in call_impl_cache_miss
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/pjit.py", line 1668, in _pjit_call_impl_python
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/profiler.py", line 333, in wrapper
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1278, in __call__
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/interpreters/mlir.py", line 2768, in _wrapped_callback
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/callback.py", line 269, in _callback
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/callback.py", line 97, in pure_callback_impl
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/jax/_src/callback.py", line 71, in __call__
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 444, in _compile_for_args
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 906, in compile
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 80, in compile
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 95, in _compile_cached
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/dispatcher.py", line 107, in _compile_core
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler.py", line 744, in compile_extra
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler.py", line 438, in compile_extra
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler.py", line 506, in _compile_bytecode
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler.py", line 481, in _compile_core
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 364, in run
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler_lock.py", line 35, in _acquire_compile_lock
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 311, in _runPass
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/compiler_machinery.py", line 273, in check
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typed_passes.py", line 112, in run_pass
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typed_passes.py", line 93, in type_inference_stage
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typeinfer.py", line 1083, in propagate
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typeinfer.py", line 182, in propagate
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typeinfer.py", line 583, in __call__
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typeinfer.py", line 607, in resolve
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typeinfer.py", line 1577, in resolve_call
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/context.py", line 197, in resolve_function_type
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/context.py", line 248, in _resolve_user_function_type
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/types/functions.py", line 312, in get_call_type
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/templates.py", line 350, in apply
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/templates.py", line 613, in generic
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/templates.py", line 712, in _get_impl
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/core/typing/templates.py", line 785, in _build_impl
  File "/home/adrien/PycharmProjects/.virtualenvs/jax-workshop/lib/python3.12/site-packages/numba/np/arrayobj.py", line 4924, in impl_numpy_copy
UnboundLocalError: cannot access local variable 'numpy_copy' where it is not associated with a value

This does not work: this is because arr is passed in as a JAX array, with which numba is not compatible., and we therefore need to convert it to a numpy array first.

In [54]:
@jax.jit
def wrapped_bubble_sort(arr):
    # This uses pure callback
    def callback(arr):
        np_arr = np.array(arr)
        return nb_bubble_sort(np_arr)

    return jax.pure_callback(callback, arr, arr)


%timeit wrapped_bubble_sort(arr_test).block_until_ready()

665 μs ± 1.66 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


and so we get the expected performance.
What's the point you may ask? The point is that you can now compose `numba` functions with JAX functions, and this can be very useful when you have a function that is not supported by JAX but that you want to use inside a JAX function. This includes seamlessly vectorizing functions.

In [56]:
jax.vmap(wrapped_bubble_sort)(jnp.array([arr_test, arr_test]))

Array([[-3.84594  , -3.0768633, -2.9661233, ...,  2.7572317,  3.0126548,
         3.125715 ],
       [-3.84594  , -3.0768633, -2.9661233, ...,  2.7572317,  3.0126548,
         3.125715 ]], dtype=float32)

### Questions:
#### Q1: 
As it is implemented, `pure_callback` uses a naive vectorization strategy. Look up its documentation and implement the two other strategies for `vmap_method`.

#### Q2:
Try calling the jacobian of `wrapped_bubble_sort` and see what happens. Compare with the jacobian of `bubble_sort`, and the jacobian of `jnp.sort`.
Implement a numba wrapped version of bubble sorting, but with custom gradients.  

# Branching between CPU and GPU
Sometimes, you may have an algorithm that is very fast on GPU but slow on CPU or vice versa, in which case, you may want to branch the behaviour depending on the device.
For instance, `searchsorted` function looks up `vals` in `arr` and returns the indices at which `vals` should be inserted to maintain order. 
It has several implementations, which are available as a `method` argument.

In [60]:
arr = np.sort(np.random.randn(100))
vals = np.random.randn(50)

%timeit jnp.searchsorted(arr, vals, method='scan')  # standard loop over arr
%timeit jnp.searchsorted(arr, vals, method='scan_unrolled')  # parallel loop over arr
%timeit jnp.searchsorted(arr, vals, method='sort')  # join arr and vals and find the argsort that would sort the result
%timeit jnp.searchsorted(arr, vals, method='compare_all')  # does what it says on the tin

11.3 μs ± 375 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
10.1 μs ± 378 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
30.1 μs ± 1.7 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
26.7 μs ± 1.13 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


Typically you would like to use `scan` on CPU and `sort` on GPU, but JAX does not do it for you automatically.
Thankfully, you can do this via `jax.lax.platform_dependent`:

In [64]:
def searchsorted(arr, vals):
    return jax.lax.platform_dependent(
        arr,
        vals,
        default=lambda x, y: jnp.searchsorted(
            x, y, method="scan_unrolled"
        ),  # fallback implementation
        cpu=lambda x, y: jnp.searchsorted(x, y, method="scan"),  # CPU implementation
        cuda=lambda x, y: jnp.searchsorted(x, y, method="sort"),
    )  # GPU implementation

### Questions:
#### Q1:
When the `value` array is sorted too, the `searchsorted` function can be implemented in O(n) time as
```python
def searchsorted(arr, vals):
    j = 0
    m = arr.shape[0]
    n = vals.shape[0]
    idx = np.empty(n, dtype=np.int_)
    for i in range(n):
        while (vals[i] > arr[np.minimum(j, m-1)]) and (j < m):
            j += 1    
        idx[i] = j
    return idx
```

Implement this function using numba, pure JAX, and compare the performance with the JAX implementation.
Then branch between this and the JAX implementation using sorting. Compare the  performance on CPU and GPU depending on the size of the arrays.