Skip to content

Commit

Permalink
add boundingbox and broadcast tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentcp committed Jun 24, 2019
1 parent 8d8892a commit 1f36feb
Show file tree
Hide file tree
Showing 10 changed files with 178 additions and 178 deletions.
9 changes: 6 additions & 3 deletions src/GridArrays.jl
Expand Up @@ -52,13 +52,16 @@ float_type(::Type{NTuple{N,T}}) where {N,T} = T
float_type(::Type{T}) where {T} = Float64


include("domains/extensions.jl")

include("grid.jl")
include("productgrid.jl")
include("intervalgrids.jl")
include("mappedgrid.jl")
include("scattered_grid.jl")


include("domains/boundingbox.jl")
include("domains/broadcast.jl")

include("randomgrid.jl")


Expand All @@ -74,7 +77,7 @@ include("test/test_grids.jl")


(d1::DomainSets.Interval, d2::DomainSets.Interval) =
11+abs(DomainSets.leftendpoint(d1)-DomainSets.leftendpoint(d2))+abs(DomainSets.rightendpoint(d1)-DomainSets.rightendpoint(d2))
11+abs(DomainSets.infimum(d1)-DomainSets.infimum(d2))+abs(DomainSets.supremum(d1)-DomainSets.supremum(d2))


end # module
74 changes: 74 additions & 0 deletions src/domains/boundingbox.jl
@@ -0,0 +1,74 @@

################
# Bounding boxes
#################

# A bounding box is an Interval or ProductDomain of intervals that encompasses the domain.

# If the boundingbox is not a product of intervals, something has gone wrong.

boundingbox(a::SVector{1}, b::SVector{1}) = a[1]..b[1]

boundingbox(a::Number, b::Number) = a..b

boundingbox(a, b) = ProductDomain(map((ai,bi)->ClosedInterval(ai,bi), a, b)...)

boundingbox(d::AbstractInterval) = d

boundingbox(::UnitHyperBall{N,T}) where {N,T} = boundingbox(-ones(SVector{N,T}), ones(SVector{N,T}))

boundingbox(d::ProductDomain) = cartesianproduct(map(boundingbox, elements(d))...)

boundingbox(d::DerivedDomain) = boundingbox(source(d))

boundingbox(d::DifferenceDomain) = boundingbox(d.d1)

function boundingbox(d::UnionDomain)
left = minimum(hcat(map(infimum,map(boundingbox,elements(d)))...);dims=2)
right = maximum(hcat(map(supremum,map(boundingbox,elements(d)))...);dims=2)
boundingbox(left,right)
end

function boundingbox(d::IntersectionDomain)
left = maximum(hcat(map(infimum,map(boundingbox,elements(d)))...);dims=2)
right = minimum(hcat(map(supremum,map(boundingbox,elements(d)))...);dims=2)
boundingbox(left,right)
end

DomainSets.superdomain(d::DomainSets.MappedDomain) = DomainSets.source(d)

# Now here is a problem: how do we compute a bounding box, without extra knowledge
# of the map? We can only do this for some maps.
boundingbox(d::DomainSets.MappedDomain) = mapped_boundingbox(boundingbox(source(d)), forward_map(d))

function mapped_boundingbox(box::Interval, fmap)
l,r = (infimum(box),supremum(box))
ml = fmap*l
mr = fmap*r
boundingbox(min(ml,mr), max(ml,mr))
end

# In general, we can at least map all the corners of the bounding box of the
# underlying domain, and compute a bounding box for those points. This will be
# correct for affine maps.
function mapped_boundingbox(box::ProductDomain, fmap)
crn = corners(infimum(box),supremum(box))
mapped_corners = [fmap*crn[:,i] for i in 1:size(crn,2)]
left = [minimum([mapped_corners[i][j] for i in 1:length(mapped_corners)]) for j in 1:size(crn,1)]
right = [maximum([mapped_corners[i][j] for i in 1:length(mapped_corners)]) for j in 1:size(crn,1)]
boundingbox(left, right)
end

# Auxiliary functions to rotate a bounding box when mapping it.
function corners(left::AbstractVector, right::AbstractVector)
@assert length(left)==length(right)
N=length(left)
corners = zeros(N,2^N)
# All possible permutations of the corners
for i=1:2^length(left)
for j=1:N
corners[j,i] = ((i>>(j-1))%2==0) ? left[j] : right[j]
end
end
corners
end
46 changes: 46 additions & 0 deletions src/domains/broadcast.jl
@@ -0,0 +1,46 @@
# A collection of extensions to the DomainSets package.

using DomainSets: inverse_map, forward_map

###########################
# Applying broadcast to in
###########################

# Intercept a broadcasted call to indomain. We assume that the user wants evaluation
# in a set of points (which we call a grid), rather than in a single point.
# TODO: the user may want to evaluate a single point in a sequence of domains...
broadcast(::typeof(in), grid, d::Domain) = indomain_broadcast(grid, d)

