Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove duplicate code for collocation #180

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ version = "1.26.0"
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
Dierckx = "39dd38d3-220a-591b-8e3c-4c3a8c710a94"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -21,6 +22,7 @@ SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Calculus = "0.5"
Dierckx = "0.4, 0.5"
DiffEqBase = "6"
DiffEqFlux = "1"
Distributions = "0.25"
ForwardDiff = "0.10"
LsqFit = "0.8, 0.9, 0.10, 0.11, 0.12"
Expand Down
2 changes: 1 addition & 1 deletion docs/src/methods/collocation_loss.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ but is much faster, and is a good method to try first to get in the general
"good parameter" region, to then finish using one of the other methods.

```julia
function two_stage_method(prob::DEProblem,tpoints,data;kernel= :Epanechnikov,
function two_stage_method(prob::DEProblem,tpoints,data;kernel= EpanechnikovKernel(),
loss_func = L2DistLoss,mpg_autodiff = false,
verbose = false,verbose_steps = 100)
```
5 changes: 2 additions & 3 deletions src/DiffEqParamEstim.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module DiffEqParamEstim
using DiffEqBase, LsqFit, PenaltyFunctions,
RecursiveArrayTools, ForwardDiff, Calculus, Distributions,
LinearAlgebra, SciMLSensitivity, Dierckx,
SciMLBase
LinearAlgebra, SciMLSensitivity, Dierckx, DiffEqFlux, SciMLBase

import DiffEqFlux.CollocationKernel
import PreallocationTools
STANDARD_PROB_GENERATOR(prob, p) = remake(prob; u0 = eltype(p).(prob.u0), p = p)
function STANDARD_PROB_GENERATOR(prob::EnsembleProblem, p)
Expand All @@ -26,7 +26,6 @@ include("cost_functions.jl")
include("lm_fit.jl")
include("build_loss_objective.jl")
include("build_lsoptim_objective.jl")
include("kernels.jl")
include("two_stage_method.jl")
include("multiple_shooting_objective.jl")

Expand Down
86 changes: 0 additions & 86 deletions src/kernels.jl

This file was deleted.

66 changes: 20 additions & 46 deletions src/two_stage_method.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
export TwoStageCost, two_stage_method
export EpanechnikovKernel, UniformKernel, TriangularKernel, QuarticKernel
export TriweightKernel, TricubeKernel, GaussianKernel, CosineKernel
export LogisticKernel, SigmoidKernel, SilvermanKernel

struct TwoStageCost{F, F2, D} <: Function
cost_function::F
Expand All @@ -10,7 +13,6 @@ end
(f::TwoStageCost)(p) = f.cost_function(p)
(f::TwoStageCost)(p, g) = f.cost_function2(p, g)

decide_kernel(kernel::CollocationKernel) = kernel
function decide_kernel(kernel::Symbol)
if kernel == :Epanechnikov
return EpanechnikovKernel()
Expand All @@ -32,51 +34,19 @@ function decide_kernel(kernel::Symbol)
return LogisticKernel()
elseif kernel == :Sigmoid
return SigmoidKernel()
else
elseif kernel == :Silverman
return SilvermanKernel()
else
return error("Kernel name not recognized")
end
end

function construct_t1(t, tpoints)
Vaibhavdixit02 marked this conversation as resolved.
Show resolved Hide resolved
hcat(ones(eltype(tpoints), length(tpoints)), tpoints .- t)
end
function construct_t2(t, tpoints)
hcat(ones(eltype(tpoints), length(tpoints)), tpoints .- t, (tpoints .- t) .^ 2)
end
function construct_w(t, tpoints, h, kernel)
W = @. calckernel((kernel,), (tpoints - t) / h) / h
Diagonal(W)
end
function construct_estimated_solution_and_derivative!(data, kernel, tpoints)
_one = oneunit(first(data))
_zero = zero(first(data))
e1 = [_one; _zero]
e2 = [_zero; _one; _zero]
n = length(tpoints)
h = (n^(-1 / 5)) * (n^(-3 / 35)) * ((log(n))^(-1 / 16))

