In [None]:
# Probe for OpenMDAO's notebook utilities; if the import fails, install the
# `openmdao[notebooks]` extra with pip.
try:
    from openmdao.utils.notebook_utils import notebook_mode
except ImportError:
    # `!` is IPython shell-escape syntax, so this line only works inside a notebook kernel.
    !python -m pip install openmdao[notebooks]

(sec:openmdao_math)=
# Math Library (`openmdao.math`)

Certain functions are useful in a gradient-based optimization context, such as smooth activation functions or differentiable maximum/minimum functions.

Rather than provide a component that forces a user to structure their system in a certain way and add more components than necessary, the `openmdao.math` package is intended to provide a universal source for _composable_ functions that users can use within their own components.

These functions trade accuracy for differentiability.
Near regions where the nominal functions would have invalid derivatives, these functions are smooth but will not perfectly match their non-smooth counterparts.

These functions can also be constructed using the `jax` library to support automatic differentiation and just-in-time compilation, as explained below.


## Available Functions

```{eval-rst}
    .. autofunction:: openmdao.math.act_tanh
        :noindex:
```

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import openmdao.math as omm

# 2x2 figure: each panel varies one act_tanh argument (mu, z, a, b) while the
# other three stay at mu=0.01, z=0.5, a=0, b=1.
fig, ax = plt.subplots(2, 2, figsize=(8, 8))
fig.suptitle('Impact of different parameters on act_tanh')
x = np.linspace(0, 1, 1000)

# Top-left panel: vary mu.
mup001 = omm.act_tanh(x, mu=0.001, z=0.5, a=0, b=1)
mup01 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=1)
mup1 = omm.act_tanh(x, mu=0.1, z=0.5, a=0, b=1)

panel = ax[0, 0]
for curve, label in ((mup001, r'$\mu$ = 0.001'),
                     (mup01, r'$\mu$ = 0.01'),
                     (mup1, r'$\mu$ = 0.1')):
    panel.plot(x, curve, label=label)
panel.legend()
panel.grid()

# Top-right panel: vary z.
zp4 = omm.act_tanh(x, mu=0.01, z=0.4, a=0, b=1)
zp5 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=1)
zp6 = omm.act_tanh(x, mu=0.01, z=0.6, a=0, b=1)

panel = ax[0, 1]
for curve, label in ((zp4, r'$z$ = 0.4'),
                     (zp5, r'$z$ = 0.5'),
                     (zp6, r'$z$ = 0.6')):
    panel.plot(x, curve, label=label)
panel.legend()
panel.grid()

# Bottom-left panel: vary a.
a0 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=1)
ap2 = omm.act_tanh(x, mu=0.01, z=0.5, a=0.2, b=1)
ap4 = omm.act_tanh(x, mu=0.01, z=0.5, a=0.4, b=1)

panel = ax[1, 0]
for curve, label in ((a0, r'$a$ = 0.0'),
                     (ap2, r'$a$ = 0.2'),
                     (ap4, r'$a$ = 0.4')):
    panel.plot(x, curve, label=label)
panel.legend()
panel.grid()

# Bottom-right panel: vary b.
bp6 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=.6)
bp8 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=.8)
b1 = omm.act_tanh(x, mu=0.01, z=0.5, a=0, b=1)

panel = ax[1, 1]
for curve, label in ((bp6, r'$b$ = 0.6'),
                     (bp8, r'$b$ = 0.8'),
                     (b1, r'$b$ = 1.0')):
    panel.plot(x, curve, label=label)
panel.legend()
panel.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_abs
        :noindex:
```

In [None]:
# Single panel comparing smooth_abs for three values of mu around x = 0.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_abs')
x = np.linspace(-0.2, 0.2, 1000)

mup001, mup01, mup1 = (omm.smooth_abs(x, mu=mu_val) for mu_val in (0.001, 0.01, 0.1))

for curve, label in ((mup001, r'$\mu$ = 0.001'),
                     (mup01, r'$\mu$ = 0.01'),
                     (mup1, r'$\mu$ = 0.1')):
    ax.plot(x, curve, label=label)
ax.legend()
ax.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_max
        :noindex:
```

In [None]:
# Smooth maximum of sin(x) and cos(x) on [0.5, 1] for several values of mu.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_max of sin and cos')
x = np.linspace(0.5, 1, 1000)

# The two signals whose smooth maximum is taken.
sin, cos = np.sin(x), np.cos(x)

