Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 146 additions & 43 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,32 @@ function recursive_hasoperator(op, O)
end
end

struct DerivativeNotDefinedError <: Exception
expr
i::Int
end

function Base.showerror(io::IO, err::DerivativeNotDefinedError)
op = operation(err.expr)
nargs = length(arguments(err.expr))
# `Markdown.parse` instead of `@md_str` to allow interpolating inside `literal` blocks
# and code fences
err_str = Markdown.parse("""
Derivative of `$(err.expr)` with respect to its $(err.i)-th argument is not defined.
Define a derivative by adding a method to `Symbolics.derivative`:

```julia
function Symbolics.derivative(::typeof($op), args::NTuple{$nargs, Any}, ::Val{$(err.i)})
# ...
end
```

Refer to the documentation for `Symbolics.derivative` and the
"[Adding Analytical Derivatives](@ref)" section of the docs for further information.
""")
show(io, MIME"text/plain"(), err_str)
end

"""
executediff(D, arg, simplify=false; occurrences=nothing)

Expand All @@ -166,8 +192,10 @@ passed differential and not any other Differentials it encounters.
- `occurrences=nothing`: Information about the occurrences of the independent
variable in the argument of the derivative. This is used internally for
optimization purposes.
- `throw_no_derivative=false`: Whether to throw if a function with unknown
derivative is encountered.
"""
function executediff(D, arg, simplify=false; occurrences=nothing)
function executediff(D, arg, simplify=false; occurrences=nothing, throw_no_derivative=false)
if occurrences == nothing
occurrences = occursin_info(D.x, arg)
end
Expand All @@ -183,15 +211,15 @@ function executediff(D, arg, simplify=false; occurrences=nothing)
return D(arg) # base case if any argument is directly equal to the i.v.
else
return sum(inner_args, init=0) do a
return executediff(Differential(a), arg) *
executediff(D, a)
return executediff(Differential(a), arg; throw_no_derivative) *
executediff(D, a; throw_no_derivative)
end
end
elseif op === ifelse
args = arguments(arg)
O = op(args[1],
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2]),
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3]))
executediff(D, args[2], simplify; occurrences=arguments(occurrences)[2], throw_no_derivative),
executediff(D, args[3], simplify; occurrences=arguments(occurrences)[3], throw_no_derivative))
return O
elseif isa(op, Differential)
# The recursive expand_derivatives was not able to remove
Expand All @@ -201,13 +229,13 @@ function executediff(D, arg, simplify=false; occurrences=nothing)
if isequal(op.x, D.x)
return D(arg)
else
inner = executediff(D, arguments(arg)[1], false)
inner = executediff(D, arguments(arg)[1], false; throw_no_derivative)
# if the inner expression is not expandable either, return
if iscall(inner) && operation(inner) isa Differential
return D(arg)
else
# otherwise give the nested Differential another try
return executediff(op, inner, simplify)
return executediff(op, inner, simplify; throw_no_derivative)
end
end
elseif isa(op, Integral)
Expand All @@ -226,7 +254,7 @@ function executediff(D, arg, simplify=false; occurrences=nothing)
t2 = D(b)
c += t1*t2
end
inner = executediff(D, arguments(arg)[1])
inner = executediff(D, arguments(arg)[1]; throw_no_derivative)
c += op(inner)
return value(c)
end
Expand All @@ -238,16 +266,26 @@ function executediff(D, arg, simplify=false; occurrences=nothing)
c = 0

for i in 1:l
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i])
t2 = executediff(D, inner_args[i],false; occurrences=arguments(occurrences)[i], throw_no_derivative)

x = if _iszero(t2)
t2
elseif _isone(t2)
d = derivative_idx(arg, i)
d isa NoDeriv ? D(arg) : d
if d isa NoDeriv
throw_no_derivative && throw(DerivativeNotDefinedError(arg, i))
D(arg)
else
d
end
else
t1 = derivative_idx(arg, i)
t1 = t1 isa NoDeriv ? D(arg) : t1
t1 = if t1 isa NoDeriv
throw_no_derivative && throw(DerivativeNotDefinedError(arg, i))
D(arg)
else
t1
end
t1 * t2
end

Expand Down Expand Up @@ -284,6 +322,10 @@ and other derivative rules to expand any derivatives it encounters.
- `simplify::Bool=false`: Whether to simplify the resulting expression using
[`SymbolicUtils.simplify`](@ref).

# Keyword Arguments
- `throw_no_derivative=false`: Whether to throw if a function with unknown
derivative is encountered.

