Skip to content

Commit

Permalink
Merge pull request #781 from AstitvaAggarwal/Bpinn_pde
Browse files Browse the repository at this point in the history
BPINN solver Docs(Manual and tutorial)
  • Loading branch information
xtalax committed Jan 17, 2024
2 parents 0687aaf + ee5c1df commit ba5dbf9
Show file tree
Hide file tree
Showing 13 changed files with 272 additions and 52 deletions.
5 changes: 5 additions & 0 deletions docs/Project.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
[deps]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
IntegralsCubature = "c31f79ba-6e32-46d4-a52f-182a8ac42a54"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
Expand All @@ -20,6 +23,7 @@ Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[compat]
AdvancedHMC = "0.5"
DiffEqBase = "6.106"
Documenter = "1"
DomainSets = "0.6"
Expand All @@ -28,6 +32,7 @@ Integrals = "3.3"
IntegralsCubature = "=0.2.2"
Lux = "0.4, 0.5"
ModelingToolkit = "8.33"
MonteCarloMeasurements = "1"
NeuralPDE = "5.3"
Optimization = "3.9"
OptimizationOptimJL = "0.1"
Expand Down
2 changes: 2 additions & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pages = ["index.md",
#"examples/nnrode_example.md", # currently incorrect
],
"PDE PINN Tutorials" => Any["Introduction to NeuralPDE for PDEs" => "tutorials/pdesystem.md",
"Bayesian PINNs for PDEs" => "tutorials/low_level_2.md",
"Using GPUs" => "tutorials/gpu.md",
"Defining Systems of PDEs" => "tutorials/systems.md",
"Imposing Constraints" => "tutorials/constraints.md",
Expand All @@ -21,6 +22,7 @@ pages = ["index.md",
"examples/nonlinear_hyperbolic.md"],
"Manual" => Any["manual/ode.md",
"manual/pinns.md",
"manual/bpinns.md",
"manual/training_strategies.md",
"manual/adaptive_losses.md",
"manual/logging.md",
Expand Down
22 changes: 22 additions & 0 deletions docs/src/manual/bpinns.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# `BayesianPINN` Discretizer for PDESystems

Using the Bayesian PINN solvers, we can solve general nonlinear PDEs, ODEs and also simultaneously perform parameter estimation on them.

Note: The BPINN PDE solver also works for ODEs defined using ModelingToolkit, [ModelingToolkit.jl PDESystem documentation](https://docs.sciml.ai/ModelingToolkit/stable/systems/PDESystem/). Despite this, the ODE specific BPINN solver `BNNODE` [refer](https://docs.sciml.ai/NeuralPDE/dev/manual/ode/#NeuralPDE.BNNODE) exists and uses `NeuralPDE.ahmc_bayesian_pinn_ode` at a lower level.

# `BayesianPINN` Discretizer for PDESystems and lower level Bayesian PINN Solver calls for PDEs and ODEs.

```@docs
NeuralPDE.BayesianPINN
NeuralPDE.ahmc_bayesian_pinn_ode
NeuralPDE.ahmc_bayesian_pinn_pde
```

## `symbolic_discretize` for `BayesianPINN` and lower level interface.

```@docs
SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
NeuralPDE.BPINNstats
NeuralPDE.BPINNsolution
```

4 changes: 2 additions & 2 deletions docs/src/manual/pinns.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ NeuralPDE.Phi
SciMLBase.discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
```

## `symbolic_discretize` and the lower-level interface
## `symbolic_discretize` for `PhysicsInformedNN` and the lower-level interface

```@docs
SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.PhysicsInformedNN)
SciMLBase.symbolic_discretize(::PDESystem, ::NeuralPDE.AbstractPINN)
NeuralPDE.PINNRepresentation
NeuralPDE.PINNLossFunctions
```
2 changes: 1 addition & 1 deletion docs/src/tutorials/low_level.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Investigating `symbolic_discretize` with the 1-D Burgers' Equation
# Investigating `symbolic_discretize` with the `PhysicsInformedNN` Discretizer for the 1-D Burgers' Equation

Let's consider the Burgers' equation:

Expand Down
143 changes: 143 additions & 0 deletions docs/src/tutorials/low_level_2.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Using `ahmc_bayesian_pinn_pde` with the `BayesianPINN` Discretizer for the Kuramoto–Sivashinsky equation

Consider the Kuramoto–Sivashinsky equation:

```math
∂_t u(x, t) + u(x, t) ∂_x u(x, t) + \alpha ∂^2_x u(x, t) + \beta ∂^3_x u(x, t) + \gamma ∂^4_x u(x, t) = 0 \, ,
```

where $\alpha = \gamma = 1$ and $\beta = 4$. The exact solution is:

```math
u_e(x, t) = 11 + 15 \tanh \theta - 15 \tanh^2 \theta - 15 \tanh^3 \theta \, ,
```

where $\theta = t - x/2$ and with initial and boundary conditions:

```math
\begin{align*}
u( x, 0) &= u_e( x, 0) \, ,\\
u( 10, t) &= u_e( 10, t) \, ,\\
u(-10, t) &= u_e(-10, t) \, ,\\
∂_x u( 10, t) &= ∂_x u_e( 10, t) \, ,\\
∂_x u(-10, t) &= ∂_x u_e(-10, t) \, .
\end{align*}
```

With Bayesian Physics-Informed Neural Networks, here is an example of using `BayesianPINN` discretization with `ahmc_bayesian_pinn_pde` :

```@example low_level_2
using NeuralPDE, Flux, Lux, ModelingToolkit, LinearAlgebra, AdvancedHMC
import ModelingToolkit: Interval, infimum, supremum, Distributions
using Plots, MonteCarloMeasurements
@parameters x, t, α
@variables u(..)
Dt = Differential(t)
Dx = Differential(x)
Dx2 = Differential(x)^2
Dx3 = Differential(x)^3
Dx4 = Differential(x)^4
# α = 1
β = 4
γ = 1
eq = Dt(u(x, t)) + u(x, t) * Dx(u(x, t)) + α * Dx2(u(x, t)) + β * Dx3(u(x, t)) + γ * Dx4(u(x, t)) ~ 0
u_analytic(x, t; z = -x / 2 + t) = 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
du(x, t; z = -x / 2 + t) = 15 / 2 * (tanh(z) + 1) * (3 * tanh(z) - 1) * sech(z)^2
bcs = [u(x, 0) ~ u_analytic(x, 0),
u(-10, t) ~ u_analytic(-10, t),
u(10, t) ~ u_analytic(10, t),
Dx(u(-10, t)) ~ du(-10, t),
Dx(u(10, t)) ~ du(10, t)]
# Space and time domains
domains = [x ∈ Interval(-10.0, 10.0),
t ∈ Interval(0.0, 1.0)]
# Discretization
dx = 0.4;
dt = 0.2;
# Function to compute analytical solution at a specific point (x, t)
function u_analytic_point(x, t)
z = -x / 2 + t
return 11 + 15 * tanh(z) - 15 * tanh(z)^2 - 15 * tanh(z)^3
end
# Function to generate the dataset matrix
function generate_dataset_matrix(domains, dx, dt)
x_values = -10:dx:10
t_values = 0.0:dt:1.0
dataset = []
for t in t_values
for x in x_values
u_value = u_analytic_point(x, t)
push!(dataset, [u_value, x, t])
end
end
return vcat([data' for data in dataset]...)
end
datasetpde = [generate_dataset_matrix(domains, dx, dt)]
# noise to dataset
noisydataset = deepcopy(datasetpde)
noisydataset[1][:, 1] = noisydataset[1][:, 1] .+ randn(size(noisydataset[1][:, 1])) .* 5 / 100 .*
noisydataset[1][:, 1]
```

Plotting dataset, added noise is set at 5%.
```@example low_level_2
plot(datasetpde[1][:, 2], datasetpde[1][:, 1], title="Dataset from Analytical Solution")
plot!(noisydataset[1][:, 2], noisydataset[1][:, 1])
```

```@example low_level_2
# Neural network
chain = Lux.Chain(Lux.Dense(2, 8, Lux.tanh),
Lux.Dense(8, 8, Lux.tanh),
Lux.Dense(8, 1))
discretization = NeuralPDE.BayesianPINN([chain],
GridTraining([dx, dt]), param_estim = true, dataset = [noisydataset, nothing])
@named pde_system = PDESystem(eq,
bcs,
domains,
[x, t],
[u(x, t)],
[α],
defaults = Dict([α => 0.5]))
sol1 = ahmc_bayesian_pinn_pde(pde_system,
discretization;
draw_samples = 100, Kernel = AdvancedHMC.NUTS(0.8),
bcstd = [0.2, 0.2, 0.2, 0.2, 0.2],
phystd = [1.0], l2std = [0.05], param = [Distributions.LogNormal(0.5, 2)],
priorsNNw = (0.0, 10.0),
saveats = [1 / 100.0, 1 / 100.0], progress = true)
```

And some analysis:

```@example low_level_2
phi = discretization.phi[1]
xs, ts = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, [dx / 10, dt])]
u_predict = [[first(pmean(phi([x, t], sol1.estimated_nn_params[1]))) for x in xs]
for t in ts]
u_real = [[u_analytic(x, t) for x in xs] for t in ts]
diff_u = [[abs(u_analytic(x, t) - first(pmean(phi([x, t], sol1.estimated_nn_params[1]))))
for x in xs]
for t in ts]
p1 = plot(xs, u_predict, title = "predict")
p2 = plot(xs, u_real, title = "analytic")
p3 = plot(xs, diff_u, title = "error")
plot(p1, p2, p3)
```
2 changes: 1 addition & 1 deletion src/BPINN_ode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ of the physics-informed neural network which is used as a solver for a standard
* `Kernel`: Choice of MCMC Sampling Algorithm. Defaults to `AdvancedHMC.HMC`
## Keyword Arguments
(refer ahmc_bayesian_pinn_ode() keyword arguments.)
(refer `NeuralPDE.ahmc_bayesian_pinn_ode` keyword arguments.)
## Example
Expand Down
2 changes: 1 addition & 1 deletion src/NeuralPDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,6 @@ export NNODE, TerminalPDEProblem, NNPDEHan, NNPDENS, NNRODE,
AbstractAdaptiveLoss, NonAdaptiveLoss, GradientScaleAdaptiveLoss,
MiniMaxAdaptiveLoss, LogOptions,
ahmc_bayesian_pinn_ode, BNNODE, ahmc_bayesian_pinn_pde, vector_to_parameters,
BPINNsolution
BPINNsolution, BayesianPINN

end # module
48 changes: 47 additions & 1 deletion src/PDE_BPINN.jl
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,52 @@ function inference(samples, pinnrep, saveats, numensemble, ℓπ)
end
end

# priors: pdf for W,b + pdf for ODE params
"""
```julia
ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000,
bcstd = [0.01], l2std = [0.05],
phystd = [0.05], priorsNNw = (0.0, 2.0),
param = [], nchains = 1, Kernel = HMC(0.1, 30),
Adaptorkwargs = (Adaptor = StanHMCAdaptor,
Metric = DiagEuclideanMetric, targetacceptancerate = 0.8),
Integratorkwargs = (Integrator = Leapfrog,), saveats = [1 / 10.0],
numensemble = floor(Int, draw_samples / 3), progress = false, verbose = false)
```
## NOTES
* Dataset is required for accurate Parameter estimation + solving equations.
* Returned solution is a BPINNsolution consisting of Ensemble solution, estimated PDE and NN parameters
for chosen `saveats` grid spacing and last n = `numensemble` samples in Chain. the complete set of samples
in the MCMC chain is returned as `fullsolution`, refer `BPINNsolution` for more details.
## Positional Arguments
* `pde_system`: ModelingToolkit defined PDE equation or system of equations.
* `discretization`: BayesianPINN discretization for the given pde_system, Neural Network and training strategy.
## Keyword Arguments
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
* `bcstd`: Vector of standard deviations of BPINN prediction against Initial/Boundary Condition equations.
* `l2std`: Vector of standard deviations of BPINN prediction against L2 losses/Dataset for each dependant variable of interest.
* `phystd`: Vector of standard deviations of BPINN prediction against Chosen Underlying PDE equations.
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default.
* `param`: Vector of chosen PDE's parameter's Distributions in case of Inverse problems.
* `nchains`: number of chains you want to sample
# AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
* `Kernel`: Choice of MCMC Sampling Algorithm object HMC/NUTS/HMCDA (AdvancedHMC.jl implemenations ).
* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
Note: Target percentage(in decimal) of iterations in which the proposals are accepted (0.8 by default)
* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `saveats`: Grid spacing for each independant variable for evaluation of ensemble solution, estimated parameters.
* `numensemble`: Number of last samples to take for creation of ensemble solution, estimated parameters.
* `progress`: controls whether to show the progress meter or not.
* `verbose`: controls the verbosity. (Sample call args in AHMC)
"""

"""
priors: pdf for W,b + pdf for PDE params
"""
function ahmc_bayesian_pinn_pde(pde_system, discretization;
draw_samples = 1000,
bcstd = [0.01], l2std = [0.05],
Expand Down Expand Up @@ -369,6 +414,7 @@ function ahmc_bayesian_pinn_pde(pde_system, discretization;
#ode parameter estimation
nparameters = length(initial_θ)
ninv = length(param)
# add init_params for NN params
priors = [
MvNormal(priorsNNw[1] * ones(nparameters),
LinearAlgebra.Diagonal(abs2.(priorsNNw[2] .* ones(nparameters)))),
Expand Down
22 changes: 8 additions & 14 deletions src/advancedHMC_MCMC.jl
Original file line number Diff line number Diff line change
Expand Up @@ -436,40 +436,34 @@ Incase you are only solving the Equations for solution, do not provide dataset
## Keyword Arguments
* `strategy`: The training strategy used to choose the points for the evaluations. By default GridTraining is used with given physdt discretization.
* `dataset`: Vector containing Vectors of corresponding u,t values
* `init_params`: intial parameter values for BPINN (ideally for multiple chains different initializations preferred)
* `nchains`: number of chains you want to sample (random initialisation of params by default)
* `nchains`: number of chains you want to sample
* `draw_samples`: number of samples to be drawn in the MCMC algorithms (warmup samples are ~2/3 of draw samples)
* `l2std`: standard deviation of BPINN predicition against L2 losses/Dataset
* `phystd`: standard deviation of BPINN predicition against Chosen Underlying ODE System
* `priorsNNw`: Vector of [mean, std] for BPINN parameter. Weights and Biases of BPINN are Normal Distributions by default
* `l2std`: standard deviation of BPINN prediction against L2 losses/Dataset
* `phystd`: standard deviation of BPINN prediction against Chosen Underlying ODE System
* `priorsNNw`: Tuple of (mean, std) for BPINN Network parameters. Weights and Biases of BPINN are Normal Distributions by default.
* `param`: Vector of chosen ODE parameters Distributions in case of Inverse problems.
* `autodiff`: Boolean Value for choice of Derivative Backend(default is numerical)
* `physdt`: Timestep for approximating ODE in it's Time domain. (1/20.0 by default)
# AdvancedHMC.jl is still developing convenience structs so might need changes on new releases.
* `Kernel`: Choice of MCMC Sampling Algorithm (AdvancedHMC.jl implemenations HMC/NUTS/HMCDA)
* `Integratorkwargs`: A NamedTuple containing the chosen integrator and its keyword Arguments, as follows :
* `Integrator`: https://turinglang.org/AdvancedHMC.jl/stable/
* `jitter_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
* `tempering_rate`: https://turinglang.org/AdvancedHMC.jl/stable/
* `Adaptorkwargs`: A NamedTuple containing the chosen Adaptor, it's Metric and targetacceptancerate, as follows :
* `Adaptor`: https://turinglang.org/AdvancedHMC.jl/stable/
* `Metric`: https://turinglang.org/AdvancedHMC.jl/stable/
* `targetacceptancerate`: Target percentage(in decimal) of iterations in which the proposals were accepted(0.8 by default)
* `Integratorkwargs`: `Integrator`, `jitter_rate`, `tempering_rate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `Adaptorkwargs`: `Adaptor`, `Metric`, `targetacceptancerate`. Refer: https://turinglang.org/AdvancedHMC.jl/stable/
Note: Target percentage(in decimal) of iterations in which the proposals are accepted (0.8 by default)
* `MCMCargs`: A NamedTuple containing all the chosen MCMC kernel's(HMC/NUTS/HMCDA) Arguments, as follows :
* `n_leapfrog`: number of leapfrog steps for HMC
* `δ`: target acceptance probability for NUTS and HMCDA
* `λ`: target trajectory length for HMCDA
* `max_depth`: Maximum doubling tree depth (NUTS)
* `Δ_max`: Maximum divergence during doubling tree (NUTS)
Refer: https://turinglang.org/AdvancedHMC.jl/stable/
* `progress`: controls whether to show the progress meter or not.
* `verbose`: controls the verbosity. (Sample call args in AHMC)
"""

"""
dataset would be (x̂,t)
priors: pdf for W,b + pdf for ODE params
"""
function ahmc_bayesian_pinn_ode(prob::DiffEqBase.ODEProblem, chain;
Expand Down
7 changes: 4 additions & 3 deletions src/discretize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,15 @@ end

"""
```julia
prob = symbolic_discretize(pde_system::PDESystem, discretization::PhysicsInformedNN)
prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN)
```
`symbolic_discretize` is the lower level interface to `discretize` for inspecting internals.
It transforms a symbolic description of a ModelingToolkit-defined `PDESystem` into a
`PINNRepresentation` which holds the pieces required to build an `OptimizationProblem`
for [Optimization.jl](https://docs.sciml.ai/Optimization/stable) whose solution is the solution
to the PDE.
for [Optimization.jl](https://docs.sciml.ai/Optimization/stable) or a Likelihood Function
used for HMC based Posterior Sampling Algorithms [AdvancedHMC.jl](https://turinglang.org/AdvancedHMC.jl/stable/)
which is later optimized upon to give Solution or the Solution Distribution of the PDE.
For more information, see `discretize` and `PINNRepresentation`.
"""
Expand Down
Loading

0 comments on commit ba5dbf9

Please sign in to comment.