# NOTE(review): mup001 is evaluated but never drawn below -- confirm whether the
# mu = 0.001 curve was left out of the figure on purpose.
mup001, mup01, mup1 = (omm.smooth_max(sin, cos, mu=mu_val) for mu_val in (0.001, 0.01, 0.1))

# Dashed lines are the raw inputs, solid lines the smoothed results.
for signal, label in ((sin, r'$\sin{x}$'), (cos, r'$\cos{x}$')):
    ax.plot(x, signal, '--', label=label)
for curve, label in ((mup01, r'$\mu$ = 0.01'), (mup1, r'$\mu$ = 0.1')):
    ax.plot(x, curve, label=label)
ax.legend()
ax.grid()

```{eval-rst}
    .. autofunction:: openmdao.math.smooth_min
        :noindex:
```

In [None]:
# Smooth minimum of sin(x) and cos(x) on [0.5, 1] for several values of mu.
fig, ax = plt.subplots(1, 1, figsize=(4, 4))
fig.suptitle('Impact of different parameters on smooth_min of sin and cos')
x = np.linspace(0.5, 1, 1000)

# The two signals whose smooth minimum is taken.
sin, cos = np.sin(x), np.cos(x)

# NOTE(review): mup001 is evaluated but never drawn below -- confirm whether the
# mu = 0.001 curve was left out of the figure on purpose.
mup001, mup01, mup1 = (omm.smooth_min(sin, cos, mu=mu_val) for mu_val in (0.001, 0.01, 0.1))

# Dashed lines are the raw inputs, solid lines the smoothed results.
for signal, label in ((sin, r'$\sin{x}$'), (cos, r'$\cos{x}$')):
    ax.plot(x, signal, '--', label=label)
for curve, label in ((mup01, r'$\mu$ = 0.01'), (mup1, r'$\mu$ = 0.1')):
    ax.plot(x, curve, label=label)
ax.legend(ncol=2)
ax.grid()

## Derivative functions

The functions provided in `openmdao.math` have corresponding analytic derivative functions.

These functions are vectorized in such a way that their jacobian matrix is typically diagonal, where the diagonal elements of the jacobian are nonzero and all off-diagonal elements are zero.
Instead of returning the entire dense jacobian, these derivative functions generally return the diagonal elements only.
There are some exceptions to this, such as `norm`, since it can return a non-diagonal jacobian when the `axis` argument is specified.

Derivatives of function outputs with respect to their inputs are returned by functions named
`d_{func_name}`. In addition to the arguments of the function, functions which accept more than one argument
provide boolean arguments which determine whether the derivative with respect to each argument is returned.

For instance, `d_smooth_abs(x, mu=1.0E-5, dmu=False)` returns the jacobian matrix for the smooth absolute value function with respect to `x`, and `None` in place of the derivative with respect to `mu`, since the user disabled the calculation of the derivatives with respect to `mu`.  This can be useful when some arguments to the function are constants in the particular use-case.

```{eval-rst}
    .. autofunction:: openmdao.math.d_act_tanh
        :noindex:
```

```{eval-rst}
    .. autofunction:: openmdao.math.d_smooth_abs
        :noindex:
```

```{eval-rst}
    .. autofunction:: openmdao.math.d_smooth_max
        :noindex:
```

```{eval-rst}
    .. autofunction:: openmdao.math.d_smooth_min
        :noindex:
```

In [None]:
import numpy as np
import openmdao.math as omm

x = np.linspace(-2, 2, 5)
print(f'x = {x}\n')

# dmu=False asks d_smooth_abs to skip the derivative with respect to mu; the
# call still returns a pair, which is unpacked into the x and mu slots.
smooth_abs_derivs = omm.d_smooth_abs(x, mu=1.0E-5, dmu=False)
d_sabs_dx, d_sabs_dmu = smooth_abs_derivs

print(f'd_smoothabs_dx = {d_sabs_dx}\n')
print(f'd_smoothabs_dmu = {d_sabs_dmu}\n')

## Example Use-Case - Differentiable Counting

Suppose we have some array of data and we want a count of the number of elements in the array that are greater than or equal to some given value.

We can use the `act_tanh` function and set it to return 0 for values less than our threshold, and 1 for values greater than our threshold.

Summing the result of _this_ function will then give us the count.
It will be approximate in that there is some inaccuracy where the activation function is smoothed, but it will be differentiable.

In [None]:
import openmdao.api as om
from openmdao.math import act_tanh, d_act_tanh


