Skip to content

Commit

Permalink
refactor: throw error when autodiff is true for all except Quadrature…
Browse files Browse the repository at this point in the history
…Training
  • Loading branch information
sathvikbhagavan committed Jan 22, 2024
1 parent 607d4f1 commit 8a0dccc
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/ode_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ end

function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, p, batch)
ts = tspan[1]:(strategy.dx):tspan[2]

# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
autodiff && throw(ArgumentError("autodiff not supported for GridTraining."))
function loss(θ, _)
if batch
sum(abs2, inner_loss(phi, f, autodiff, ts, θ, p))
Expand All @@ -287,6 +287,7 @@ end
function generate_loss(strategy::StochasticTraining, phi, f, autodiff::Bool, tspan, p,
batch)
# sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken
autodiff && throw(ArgumentError("autodiff not supported for StochasticTraining."))
function loss(θ, _)
ts = adapt(parameterless_type(θ),
[(tspan[2] - tspan[1]) * rand() + tspan[1] for i in 1:(strategy.points)])
Expand All @@ -302,6 +303,7 @@ end

function generate_loss(strategy::WeightedIntervalTraining, phi, f, autodiff::Bool, tspan, p,
batch)
autodiff && throw(ArgumentError("autodiff not supported for WeightedIntervalTraining."))
minT = tspan[1]
maxT = tspan[2]

Expand Down

0 comments on commit 8a0dccc

Please sign in to comment.