# Examples
```jldoctest
julia> @variables x y z k;
Expand All @@ -298,29 +340,29 @@ julia> dfx = expand_derivatives(Dx(f))
(k*((2abs(x - y)) / y - 2z)*ifelse(signbit(x - y), -1, 1)) / y
```
"""
function expand_derivatives(O::Symbolic, simplify=false)
function expand_derivatives(O::Symbolic, simplify=false; throw_no_derivative=false)
if iscall(O) && isa(operation(O), Differential)
arg = only(arguments(O))
arg = expand_derivatives(arg, false)
return executediff(operation(O), arg, simplify)
arg = expand_derivatives(arg, false; throw_no_derivative)
return executediff(operation(O), arg, simplify; throw_no_derivative)
elseif iscall(O) && isa(operation(O), Integral)
return operation(O)(expand_derivatives(arguments(O)[1]))
return operation(O)(expand_derivatives(arguments(O)[1]; throw_no_derivative))
elseif !hasderiv(O)
return O
else
args = map(a->expand_derivatives(a, false), arguments(O))
args = map(a->expand_derivatives(a, false; throw_no_derivative), arguments(O))
O1 = operation(O)(args...)
return simplify ? SymbolicUtils.simplify(O1) : O1
end
end
function expand_derivatives(n::Num, simplify=false)
wrap(expand_derivatives(value(n), simplify))
function expand_derivatives(n::Num, simplify=false; kwargs...)
wrap(expand_derivatives(value(n), simplify; kwargs...))
end
function expand_derivatives(n::Complex{Num}, simplify=false)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify),
expand_derivatives(imag(n), simplify)))
function expand_derivatives(n::Complex{Num}, simplify=false; kwargs...)
wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; kwargs...),
expand_derivatives(imag(n), simplify; kwargs...)))
end
expand_derivatives(x, simplify=false) = x
expand_derivatives(x, simplify=false; kwargs...) = x

_iszero(x) = false
_isone(x) = false
Expand Down Expand Up @@ -368,6 +410,16 @@ end
# Indicate that no derivative is defined.
struct NoDeriv
end

"""
Symbolics.derivative(::typeof(f), args::NTuple{N, Any}, ::Val{i})

Return the derivative of `f(args...)` with respect to `args[i]`. `N` should be the number
of arguments that `f` takes and `i` is the argument with respect to which the derivative
is taken. The result can be a numeric value (if the derivative is constant) or a symbolic
expression. This function is useful for defining derivatives of custom functions registered
via `@register_symbolic`, to be used when calling `expand_derivatives`.
"""
derivative(f, args, v) = NoDeriv()

# Pre-defined derivatives
Expand Down Expand Up @@ -455,12 +507,18 @@ $(SIGNATURES)

A helper function for computing the derivative of the expression `O` with respect to
`var`.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function derivative(O, var; simplify=false)
function derivative(O, var; simplify=false, kwargs...)
if O isa AbstractArray
Num[Num(expand_derivatives(Differential(var)(value(o)), simplify)) for o in O]
Num[Num(expand_derivatives(Differential(var)(value(o)), simplify; kwargs...)) for o in O]
else
Num(expand_derivatives(Differential(var)(value(O)), simplify))
Num(expand_derivatives(Differential(var)(value(O)), simplify; kwargs...))
end
end

Expand All @@ -469,44 +527,63 @@ $(SIGNATURES)

A helper function for computing the gradient of the expression `O` with respect to
an array of variable expressions.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function gradient(O, vars::AbstractVector; simplify=false)
Num[Num(expand_derivatives(Differential(v)(value(O)),simplify)) for v in vars]
function gradient(O, vars::AbstractVector; simplify=false, kwargs...)
Num[Num(expand_derivatives(Differential(v)(value(O)),simplify; kwargs...)) for v in vars]
end

"""
$(SIGNATURES)

A helper function for computing the Jacobian of an array of expressions with respect to
an array of variable expressions.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.
- `scalarize=true`: Whether to scalarize `ops` and `vars` before computing the jacobian.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function jacobian(ops::AbstractVector, vars::AbstractVector; simplify=false, scalarize=true)
function jacobian(ops::AbstractVector, vars::AbstractVector; simplify=false, scalarize=true, kwargs...)
if scalarize
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)
end
Num[Num(expand_derivatives(Differential(value(v))(value(O)),simplify)) for O in ops, v in vars]
Num[Num(expand_derivatives(Differential(value(v))(value(O)),simplify; kwargs...)) for O in ops, v in vars]
end

function jacobian(ops, vars; simplify=false)
function jacobian(ops, vars; simplify=false, kwargs...)
ops = vec(scalarize(ops))
vars = vec(scalarize(vars)) # Suboptimal, but prevents wrong results on Arr for now. Arr resulting from a symbolic function will fail on this due to unknown size.
jacobian(ops, vars; simplify=simplify, scalarize=false)
jacobian(ops, vars; simplify=simplify, scalarize=false, kwargs...)
end

