Skip to content

Commit

Permalink
Merge b0d2583 into a9b00c8
Browse files Browse the repository at this point in the history
  • Loading branch information
fp4code committed Jul 10, 2019
2 parents a9b00c8 + b0d2583 commit c3598be
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 92 deletions.
34 changes: 30 additions & 4 deletions src/api.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@

const DEFINED_DIFFRULES = Dict{Tuple{Union{Expr,Symbol},Symbol,Int},Any}()
const DEFINED_COMPLEX_DIFFRULES = Dict{Tuple{Union{Expr,Symbol},Symbol,Int},Any}()

"""
@define_diffrule M.f(x) = :(df_dx(\$x))
@define_diffrule M.f(x, y) = :(df_dx(\$x, \$y)), :(df_dy(\$x, \$y))
@define_complex_diffrule M.f(x) = :(df_dx(\$x))
@define_complex_diffrule M.f(x, y) = :(df_dx(\$x, \$y)), :(df_dy(\$x, \$y))
Define a new differentiation rule for the function `M.f` and the given arguments, which should
Expand All @@ -16,14 +19,18 @@ interpolated wherever they are used on the RHS.
Note that differentiation rules are purely symbolic, so no type annotations should be used.
The complex version @define_complex_diffrule should be used if M.f is complex differentiable.
If not, @define_diffrule should be used instead.
Examples:
@define_diffrule Base.cos(x) = :(-sin(\$x))
@define_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
@define_complex_diffrule Base.cos(x) = :(-sin(\$x))
@define_complex_diffrule Base.:/(x, y) = :(inv(\$y)), :(-\$x / (\$y^2))
@define_diffrule Base.polygamma(m, x) = :NaN, :(polygamma(\$m + 1, \$x))
"""
macro define_diffrule(def)

function _getkeyrule(def)
@assert isa(def, Expr) && def.head == :(=) "Diff rule expression does not have a left and right side"
lhs = def.args[1]
rhs = def.args[2]
Expand All @@ -35,6 +42,20 @@ macro define_diffrule(def)
args = lhs.args[2:end]
rule = Expr(:->, Expr(:tuple, args...), rhs)
key = Expr(:tuple, Expr(:quote, M), Expr(:quote, f), length(args))
return key,rule
end

macro define_complex_diffrule(def)
key,rule = _getkeyrule(def)
return esc(quote
$DiffRules.DEFINED_DIFFRULES[$key] = $rule
$DiffRules.DEFINED_COMPLEX_DIFFRULES[$key] = $rule
$key
end)
end

macro define_diffrule(def)
key,rule = _getkeyrule(def)
return esc(quote
$DiffRules.DEFINED_DIFFRULES[$key] = $rule
$key
Expand All @@ -43,6 +64,7 @@ end

"""
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
complex_diffrule(M::Union{Expr,Symbol}, f::Symbol, args...)
Return the derivative expression for `M.f` at the given argument(s), with the argument(s)
interpolated into the returned expression.
Expand All @@ -65,9 +87,11 @@ Examples:
(:(c * (x + 2) ^ (c - 1)), :((x + 2) ^ c * log(x + 2)))
"""
diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_DIFFRULES[M,f,length(args)](args...)
complex_diffrule(M::Union{Expr,Symbol}, f::Symbol, args...) = DEFINED_COMPLEX_DIFFRULES[M,f,length(args)](args...)

"""
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
hascomplex_diffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int)
Return `true` if a differentiation rule is defined for `M.f` and `arity`, or return `false`
otherwise.
Expand All @@ -92,6 +116,7 @@ Examples:
false
"""
hasdiffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_DIFFRULES, (M, f, arity))
hascomplex_diffrule(M::Union{Expr,Symbol}, f::Symbol, arity::Int) = haskey(DEFINED_COMPLEX_DIFFRULES, (M, f, arity))

"""
diffrules()
Expand All @@ -109,6 +134,7 @@ Examples:
"""
diffrules() = keys(DEFINED_DIFFRULES)
complex_diffrules() = keys(DEFINED_COMPLEX_DIFFRULES)

# For v0.6 and v0.7 compatibility, need to support having the diff rule function enter as a
# `Expr(:quote...)` and a `QuoteNode`. When v0.6 support is dropped, the function will
Expand Down
173 changes: 90 additions & 83 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,72 +5,79 @@
# unary #
#-------#

