-
Notifications
You must be signed in to change notification settings - Fork 80
/
InverseFunctions.jl
92 lines (68 loc) · 2.92 KB
/
InverseFunctions.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
module InverseFunctionsModule

using InverseFunctions: inverse as _inverse, NoInverse

#! format: off
using ..CoreModule:
    square, cube, safe_pow, safe_log, safe_log2,
    safe_log10, safe_log1p, safe_sqrt, safe_acosh, neg, greater, cond,
    relu, logical_or, logical_and, gamma, erf, erfc, atanh_clip
#! format: on

"""
    approx_inverse(f::Function)

Return a function that acts as the inverse of `f` on at least part of its
domain, e.g. for `x ∈ [0, ϵ]` for some `ϵ > 0`. For example, `abs` has an
`approx_inverse` of `abs`, since `abs(abs(x)) == x` for any `x > 0`.

This is not meant to be an exact mathematical inverse. It exists purely so
that mutation operators can use information about inverse functions to
improve the search space.

Operators without a more specific method in this module fall back to
`InverseFunctions.inverse`; if that reports `NoInverse`, an error is thrown.
Partially-applied operators (`Base.Fix1` / `Base.Fix2`) do not use this
fallback: they only have an inverse when a dedicated method exists.
"""
function approx_inverse(f::F) where {F<:Function}
    i_f = _inverse(f)
    if i_f isa NoInverse
        # Throws; never returns a `NoInverse` sentinel to the caller.
        _no_inverse(f)
    end
    return i_f
end

# Throw an `ErrorException` stating that no inverse is registered for `f`.
function _no_inverse(f)
    return error("Inverse of $(f) not yet implemented. Please extend `$(approx_inverse)`.")
end

# `Base.Fix1`/`Base.Fix2` are deliberately NOT forwarded to InverseFunctions.jl.
# This catch-all errors for any partially-applied operator, and the more
# specific methods in the "Binary operators" section override it by dispatch.
approx_inverse(f::Union{Base.Fix1,Base.Fix2}) = _no_inverse(f)

###########################################################################
## Unary operators ########################################################
###########################################################################

approx_inverse(::typeof(sin)) = asin
approx_inverse(::typeof(asin)) = sin
approx_inverse(::typeof(cos)) = acos
approx_inverse(::typeof(acos)) = cos
approx_inverse(::typeof(tan)) = atan
approx_inverse(::typeof(atan)) = tan
approx_inverse(::typeof(sinh)) = asinh
approx_inverse(::typeof(asinh)) = sinh
approx_inverse(::typeof(cosh)) = safe_acosh
approx_inverse(::typeof(safe_acosh)) = cosh
approx_inverse(::typeof(tanh)) = atanh_clip
approx_inverse(::typeof(atanh_clip)) = tanh
approx_inverse(::typeof(square)) = safe_sqrt
approx_inverse(::typeof(safe_sqrt)) = square
approx_inverse(::typeof(cube)) = cbrt
approx_inverse(::typeof(cbrt)) = cube
approx_inverse(::typeof(exp)) = safe_log
approx_inverse(::typeof(safe_log)) = exp
approx_inverse(::typeof(safe_log2)) = exp2
approx_inverse(::typeof(exp2)) = safe_log2
approx_inverse(::typeof(safe_log10)) = exp10
approx_inverse(::typeof(exp10)) = safe_log10

# `Base.expm1` evaluates exp(x) - 1 accurately for small |x|. The naive
# `exp(x) - 1` loses precision to cancellation there, which is exactly the
# regime where `log1p`-style operators matter.
approx_inverse(::typeof(safe_log1p)) = expm1
approx_inverse(::typeof(expm1)) = safe_log1p

# Backward-compatible alias for code that references `exp1m`; it forwards to
# the numerically accurate `Base.expm1`.
exp1m(x) = expm1(x)
approx_inverse(::typeof(exp1m)) = safe_log1p

# Self-inverse operators. `relu` and `abs` only invert themselves for x ≥ 0.
approx_inverse(::typeof(neg)) = neg
approx_inverse(::typeof(relu)) = relu
approx_inverse(::typeof(abs)) = abs

###########################################################################
## Binary operators #######################################################
###########################################################################
# Each method inverts an arithmetic operator with one argument held fixed.
# They are exact for scalar (commutative) arithmetic, up to division by zero.

# y = f.x + x  =>  x = y - f.x
approx_inverse(f::Base.Fix1{typeof(+)}) = Base.Fix2(-, f.x)
# y = x + f.x  =>  x = y - f.x
approx_inverse(f::Base.Fix2{typeof(+)}) = Base.Fix2(-, f.x)
# y = f.x - x  =>  x = f.x - y
approx_inverse(f::Base.Fix1{typeof(-)}) = Base.Fix1(-, f.x)
# y = x - f.x  =>  x = y + f.x
approx_inverse(f::Base.Fix2{typeof(-)}) = Base.Fix2(+, f.x)
# y = f.x * x  =>  x = y / f.x
approx_inverse(f::Base.Fix1{typeof(*)}) = Base.Fix2(/, f.x)
# y = x * f.x  =>  x = y / f.x
approx_inverse(f::Base.Fix2{typeof(*)}) = Base.Fix2(/, f.x)
# y = f.x / x  =>  x = f.x / y
approx_inverse(f::Base.Fix1{typeof(/)}) = Base.Fix1(/, f.x)
# y = x / f.x  =>  x = y * f.x
approx_inverse(f::Base.Fix2{typeof(/)}) = Base.Fix2(*, f.x)

###########################################################################

end