Skip to content

Commit

Permalink
Explicit Base and AT overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
LuEdRaMo committed Apr 8, 2024
1 parent 889f409 commit b192757
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 15 deletions.
27 changes: 13 additions & 14 deletions src/adsbinarynode.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# This file is part of the TaylorIntegration.jl package; MIT licensed

using StaticArrays: SVector
import AbstractTrees: HasNodeType, NodeType, ParentLinks, StoredParents, TreeIterator,
children, nodetype, nodevalue, parent, printnode, PreOrderDFS
import Base: IteratorEltype, HasEltype, in, eltype, split
using AbstractTrees: StoredParents, HasNodeType, TreeIterator, PreOrderDFS
import AbstractTrees

if !isdefined(Base, :isnothing) # Julia 1.0 support
using AbstractTrees: isnothing
Expand Down Expand Up @@ -63,7 +62,7 @@ infima(s::ADSDomain{N, T}) where {N, T <: Real} = s.lo
# Return the upper bounds in s
suprema(s::ADSDomain{N, T}) where {N, T <: Real} = s.hi

Check warning on line 63 in src/adsbinarynode.jl

View check run for this annotation

Codecov / codecov/patch

src/adsbinarynode.jl#L63

Added line #L63 was not covered by tests
# Split s in half along direction i
function split(s::ADSDomain{N, T}, i::Int) where {N, T <: Real}
function Base.split(s::ADSDomain{N, T}, i::Int) where {N, T <: Real}
@assert 1 <= i <= N
mid = (s.lo[i] + s.hi[i])/2
a = ADSDomain{N, T}(
Expand All @@ -77,7 +76,7 @@ function split(s::ADSDomain{N, T}, i::Int) where {N, T <: Real}
return a, b
end

function in(x::AbstractVector{T}, s::ADSDomain{N, T}) where {N, T <: Real}
function Base.in(x::AbstractVector{T}, s::ADSDomain{N, T}) where {N, T <: Real}
@assert length(x) == N "x must be of length $N"
mask = SVector{N, Bool}(s.lo[i] <= x[i] <= s.hi[i] for i in 1:N)
return all(mask)
Expand Down Expand Up @@ -170,7 +169,7 @@ function rightchild!(parent::ADSBinaryNode{N, M, T}, node::ADSBinaryNode{N, M, T
end

# AbstractTrees interface
function children(node::ADSBinaryNode{N, M, T}) where {N, M, T <: Real}
function AbstractTrees.children(node::ADSBinaryNode{N, M, T}) where {N, M, T <: Real}
if isnothing(node.left) && isnothing(node.right)
()
elseif isnothing(node.left) && !isnothing(node.right)
Expand All @@ -182,20 +181,20 @@ function children(node::ADSBinaryNode{N, M, T}) where {N, M, T <: Real}
end
end

printnode(io::IO, n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = print(io, "t: ", n.t)
AbstractTrees.printnode(io::IO, n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = print(io, "t: ", n.t)

Check warning on line 184 in src/adsbinarynode.jl

View check run for this annotation

Codecov / codecov/patch

src/adsbinarynode.jl#L184

Added line #L184 was not covered by tests

nodevalue(n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = (n.s, n.t, n.x, n.p)
AbstractTrees.nodevalue(n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = (n.s, n.t, n.x, n.p)

Check warning on line 186 in src/adsbinarynode.jl

View check run for this annotation

Codecov / codecov/patch

src/adsbinarynode.jl#L186

Added line #L186 was not covered by tests

ParentLinks(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = StoredParents()
AbstractTrees.ParentLinks(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = StoredParents()

parent(n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = n.parent
AbstractTrees.parent(n::ADSBinaryNode{N, M, T}) where {N, M, T <: Real} = n.parent

NodeType(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = HasNodeType()
nodetype(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = ADSBinaryNode{N, M, T}
AbstractTrees.NodeType(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = HasNodeType()
AbstractTrees.nodetype(::Type{<:ADSBinaryNode{N, M, T}}) where {N, M, T <: Real} = ADSBinaryNode{N, M, T}

Check warning on line 193 in src/adsbinarynode.jl

View check run for this annotation

Codecov / codecov/patch

src/adsbinarynode.jl#L193

Added line #L193 was not covered by tests

# For TreeIterator
IteratorEltype(::Type{<:TreeIterator{ADSBinaryNode{N, M, T}}}) where {N, M, T <: Real} = HasEltype()
eltype(::Type{<:TreeIterator{ADSBinaryNode{N, M, T}}}) where {N, M, T <: Real} = ADSBinaryNode{N, M, T}
Base.IteratorEltype(::Type{<:TreeIterator{ADSBinaryNode{N, M, T}}}) where {N, M, T <: Real} = HasEltype()
Base.eltype(::Type{<:TreeIterator{ADSBinaryNode{N, M, T}}}) where {N, M, T <: Real} = ADSBinaryNode{N, M, T}

Check warning on line 197 in src/adsbinarynode.jl

View check run for this annotation

Codecov / codecov/patch

src/adsbinarynode.jl#L196-L197

Added lines #L196 - L197 were not covered by tests

"""
countnodes(n::ADSBinaryNode{N, M, T}, k::Int) where {N, M, T <: Real}
Expand Down
2 changes: 1 addition & 1 deletion src/adstaylorinteg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ end

# Split node's domain in half
# See section 3 of https://doi.org/10.1007/s10569-015-9618-3
function split(node::ADSBinaryNode{N, M, T}, x::SVector{M, TaylorN{T}},
function Base.split(node::ADSBinaryNode{N, M, T}, x::SVector{M, TaylorN{T}},
p::SVector{M, Taylor1{TaylorN{T}}}, dt::T) where {N, M, T <: Real}
# Split direction
j = splitdirection(x)
Expand Down

0 comments on commit b192757

Please sign in to comment.