-
-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Simple Parameter Estimation Tutorial #305
Merged
Merged
Changes from 2 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
d385bf4
Add tutorial text
TorkelE 0154d6b
add_plots
TorkelE 0c9535c
reference DiffEqFlux and mentione final parameter set.
TorkelE bfa65e6
Separates the last section with subheading
TorkelE 785e6f7
very minor change to second sentence
TorkelE File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Parameter Estimation | ||
The parameters of a model, generated by Catalyst, can be fitted using various packages available in the Julia ecosystem. Refer [here](https://diffeq.sciml.ai/stable/analysis/parameter_estimation/) for a more detailed description. Below follows a quick tutorial of how a parameter set for a model can be fitted to data. | ||
|
||
First, we fetch the required packages. | ||
```julia | ||
using OrdinaryDiffEq | ||
using DiffEqFlux, Flux | ||
using Catalyst | ||
``` | ||
|
||
Next, we declare our model. For our example, we will use the Brusselator, a simple oscillator. | ||
```julia | ||
brusselator = @reaction_network begin | ||
A, ∅ → X | ||
1, 2X + Y → 3X | ||
B, X → Y | ||
1, X → ∅ | ||
end A B | ||
p_real = [1., 2.] | ||
``` | ||
|
||
We simulate our model, and from the simulation generate sampled data points (with added noise), to which we will attempt to fit a parameter et. | ||
```julia | ||
u0 = [1.0, 1.0] | ||
tspan = (0.0, 30.0) | ||
|
||
sample_times = range(tspan[1],stop=tspan[2],length=100) | ||
prob = ODEProblem(brusselator, u0, tspan, p_real) | ||
sol_real = solve(prob, Rosenbrock23(), tstops=sample_times) | ||
|
||
sample_vals = [sol_real.u[findfirst(sol_real.t .>= ts)][var] * (1+(0.1rand()-0.05)) for var in 1:2, ts in sample_times]; | ||
``` | ||
|
||
We can plot the real solution, as well as the noisy samples. | ||
```julia | ||
using Plots | ||
plot(sol_real,size=(1200,400),label="",framestyle=:box,lw=3,color=[:darkblue :darkred]) | ||
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label="") | ||
``` | ||
![parameter_estimation_plot1](../assets/parameter_estimation_plot1.svg) | ||
|
||
Next, we create an optimisation function. For a given initial estimate of the parameter values, p, this function will fit parameter values to our data samples. However, it will only do so on the interval [0,tend]. | ||
```julia | ||
function optimise_p(p_init,tend) | ||
function loss(p) | ||
sol = solve(remake(prob,tspan=(0.,tend),p=p), Rosenbrock23(), tstops=sample_times) | ||
vals = hcat(map(ts -> sol.u[findfirst(sol.t .>= ts)], sample_times[1:findlast(sample_times .<= tend)])...) | ||
loss = sum(abs2, vals .- sample_vals[:,1:size(vals)[2]]) | ||
return loss, sol | ||
end | ||
return DiffEqFlux.sciml_train(loss,p_init,ADAM(0.1),maxiters = 100) | ||
end | ||
``` | ||
|
||
Next, we will fit a parameter set to the data on the interval [0,10]. | ||
```julia | ||
p_estimate = optimise_p([5.,5.],10.).minimizer | ||
``` | ||
|
||
We can compare this to the real solution, as well as the sample data | ||
```julia | ||
sol_estimate = solve(remake(prob,tspan=(0.,10.),p=p_estimate), Rosenbrock23()) | ||
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) | ||
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) | ||
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) | ||
``` | ||
![parameter_estimation_plot2](../assets/parameter_estimation_plot2.svg) | ||
|
||
Next, we use this parameter estimation as the input to the next iteration of our fitting process, this time on the interval [0,20]. | ||
```julia | ||
p_estimate = optimise_p(p_estimate,20.).minimizer | ||
|
||
sol_estimate = solve(remake(prob,tspan=(0.,20.),p=p_estimate), Rosenbrock23()) | ||
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) | ||
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) | ||
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) | ||
``` | ||
![parameter_estimation_plot3](../assets/parameter_estimation_plot3.svg) | ||
|
||
Finally, we use this estimate as the input to fit a parameter set on the full interval of sampled data. | ||
```julia | ||
p_estimate = optimise_p(p_estimate,30.).minimizer | ||
|
||
sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23()) | ||
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) | ||
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) | ||
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) | ||
``` | ||
![parameter_estimation_plot4](../assets/parameter_estimation_plot4.svg) | ||
|
||
The reason we chose to fit the model on a smaller interval to begin with, and then extend the interval, is to avoid getting stuck in a local minimum. Here specifically, we chose our initial interval to be smaller than a full cycle of the oscillation. If we had chosen to fit a parameter set on the full interval immediately we would have received an inferior solution. | ||
```julia | ||
p_estimate = optimise_p([5.,5.],30.).minimizer | ||
|
||
sol_estimate = solve(remake(prob,tspan=(0.,30.),p=p_estimate), Rosenbrock23()) | ||
plot(sol_real,size=(1200,400),color=[:blue :red],framestyle=:box,lw=3,label=["X real" "Y real"],linealpha=0.2) | ||
plot!(sample_times,sample_vals',seriestype=:scatter,color=[:blue :red],label=["Samples of X" "Samples of Y"],alpha=0.4) | ||
plot!(sol_estimate,color=[:darkblue :darkred], linestyle=:dash,lw=3,label=["X estimated" "Y estimated"],xlimit=tspan) | ||
``` | ||
![parameter_estimation_plot5](../assets/parameter_estimation_plot5.svg) | ||
|
||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Mention this tutorial uses [DiffEqFlux] point to docs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Along those lines perhaps also say that the
sciml_train
is using theADAM
gradient descent method?