Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions ext/IntegralsCubatureExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
if mid isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature_v(_f, lb, ub;
val,
err = Cubature.hquadrature_v(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pquadrature_v(_f, lb, ub;
val,
err = Cubature.pquadrature_v(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature_v(_f, lb, ub;
val,
err = Cubature.hcubature_v(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pcubature_v(_f, lb, ub;
val,
err = Cubature.pcubature_v(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
Expand All @@ -73,21 +77,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
if mid isa Number
if alg isa CubatureJLh
val_, err = Cubature.hquadrature_v(fdim, _f, lb, ub;
val_,
err = Cubature.hquadrature_v(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pquadrature_v(fdim, _f, lb, ub;
val_,
err = Cubature.pquadrature_v(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
else
if alg isa CubatureJLh
val_, err = Cubature.hcubature_v(fdim, _f, lb, ub;
val_,
err = Cubature.hcubature_v(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pcubature_v(fdim, _f, lb, ub;
val_,
err = Cubature.pcubature_v(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
Expand All @@ -102,21 +110,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
_f = u -> f(u, p)
if lb isa Number
if alg isa CubatureJLh
val, err = Cubature.hquadrature(_f, lb, ub;
val,
err = Cubature.hquadrature(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pquadrature(_f, lb, ub;
val,
err = Cubature.pquadrature(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
else
if alg isa CubatureJLh
val, err = Cubature.hcubature(_f, lb, ub;
val,
err = Cubature.hcubature(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
else
val, err = Cubature.pcubature(_f, lb, ub;
val,
err = Cubature.pcubature(_f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters)
end
Expand All @@ -133,21 +145,25 @@ function Integrals.__solvebp_call(prob::IntegralProblem,
end
if mid isa Number
if alg isa CubatureJLh
val_, err = Cubature.hquadrature(fdim, _f, lb, ub;
val_,
err = Cubature.hquadrature(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pquadrature(fdim, _f, lb, ub;
val_,
err = Cubature.pquadrature(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
else
if alg isa CubatureJLh
val_, err = Cubature.hcubature(fdim, _f, lb, ub;
val_,
err = Cubature.hcubature(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
else
val_, err = Cubature.pcubature(fdim, _f, lb, ub;
val_,
err = Cubature.pcubature(fdim, _f, lb, ub;
reltol = reltol, abstol = abstol,
maxevals = maxiters, error_norm = alg.error_norm)
end
Expand Down
3 changes: 2 additions & 1 deletion ext/IntegralsMCIntegrationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ function Integrals.__solvebp_call(prob::IntegralProblem, alg::VEGASMC, sensealg,
res = integrate(_f; var, dof, inplace = isinplace(prob), type = eltype(prototype),
solver = :vegasmc, niter = maxiters, verbose = -2, print = -2, alg.kws...)
# the package itself is not type-stable
out::typeof(prototype), err::typeof(prototype), chi2 = if prototype isa Number
out::typeof(prototype), err::typeof(prototype),
chi2 = if prototype isa Number
map(only, (res.mean, res.stdev, res.chi2))
else
map(a -> reshape(a, size(prototype)), (res.mean, res.stdev, res.chi2))
Expand Down
29 changes: 19 additions & 10 deletions src/Integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ abstract type QuadSensitivityAlg end
Wrapper for custom vector-Jacobian product functions in automatic differentiation.

# Fields
- `vjp::V`: The vector-Jacobian product function

- `vjp::V`: The vector-Jacobian product function
"""
struct ReCallVJP{V}
vjp::V
Expand All @@ -53,7 +54,8 @@ struct ZygoteVJP <: IntegralVJP end
Uses ReverseDiff.jl for vector-Jacobian products in automatic differentiation of integrals.

# Fields
- `compile::Bool`: Whether to compile the tape for better performance

- `compile::Bool`: Whether to compile the tape for better performance
"""
struct ReverseDiffVJP <: IntegralVJP
compile::Bool
Expand Down Expand Up @@ -81,7 +83,8 @@ const KWARGERROR_MESSAGE = """
Exception thrown when unrecognized keyword arguments are passed to `solve`.

# Fields
- `kwargs`: The keyword arguments that were passed

- `kwargs`: The keyword arguments that were passed
"""
struct CommonKwargError <: Exception
kwargs::Any
Expand Down Expand Up @@ -185,7 +188,8 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
return
end
end
val, err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
val,
err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
else
prototype = f(typeof(mid)[], p)
Expand All @@ -198,19 +202,22 @@ function __solvebp_call(cache::IntegralCache, alg::QuadGKJL, sensealg, domain, p
return
end
end
val, err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
val,
err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
end
else
if isinplace(f)
result = f.integrand_prototype * mid # result may have different units than prototype
_f = (y, u) -> f(y, u, p)
val, err = quadgk!(_f, result, lb, ub, segbuf = cache.cacheval,
val,
err = quadgk!(_f, result, lb, ub, segbuf = cache.cacheval,
maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
else
_f = u -> f(u, p)
val, err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
val,
err = quadgk(_f, lb, ub, segbuf = cache.cacheval, maxevals = maxiters,
rtol = reltol, atol = abstol, order = alg.order, norm = alg.norm)
end
end
Expand Down Expand Up @@ -244,7 +251,8 @@ function __solvebp_call(cache::IntegralCache, alg::HCubatureJL, sensealg, domain
end
end

val, err = if lb isa Number
val,
err = if lb isa Number
hquadrature(_f, lb, ub;
rtol = reltol, atol = abstol, buffer = cache.cacheval,
maxevals = maxiters, norm = alg.norm, initdiv = alg.initdiv)
Expand Down Expand Up @@ -307,8 +315,9 @@ function __solvebp_call(prob::IntegralProblem, alg::VEGAS, sensealg, domain, p;
out = vegas(_f, lb, ub, rtol = reltol, atol = abstol,
maxiter = maxiters, nbins = alg.nbins, debug = alg.debug,
ncalls = ncalls, batch = prob.f isa BatchIntegralFunction)
val, err, chi = out isa Tuple ? out :
(out.integral_estimate, out.standard_deviation, out.chi_squared_average)
val, err,
chi = out isa Tuple ? out :
(out.integral_estimate, out.standard_deviation, out.chi_squared_average)
SciMLBase.build_solution(prob, alg, val, err, chi = chi, retcode = ReturnCode.Success)
end

Expand Down
8 changes: 5 additions & 3 deletions src/sampled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ abstract type AbstractWeights end
Abstract type for quadrature weights with uniform (equally-spaced) sampling points.

Implementations must have:
- field `n` for the number of points
- field `h` for the step size between points

- field `n` for the number of points
- field `h` for the step size between points
"""
abstract type UniformWeights <: AbstractWeights end
@inline Base.iterate(w::UniformWeights) = (0 == w.n) ? nothing : (w[1], 1)
Expand All @@ -27,7 +28,8 @@ Base.size(w::UniformWeights) = (length(w),)
Abstract type for quadrature weights with non-uniform (arbitrarily-spaced) sampling points.

Implementations must have:
- field `x` containing the sampling points

- field `x` containing the sampling points
"""
abstract type NonuniformWeights <: AbstractWeights end
@inline Base.iterate(w::NonuniformWeights) = (0 == length(w.x)) ? nothing :
Expand Down
8 changes: 5 additions & 3 deletions src/simpsons.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
Quadrature weights for Simpson's composite 1/3-3/8 rule with uniformly spaced points.

# Fields
- `n::Int`: Number of points
- `h::T`: Step size between points

- `n::Int`: Number of points
- `h::T`: Step size between points
"""
struct SimpsonUniformWeights{T} <: UniformWeights
n::Int
Expand All @@ -29,7 +30,8 @@ end
Quadrature weights for Simpson's composite 1/3 rule with non-uniformly spaced points.

# Fields
- `x::X`: Array of sampling points

- `x::X`: Array of sampling points
"""
struct SimpsonNonuniformWeights{X <: AbstractArray} <: NonuniformWeights
x::X
Expand Down
8 changes: 5 additions & 3 deletions src/trapezoidal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
Quadrature weights for the trapezoidal rule with uniformly spaced points.

# Fields
- `n::Int`: Number of points
- `h::T`: Step size between points

- `n::Int`: Number of points
- `h::T`: Step size between points
"""
struct TrapezoidalUniformWeights{T} <: UniformWeights
n::Int
Expand All @@ -22,7 +23,8 @@ end
Quadrature weights for the trapezoidal rule with non-uniformly spaced points.

# Fields
- `x::X`: Array of sampling points

- `x::X`: Array of sampling points
"""
struct TrapezoidalNonuniformWeights{X <: AbstractArray} <: NonuniformWeights
x::X
Expand Down
Loading
Loading