
The Partial Differential Equation (PDE) Constrained Optimization example #683

@YichengDWu


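I rewrote the example using MethodOfLines. I believe this is a preferable approach because the PDE is written symbolically with ModelingToolkit and the spatial discretization is generated automatically.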
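The following is a minimal, dependency-free sketch of the kind of work `MOLFiniteDifference` automates. It applies the method of lines to the same equation, u_t = 2·a0·u + a1·u_xx with u = 0 at both ends and the same Gaussian initial condition. It assumes second-order centred differences in space and plain forward Euler in time. The helper names `heat_rhs!` and `mol_euler` are hypothetical, and the code MethodOfLines.jl actually generates will differ in detail.

```julia
# Method-of-lines sketch (hypothetical helpers), assuming second-order
# centred differences in space and forward Euler in time.

# Semi-discrete right-hand side du/dt = 2*a0*u + a1*D2*u on the interior
# nodes; the Dirichlet boundary values u(t,0) = u(t,Lx) = 0 are built in.
function heat_rhs!(du, u, a0, a1, dx)
    n = length(u)
    for i in 1:n
        left  = i == 1 ? 0.0 : u[i-1]   # boundary value at x = 0
        right = i == n ? 0.0 : u[i+1]   # boundary value at x = Lx
        du[i] = 2.0 * a0 * u[i] + a1 * (left - 2.0 * u[i] + right) / dx^2
    end
    return du
end

# Integrate the resulting ODE system with nsteps forward Euler steps.
function mol_euler(a0, a1; Lx = 10.0, dx = 0.01, nsteps = 1000)
    n  = round(Int, Lx / dx) - 1      # number of interior nodes
    x  = dx .* (1:n)                  # interior grid points
    u  = exp.(-(x .- 3.0) .^ 2)       # initial condition
    du = similar(u)
    dt = 0.40 * dx^2                  # same step size as in the script
    for _ in 1:nsteps
        heat_rhs!(du, u, a0, a1, dx)
        u .+= dt .* du
    end
    return x, u
end

x, u = mol_euler(1.0, 1.0)
println(length(x), " interior points, max u = ", maximum(u))
```

Each interior node becomes one ODE. Handing `heat_rhs!` to an adaptive solver such as `Tsit5()` in place of the Euler loop gives essentially the structure that `discretize(pdesys, discretization)` produces.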

using Plots
using DifferentialEquations, Optimization, OptimizationPolyalgorithms, Zygote, OptimizationOptimJL
using DomainSets, MethodOfLines, ModelingToolkit, SciMLSensitivity

# Problem setup parameters:
Lx     = 10.0
dx     = 0.01
x_grid = 0.0:dx:Lx             # spatial grid, including both boundary points

## Problem Parameters
p        = [1.0, 1.0]          # true values of the parameters a0 and a1
dt       = 0.40*dx^2           # keeps a1*dt/dx^2 below the explicit stability limit of 1/2
t0, tMax = 0.0, 1000*dt
tspan    = (t0, tMax)
t_grid   = t0:dt:tMax          # times at which the solution is saved

##
@parameters t, x, a0, a1
@variables u(..)
Dt = Differential(t)
Dxx = Differential(x)^2

heateq = [Dt(u(t,x)) ~ 2.0 * a0 * u(t,x) + a1 * Dxx(u(t,x))]

bcs = [u(t,0) ~ 0.0, u(t,Lx) ~ 0.0, u(0,x) ~ exp(-(x-3.0)^2)]
domains = [t ∈ Interval(t0,tMax), x ∈ Interval(0.0,Lx)]

@named pdesys = PDESystem(heateq, bcs, domains,[t,x],[u(t,x)], [a0=>0.1, a1=>0.2])
discretization = MOLFiniteDifference([x => dx], t)
prob = discretize(pdesys, discretization)

# Solve with the true parameters to generate the reference ("truth") data
sol = solve(prob, Tsit5(), p = p, dt = dt, saveat = t_grid);

# sol.u holds the interior points only; pad with the zero boundary values
plot(x_grid, [0; sol.u[1]; 0], lw=3, label="t0", size=(800,500))
plot!(x_grid, [0; sol.u[end]; 0], lw=3, ls=:dash, label="tMax")

ps  = [0.1, 0.2];   # Initial guess for model parameters
function predict(θ)
    Array(solve(prob,Tsit5(),p=θ,dt=dt,saveat=t_grid))
end

## Define the loss function
function loss(θ)
    pred = predict(θ)
    l = pred - Array(sol)           # residual against the reference solution
    return sum(abs2, l), pred       # sum of squared errors
end

l, pred = loss(ps)
size(pred), size(sol), size(t_grid) # check that the sizes match

LOSS  = []                              # Loss accumulator
PRED  = []                              # prediction accumulator
PARS  = []                              # parameters accumulator

callback = function (θ, l, pred) # callback to observe training progress
  display(l)
  push!(PRED, pred)
  push!(LOSS, l)
  push!(PARS, copy(θ))   # copy, in case the optimizer mutates θ in place
  false                  # returning false lets the optimization continue
end

callback(ps,loss(ps)...) # Testing callback function

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x,p)->loss(x), adtype)

optprob = Optimization.OptimizationProblem(optf, ps)
res = Optimization.solve(optprob, PolyOpt(), callback = callback)
@show res.u  # returns [1.0000000000000162, 1.0000000000000044]

# Compare the prediction with the truth at tMax
plot(sol[:,end], lw=3, label="Truth", size=(800,500))
plot!(PRED[end][:,end], lw=3, ls=:dash, label="Prediction")
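The step `dt = 0.40*dx^2` in the setup keeps the diffusion number r = a1·dt/dx² at 0.4 for a1 = 1. This is below 1/2, the classical stability limit for forward Euler with centred second differences. The reaction term 2·a0·u only shifts the amplification factor by about 2·a0·dt. Because `Tsit5()` is adaptive, `dt` in the script only seeds the first step, while `saveat` fixes the output grid. The short check below is a sketch, and `diffusion_number` is a hypothetical helper name.

```julia
# Diffusion number r = a1*dt/dx^2 for the explicit (forward Euler) scheme.
diffusion_number(a1, dt, dx) = a1 * dt / dx^2

dx = 0.01
dt = 0.40 * dx^2      # step size used in the script's setup
r  = diffusion_number(1.0, dt, dx)

# Forward Euler with centred differences is stable for r <= 1/2.
println("r = ", round(r, digits = 3), r <= 0.5 ? " (within the limit)" : " (exceeds the limit)")
```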
