Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle units consistently in handle_infinites #96

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/QuadGK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,14 @@ struct InplaceIntegrand{F,R,RI}
Ig::R
Ik::R
fx::R
Idiff::RI
Idiff::R

# final result array
I::RI
end

InplaceIntegrand(f!::F, I::RI, fx::R) where {F,RI,R} =
InplaceIntegrand{F,R,RI}(f!, similar(fx), similar(fx), similar(fx), similar(fx), fx, similar(I), I)
InplaceIntegrand{F,R,RI}(f!, similar(fx), similar(fx), similar(fx), similar(fx), fx, similar(fx), I)

include("gausskronrod.jl")
include("evalrule.jl")
Expand Down
65 changes: 42 additions & 23 deletions src/adapt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,61 +119,77 @@ realone(x::Number) = one(x) isa Real
# and pass transformed data to workfunc(f, s, tfunc)
function handle_infinities(workfunc, f, s)
s1, s2 = s[1], s[end]
u = float(real(oneunit(s1))) # the units of the segment
if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
inf1, inf2 = isinf(s1), isinf(s2)
if inf1 || inf2
if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
return workfunc(t -> begin t2 = t*t; den = 1 / (1 - t2);
f(oneunit(s1)*t*den) * (1+t2)*den*den*oneunit(s1); end,
I, E = workfunc(t -> begin t2 = t*t; den = 1 / (1 - t2);
f(u*t*den) * (1+t2)*den*den; end,
map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
t -> oneunit(s1) * t / (1 - t^2))
t -> u * t / (1 - t^2))
return u * I, u * E
end
let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
if si < zero(si) # x = s0 - t/(1-t)
return workfunc(t -> begin den = 1 / (1 - t);
f(s0 - oneunit(s1)*t*den) * den*den*oneunit(s1); end,
I, E = workfunc(t -> begin den = 1 / (1 - t);
f(s0 - u*t*den) * den*den; end,
reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
t -> s0 - oneunit(s1)*t/(1-t))
t -> s0 - u*t/(1-t))
return u * I, u * E
else # x = s0 + t/(1-t)
return workfunc(t -> begin den = 1 / (1 - t);
f(s0 + oneunit(s1)*t*den) * den*den*oneunit(s1); end,
I, E = workfunc(t -> begin den = 1 / (1 - t);
f(s0 + u*t*den) * den*den; end,
map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
t -> s0 + oneunit(s1)*t/(1-t))
t -> s0 + u*t/(1-t))
return u * I, u * E
end
end
end
end
return workfunc(f, s, identity)
I, E = workfunc(t -> f(t*u), map(x -> x/oneunit(x), s), identity)
return u * I, u * E
end

function handle_infinities(workfunc, f::InplaceIntegrand, s)
s1, s2 = s[1], s[end]
result = f.I
u = float(real(oneunit(s1))) # the units of the segment
if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
inf1, inf2 = isinf(s1), isinf(s2)
if inf1 || inf2
ftmp = f.fx # original integrand may have different units
if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
return workfunc(InplaceIntegrand((v, t) -> begin t2 = t*t; den = 1 / (1 - t2);
f.f!(ftmp, oneunit(s1)*t*den); v .= ftmp .* ((1+t2)*den*den*oneunit(s1)); end, f.I, f.fx * oneunit(s1)),
I, E = workfunc(InplaceIntegrand((v, t) -> begin t2 = t*t; den = 1 / (1 - t2);
f.f!(v, u*t*den); v .*= ((1+t2)*den*den); end, f.fg, f.fk, f.Ig, f.Ik, f.fx, f.Idiff, similar(f.fx)),
map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
t -> oneunit(s1) * t / (1 - t^2))
t -> u * t / (1 - t^2))
result .= u .* I
return result, u * E
end
let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
if si < zero(si) # x = s0 - t/(1-t)
return workfunc(InplaceIntegrand((v, t) -> begin den = 1 / (1 - t);
f.f!(ftmp, s0 - oneunit(s1)*t*den); v .= ftmp .* (den * den * oneunit(s1)); end, f.I, f.fx * oneunit(s1)),
I, E = workfunc(InplaceIntegrand((v, t) -> begin den = 1 / (1 - t);
f.f!(v, s0 - u*t*den); v .*= (den * den); end, f.fg, f.fk, f.Ig, f.Ik, f.fx, f.Idiff, similar(f.fx)),
reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
t -> s0 - oneunit(s1)*t/(1-t))
t -> s0 - u*t/(1-t))
result .= u .* I
return result, u * E
else # x = s0 + t/(1-t)
return workfunc(InplaceIntegrand((v, t) -> begin den = 1 / (1 - t);
f.f!(ftmp, s0 + oneunit(s1)*t*den); v .= ftmp .* (den * den * oneunit(s1)); end, f.I, f.fx * oneunit(s1)),
I, E = workfunc(InplaceIntegrand((v, t) -> begin den = 1 / (1 - t);
f.f!(v, s0 + u*t*den); v .*= (den * den); end, f.fg, f.fk, f.Ig, f.Ik, f.fx, f.Idiff, similar(f.fx)),
map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
t -> s0 + oneunit(s1)*t/(1-t))
t -> s0 + u*t/(1-t))
result .= u .* I
return result, u * E
end
end
end
end
return workfunc(f, s, identity)
I, E = workfunc(InplaceIntegrand((y,t) -> f.f!(y, t*u), f.fg, f.fk, f.Ig, f.Ik, f.fx, f.Idiff, similar(f.fx)),
map(x -> x/oneunit(x), s),
identity)
result .= u .* I
return result, u * E
end