class CountingComp(om.ExplicitComponent):
    """
    Differentiable count of the entries of ``x`` that lie above a threshold.

    Every entry of ``x`` is passed through ``act_tanh`` with ``a=0`` and ``b=1``
    and the activations are summed into the single output ``count``.
    """

    def initialize(self):
        """Declare the ``vec_size``, ``threshold`` and ``mu`` options."""
        opts = self.options
        opts.declare('vec_size', types=(int,))
        opts.declare('threshold', types=(float,), default=0.0)
        opts.declare('mu', types=(float,), default=0.01)

    def setup(self):
        """Add the ``x`` input, the ``count`` output and the partial between them."""
        self.add_input('x', shape=(self.options['vec_size'],))
        self.add_output('count', shape=(1,))

        self.declare_partials(of='count', wrt='x')

    def compute(self, inputs, outputs):
        """Sum the activation of every element of ``x`` into ``count``."""
        opts = self.options
        activation = act_tanh(inputs['x'], mu=opts['mu'], z=opts['threshold'], a=0, b=1)
        outputs['count'] = np.sum(activation)

    def compute_partials(self, inputs, partials):
        """Fill the partial of ``count`` with respect to ``x`` from ``d_act_tanh``."""
        opts = self.options
        # Only the x-derivative is needed, so the other four are switched off.
        derivs = d_act_tanh(inputs['x'], mu=opts['mu'], z=opts['threshold'], a=0, b=1,
                            dmu=False, dz=False, da=False, db=False)
        # The derivative of the sum is a row of ones, so the x-derivative of the
        # activation is the partial itself.
        partials['count', 'x'] = derivs[0]


In [None]:
N = 10

# Seed the generator so the random sample, and therefore the check_partials
# report, is the same on every run of the notebook.
rng = np.random.default_rng(0)

p = om.Problem()
p.model.add_subsystem('counter',
                      CountingComp(vec_size=N, threshold=0.5, mu=0.01),
                      promotes_inputs=['x'], promotes_outputs=['count'])
# Complex allocation is required for the complex-step check below.
p.setup(force_alloc_complex=True)
p.set_val('x', rng.random(N))
p.run_model()
p.check_partials(method='cs', compact_print=True);

# `jax` Math Library (openmdao.math.jax)

The `openmdao.math.jax` package duplicates the math capability of `openmdao.math`, but is built upon the `jax` Python package.
This allows these functions to support automatic differentiation and just-in-time (jit) compilation.

## Getting derivatives from `jax`

Now instead of using the `d_smooth_abs` function, we have the option of using the automatic differentiation tools in `jax` to do the differentiation for us.  Useful methods from `jax` here are `grad`, which only supports scalar functions, `jacfwd` and `jacrev`, which provide a full, dense jacobian, and the `jvp` and `vjp` methods, which can provide matrix-free derivatives that only include the diagonal elements we care about.

In [None]:
import jax
import jax.numpy as jnp

import openmdao.math.jax as omj


x = np.linspace(-1, 1, 100)

# Evaluate with the jax implementation so the curve matches the plot title.
act_x = omj.act_tanh(x, z=0, mu=1.0E-1, a=5, b=10)

fig, ax = plt.subplots(2, 1, figsize=(6, 10))

ax[0].set_title('openmdao.math.jax.act_tanh(x, z=0, mu=0.1, a=5, b=10)')
ax[0].plot(x, act_x)
ax[0].grid()


def jac_act_tanh(x):
    """
    Return the dense jacobian of the jax ``act_tanh`` with respect to ``x``.

    Deliberately not named ``d_act_tanh``: that name is imported from
    ``openmdao.math`` earlier in this notebook and is looked up at call time by
    ``CountingComp.compute_partials``, so rebinding it here would break that
    component if its cells were run again afterwards.

    Parameters
    ----------
    x : array_like
        Points at which to evaluate the jacobian; flattened before use.

    Returns
    -------
    jax.Array
        Jacobian of ``act_tanh(x, z=0, mu=0.1, a=5, b=10)`` computed with ``jax.jacfwd``.
    """
    return jax.jacfwd(omj.act_tanh)(jax.numpy.ravel(x), z=0, mu=1.0E-1, a=5, b=10)


ax[1].set_title('Jacobian matrix of openmdao.math.jax.act_tanh(x, z=0, mu=0.1, a=5, b=10)')
ax[1].imshow(jac_act_tanh(x), aspect='auto')

plt.show()


In [None]:
# We can change this to retrieve only the diagonal if we wish, since all off-diagonal elements are zero.

In [None]:
# Notebook-level values shared by the derivative helpers and the timing cell.
# `mu` is read as a global inside the d_act_tanh_jac* helpers; `z` is passed
# to them explicitly by the callers.
mu = 0.01
z = 0.5