# # Default methods for evaluation on a grid: the default is to call eval on the domain with
# # points as arguments. Domains that have faster grid evaluation routines may define their own version.
indomain_broadcast(grid, d::Domain) = indomain_broadcast!(BitArray(undef, size(grid)), grid, d)
# TODO: use BitArray here

function indomain_broadcast!(result, grid::AbstractGrid, domain::Domain)
for (i,x) in enumerate(grid)
result[i] = DomainSets.in(x, domain)
end
result
end

function indomain_broadcast(grid::AbstractGrid, d::UnionDomain)
z = indomain_broadcast(grid, element(d,1))
for i in 2:numelements(d)
z = z .| indomain_broadcast(grid, element(d,i))
end
z
end

function indomain_broadcast(grid::AbstractGrid, d::IntersectionDomain)
z = indomain_broadcast(grid, element(d,1))
for i in 2:numelements(d)
z = z .& indomain_broadcast(grid, element(d,i))
end
z
end

function indomain_broadcast(grid::AbstractGrid, d::DifferenceDomain)
z1 = indomain_broadcast(grid, d.d1)
z2 = indomain_broadcast(grid, d.d2)
z1 .& (.~z2)
end
157 changes: 0 additions & 157 deletions src/domains/extensions.jl

This file was deleted.

9 changes: 0 additions & 9 deletions src/recipes.jl
Expand Up @@ -33,12 +33,3 @@ end
legend --> false
collect(grid), zeros(size(grid))
end

# # Plot a matrix of values on a 2D equispaced grid
# @recipe function f(grid::AbstractGrid2d, vals)
# seriestype --> :surface
# size --> (500,400)
# xrange = linspace(leftendpoint(grid)[1],rightendpoint(grid)[1],size(grid,1))
# yrange = linspace(leftendpoint(grid)[2],rightendpoint(grid)[2],size(grid,2))
# xrange, yrange, vals'
# end
4 changes: 2 additions & 2 deletions src/subgrid/AbstractSubGrids.jl
Expand Up @@ -23,8 +23,8 @@ include("boundary.jl")
subgrid(grid::AbstractGrid, domain::Domain) = MaskedGrid(grid, domain)

function subgrid(grid::AbstractEquispacedGrid, domain::AbstractInterval)
a = leftendpoint(domain)
b = rightendpoint(domain)
a = infimum(domain)
b = supremum(domain)
h = step(grid)
idx_a = convert(Int, ceil( (a-grid[1])/step(grid))+1 )
idx_b = convert(Int, floor( (b-grid[1])/step(grid))+1 )
Expand Down
4 changes: 2 additions & 2 deletions src/test/test_grids.jl
Expand Up @@ -2,7 +2,7 @@
function test_interval_grid(grid::AbstractGrid, show_timings=false)
test_generic_grid(grid, show_timings=show_timings)
T = eltype(grid)
g1 = rescale(rescale(grid, -T(10), T(3)), leftendpoint(support(grid)), rightendpoint(support(grid)))
g1 = rescale(rescale(grid, -T(10), T(3)), infimum(support(grid)), supremum(support(grid)))
@test support(g1) support(grid)
g2 = resize(grid, length(grid)<<1)
@test length(g2) == length(grid)<<1
Expand All @@ -11,7 +11,7 @@ function test_interval_grid(grid::AbstractGrid, show_timings=false)
g3 = resize(g1, length(grid)<<1)
@test length(g3) == length(grid)<<1

g4 = rescale(rescale(g2, -T(10), T(3)), leftendpoint(support(g2)), rightendpoint(support(g2)))
g4 = rescale(rescale(g2, -T(10), T(3)), infimum(support(g2)), supremum(support(g2)))
@test support(g4) support(g2)

if hasextension(grid)
Expand Down
11 changes: 6 additions & 5 deletions test/runtests.jl
Expand Up @@ -124,14 +124,14 @@ function test_grids(T)
# Does mapped_grid simplify?
mg2 = mapped_grid(PeriodicEquispacedGrid(30, T(0), T(1)), m)
@test typeof(mg2) <: PeriodicEquispacedGrid
@test leftendpoint(support(mg2)) T(2)
@test rightendpoint(support(mg2)) T(3)
@test infimum(support(mg2)) T(2)
@test supremum(support(mg2)) T(3)

# Apply a second map and check whether everything simplified
m2 = interval_map(T(2), T(3), T(4), T(5))
mg3 = mapped_grid(mg1, m2)
@test leftendpoint(support(mg3)) T(4)
@test rightendpoint(support(mg3)) T(5)
@test infimum(support(mg3)) T(4)
@test supremum(support(mg3)) T(5)
@test typeof(supergrid(mg3)) <: PeriodicEquispacedGrid

# Scattered grid
Expand Down Expand Up @@ -318,7 +318,8 @@ end

include("test_modcartesianindices.jl")


include("test_boundingbox.jl")
include("test_broadcast.jl")
test_subgrids()
test_randomgrids()

Expand Down

0 comments on commit 1f36feb

Please sign in to comment.