Skip to content

Commit

Permalink
add unitless buffer to BatchIntegrand
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Nov 20, 2023
1 parent eda5c03 commit abcc700
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 19 deletions.
37 changes: 19 additions & 18 deletions src/batch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,19 @@ If `x` or `X` are not specified, `quadgk` internally creates a new `BatchIntegra
user-supplied `y` buffer and a freshly-allocated `x` buffer based on the domain types. So,
if you want to reuse the `x` buffer between calls, supply `{Y,X}` or pass `y,x` explicitly.
"""
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F}
struct BatchIntegrand{Y,X,Ty<:AbstractVector{Y},Tx<:AbstractVector{X},F,T}
# in-place function f!(y, x) that takes an array of x values and outputs an array of results in-place
f!::F
y::Ty
x::Tx
t::T
max_batch::Int # maximum number of x to supply in parallel
end

function BatchIntegrand(f!, y::AbstractVector, x::AbstractVector=similar(y, Nothing); max_batch::Integer=typemax(Int))
max_batch > 0 || throw(ArgumentError("max_batch must be positive"))
return BatchIntegrand(f!, y, x, max_batch)
X = eltype(x)
return BatchIntegrand(f!, y, x, X <: Nothing ? nothing : similar(x, typeof(one(X))), max_batch)
end
BatchIntegrand{Y,X}(f!; kws...) where {Y,X} = BatchIntegrand(f!, Y[], X[]; kws...)
BatchIntegrand{Y}(f!; kws...) where {Y} = BatchIntegrand(f!, Y[]; kws...)
Expand Down Expand Up @@ -65,20 +67,20 @@ function evalrules(f::BatchIntegrand, s::NTuple{N}, x,w,gw, nrm) where {N}
l = length(x)
m = 2l-1 # evaluations per segment
n = (N-1)*m # total evaluations
resize!(f.x, n)
resize!(f.t, n)
resize!(f.y, n)
for i in 1:(N-1) # fill buffer with evaluation points
a = s[i]; b = s[i+1]
check_endpoint_roundoff(a, b, x, throw_error=true)
c = convert(eltype(x), 0.5) * (b-a)
o = (i-1)*m
f.x[l+o] = a + c
f.t[l+o] = a + c
for j in 1:l-1
f.x[j+o] = a + (1 + x[j]) * c
f.x[m+1-j+o] = a + (1 - x[j]) * c
f.t[j+o] = a + (1 + x[j]) * c
f.t[m+1-j+o] = a + (1 - x[j]) * c
end
end
f.f!(f.y, f.x) # evaluate integrand
f.f!(f.y, f.t) # evaluate integrand
return ntuple(Val(N-1)) do i
return batchevalrule(view(f.y, (1+(i-1)*m):(i*m)), s[i], s[i+1], x,w,gw, nrm)
end
Expand All @@ -105,7 +107,7 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
len > nsegs && DataStructures.percolate_down!(segs, 1, y, Reverse, len-nsegs)
end

resize!(f.x, 2m*nsegs)
resize!(f.t, 2m*nsegs)
resize!(f.y, 2m*nsegs)
for i in 1:nsegs # fill buffer with evaluation points
s = segs[len-i+1]
Expand All @@ -114,15 +116,15 @@ function refine(f::BatchIntegrand, segs::Vector{T}, I, E, numevals, x,w,gw,n, at
check_endpoint_roundoff(a, b, x) && return segs
c = convert(eltype(x), 0.5) * (b-a)
o = (2i-j)*m
f.x[l+o] = a + c
f.t[l+o] = a + c
for k in 1:l-1
# early return if integrand evaluated at endpoints
f.x[k+o] = a + (1 + x[k]) * c
f.x[m+1-k+o] = a + (1 - x[k]) * c
f.t[k+o] = a + (1 + x[k]) * c
f.t[m+1-k+o] = a + (1 - x[k]) * c
end
end
end
f.f!(f.y, f.x)
f.f!(f.y, f.t)

resize!(segs, len+nsegs)
for i in 1:nsegs # evaluate segments and update estimates & heap
Expand All @@ -145,27 +147,26 @@ end
function handle_infinities(workfunc, f::BatchIntegrand, s)
s1, s2 = s[1], s[end]
u = float(real(oneunit(s1))) # the units of the segment
tbuf = similar(f.x, typeof(s1/oneunit(s1)))
if realone(s1) && realone(s2) # check for infinite or semi-infinite intervals
inf1, inf2 = isinf(s1), isinf(s2)
if inf1 || inf2
if inf1 && inf2 # x = t/(1-t^2) coordinate transformation
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2; end, f.y, tbuf, f.max_batch),
f.f!(v, f.x .= u .* t ./ (1 .- t .* t)); v .*= (1 .+ t .* t) ./ (1 .- t .* t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
map(x -> isinf(x) ? (signbit(x) ? -one(x) : one(x)) : 2x / (oneunit(x)+hypot(oneunit(x),2x)), s),
t -> u * t / (1 - t^2))
return u * I, u * E
end
let (s0,si) = inf1 ? (s2,s1) : (s1,s2) # let is needed for JuliaLang/julia#15276
if si < zero(si) # x = s0 - t/(1-t)
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, tbuf, f.max_batch),
f.f!(v, f.x .= s0 .- u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
reverse(map(x -> 1 / (1 + oneunit(x) / (s0 - x)), s)),
t -> s0 - u*t/(1-t))
return u * I, u * E
else # x = s0 + t/(1-t)
I, E = workfunc(BatchIntegrand((v, t) -> begin resize!(f.x, length(t));
f.f!(v, f.x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, tbuf, f.max_batch),
f.f!(v, f.x .= s0 .+ u .* t ./ (1 .- t)); v ./= (1 .- t) .^ 2; end, f.y, f.x, f.t, f.max_batch),
map(x -> 1 / (1 + oneunit(x) / (x - s0)), s),
t -> s0 + u*t/(1-t))
return u * I, u * E
Expand All @@ -174,7 +175,7 @@ function handle_infinities(workfunc, f::BatchIntegrand, s)
end
end
I, E = workfunc(BatchIntegrand((y, t) -> begin resize!(f.x, length(t));
f.f!(y, f.x .= u .* t); end, f.y, tbuf, f.max_batch),
f.f!(y, f.x .= u .* t); end, f.y, f.x, f.t, f.max_batch),
map(x -> x/oneunit(x), s),
identity)
return u * I, u * E
Expand All @@ -199,6 +200,6 @@ simultaneously. In particular, there are two differences from `quadgk`
"""
function quadgk(f::BatchIntegrand{Y,Nothing}, segs::T...; kws...) where {Y,T}
FT = float(T) # the gk points are floating-point
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), f.max_batch)
g = BatchIntegrand(f.f!, f.y, similar(f.x, FT), similar(f.x, typeof(float(one(FT)))), f.max_batch)
return quadgk(g, segs...; kws...)
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ end
end

# test constructors
ref = BatchIntegrand(f!, Float64[], Nothing[], typemax(Int))
ref = BatchIntegrand(f!, Float64[], Nothing[], nothing, typemax(Int))
for b in (
BatchIntegrand(f!, Float64[]),
BatchIntegrand(f!, Float64[], Nothing[]),
Expand Down

0 comments on commit abcc700

Please sign in to comment.