def d_act_tanh_jacfwd(x, z):
    """
    Forward-mode derivatives of the jax ``act_tanh`` with respect to ``x`` and ``z``.

    ``mu`` is read from the module scope; ``a`` and ``b`` are fixed at 5 and 10.

    Returns
    -------
    tuple
        The diagonal of the jacobian with respect to ``x``, and the jacobian
        with respect to ``z``.
    """
    # argnums=[0, 2] selects the first and third positional arguments: x and z.
    jacobians = jax.jacfwd(omj.act_tanh, argnums=[0, 2])
    dense_dx, d_dz = jacobians(jax.numpy.ravel(x), mu, z, a=5, b=10)
    return jax.numpy.diagonal(dense_dx), d_dz

def d_act_tanh_jacrev(x, z):
    """
    Reverse-mode derivatives of the jax ``act_tanh`` with respect to ``x`` and ``z``.

    ``mu`` is read from the module scope; ``a`` and ``b`` are fixed at 5 and 10.

    Returns
    -------
    tuple
        The diagonal of the jacobian with respect to ``x``, and the jacobian
        with respect to ``z``.
    """
    # argnums=[0, 2] selects the first and third positional arguments: x and z.
    jacobians = jax.jacrev(omj.act_tanh, argnums=[0, 2])
    dense_dx, d_dz = jacobians(jax.numpy.ravel(x), mu, z, a=5, b=10)
    return jax.numpy.diagonal(dense_dx), d_dz

## Performance of `jax` vs `numpy`

Typically, automatically differentiated jax code is up to an order of magnitude slower in execution than providing analytic derivatives.

We can improve the performance by using the just-in-time compiling capability of `jax`.

In [None]:
@jax.jit
def d_act_tanh_jacfwd_jit(x, z):
    """
    Jit-compiled forward-mode derivatives of the jax ``act_tanh`` w.r.t. ``x`` and ``z``.

    ``mu`` is a module-level global rather than an argument, so under ``jax.jit``
    its value is fixed when the function is traced. ``a`` and ``b`` are fixed at
    5 and 10.

    Returns
    -------
    tuple
        The diagonal of the jacobian with respect to ``x``, and the jacobian
        with respect to ``z``.
    """
    # argnums=[0, 2] selects the first and third positional arguments: x and z.
    jacobians = jax.jacfwd(omj.act_tanh, argnums=[0, 2])
    dense_dx, d_dz = jacobians(jax.numpy.ravel(x), mu, z, a=5, b=10)
    return jax.numpy.diagonal(dense_dx), d_dz

@jax.jit
def d_act_tanh_jacrev_jit(x, z):
    """
    Jit-compiled reverse-mode derivatives of the jax ``act_tanh`` w.r.t. ``x`` and ``z``.

    ``mu`` is a module-level global rather than an argument, so under ``jax.jit``
    its value is fixed when the function is traced. ``a`` and ``b`` are fixed at
    5 and 10.

    Returns
    -------
    tuple
        The diagonal of the jacobian with respect to ``x``, and the jacobian
        with respect to ``z``.
    """
    # argnums=[0, 2] selects the first and third positional arguments: x and z.
    jacobians = jax.jacrev(omj.act_tanh, argnums=[0, 2])
    dense_dx, d_dz = jacobians(jax.numpy.ravel(x), mu, z, a=5, b=10)
    return jax.numpy.diagonal(dense_dx), d_dz

In [None]:
import timeit
loop = 1000

# Call the jit-compiled functions once before timing so that the one-off
# tracing/compilation cost is not averaged into their per-call time.
jax.block_until_ready(d_act_tanh_jacfwd_jit(x, z))
jax.block_until_ready(d_act_tanh_jacrev_jit(x, z))

# Statements to time.
# * The analytic call explicitly disables the mu, a and b derivatives so that it
#   computes the same two derivatives (x and z) as the jax versions.
# * jax dispatches work asynchronously, so each jax call is wrapped in
#   block_until_ready to time the computation rather than only its dispatch.
statements = {
    'analytic': ('omm.d_act_tanh(x, mu, z, a=5, b=10, '
                 'dx=True, dmu=False, dz=True, da=False, db=False)'),
    'jacfwd': 'jax.block_until_ready(d_act_tanh_jacfwd(x, z))',
    'jacrev': 'jax.block_until_ready(d_act_tanh_jacrev(x, z))',
    'jacfwd_jit': 'jax.block_until_ready(d_act_tanh_jacfwd_jit(x, z))',
    'jacrev_jit': 'jax.block_until_ready(d_act_tanh_jacrev_jit(x, z))',
}

