From f7f618bd4c82e715a1f4ee8801e563ad0055434f Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 16 Oct 2025 23:41:07 -0400 Subject: [PATCH 1/4] Added custom rrules for IntegralProblem --- ext/IntegralsZygoteExt.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index a6b270e..739d2fa 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -12,6 +12,28 @@ ChainRulesCore.@non_differentiable Integrals.substitute_f(args...) # use ∂f/ ChainRulesCore.@non_differentiable Integrals.substitute_v(args...) # TODO for ∂f/∂u ChainRulesCore.@non_differentiable Integrals.substitute_bv(args...) # TODO for ∂f/∂u +# Add custom rrule for IntegralProblem to avoid segfault +function ChainRulesCore.rrule(::Type{<:IntegralProblem}, f, domain, p; kwargs...) + prob = IntegralProblem(f, domain, p; kwargs...) + function IntegralProblem_pullback(Δ) + # For the constructor, we only need to propagate gradients for the parameters + # The function f and domain are treated as non-differentiable structural components + return NoTangent(), NoTangent(), NoTangent(), Δ.p + end + return prob, IntegralProblem_pullback +end + +# Handle both the inner constructor call patterns that might occur +function ChainRulesCore.rrule(::Type{IntegralProblem{iip}}, f, domain, p; kwargs...) where {iip} + prob = IntegralProblem{iip}(f, domain, p; kwargs...) + function IntegralProblem_iip_pullback(Δ) + # Extract the parameter gradient from the tangent + dp = hasproperty(Δ, :p) ? Δ.p : NoTangent() + return NoTangent(), NoTangent(), NoTangent(), dp + end + return prob, IntegralProblem_iip_pullback +end + # TODO move this adjoint to SciMLBase function ChainRulesCore.rrule( ::typeof(SciMLBase.build_solution), prob::IntegralProblem, alg, u, resid; kwargs...) From 2a312bddef90ed00597849adb4208d5970b05ed2 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 16 Oct 2025 23:47:45 -0400 Subject: [PATCH 2/4] Added rrules for the integralproblem domain --- ext/IntegralsZygoteExt.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 739d2fa..736c135 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -16,9 +16,9 @@ ChainRulesCore.@non_differentiable Integrals.substitute_bv(args...) # TODO for function ChainRulesCore.rrule(::Type{<:IntegralProblem}, f, domain, p; kwargs...) prob = IntegralProblem(f, domain, p; kwargs...) function IntegralProblem_pullback(Δ) - # For the constructor, we only need to propagate gradients for the parameters - # The function f and domain are treated as non-differentiable structural components - return NoTangent(), NoTangent(), NoTangent(), Δ.p + ddomain = hasproperty(Δ, :domain) ? Δ.domain : NoTangent() + dp = hasproperty(Δ, :p) ? Δ.p : NoTangent() + return NoTangent(), NoTangent(), ddomain, dp end return prob, IntegralProblem_pullback end @@ -27,9 +27,9 @@ end function ChainRulesCore.rrule(::Type{IntegralProblem{iip}}, f, domain, p; kwargs...) where {iip} prob = IntegralProblem{iip}(f, domain, p; kwargs...) function IntegralProblem_iip_pullback(Δ) - # Extract the parameter gradient from the tangent + ddomain = hasproperty(Δ, :domain) ? Δ.domain : NoTangent() dp = hasproperty(Δ, :p) ? Δ.p : NoTangent() - return NoTangent(), NoTangent(), NoTangent(), dp + return NoTangent(), NoTangent(), ddomain, dp end return prob, IntegralProblem_iip_pullback end From eb9bf21f63cb11fe937d51dff6a6016916346943 Mon Sep 17 00:00:00 2001 From: marcobonici Date: Thu, 16 Oct 2025 23:48:48 -0400 Subject: [PATCH 3/4] Removed comments --- ext/IntegralsZygoteExt.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index 736c135..cbed683 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -12,7 +12,6 @@ ChainRulesCore.@non_differentiable Integrals.substitute_f(args...) # use ∂f/ ChainRulesCore.@non_differentiable Integrals.substitute_v(args...) # TODO for ∂f/∂u ChainRulesCore.@non_differentiable Integrals.substitute_bv(args...) # TODO for ∂f/∂u -# Add custom rrule for IntegralProblem to avoid segfault function ChainRulesCore.rrule(::Type{<:IntegralProblem}, f, domain, p; kwargs...) prob = IntegralProblem(f, domain, p; kwargs...) function IntegralProblem_pullback(Δ) @@ -23,7 +22,6 @@ function ChainRulesCore.rrule(::Type{<:IntegralProblem}, f, domain, p; kwargs... return prob, IntegralProblem_pullback end -# Handle both the inner constructor call patterns that might occur function ChainRulesCore.rrule(::Type{IntegralProblem{iip}}, f, domain, p; kwargs...) where {iip} prob = IntegralProblem{iip}(f, domain, p; kwargs...) function IntegralProblem_iip_pullback(Δ) From 49f215b8e533cec526b6984b36fb2e8fab2342dc Mon Sep 17 00:00:00 2001 From: marcobonici Date: Fri, 17 Oct 2025 00:05:03 -0400 Subject: [PATCH 4/4] Added formatting --- ext/IntegralsZygoteExt.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ext/IntegralsZygoteExt.jl b/ext/IntegralsZygoteExt.jl index cbed683..2e77adf 100644 --- a/ext/IntegralsZygoteExt.jl +++ b/ext/IntegralsZygoteExt.jl @@ -22,7 +22,8 @@ function ChainRulesCore.rrule(::Type{<:IntegralProblem}, f, domain, p; kwargs... return prob, IntegralProblem_pullback end -function ChainRulesCore.rrule(::Type{IntegralProblem{iip}}, f, domain, p; kwargs...) where {iip} +function ChainRulesCore.rrule( + ::Type{IntegralProblem{iip}}, f, domain, p; kwargs...) where {iip} prob = IntegralProblem{iip}(f, domain, p; kwargs...) function IntegralProblem_iip_pullback(Δ) ddomain = hasproperty(Δ, :domain) ? Δ.domain : NoTangent()