"""
$(SIGNATURES)

A helper function for computing the sparse Jacobian of an array of expressions with respect to
an array of variable expressions.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function sparsejacobian(ops::AbstractVector, vars::AbstractVector; simplify::Bool=false)
function sparsejacobian(ops::AbstractVector, vars::AbstractVector; simplify::Bool=false, kwargs...)
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)
sp = jacobian_sparsity(ops, vars)
I,J,_ = findnz(sp)

exprs = sparsejacobian_vals(ops, vars, I, J, simplify=simplify)
exprs = sparsejacobian_vals(ops, vars, I, J; simplify=simplify, kwargs...)

sparse(I, J, exprs, length(ops), length(vars))
end
Expand All @@ -516,15 +593,21 @@ $(SIGNATURES)

A helper function for computing the values of the sparse Jacobian of an array of expressions with respect to
an array of variable expressions given the sparsity structure.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function sparsejacobian_vals(ops::AbstractVector, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false)
function sparsejacobian_vals(ops::AbstractVector, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false, kwargs...)
ops = Symbolics.scalarize(ops)
vars = Symbolics.scalarize(vars)

exprs = Num[]

for (i,j) in zip(I, J)
push!(exprs, Num(expand_derivatives(Differential(vars[j])(ops[i]), simplify)))
push!(exprs, Num(expand_derivatives(Differential(vars[j])(ops[i]), simplify; kwargs...)))
end
exprs
end
Expand Down Expand Up @@ -637,16 +720,22 @@ $(SIGNATURES)

A helper function for computing the Hessian of the expression `O` with respect to
an array of variable expressions.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function hessian(O, vars::AbstractVector; simplify=false)
function hessian(O, vars::AbstractVector; simplify=false, kwargs...)
vars = map(value, vars)
first_derivs = map(value, vec(jacobian([values(O)], vars, simplify=simplify)))
first_derivs = map(value, vec(jacobian([values(O)], vars; simplify=simplify, kwargs...)))
n = length(vars)
H = Array{Num, 2}(undef,(n, n))
fill!(H, 0)
for i=1:n
for j=1:i
H[j, i] = H[i, j] = expand_derivatives(Differential(vars[i])(first_derivs[j]))
H[j, i] = H[i, j] = expand_derivatives(Differential(vars[i])(first_derivs[j]), simplify; kwargs...)
end
end
H
Expand Down Expand Up @@ -766,14 +855,22 @@ $(SIGNATURES)

A helper function for computing the sparse Hessian of an expression with respect to
an array of variable expressions.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.
- `full=false`: Whether to construct the full hessian by also including entries in
the upper-triangular half of the matrix.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function sparsehessian(op, vars::AbstractVector; simplify::Bool=false, full::Bool=true)
function sparsehessian(op, vars::AbstractVector; simplify::Bool=false, full::Bool=true, kwargs...)
op = value(op)
vars = map(value, vars)
S = hessian_sparsity(op, vars, full=full)
I, J, _ = findnz(S)

exprs = sparsehessian_vals(op, vars, I, J, simplify=simplify)
exprs = sparsehessian_vals(op, vars, I, J; simplify=simplify, kwargs...)

H = sparse(I, J, exprs, length(vars), length(vars))

Expand All @@ -790,8 +887,14 @@ $(SIGNATURES)

A helper function for computing the values of the sparse Hessian of an expression with respect to
an array of variable expressions given the sparsity structure.

# Keyword Arguments

- `simplify=false`: The simplify argument of `expand_derivatives`.

All other keyword arguments are forwarded to `expand_derivatives`.
"""
function sparsehessian_vals(op, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false)
function sparsehessian_vals(op, vars::AbstractVector, I::AbstractVector, J::AbstractVector; simplify::Bool=false, kwargs...)
vars = Symbolics.scalarize(vars)

exprs = Array{Num}(undef, length(I))
Expand All @@ -802,9 +905,9 @@ function sparsehessian_vals(op, vars::AbstractVector, I::AbstractVector, J::Abst
for (k, (i, j)) in enumerate(zip(I, J))
j > i && continue
if j != prev_j
d = expand_derivatives(Differential(vars[j])(op), false)
d = expand_derivatives(Differential(vars[j])(op), false; kwargs...)
end
expr = expand_derivatives(Differential(vars[i])(d), simplify)
expr = expand_derivatives(Differential(vars[i])(d), simplify; kwargs...)
exprs[k] = expr
prev_j = j
end
Expand Down
7 changes: 7 additions & 0 deletions src/extra_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,3 +95,10 @@ end
@register_symbolic ⊆(x, y)

LinearAlgebra.norm(x::Num, p::Real) = abs(x)

derivative(::typeof(<), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
derivative(::typeof(<=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
derivative(::typeof(>), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
derivative(::typeof(>=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
derivative(::typeof(==), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
derivative(::typeof(!=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0
Loading
Loading