Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ArrayDifferentialOperators for Vector calculus #942

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
1 change: 1 addition & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ end
# turn `f(x...)` into `term(f, x...)`
#
function call2term(expr, arrs=[])
(expr isa QuoteNode) && return expr
!(expr isa Expr) && return :($unwrap($expr))
if expr.head == :call
if expr.args[1] == :(:)
Expand Down
106 changes: 99 additions & 7 deletions src/diff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,17 @@ struct Differential <: Operator
x
Differential(x) = new(value(x))
end
(D::Differential)(x) = Term{symtype(x)}(D, [x])
(D::Differential)(x::Num) = Num(D(value(x)))
(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x)))))
(D::Operator)(x) = Term{symtype(x)}(D, [x])
(D::Operator)(x::Num) = Num(D(value(x)))
(D::Operator)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x)))))
SymbolicUtils.promote_symtype(::Differential, x) = x

is_derivative(x) = istree(x) ? operation(x) isa Differential : false

Base.:*(D1, D2::Differential) = D1 ∘ D2
Base.:*(D1::Differential, D2) = D1 ∘ D2
Base.:*(D1::Differential, D2::Differential) = D1 ∘ D2
Base.:^(D::Differential, n::Integer) = _repeat_apply(D, n)
Base.:*(D1, D2::Operator) = D1 ∘ D2
Base.:*(D1::Operator, D2) = D1 ∘ D2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering why we have these methods @ChrisRackauckas ? this does not make sense in general, only maybe for 2 operators

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's a case where they don't make sense?

Base.:*(D1::Operator, D2::Operator) = D1 ∘ D2
Base.:^(D::Operator, n::Integer) = _repeat_apply(D, n)

Base.show(io::IO, D::Differential) = print(io, "Differential(", D.x, ")")

Expand Down Expand Up @@ -785,3 +785,95 @@ end
function SymbolicUtils.substitute(op::Differential, dict; kwargs...)
@set! op.x = substitute(op.x, dict; kwargs...)
end


#######################################################################################################################
# Vector Calculus
#######################################################################################################################

struct ArrayDifferentialOperator <: Operator
"""The variables to differentiate with respect to."""
vars
"""The differentials, can be other functions if composite"""
differentials
"""name"""
name
ArrayDifferentialOperator(vars, differentials, name) = new(vars, differentials, name)
end
Nabla(vars) = ArrayDifferentialOperator(value.(vars), map(Differential, scalarize(value.(vars))), "∇")
const Grad = Nabla
Div(vars) = (x) -> Nabla(vars) ⋅ x
Curl(vars) = (x) -> Nabla(vars) × x
Laplacian(vars) = Nabla(vars) ⋅ Nabla(vars)

#? How to get transpose and Jac working?

function (D::ArrayDifferentialOperator)(x::SymVec)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
@arrayop (i,) (D.differentials)[i](x[i]) term=D(x)
end
(D::ArrayDifferentialOperator)(x::Arr) = Arr(D(value(x)))

function (D1::ArrayDifferentialOperator)(D2::ArrayDifferentialOperator)
@assert all(x -> any(isequal.((x,), D2.vars)), D1.vars)

ArrayDifferentialOperator(D1.vars, scalarize(D1.differentials .∘ D2.differentials), "("*D1.name*"∘"*D2.name*")")
end

function LinearAlgebra.dot(D::ArrayDifferentialOperator, x::SymVec)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
@show D(x), scalarize(D(x))
sum(scalarize(D(x)))
end
LinearAlgebra.dot(D::ArrayDifferentialOperator, x::Arr) = Num(D ⋅ value(x))

function LinearAlgebra.dot(x::SymVec, D::ArrayDifferentialOperator)
@assert length(D.vars) == length(x) "Vector must be same length as vars in Operator $(D.name)."
(y) -> sum(@arrayop (i,) x[i]*D.differentials[i](y) term = (x⋅D)(y))
end
LinearAlgebra.dot(x::Arr, D::ArrayDifferentialOperator) = value(x) ⋅ D

function LinearAlgebra.dot(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@assert all(scalarize(isequal.(D1.vars, D2.vars))) "Operators have different variables and cannot be composed."
lap = x -> sum((D1.differentials[i] ∘ D2.differentials[i])(x) for i in 1:length(D1.vars))
(x) -> @arrayop (i,) lap(x[i]) term=(D1⋅D2)(x) reduce=+
end

function crosscompose(a, b)
v1 = x -> (a[2] ∘ b[3])(x) - (a[3] ∘ b[2])(x)
v2 = x -> (a[3] ∘ b[1])(x) - (a[1] ∘ b[3])(x)
v3 = x -> (a[1] ∘ b[2])(x) - (a[2] ∘ b[1])(x)
return [v1, v2, v3]
end

function crosscall(a, b)
v1 = a[2](b[3]) - a[3](b[2])
v2 = a[3](b[1]) - a[1](b[3])
v3 = a[1](b[2]) - a[2](b[1])
return [v1, v2, v3]
end
function LinearAlgebra.cross(D::ArrayDifferentialOperator, x::SymVec)
@assert length(D.vars) == length(x) == 3 "Cross product is only defined in 3 dimensions."
curl = crosscall(D.differentials, x)
@arrayop (i,) curl[i] term=D×x
end
LinearAlgebra.cross(D::ArrayDifferentialOperator, x::Arr) = Arr(D × value(x))

function LinearAlgebra.cross(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@assert length(D1.vars) == length(D2.vars) == 3 "Cross product is only defined in 3 dimensions."
@assert all(scalarize(isequal.(D1.vars, D2.vars))) "Operators have different variables and cannot be composed."

ArrayDifferentialOperator(D1.vars, crosscompose(D1.differentials, D2.differentials), "("*D1.name*"×"*D2.name*")")
end

SymbolicUtils.promote_symtype(::ArrayDifferentialOperator, x) = x

Base.show(io::IO, D::ArrayDifferentialOperator) = print(io, D.name)
Base.nameof(D::ArrayDifferentialOperator) = Symbol(D.name)

function Base.:(==)(D1::ArrayDifferentialOperator, D2::ArrayDifferentialOperator)
@variables x[1:length(D1.vars)]
all(scalarize(isequal.(D1.vars, D2.vars))) && all(scalarize(isequal.(D1(x), D2(x))))
end

# TODO: Add simplification rules for dot and cross products to remove 0 terms and simplify