In [None]:
try:
    from openmdao.utils.notebook_utils import notebook_mode
except ImportError:
    !python -m pip install openmdao[notebooks]

(sec:openmdao_math)=
# Math Library (`openmdao.math`)

Certain functions are useful in a gradient-based optimization context, such as smooth activation functions or differentiable maximum/minimum functions.

Rather than provide a component that forces a user to structure their system in a certain way and add more components than necessary, the `openmdao.math` package is intended to provide a universal source for _composable_ functions that users can use within their own components.

Functions in `openmdao.math` are built using [jax](https://github.com/google/jax).
This allows users to develop components that use these functions, along with other code written with jax, and leverage capabilities of `jax` like automatic differentiation and just-in-time compilation.

Many of these functions are focused on providing differentiable forms of strictly non-differentiable functions, such as step responses, absolute value, and minimums or maximums.
Near regions where the nominal functions would have invalid derivatives, these functions are smooth but will not perfectly match their non-smooth counterparts.

## Available Functions

```{eval-rst}
    .. autofunction:: openmdao.math.act_tanh
        :noindex:
```

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import openmdao.math as omm

# Each panel sweeps one act_tanh parameter while the remaining ones stay at
# the common baseline used throughout this cell (mu=0.01, z=0.5, a=0, b=1).
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
fig.suptitle('Impact of different parameters on act_tanh')
x = np.linspace(0, 1, 1000)

# One entry per panel: (axes, [(legend label, act_tanh keyword arguments), ...]).
# Curves are listed in the order in which they are drawn.
sweeps = [
    (ax[0, 0], [(r'$\mu$ = 0.001', dict(mu=0.001, z=0.5, a=0, b=1)),
                (r'$\mu$ = 0.01', dict(mu=0.01, z=0.5, a=0, b=1)),
                (r'$\mu$ = 0.1', dict(mu=0.1, z=0.5, a=0, b=1))]),
    (ax[0, 1], [(r'$z$ = 0.4', dict(mu=0.01, z=0.4, a=0, b=1)),
                (r'$z$ = 0.5', dict(mu=0.01, z=0.5, a=0, b=1)),
                (r'$z$ = 0.6', dict(mu=0.01, z=0.6, a=0, b=1))]),
    (ax[1, 0], [(r'$a$ = 0.0', dict(mu=0.01, z=0.5, a=0, b=1)),
                (r'$a$ = 0.2', dict(mu=0.01, z=0.5, a=0.2, b=1)),
                (r'$a$ = 0.4', dict(mu=0.01, z=0.5, a=0.4, b=1))]),
    (ax[1, 1], [(r'$b$ = 0.6', dict(mu=0.01, z=0.5, a=0, b=.6)),
                (r'$b$ = 0.8', dict(mu=0.01, z=0.5, a=0, b=.8)),
                (r'$b$ = 1.0', dict(mu=0.01, z=0.5, a=0, b=1))]),
]

for panel, curves in sweeps:
    for label, kwargs in curves:
        panel.plot(x, omm.act_tanh(x, **kwargs), label=label)
    panel.legend()
    panel.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_abs
        :noindex:
```

In [None]:
# Plot smooth_abs on a narrow window around zero for three values of mu.
x = np.linspace(-0.2, 0.2, 1000)
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_abs')

# (mu value, legend label) pairs, in drawing order.
curves = (
    (0.001, r'$\mu$ = 0.001'),
    (0.01, r'$\mu$ = 0.01'),
    (0.1, r'$\mu$ = 0.1'),
)
for mu, label in curves:
    ax.plot(x, omm.smooth_abs(x, mu=mu), label=label)

ax.legend()
ax.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_max
        :noindex:
```

In [None]:
# Compare smooth_max(sin(x), cos(x)) against the two underlying curves.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_max of sin and cos')
x = np.linspace(0.5, 1, 1000)

sin_x = np.sin(x)
cos_x = np.cos(x)

# Reference curves are dashed so the smoothed maximum stands out.
ax.plot(x, sin_x, '--', label=r'$\sin{x}$')
ax.plot(x, cos_x, '--', label=r'$\cos{x}$')

# Only the mu values that are actually drawn are evaluated.
for mu, label in ((0.01, r'$\mu$ = 0.01'), (0.1, r'$\mu$ = 0.1')):
    ax.plot(x, omm.smooth_max(sin_x, cos_x, mu=mu), label=label)

ax.legend()
ax.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_min
        :noindex:
```

In [None]:
# Compare smooth_min(sin(x), cos(x)) against the two underlying curves.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_min of sin and cos')
x = np.linspace(0.5, 1, 1000)

sin_x = np.sin(x)
cos_x = np.cos(x)

# Reference curves are dashed so the smoothed minimum stands out.
ax.plot(x, sin_x, '--', label=r'$\sin{x}$')
ax.plot(x, cos_x, '--', label=r'$\cos{x}$')

# Only the mu values that are actually drawn are evaluated.
for mu, label in ((0.01, r'$\mu$ = 0.01'), (0.1, r'$\mu$ = 0.1')):
    ax.plot(x, omm.smooth_min(sin_x, cos_x, mu=mu), label=label)

ax.legend(ncol=2)
ax.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.ks_max
        :noindex:
```

In [None]:
from openmdao.math import ks_max

# Show the scalar ks_max aggregate of a random sample as horizontal lines
# drawn over the sample itself, for two values of rho.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on ks_max')

# A seeded generator keeps the sample, and therefore the rendered figure,
# identical from one run of the notebook to the next.
y = np.random.default_rng(0).random(100)
x = np.linspace(0, 1, 100)

ax.plot(x, y, '.', label='y')
for rho in (10, 100):
    # ks_max reduces y to a single value; broadcast it across x to draw a line.
    ks_value = ks_max(y, rho=float(rho))
    ax.plot(x, ks_value * np.ones_like(x), label=f'ks_max(y, rho={rho})')
ax.legend(ncol=1)
ax.grid()

In [None]:
from openmdao.math import ks_min

# Show the scalar ks_min aggregate of a random sample as horizontal lines
# drawn over the sample itself, for two values of rho.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
# This cell demonstrates ks_min, so the title names ks_min.
fig.suptitle('Impact of different parameters on ks_min')

# A seeded generator keeps the sample, and therefore the rendered figure,
# identical from one run of the notebook to the next. The sample is offset by 5.
y = np.random.default_rng(0).random(100) + 5
x = np.linspace(0, 1, 100)

ax.plot(x, y, '.', label='y')
for rho in (10, 100):
    # ks_min reduces y to a single value; broadcast it across x to draw a line.
    ks_value = ks_min(y, rho=float(rho))
    ax.plot(x, ks_value * np.ones_like(x), label=f'ks_min(y, rho={rho})')
ax.legend(ncol=1)
ax.grid()

## Getting derivatives from jax-composed functions

If the user writes a function that is composed entirely using jax-based functions (from `jax.numpy`, etc.), then `jax` will in most cases be able to provide derivatives of those functions automatically.

The library has several ways of doing this and the best approach will likely depend on the specific use-case at hand.
Rather than provide a component to wrap a `jax` function and provide derivatives automatically, consider the following example as a template for how to utilize `jax` in combination with OpenMDAO components.

The following component wraps the `act_tanh` function in `openmdao.math`. In this case we've assumed values for $\mu$, `a`, and `b`, and will only be passing a vector input `x` as well as the activation value `z`.
The output `h` will have a value of approximately `1.0` in those indices of `x` where `x > z`. The value of `h` will be approximately `0.0` in those indices of `x` where `x < z`, and there will be some smooth transition between them. The smoothness of this transition is governed by $\mu$, with smaller values yielding behavior closer to a true step (at the expense of more abrupt changes in the derivatives near `z`).

### compute_primal

In this particular instance, we declare a method of the component named `compute_primal`.
That function name is not special to OpenMDAO and the user could call this function whatever they choose so long as it doesn't interfere with some pre-existing component method name.
In addition to the `self` argument, `compute_primal` takes positional arguments to make it compatible with `jax`.
We also wrap the method with the `jax.jit` decorator (and use `static_argnums` to inform it that the first argument (`self`) is not relevant to `jax`).

### compute
Compute in this case is just a matter of passing the values of the inputs to `compute_primal` and populating the outputs with the results.

###  compute_partials

Computing the partial derivatives across the component is done by passing the inputs to a separate method. Since there are multiple ways of computing the partials with `jax`, this example has three different `_compute_partials_xxx` methods, though only one is used.

Again, these method names are not special and are only used in the context of this example.

### _compute_partials_jacfwd

This uses the `jax.jacfwd` method to compute the partial derivatives of the calculation with a forward differentiation approach.
This approach should be one of the faster options when there are comparatively few inputs versus outputs.

Note that because we know that the sparsity structure of the jacobian for the output `h` wrt `x` will be diagonal, we extract the diagonal from the derivative that `jax` returns. We also declared this sparsity structure in the corresponding `declare_partials` call.
This pattern is common in vectorized functions but ultimately it's up to the user to know the sparsity structure when they implement the component.

### _compute_partials_jacrev

This is similar to the previous approach except `jax.jacrev` is used.
This function still returns dense jacobians and so the diagonal is extracted from the derivative of `h` with respect to `x`.

Reverse differentiation should be faster when the number of outputs of a function is significantly fewer than the number of inputs, such as in reduction operations.

### _compute_partials_jvp

Because we know the sparsity structure will be diagonal for the derivative of `h` with respect to `x`, we can use the jacobian vector product method provided by `jax` to extract the diagonal elements only, rather than computing the full matrix and extracting the diagonal.

Because there are two arguments we call `jax.jvp` two times, once with the tangents of `x` populated with ones and the tangents of `z` populated with zeros (to compute the partials with respect to `x`), and once for the other way around.

### Which approach to use?

In practice, it's going to be a matter of the user profiling their code to see which of these approaches is fastest.
For the example below, some testing indicated that `jvp` was _slightly_ faster.

An analytic implementation of the derivatives of `act_tanh` was _slightly_ faster than all of these, but only by a few percent at most.
The ability to provide accurate derivatives with little effort vs. hand differentiation at a nearly negligible loss in performance is an attractive proposition for the use of `jax` in OpenMDAO.


In [None]:
    from functools import partial
    import numpy as np
    import jax
    import jax.numpy as jnp
    import openmdao.api as om
    # Imported here so this cell does not depend on a later cell having run.
    from openmdao.math import act_tanh


    class ActTanhComp(om.ExplicitComponent):
        """
        Explicit component that applies ``act_tanh`` to a vector input.

        Inputs are the vector ``x`` (length ``vec_size``) and the scalar
        activation location ``z``; the output ``h`` has the same length as ``x``.
        Partials are obtained from ``jax`` rather than being hand-derived.
        """

        def initialize(self):
            """Declare the ``vec_size`` option (length of ``x`` and ``h``)."""
            self.options.declare('vec_size', types=int)

        def setup(self):
            """Add the inputs/output and declare the partial derivatives."""
            N = self.options['vec_size']
            self.add_input('x', shape=(N,))
            self.add_input('z', shape=(1,))
            self.add_output('h', shape=(N,))

            ar = np.arange(N, dtype=int)

            # h[i] depends only on x[i], so dh/dx is declared as a diagonal.
            self.declare_partials(of='h', wrt='x', rows=ar, cols=ar)
            # Every h[i] depends on the single z, so dh/dz is left dense.
            self.declare_partials(of='h', wrt='z')

        @partial(jax.jit, static_argnums=(0,))
        def compute_primal(self, x, z):
            """
            Evaluate the wrapped function; this is where the jax implementation belongs.

            ``self`` is marked static (argument 0) so ``jax.jit`` only traces
            ``x`` and ``z``.
            """
            return act_tanh(x, 0.01, z, 0.0, 1.0)

        @partial(jax.jit, static_argnums=(0,))
        def _compute_partials_jacfwd(self, x, z):
            """
            Partials of ``compute_primal`` via forward-mode ``jax.jacfwd``.

            Returns the diagonal of the dense jacobian wrt ``x`` and the
            jacobian wrt ``z``.
            """
            deriv_func = jax.jacfwd(self.compute_primal, argnums=[0, 1])
            dx, dz = deriv_func(x, z)
            # Only the diagonal was declared for h wrt x, so only it is returned.
            return jnp.diagonal(dx), dz

        @partial(jax.jit, static_argnums=(0,))
        def _compute_partials_jacrev(self, x, z):
            """
            Partials of ``compute_primal`` via reverse-mode ``jax.jacrev``.

            Returns the diagonal of the dense jacobian wrt ``x`` and the
            jacobian wrt ``z``.
            """
            deriv_func = jax.jacrev(self.compute_primal, argnums=[0, 1])
            dx, dz = deriv_func(x, z)
            # Only the diagonal was declared for h wrt x, so only it is returned.
            return jnp.diagonal(dx), dz

        @partial(jax.jit, static_argnums=(0,))
        def _compute_partials_jvp(self, x, z):
            """
            Partials of ``compute_primal`` via two jacobian-vector products.

            ``jax.jvp`` returns ``(primal_out, tangent_out)``; only the tangent
            (index 1) is kept from each call.
            """
            # Seed all of x and none of z: tangent_out is the derivative wrt x.
            dx = jax.jvp(self.compute_primal,
                         primals=(x, z),
                         tangents=(jnp.ones_like(x), jnp.zeros_like(z)))[1]

            # Seed z only: tangent_out is the derivative wrt z.
            dz = jax.jvp(self.compute_primal,
                         primals=(x, z),
                         tangents=(jnp.zeros_like(x), jnp.ones_like(z)))[1]

            return dx, dz

        def compute(self, inputs, outputs, discrete_inputs=None, discrete_outputs=None):
            """Populate ``h`` by forwarding the input values to ``compute_primal``."""
            # Input values are unpacked positionally into compute_primal(x, z).
            outputs['h'] = self.compute_primal(*inputs.values())

        def compute_partials(self, inputs, partials, discrete_inputs=None):
            """Populate the partials using the jvp-based implementation."""
            dx, dz = self._compute_partials_jvp(*inputs.values())

            partials['h', 'x'] = dx
            partials['h', 'z'] = dz

## Example Use-Case - Differentiable Counting

Suppose we have some array of data and we want a count of the number of elements in the array that are greater than or equal to some given value.

We can use the `act_tanh` function and set it to return 0 for values less than our threshold, and 1 for values greater than our threshold.

Summing the result of _this_ function will then give us the count.
It will be approximate in that there is some inaccuracy where the activation function is smoothed, but it will be differentiable.

In [None]:
import jax.numpy as jnp

import openmdao.api as om
from openmdao.math import act_tanh


class CountingComp(om.ExplicitComponent):
    """
    Differentiable count of the entries of ``x`` that exceed a threshold.

    The count is the sum of ``act_tanh`` evaluated at every entry of ``x``
    with activation location ``threshold`` and smoothing ``mu``.

    Several ``_compute_partials_*`` methods show different ways to obtain
    the derivative from ``jax``; ``compute_partials`` uses the reverse-mode one.

    None of the methods are wrapped in ``jax.jit``: ``compute_primal`` reads
    ``mu`` and ``threshold`` from ``self.options``, and with ``self`` marked
    static those values would be fixed when the function is first traced.
    """

    def initialize(self):
        """Declare the vector size, threshold and smoothing options."""
        self.options.declare('vec_size', types=(int,))
        self.options.declare('threshold', types=(float,), default=0.0)
        self.options.declare('mu', types=(float,), default=0.01)

    def setup(self):
        """Add the input/output and declare the partial derivative."""
        n = self.options['vec_size']
        self.add_input('x', shape=(n,))
        self.add_output('count', shape=(1,))

        # The partials are a dense row in this case (1 row x N inputs)
        # There is no need to specify a sparsity pattern.
        self.declare_partials(of='count', wrt='x')

    def compute_primal(self, x):
        """Return the smoothed count: the sum of ``act_tanh`` over ``x``."""
        mu = self.options['mu']
        z = self.options['threshold']
        return jnp.sum(act_tanh(x, mu, z, 0.0, 1.0))

    def _compute_partials_jacfwd(self, x):
        """Derivative of the count wrt ``x`` using forward-mode ``jax.jacfwd``."""
        deriv_func = jax.jacfwd(self.compute_primal, argnums=[0])
        # argnums is a list, so the result is a one-element tuple.
        dx, = deriv_func(x)
        return dx

    def _compute_partials_jacrev(self, x):
        """Derivative of the count wrt ``x`` using reverse-mode ``jax.jacrev``."""
        deriv_func = jax.jacrev(self.compute_primal, argnums=[0])
        # argnums is a list, so the result is a one-element tuple.
        dx, = deriv_func(x)
        return dx

    def _compute_partials_jvp(self, x):
        """
        Derivative of the count wrt ``x`` using jacobian-vector products.

        JVP is a poor choice here, since the jacobian is a row vector: one
        forward pass is needed for every element of ``x``.
        """
        x = jnp.asarray(x)
        # Row i of the identity selects d(count)/dx[i]. A single tangent of all
        # ones would instead return the sum of every entry of the row.
        basis = jnp.eye(x.shape[0], dtype=x.dtype)

        def directional(tangent):
            # jvp always returns the primal and the jvp; keep only the jvp.
            return jax.jvp(self.compute_primal,
                           primals=(x,),
                           tangents=(tangent,))[1]

        return jax.vmap(directional)(basis)

    def _compute_partials_vjp(self, x):
        """
        Derivative of the count wrt ``x`` using a vector-jacobian product.

        The output is a scalar, so one reverse pass yields the whole row.
        """
        count, vjp_fun = jax.vjp(self.compute_primal, x)
        # The cotangent must be one (shaped like the output), not the output
        # value itself; vjp_fun returns one cotangent per primal, as a tuple.
        dx, = vjp_fun(jnp.ones_like(count))
        return dx

    def compute(self, inputs, outputs):
        """Populate ``count`` by forwarding the input values to ``compute_primal``."""
        outputs['count'] = self.compute_primal(*inputs.values())

    def compute_partials(self, inputs, partials):
        """Populate the partial of ``count`` wrt ``x`` via reverse-mode AD."""
        dx = self._compute_partials_jacrev(*inputs.values())

        partials['count', 'x'] = dx


In [None]:
# Build a one-component problem around CountingComp, run it on evenly spaced
# samples in [0, 1], then compare its partials against complex step.
N = 10
counter = CountingComp(vec_size=N, threshold=0.5, mu=0.01)

p = om.Problem()
p.model.add_subsystem(
    'counter',
    counter,
    promotes_inputs=['x'],
    promotes_outputs=['count'],
)

# Complex storage is allocated up front because the check below uses method='cs'.
p.setup(force_alloc_complex=True)
p.set_val('x', np.linspace(0, 1, N))
p.run_model()

p.model.list_inputs(print_arrays=True)
p.model.list_outputs(print_arrays=True)

# Widen the print width so each array row prints on a single line.
with np.printoptions(linewidth=1024):
    p.check_partials(method='cs', compact_print=False)