# Total wall time of `loop` executions of each statement, keyed by label.
timing = {}
for name, stmt in statements.items():
    timing[name] = timeit.timeit(stmt, globals=globals(), number=loop)

fig, ax = plt.subplots(1, 1)
ax.set_yscale('log')

fig.suptitle('Performance of analytic derivatives of act_tanh\nvs jax with and without jit')
# Divide by the loop count to plot the average time of a single call.
ax.bar(timing.keys(), np.asarray(list(timing.values())) / loop, width = 0.4)
ax.set_ylabel('Average time (s)')
ax.grid()

Now the jit-compiled jax forms of the functions are considerably faster than the numpy analytic derivative implementation.

# Using the jax form of functions in components

Building components with `jax` based functions is more easily accomplished by defining and jit-compiling those functions outside of the components.



In [None]:
import openmdao.api as om
import jax
import jax.numpy as jnp


@jax.jit
def _f_count(x, mu, z, a, b):
    """
    Differentiable count: the sum of ``act_tanh`` evaluated over ``x``.

    Parameters
    ----------
    x : array_like
        Values to be counted.
    mu : float
        ``mu`` argument forwarded to ``act_tanh``.
    z : float
        ``z`` argument forwarded to ``act_tanh`` (the counting threshold).
    a : float
        ``a`` argument forwarded to ``act_tanh``; callers in this notebook pass 0.
    b : float
        ``b`` argument forwarded to ``act_tanh``; callers in this notebook pass 1.

    Returns
    -------
    jax.Array
        Scalar sum of the activation over all elements of ``x``.
    """
    # Forward a and b instead of hard-coding 0 and 1, and reduce with jnp.sum
    # so the whole computation stays in jax under jit.
    return jnp.sum(omj.act_tanh(x, mu=mu, z=z, a=a, b=b))


@jax.jit
def _d_f_count(x, mu, z, a, b):
    """
    Derivative of ``_f_count`` with respect to ``x``.

    ``_f_count`` maps many inputs to a single scalar, so reverse mode
    (``jax.jacrev``) obtains the whole gradient in one backward pass, whereas
    forward mode needs one pass per element of ``x``.

    Returns
    -------
    tuple
        One-element tuple (because ``argnums`` is a list) holding the derivative
        of the count with respect to each element of the flattened ``x``.
    """
    return jax.jacrev(_f_count, argnums=[0])(jax.numpy.ravel(x), mu, z, a, b)


class CountingJaxComp(om.ExplicitComponent):
    """
    Differentiable count of the entries of ``x`` above a threshold, using jax.

    The value comes from ``_f_count`` and the partials from ``_d_f_count``, both
    defined and jit-compiled at module level.
    """

    def initialize(self):
        """Declare the ``vec_size``, ``threshold`` and ``mu`` options."""
        opts = self.options
        opts.declare('vec_size', types=(int,))
        opts.declare('threshold', types=(float,), default=0.0)
        opts.declare('mu', types=(float,), default=0.01)

    def setup(self):
        """Add the ``x`` input, the ``count`` output and the partial between them."""
        self.add_input('x', shape=(self.options['vec_size'],))
        self.add_output('count', shape=(1,))

        self.declare_partials(of='count', wrt='x')

    def compute(self, inputs, outputs):
        """Evaluate ``_f_count`` on ``x`` with ``a=0`` and ``b=1``."""
        opts = self.options
        outputs['count'] = _f_count(inputs['x'], opts['mu'], opts['threshold'], 0, 1)

    def compute_partials(self, inputs, partials):
        """Assign the result of ``_d_f_count`` as the partial of ``count`` w.r.t. ``x``."""
        opts = self.options
        partials['count', 'x'] = _d_f_count(inputs['x'], opts['mu'], opts['threshold'], 0, 1)


In [None]:
N = 10_000

# Seed the generator so the random sample, and therefore the check_partials
# report, is the same on every run of the notebook.
rng = np.random.default_rng(0)

p = om.Problem()
p.model.add_subsystem('counter',
                      CountingJaxComp(vec_size=N, threshold=0.5, mu=0.01),
                      promotes_inputs=['x'], promotes_outputs=['count'])
# Complex allocation is required for the complex-step check below.
p.setup(force_alloc_complex=True)
p.set_val('x', rng.random(N))
p.run_model()
p.check_partials(method='cs', compact_print=True);