@define_diffrule Base.:+(x) = :( 1 )
@define_diffrule Base.:-(x) = :( -1 )
@define_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
@define_diffrule Base.abs2(x) = :( $x + $x )
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
@define_diffrule Base.log(x) = :( inv($x) )
@define_diffrule Base.log10(x) = :( inv($x) / log(10) )
@define_diffrule Base.log2(x) = :( inv($x) / log(2) )
@define_diffrule Base.log1p(x) = :( inv($x + 1) )
@define_diffrule Base.exp(x) = :( exp($x) )
@define_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
@define_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
@define_diffrule Base.expm1(x) = :( exp($x) )
@define_diffrule Base.sin(x) = :( cos($x) )
@define_diffrule Base.cos(x) = :( -sin($x) )
@define_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
@define_diffrule Base.sec(x) = :( sec($x) * tan($x) )
@define_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
@define_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) )
@define_diffrule Base.cosd(x) = :( -/ 180) * sind($x) )
@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) )
@define_diffrule Base.cscd(x) = :( -/ 180) * cscd($x) * cotd($x) )
@define_diffrule Base.cotd(x) = :( -/ 180) * (1 + cotd($x)^2) )
@define_diffrule Base.sinpi(x) = :( π * cospi($x) )
@define_diffrule Base.cospi(x) = :( -π * sinpi($x) )
@define_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
@define_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
@define_diffrule Base.atan(x) = :( inv(1 + $x^2) )
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
@define_diffrule Base.sinh(x) = :( cosh($x) )
@define_diffrule Base.cosh(x) = :( sinh($x) )
@define_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
@define_diffrule Base.sech(x) = :( -tanh($x) * sech($x) )
@define_diffrule Base.csch(x) = :( -coth($x) * csch($x) )
@define_diffrule Base.coth(x) = :( -(csch($x)^2) )
@define_diffrule Base.asinh(x) = :( inv(sqrt($x^2 + 1)) )
@define_diffrule Base.acosh(x) = :( inv(sqrt($x^2 - 1)) )
@define_diffrule Base.atanh(x) = :( inv(1 - $x^2) )
@define_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) )
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
@define_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
@define_diffrule Base.deg2rad(x) = :( π / 180 )
@define_diffrule Base.rad2deg(x) = :( 180 / π )
@define_complex_diffrule Base.:+(x) = :( 1 )
@define_complex_diffrule Base.:-(x) = :( -1 )
@define_complex_diffrule Base.sqrt(x) = :( inv(2 * sqrt($x)) )
@define_diffrule Base.cbrt(x) = :( inv(3 * cbrt($x)^2) )
@define_diffrule Base.abs2(x) = :( $x + $x )
@define_complex_diffrule Base.inv(x) = :( -(inv($x^2)) )
@define_diffrule Base.inv(x) = :( -abs2(inv($x)) )
@define_complex_diffrule Base.log(x) = :( inv($x) )
@define_complex_diffrule Base.log10(x) = :( inv($x) / log(10) )
@define_complex_diffrule Base.log2(x) = :( inv($x) / log(2) )
@define_complex_diffrule Base.log1p(x) = :( inv($x + 1) )
@define_complex_diffrule Base.exp(x) = :( exp($x) )
@define_complex_diffrule Base.exp2(x) = :( exp2($x) * log(2) )
@define_complex_diffrule Base.exp10(x) = :( exp10($x) * log(10) )
@define_complex_diffrule Base.expm1(x) = :( exp($x) )
@define_complex_diffrule Base.sin(x) = :( cos($x) )
@define_complex_diffrule Base.cos(x) = :( -sin($x) )
@define_complex_diffrule Base.tan(x) = :( 1 + tan($x)^2 )
@define_complex_diffrule Base.sec(x) = :( sec($x) * tan($x) )
@define_complex_diffrule Base.csc(x) = :( -csc($x) * cot($x) )
@define_complex_diffrule Base.cot(x) = :( -(1 + cot($x)^2) )
@define_diffrule Base.sind(x) = :( (π / 180) * cosd($x) )
@define_diffrule Base.cosd(x) = :( -/ 180) * sind($x) )
@define_diffrule Base.tand(x) = :( (π / 180) * (1 + tand($x)^2) )
@define_diffrule Base.secd(x) = :( (π / 180) * secd($x) * tand($x) )
@define_diffrule Base.cscd(x) = :( -/ 180) * cscd($x) * cotd($x) )
@define_diffrule Base.cotd(x) = :( -/ 180) * (1 + cotd($x)^2) )
@define_complex_diffrule Base.sinpi(x) = :( π * cospi($x) )
@define_complex_diffrule Base.cospi(x) = :( -π * sinpi($x) )
@define_complex_diffrule Base.asin(x) = :( inv(sqrt(1 - $x^2)) )
@define_complex_diffrule Base.acos(x) = :( -inv(sqrt(1 - $x^2)) )
@define_complex_diffrule Base.atan(x) = :( inv(1 + $x^2) )
@define_complex_diffrule Base.asec(x) = :( inv($x^2 * sqrt(1 - inv($x^2))) )
@define_diffrule Base.asec(x) = :( inv(abs($x) * sqrt($x^2 - 1)) )
@define_complex_diffrule Base.acsc(x) = :( -inv($x^2*sqrt(1 - inv($x^2))) )
@define_diffrule Base.acsc(x) = :( -inv(abs($x) * sqrt($x^2 - 1)) )
@define_diffrule Base.acot(x) = :( -inv(1 + $x^2) )
@define_diffrule Base.asind(x) = :( 180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.acosd(x) = :( -180 / π / sqrt(1 - $x^2) )
@define_diffrule Base.atand(x) = :( 180 / π / (1 + $x^2) )
@define_diffrule Base.asecd(x) = :( 180 / π / $x^2 / sqrt(1 - inv($x^2)))
@define_diffrule Base.asecd(x) = :( 180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acscd(x) = :( -180 / π / $x^2 / sqrt(1 - inv($x^2)))
@define_diffrule Base.acscd(x) = :( -180 / π / abs($x) / sqrt($x^2 - 1) )
@define_diffrule Base.acotd(x) = :( -180 / π / (1 + $x^2) )
@define_complex_diffrule Base.sinh(x) = :( cosh($x) )
@define_complex_diffrule Base.cosh(x) = :( sinh($x) )
@define_complex_diffrule Base.tanh(x) = :( 1 - tanh($x)^2 )
@define_complex_diffrule Base.sech(x) = :( -tanh($x) * sech($x) )
@define_complex_diffrule Base.csch(x) = :( -coth($x) * csch($x) )
@define_complex_diffrule Base.coth(x) = :( -(csch($x)^2) )
@define_complex_diffrule Base.asinh(x) = :( inv(sqrt($x^2 + 1)) )
@define_complex_diffrule Base.acosh(x) = :( inv(sqrt($x - 1)*sqrt($x+1)) )
@define_diffrule Base.acosh(x) = :( inv(sqrt($x^2 - 1)) )
@define_complex_diffrule Base.atanh(x) = :( inv(1 - $x^2) )
@define_complex_diffrule Base.asech(x) = :( -inv($x * sqrt(1 - $x^2)) )
@define_complex_diffrule Base.acsch(x) = :( -inv(sqrt($x^4 + $x^2)) )
@define_diffrule Base.acsch(x) = :( -inv(abs($x) * sqrt(1 + $x^2)) )
@define_complex_diffrule Base.acoth(x) = :( inv(1 - $x^2) )
@define_diffrule Base.deg2rad(x) = :( π / 180 )
@define_diffrule Base.rad2deg(x) = :( 180 / π )
if VERSION < v"0.7-"
@define_diffrule Base.gamma(x) = :( digamma($x) * gamma($x) )
@define_diffrule Base.lgamma(x) = :( digamma($x) )
@define_diffrule Base.Math.JuliaLibm.log1p(x) = :( inv($x + 1) )
@define_complex_diffrule Base.gamma(x) = :( digamma($x) * gamma($x) )
@define_complex_diffrule Base.lgamma(x) = :( digamma($x) )
@define_diffrule Base.Math.JuliaLibm.log1p(x) = :( inv($x + 1) )
else
@define_diffrule SpecialFunctions.gamma(x) =
@define_complex_diffrule SpecialFunctions.gamma(x) =
:( SpecialFunctions.digamma($x) * SpecialFunctions.gamma($x) )
@define_diffrule SpecialFunctions.lgamma(x) =
@define_complex_diffrule SpecialFunctions.lgamma(x) =
:( SpecialFunctions.digamma($x) )
end
@define_diffrule Base.transpose(x) = :( 1 )
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )
@define_complex_diffrule Base.transpose(x) = :( 1 )
@define_diffrule Base.abs(x) = :( DiffRules._abs_deriv($x) )

# We provide this hook for special number types like `Interval`
# that need their own special definition of `abs`.
Expand All @@ -79,12 +86,12 @@ _abs_deriv(x) = signbit(x) ? -one(x) : one(x)
# binary #
#--------#

@define_diffrule Base.:+(x, y) = :( 1 ), :( 1 )
@define_diffrule Base.:-(x, y) = :( 1 ), :( -1 )
@define_diffrule Base.:*(x, y) = :( $y ), :( $x )
@define_diffrule Base.:/(x, y) = :( inv($y) ), :( -($x / $y / $y) )
@define_diffrule Base.:\(x, y) = :( -($y / $x / $x) ), :( inv($x) )
@define_diffrule Base.:^(x, y) = :( $y * ($x^($y - 1)) ), :( ($x^$y) * log($x) )
@define_complex_diffrule Base.:+(x, y) = :( 1 ), :( 1 )
@define_complex_diffrule Base.:-(x, y) = :( 1 ), :( -1 )
@define_complex_diffrule Base.:*(x, y) = :( $y ), :( $x )
@define_complex_diffrule Base.:/(x, y) = :( inv($y) ), :( -($x / $y / $y) )
@define_complex_diffrule Base.:\(x, y) = :( -($y / $x / $x) ), :( inv($x) )
@define_diffrule Base.:^(x, y) = :( $y * ($x^($y - 1)) ), :( ($x^$y) * log($x) )

if VERSION < v"0.7-"
@define_diffrule Base.atan2(x, y) = :( $y / ($x^2 + $y^2) ), :( -$x / ($x^2 + $y^2) )
Expand All @@ -105,38 +112,38 @@ end
# unary #
#-------#

@define_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
@define_complex_diffrule SpecialFunctions.erf(x) = :( (2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfinv(x) =
:( (sqrt(π) / 2) * exp(SpecialFunctions.erfinv($x)^2) )
@define_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfcinv(x) =
@define_complex_diffrule SpecialFunctions.erfc(x) = :( -(2 / sqrt(π)) * exp(-$x * $x) )
@define_diffrule SpecialFunctions.erfcinv(x) =
:( -(sqrt(π) / 2) * exp(SpecialFunctions.erfcinv($x)^2) )
@define_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
@define_diffrule SpecialFunctions.erfcx(x) =
@define_complex_diffrule SpecialFunctions.erfi(x) = :( (2 / sqrt(π)) * exp($x * $x) )
@define_complex_diffrule SpecialFunctions.erfcx(x) =
:( (2 * $x * SpecialFunctions.erfcx($x)) - (2 / sqrt(π)) )
@define_diffrule SpecialFunctions.dawson(x) =
@define_complex_diffrule SpecialFunctions.dawson(x) =
:( 1 - (2 * $x * SpecialFunctions.dawson($x)) )
@define_diffrule SpecialFunctions.digamma(x) =
@define_complex_diffrule SpecialFunctions.digamma(x) =
:( SpecialFunctions.trigamma($x) )
@define_diffrule SpecialFunctions.invdigamma(x) =
:( inv(SpecialFunctions.trigamma(SpecialFunctions.invdigamma($x))) )
@define_diffrule SpecialFunctions.trigamma(x) =
@define_complex_diffrule SpecialFunctions.trigamma(x) =
:( SpecialFunctions.polygamma(2, $x) )
@define_diffrule SpecialFunctions.airyai(x) =
@define_complex_diffrule SpecialFunctions.airyai(x) =
:( SpecialFunctions.airyaiprime($x) )
@define_diffrule SpecialFunctions.airyaiprime(x) =
@define_complex_diffrule SpecialFunctions.airyaiprime(x) =
:( $x * SpecialFunctions.airyai($x) )
@define_diffrule SpecialFunctions.airybi(x) =
@define_complex_diffrule SpecialFunctions.airybi(x) =
:( SpecialFunctions.airybiprime($x) )
@define_diffrule SpecialFunctions.airybiprime(x) =
@define_complex_diffrule SpecialFunctions.airybiprime(x) =
:( $x * SpecialFunctions.airybi($x) )
@define_diffrule SpecialFunctions.besselj0(x) =
@define_complex_diffrule SpecialFunctions.besselj0(x) =
:( -SpecialFunctions.besselj1($x) )
@define_diffrule SpecialFunctions.besselj1(x) =
@define_complex_diffrule SpecialFunctions.besselj1(x) =
:( (SpecialFunctions.besselj0($x) - SpecialFunctions.besselj(2, $x)) / 2 )
@define_diffrule SpecialFunctions.bessely0(x) =
@define_complex_diffrule SpecialFunctions.bessely0(x) =
:( -SpecialFunctions.bessely1($x) )
@define_diffrule SpecialFunctions.bessely1(x) =
@define_complex_diffrule SpecialFunctions.bessely1(x) =
:( (SpecialFunctions.bessely0($x) - SpecialFunctions.bessely(2, $x)) / 2 )

# TODO:
Expand Down

0 comments on commit c3598be

Please sign in to comment.