function check_endpoint_roundoff(a, b, x; throw_error::Bool=false)
Expand Down Expand Up @@ -254,8 +270,9 @@ quadgk(f, segs...; kws...) =

function quadgk(f, segs::T...;
atol=nothing, rtol=nothing, maxevals=10^7, order=7, norm=norm, segbuf=nothing) where {T}
utol = isnothing(atol) ? atol : atol/float(real(oneunit(T))) # remove units of domain
handle_infinities(f, segs) do f, s, _
do_quadgk(f, s, order, atol, rtol, maxevals, norm, segbuf)
do_quadgk(f, s, order, utol, rtol, maxevals, norm, segbuf)
end
end

Expand All @@ -270,7 +287,9 @@ starting with the given `size`. The buffer can then be reused across multiple
compatible calls to `quadgk(...)` to avoid repeated allocation.
"""
function alloc_segbuf(domain_type=Float64, range_type=Float64, error_type=Float64; size=1)
Vector{Segment{domain_type, range_type, error_type}}(undef, size)
x = float(real(one(domain_type))) # quadrature point type
err = oneunit(error_type)/oneunit(domain_type) # error has units of the integral
Vector{Segment{typeof(x), range_type, typeof(err)}}(undef, size)
end

"""
Expand Down
60 changes: 33 additions & 27 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ If `x` or `X` are not specified, `quadgk` internally creates a new `BatchIntegra
user-supplied `y` buffer and a freshly-allocated `x` buffer based on the domain types. So,
if you want to reuse the `x` buffer between calls, supply `{Y,X}` or pass `y,x` explicitly.
"""
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F}
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F,T}
# in-place function f!(y, x) that takes an array of x values and outputs an array of results in-place
f!::F
y::Ty
x::Tx
t::T
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not great that we pay the price of having an additional point array here …

max_batch::Int # maximum number of x to supply in parallel
end

function BatchIntegrand(f!, y::AbstractVector, x::AbstractVector=similar(y, Nothing); max_batch::Integer=typemax(Int))
max_batch > 0 || throw(ArgumentError("max_batch must be positive"))
return BatchIntegrand(f!, y, x, max_batch)
X = eltype(x)
return BatchIntegrand(f!, y, x, X <: Nothing ? nothing : similar(x, typeof(one(X))), max_batch)
end
BatchIntegrand{Y,X}(f!; kws...) where {Y,X} = BatchIntegrand(f!, Y[], X[]; kws...)
BatchIntegrand{Y}(f!; kws...) where {Y} = BatchIntegrand(f!, Y[]; kws...)
Expand Down Expand Up @@ -65,20 +67,20 @@ function evalrules(f::BatchIntegrand, s::NTuple{N}, x,w,gw, nrm) where {N}
l = length(x)
m = 2l-1 # evaluations per segment
n = (N-1)*m # total evaluations
resize!(f.x, n)
resize!(f.t, n)
resize!(f.y, n)
for i in 1:(N-1) # fill buffer with evaluation points
a = s[i]; b = s[i+1]
check_endpoint_roundoff(a, b, x, throw_error=true)
c = convert(eltype(x), 0.5) * (b-a)
o = (i-1)*m
f.x[l+o] = a + c
f.t[l+o] = a + c
for j in 1:l-1
f.x[j+o] = a + (1 + x[j]) * c
f.x[m+1-j+o] = a + (1 - x[j]) * c
f.t[j+o] = a + (1 + x[j]) * c
f.t[m+1-j+o] = a + (1 - x[j]) * c
end
end
f.f!(f.y, f.x) # evaluate integrand
f.f!(f.y, f.t) # evaluate integrand
return ntuple(Val(N-1)) do i
return batchevalrule(view(f.y, (1+(i-1)*m):(i*m)), s[i], s[i+1], x,w,gw, nrm)
end
Expand All @@ -105,7 +107,7 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
len > nsegs && DataStructures.percolate_down!(segs, 1, y, Reverse, len-nsegs)
end

resize!(f.x, 2m*nsegs)
resize!(f.t, 2m*nsegs)
resize!(f.y, 2m*nsegs)
for i in 1:nsegs # fill buffer with evaluation points
s = segs[len-i+1]
Expand All @@ -114,15 +116,15 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
check_endpoint_roundoff(a, b, x) && return segs
c = convert(eltype(x), 0.5) * (b-a)
o = (2i-j)*m
f.x[l+o] = a + c
f.t[l+o] = a + c
for k in 1:l-1
# early return if integrand evaluated at endpoints
f.x[k+o] = a + (1 + x[k]) * c
f.x[m+1-k+o] = a + (1 - x[k]) * c
f.t[k+o] = a + (1 + x[k]) * c
f.t[m+1-k+o] = a + (1 - x[k]) * c
end
end
end
f.f!(f.y, f.x)
f.f!(f.y, f.t)

resize!(segs, len+nsegs)
for i in 1:nsegs # evaluate segments and update estimates & heap
Expand All @@ -144,35 +146,39 @@ end

function handle_infinities(workfunc, f::BatchIntegrand, s)
s1, s2 = s[1], s[end]
u = float(real(oneunit(s1))) # the units of the segment
if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
inf1, inf2 = isinf(s1), isinf(s2)
if inf1 || inf2
xtmp = f.x # buffer to store evaluation points
ytmp = f.y # original integrand may have different units
xbuf = similar(xtmp, typeof(one(eltype(f.x))))
ybuf = similar(ytmp, typeof(oneunit(eltype(f.y))*oneunit(s1)))
if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
return workfunc(BatchIntegrand((v, t) -> begin resize!(xtmp, length(t)); resize!(ytmp, length(v));
f.f!(ytmp, xtmp .= oneunit(s1) .* t ./ (1 .- t .* t)); v .= ytmp .* (1 .+ t .* t) .* oneunit(s1) ./ (1 .- t .* t) .^ 2; end, ybuf, xbuf, f.max_batch),
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
t -> oneunit(s1) * t / (1 - t^2))
t -> u * t / (1 - t^2))
return u * I, u * E
end
let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
if si < zero(si) # x = s0 - t/(1-t)
return workfunc(BatchIntegrand((v, t) -> begin resize!(xtmp, length(t)); resize!(ytmp, length(v));
f.f!(ytmp, xtmp .= s0 .- oneunit(s1) .* t ./ (1 .- t)); v .= ytmp .* oneunit(s1) ./ (1 .- t) .^ 2; end, ybuf, xbuf, f.max_batch),
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
t -> s0 - oneunit(s1)*t/(1-t))
t -> s0 - u*t/(1-t))
return u * I, u * E
else # x = s0 + t/(1-t)
return workfunc(BatchIntegrand((v, t) -> begin resize!(xtmp, length(t)); resize!(ytmp, length(v));
f.f!(ytmp, xtmp .= s0 .+ oneunit(s1) .* t ./ (1 .- t)); v .= ytmp .* oneunit(s1) ./ (1 .- t) .^ 2; end, ybuf, xbuf, f.max_batch),
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
t -> s0 + oneunit(s1)*t/(1-t))
t -> s0 + u*t/(1-t))
return u * I, u * E
end
end
end
end
return workfunc(f, s, identity)
I, E = workfunc(BatchIntegrand((y, t) -> begin resize!(f.x, length(t));
f.f!(y, f.x .= u .* t); end, f.y, f.x, f.t, f.max_batch),
map(x -> x/oneunit(x), s),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pays the cost of a division for every integrand evaluation?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, that is just for the segments.

identity)
return u * I, u * E
end

"""
Expand All @@ -194,6 +200,6 @@ simultaneously. In particular, there are two differences from `quadgk`
"""
function quadgk(f::BatchIntegrand{Y,Nothing}, segs::T...; kws...) where {Y,T}
FT = float(T) # the gk points are floating-point
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), f.max_batch)
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), similar(f.x, typeof(float(one(FT)))), f.max_batch)
return quadgk(g, segs...; kws...)
end
33 changes: 31 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,51 @@ module Test19626
end

# Following definitions needed for quadgk to work with MockQuantity
import Base: +, -, *, abs, isnan, isinf, isless, float
import Base: +, -, *, /, abs, isnan, isinf, isless, float, one, real, signbit
+(a::MockQuantity, b::MockQuantity) = MockQuantity(a.val+b.val)
-(a::MockQuantity, b::MockQuantity) = MockQuantity(a.val-b.val)
-(a::MockQuantity) = MockQuantity(-a.val)
*(a::MockQuantity, b::Number) = MockQuantity(a.val*b)
*(a::Number, b::MockQuantity) = MockQuantity(a*b.val)
/(a::MockQuantity, b::Number) = MockQuantity(a.val/b)
/(a::MockQuantity, b::MockQuantity) = a.val/b.val
abs(a::MockQuantity) = MockQuantity(abs(a.val))
float(a::MockQuantity) = a
isnan(a::MockQuantity) = isnan(a.val)
isinf(a::MockQuantity) = isinf(a.val)
isless(a::MockQuantity, b::MockQuantity) = isless(a.val, b.val)
one(::Type{MockQuantity}) = one(fieldtype(MockQuantity, :val))
one(a::MockQuantity) = one(a.val)
real(a::MockQuantity) = MockQuantity(real(a.val))
signbit(a::MockQuantity) = signbit(a.val)

# isapprox only needed for test purposes
Base.isapprox(a::MockQuantity, b::MockQuantity) = isapprox(a.val, b.val)

# Test physical quantity-valued functions
@test QuadGK.quadgk(x->MockQuantity(x), 0.0, 1.0, atol=MockQuantity(0.0))[1] ≈
MockQuantity(0.5)

# Test that quantities work with infinity transformations and segbufs
# for all quadgk interfaces
lims = [(-1.0, 1.0), (-Inf, 0.0), (0.0, Inf), (-Inf, Inf)]
ulims = map(x -> map(MockQuantity, x), lims)
for (lb, ub) in ulims
## function
f = x -> 1/(1+(x/MockQuantity(1.0))^2)
buf = QuadGK.alloc_segbuf(MockQuantity, Float64, MockQuantity)
@test QuadGK.quadgk(f, lb, ub, atol=MockQuantity(0.0))[1] ≈
QuadGK.quadgk(f, lb, ub, atol=MockQuantity(0.0), segbuf=buf)[1]
## inplace
fiip = (y, x) -> y[1] = 1/(1+(x/MockQuantity(1.0))^2)
ibuf = QuadGK.alloc_segbuf(MockQuantity, Array{Float64,1}, MockQuantity)
@test QuadGK.quadgk!(fiip, [MockQuantity(0.0)], lb, ub, atol=MockQuantity(0.0), norm=abs∘first)[1][] ≈
QuadGK.quadgk!(fiip, [MockQuantity(0.0)], lb, ub, atol=MockQuantity(0.0), norm=abs∘getindex, segbuf=ibuf)[1][]
## batch
fbatch = BatchIntegrand{Float64}((y, x) -> y .= 1 ./ (1 .+ (x ./ MockQuantity(1.0)) .^ 2))
@test QuadGK.quadgk(fbatch, lb, ub, atol=MockQuantity(0.0))[1] ≈
QuadGK.quadgk(fbatch, lb, ub, atol=MockQuantity(0.0), segbuf=buf)[1]
end
end

@testset "inference" begin
Expand Down Expand Up @@ -312,7 +341,7 @@ end
end

# test constructors
ref = BatchIntegrand(f!, Float64[], Nothing[], typemax(Int))
ref = BatchIntegrand(f!, Float64[], Nothing[], nothing, typemax(Int))
for b in (
BatchIntegrand(f!, Float64[]),
BatchIntegrand(f!, Float64[], Nothing[]),
Expand Down
Loading