-
Notifications
You must be signed in to change notification settings - Fork 240
/
ks.py
78 lines (66 loc) · 2.08 KB
/
ks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
"""
jax implementations of the Kreisselmeier-Steinhauser for the min and max values in an array.
"""
try:
import jax
from jax import jit
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
except (ImportError, ModuleNotFoundError):
jax = None
from openmdao.utils.jax_utils import jit_stub as jit
# BibTeX entry for the paper describing the Kreisselmeier-Steinhauser
# constraint-aggregation approach implemented by ks_max/ks_min below.
CITATIONS = """
@conference {Martins:2005:SOU,
title = {On Structural Optimization Using Constraint Aggregation},
booktitle = {Proceedings of the 6th World Congress on Structural and Multidisciplinary
Optimization},
year = {2005},
month = {May},
address = {Rio de Janeiro, Brazil},
author = {Joaquim R. R. A. Martins and Nicholas M. K. Poon}
}
"""
@jit
def ks_max(x, rho=100.0):
    """
    Compute a differentiable maximum value in an array.

    Given some array of values `x`, compute a differentiable, _conservative_ maximum using the
    Kreisselmeier-Steinhauser function. The result is always greater than or equal to the true
    maximum of `x`.

    Parameters
    ----------
    x : ndarray
        Array of values.
    rho : float
        Aggregation Factor. Larger values of rho more closely match the true maximum value.

    Returns
    -------
    float
        A conservative approximation to the maximum value in x.
    """
    x_max = jnp.max(x)
    # Shift by the true maximum before exponentiating so the largest exponent is zero,
    # preventing overflow in exp() for large rho * x.
    x_diff = x - x_max
    exponents = jnp.exp(rho * x_diff)
    summation = jnp.sum(exponents)
    # summation >= 1 (the max element contributes exp(0)), so log(summation) >= 0
    # and the result is a conservative (over-) estimate of the maximum.
    return x_max + 1.0 / rho * jnp.log(summation)
@jit
def ks_min(x, rho=100.0):
    """
    Compute a differentiable minimum value in an array.

    Aggregates the values in `x` with the Kreisselmeier-Steinhauser function to
    produce a smooth, _conservative_ estimate of the minimum (never above the
    true minimum).

    Parameters
    ----------
    x : ndarray
        Array of values.
    rho : float
        Aggregation Factor. Larger values of rho more closely match the true minimum value.

    Returns
    -------
    float
        A conservative approximation to the minimum value in x.
    """
    lowest = jnp.min(x)
    # Shifting by the true minimum keeps every exponent non-positive,
    # so exp() cannot overflow for large rho.
    shifted = rho * (lowest - x)
    aggregate = jnp.log(jnp.sum(jnp.exp(shifted)))
    return lowest - aggregate / rho