Wd = similar(data, n, size(data, 1))
WT1 = similar(data, n, 2)
WT2 = similar(data, n, 3)
x = map(tpoints) do _t
T1 = construct_t1(_t, tpoints)
T2 = construct_t2(_t, tpoints)
W = construct_w(_t, tpoints, h, kernel)
mul!(Wd, W, data')
mul!(WT1, W, T1)
mul!(WT2, W, T2)
(e2' * ((T2' * WT2) \ T2')) * Wd, (e1' * ((T1' * WT1) \ T1')) * Wd
end
estimated_derivative = reduce(hcat, transpose.(first.(x)))
estimated_solution = reduce(hcat, transpose.(last.(x)))
estimated_derivative, estimated_solution
end
function construct_iip_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
function (p)
_du = PreallocationTools.get_tmp(du, p)
vecdu = vec(_du)
cost = zero(first(p))
for i in 1:length(preview_est_sol)
for i in eachindex(preview_est_sol)
est_sol = preview_est_sol[i]
f(_du, est_sol, p, tpoints[i])
vecdu .= vec(preview_est_deriv[i]) .- vec(_du)
Expand All @@ -89,7 +59,7 @@ end
function construct_oop_cost_function(f, du, preview_est_sol, preview_est_deriv, tpoints)
function (p)
cost = zero(first(p))
for i in 1:length(preview_est_sol)
for i in eachindex(preview_est_sol)
est_sol = preview_est_sol[i]
_du = f(est_sol, p, tpoints[i])
cost += sum(abs2, vec(preview_est_deriv[i]) .- vec(_du))
Expand All @@ -102,23 +72,27 @@ get_chunksize(cs) = cs
get_chunksize(cs::Type{Val{CS}}) where {CS} = CS

function two_stage_method(prob::DiffEqBase.DEProblem, tpoints, data;
kernel = EpanechnikovKernel(),
kernel::Union{CollocationKernel, Symbol} = EpanechnikovKernel(),
loss_func = L2Loss, mpg_autodiff = false,
verbose = false, verbose_steps = 100,
autodiff_chunk = length(prob.p))
f = prob.f
kernel_function = decide_kernel(kernel)
estimated_derivative, estimated_solution = construct_estimated_solution_and_derivative!(data,
kernel_function,
tpoints)
if kernel isa Symbol
@warn "Passing kernels as Symbols will be deprecated"
Copy link
Member

@Vaibhavdixit02 Vaibhavdixit02 Oct 15, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can use @deprecate here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to do it like that at first, but I ran into issues because Julia cannot dispatch on a keyword argument, i.e. it is impossible to have two methods, one with kernel restricted to Symbol types and one with kernel restricted to CollocationKernel, and then @deprecate the Symbol method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think you can @depreacate here, it needs this special handling. But I also don't think we should be merging deprecations to this repo: it's somewhat in freeze mode because its interface does not extend well to more cases. In that case, major dependency changes and deprecations for an old repo not getting any major benefits is not a great feel. I think this PR should just hold until there's a good reason found for it.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah without kwarg handling it's not obvious, I think the way to do it with @deprecate would be
(though it's not relevant anymore since we won't go ahead with the PR, but it might be useful later)

my_func(args...; kwarg_to_deprecate::Union{T1,T2}, kwargs...) = my_func(args..., kwarg_to_deprecate; kwargs...)

my_func(args..., kwarg_to_deprecate::T1; kwargs....) = do something...

@deprecate my_func(args..., kwarg_to_deprecate::T2; kwargs...) = do something....

kernel = decide_kernel(kernel)
end

# Step - 1

estimated_derivative, estimated_solution = collocate_data(data, tpoints, kernel)

# Step - 2

du = PreallocationTools.dualcache(similar(prob.u0), autodiff_chunk)
preview_est_sol = [@view estimated_solution[:, i]
for i in 1:size(estimated_solution, 2)]
for i in axes(estimated_solution, 2)]
preview_est_deriv = [@view estimated_derivative[:, i]
for i in 1:size(estimated_solution, 2)]
for i in axes(estimated_solution, 2)]
f = prob.f
if DiffEqBase.isinplace(prob)
cost_function = construct_iip_cost_function(f, du, preview_est_sol,
preview_est_deriv, tpoints)
Expand Down
4 changes: 3 additions & 1 deletion test/out_of_place_odes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ prob_oop = ODEProblem{false}(ff, u0, tspan, ps)
data = Array(solve(prob, Tsit5(), saveat = t))
ptest = ones(rc)

obj_ts = two_stage_method(prob, t, data; kernel = :Sigmoid)
obj_ts = two_stage_method(prob, t, data; kernel = SigmoidKernel())
@test obj_ts(ptest) ≈ 418.3400017500223^2
obj_ts = two_stage_method(prob_oop, t, data; kernel = SigmoidKernel())
@test obj_ts(ptest) ≈ 418.3400017500223^2
obj_ts = two_stage_method(prob_oop, t, data; kernel = :Sigmoid)
@test obj_ts(ptest) ≈ 418.3400017500223^2