In [None]:
try:
    from openmdao.utils.notebook_utils import notebook_mode  # noqa: F401
except ImportError:
    !python -m pip install openmdao[notebooks]

# Computing Partial Derivatives of Components Using JAX

One of the barriers to using OpenMDAO is that, to take full advantage of the framework, the user needs to
write code for the analytic partial derivatives of their `Components`. To avoid writing these by hand, users can use
the optional third-party [JAX](https://jax.readthedocs.io/en/latest/index.html) library, which can
automatically differentiate native Python and NumPy functions.  To simplify JAX usage within OpenMDAO,
we've created two component classes, [JaxExplicitComponent](jax_explicitcomp_api.ipynb) and 
[JaxImplicitComponent](jax_implicitcomp_api.ipynb).  These components require only the definition of 
a `compute_primal` method that takes the component's inputs as arguments and returns the component's 
outputs.  For a `JaxExplicitComponent`, `compute_primal` replaces the `compute` method, and for a
`JaxImplicitComponent` it replaces the `apply_nonlinear` method.

This notebook describes in more detail how to create and use a `JaxExplicitComponent` or
`JaxImplicitComponent`, with examples.

Before going further, it's a good idea to acquaint yourself with some of JAX's 'sharp edges'
[here](https://docs.jax.dev/en/latest/notebooks/Common_Gotchas_in_JAX.html). Knowing them in advance should
make creating a `JaxExplicitComponent` or `JaxImplicitComponent` less frustrating.
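One sharp edge that comes up often when porting NumPy code into `compute_primal` is that JAX arrays are immutable, so in-place assignment such as `x[0] = 0.0` fails. The `.at[...]` update syntax returns a new array instead. A small sketch (the function name `zero_first` is hypothetical):

```python
import jax
import jax.numpy as jnp


def zero_first(x):
    # NumPy code would write x[0] = 0.0, but JAX arrays are immutable.
    # x.at[0].set(0.0) returns a modified copy and leaves x unchanged.
    return x.at[0].set(0.0)


x = jnp.array([1.0, 2.0, 3.0])
try:
    x[0] = 0.0  # in-place assignment is rejected for JAX arrays
except TypeError as err:
    print("in-place update rejected:", type(err).__name__)

y = zero_first(x)
print(x)  # original values
print(y)  # copy with the first entry zeroed

# The functional update is differentiable, so it is safe to use in traced code.
g = jax.grad(lambda v: jnp.sum(zero_first(v) ** 2))(x)
print(g)
```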

The use of JAX is optional for OpenMDAO, so if it is not already installed, you need to install it by
issuing *one* of the following commands at your operating system command prompt:
```
pip install jax jaxlib
pip install openmdao[jax]
pip install openmdao[all]
```


In [None]:
!pip install jax

The JAX library includes a NumPy-like API, `jax.numpy`, which implements the NumPy API using the primitives in JAX. Almost anything that can be done with NumPy can be done with `jax.numpy`. JAX arrays are similar to NumPy arrays, but they are designed to work with accelerators such as GPUs and TPUs. 

To use `jax.numpy`, import it under the commonly used `jnp` alias.
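Because `jax.numpy` mirrors most of the NumPy API, porting a function is often just a matter of swapping `np` for `jnp`, and `jnp` functions accept NumPy arrays as input. A short sketch comparing the two (the functions `f_np` and `f_jnp` are arbitrary examples):

```python
import numpy as np
import jax
import jax.numpy as jnp


def f_np(x):
    return np.sum(np.sin(x) * x**2)


def f_jnp(x):
    # Same expression, with np replaced by jnp.
    return jnp.sum(jnp.sin(x) * x**2)


x = np.linspace(0.0, 1.0, 5)
print(f_np(x), f_jnp(x))  # same value, up to floating-point precision

# Only the jnp version can be differentiated by JAX.
df = jax.grad(f_jnp)(x)
print(df)
```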

In [None]:
import jax
import jax.numpy as jnp

By default, JAX performs computations in single (32-bit) precision. These examples use double (64-bit) precision, so 64-bit mode must be enabled with the following line of code.
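To see why precision matters, compare the machine epsilon of the two floating-point types, which bounds the accuracy of every computed value, including derivatives that are compared against finite-difference or complex-step checks. A short sketch; note that the JAX documentation recommends setting `jax_enable_x64` at startup, before any arrays are created:

```python
import jax
import jax.numpy as jnp

# Machine epsilon for single vs. double precision.
print(jnp.finfo(jnp.float32).eps)  # about 1.2e-07
print(jnp.finfo(jnp.float64).eps)  # about 2.2e-16

jax.config.update("jax_enable_x64", True)

# With 64-bit mode enabled, floating-point arrays default to float64.
print(jnp.ones(3).dtype)
```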

In [None]:
jax.config.update("jax_enable_x64", True)

## Configuration Options

[JaxExplicitComponent](jax_explicitcomp_api.ipynb) and [JaxImplicitComponent](jax_implicitcomp_api.ipynb) 
both have the following options:

- **use_jit**:  
      If True, `compute_primal` and its corresponding derivatives method will be JIT-compiled.
      Defaults to True.

- **derivs_method**:  
      Defaults to 'jax', but can also be set to 'cs' (complex step) or 'fd' (finite difference)
      in order to debug the component or to compare performance.

- **default_to_dyn_shapes**:  
      If True, any variables in the component that don't have a specified
      shape will be shaped dynamically.  Unshaped inputs will be marked as `shape_by_conn` and unshaped
      outputs will be marked as `compute_shape`, where the `compute_shape` function uses `jax.eval_shape`
      internally to determine output shapes based on input shapes. Defaults to False.

- **default_shape**:  
      Defaults to (1,), meaning that unshaped variables will be allocated as size 1
      arrays.  If scalars are desired instead, set this to ().  Note that if
      `default_to_dyn_shapes` is True, this option is ignored.


## Automatic Determination of Derivative Direction
The relative size of component inputs and outputs determines the direction to be used when solving
for the partial Jacobian.  If the size of the inputs is less than or equal to the size of the
outputs, then the derivatives will be computed in forward mode.  Otherwise, reverse mode will be used.
This minimizes work because forward mode requires one pass per input entry, while reverse mode requires
one pass per output entry.  Note that this automatic determination of derivative direction only occurs
if the `matrix_free` attribute is False.
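The cost trade-off can be seen with plain JAX: `jax.jacfwd` builds the Jacobian one input direction at a time and `jax.jacrev` one output direction at a time, so both produce the same matrix at different cost. The helper `pick_direction` below is hypothetical; it simply picks the mode that needs fewer passes and is not OpenMDAO's internal code.

```python
import jax
import jax.numpy as jnp


def pick_direction(n_inputs, n_outputs):
    # Hypothetical helper: forward mode needs one pass per input entry,
    # reverse mode one pass per output entry, so choose the smaller count.
    return 'fwd' if n_inputs <= n_outputs else 'rev'


def wide(x):
    # 10 inputs -> 2 outputs: a "wide" Jacobian, where reverse mode is cheaper.
    return jnp.array([jnp.sum(x ** 2), jnp.prod(jnp.cos(x))])


x = jnp.linspace(0.1, 1.0, 10)
J_fwd = jax.jacfwd(wide)(x)  # 10 forward (JVP) passes
J_rev = jax.jacrev(wide)(x)  # 2 reverse (VJP) passes

print(J_fwd.shape)  # (2, 10)
print(jnp.allclose(J_fwd, J_rev))
print(pick_direction(x.size, 2))  # rev
```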
