In [None]:
# Activate the project environment in the current directory and list its packages.
import Pkg
Pkg.activate(".")
Pkg.status()

In [None]:
import ForwardDiff as FD
import HeatEquation as HE
import ImplicitAD as IAD
import LinearAlgebra as LA
import ProgressMeter as PM
import ReverseDiff as RD

In [None]:
import Optim

In [None]:
using Plots
using Plots.PlotMeasures

In [None]:
# Override: hand kappa back unchanged instead of converting it to the requested type T.
HE.convert_kappa(::Type{T}, kappa::Real) where T<:Real = kappa

# Override: whatever element type T is requested, always ask HeatEquation for the
# Float64 variant of the matrix (presumably so AD number types never end up inside it).
# NOTE(review): for T == Float64 this only terminates if HeatEquation provides a more
# specific Float64 method; if HE's own method has this exact signature, this definition
# replaces it and the call below recurses forever — confirm against HeatEquation.
function HE.build_2d_heat_csc(::Type{T}, kappa, dt, Nx, Ny) where T<:Real
    return HE.build_2d_heat_csc(Float64, kappa, dt, Nx, Ny)
end

# ReverseDiff tracked element types: compute the grid coordinate in the underlying
# value type, so grid coordinates are never tracked.
HE.gridpoint(dim::I, k::I, N::I, ::Type{T}) where {I<:Integer, T<:RD.TrackedReal} =
    HE.gridpoint(dim, k, N, RD.valtype(T))

# ForwardDiff dual element types: compute the grid coordinate in the underlying
# value type, so grid coordinates carry no partials.
HE.gridpoint(dim::I, k::I, N::I, ::Type{T}) where {I<:Integer, T<:FD.Dual} =
    HE.gridpoint(dim, k, N, FD.valtype(T))

# Override of HeatEquation's linear solve: delegate to ImplicitAD.implicit_linear
# (which solves A*x = b using the supplied factorization `Af`) and write the result
# into `u_sol`, which is also returned.
function HE.my_linear_solve!(
    u_sol::AbstractVector,
    A::Any,
    A_fact,
    b::AbstractVector,
)
    solution = IAD.implicit_linear(A, b; Af=A_fact)
    u_sol .= solution
    return u_sol
end

In [None]:
"""
    my_source(tk, xi, yj, a, b, r, h)

Cone-shaped heat source evaluated at the point `(xi, yj)`: height `h` at the center
`(a, b)`, decreasing linearly with distance and clipped to zero at distance `r` and
beyond. The time argument `tk` is accepted but not used.
"""
function my_source(tk::T, xi::T, yj::T, a::S, b::S, r::S, h::S) where {T,S}
    dist = sqrt((xi - a)^2 + (yj - b)^2)
    return max(0.0, h * (1.0 - dist / r))
end

In [None]:
"""
    eval_space_l2(params, k)

Call `HE.heat_step` once (step index `k`) on the problem stored in `params[:prob]`, then
return `dx*dy*norm(uk)^2`, a discrete squared L2 norm of the state `prob.uk`.
"""
function eval_space_l2(params::Dict, k::Integer)
    prob = params[:prob]
    state = prob.uk
    Nx, Ny = prob.Nx, prob.Ny

    # NOTE(review): heat_step presumably updates `state` (prob.uk) in place; the norm
    # below is taken of that same array after the call.
    HE.heat_step(
        prob.A, prob.A_fact, state, prob.rhs, prob.kappa, prob.f, k, prob.dt, Nx, Ny,
    )

    cell_area = HE.gridsize(1, Nx, Float64) * HE.gridsize(2, Ny, Float64)
    return cell_area * LA.norm(state, 2)^2
end

"""
    eval_spacetime_l2(params; progress=false)

Run `prob.Nt` steps of the problem in `params[:prob]` and accumulate
`sum(dt * eval_space_l2(params, k) for k in 0:Nt-1)`: a rectangle rule in time using
the state after each step. The initial state contributes nothing, which assumes
`u0 = zeros(Nx, Ny)`. With `progress=true` a ProgressMeter bar is shown.
"""
function eval_spacetime_l2(
    params::Dict;
    progress::Bool=false,
)
    prob = params[:prob]
    dt = prob.dt
    Nt = prob.Nt

    meter = progress ? PM.Progress(Nt) : nothing

    total = 0.0
    for step in 0:(Nt - 1)
        total += dt * eval_space_l2(params, step)
        isnothing(meter) || PM.next!(meter)
    end

    isnothing(meter) || PM.finish!(meter)

    return total
end

In [None]:
"""
    split_state_and_control(x, p)

Return views `(theta, u0)` of `x`: the first `p[:ncontrol]` entries (controls) and
the remaining entries (state). No data is copied.
"""
function split_state_and_control(x::AbstractVector, p::Dict)
    ncontrol = p[:ncontrol]
    controls = view(x, 1:ncontrol)
    state = view(x, (ncontrol + 1):lastindex(x))
    return (controls, state)
end

# Concatenate controls `x` and state `u0` into one vector (inverse of
# split_state_and_control).
fuse_state_and_control(x, u0) = [x; u0]

"""
    evaluate_f(params)

Evaluate the source term `params[:prob].f` at time `0.0` on every point of the
`N`×`N` grid (`N = params[:N]`). Return the values as a flat vector of length `N^2`,
laid out by `HE.linear_index(i, j, N, N)`.

The result has element type Float64, so `f` must return values convertible to
Float64.
"""
function evaluate_f(params::Dict)
    N = params[:N]
    f = params[:prob].f
    f0 = zeros(N^2)
    # One pass over (i, j). The previous version also looped over j in the outer
    # `for`, so the inner loop shadowed it and every entry was recomputed N times.
    for i in 1:N
        xi = HE.gridpoint(1, i, N, Float64)
        for j in 1:N
            yj = HE.gridpoint(2, j, N, Float64)
            f0[HE.linear_index(i, j, N, N)] = f(0.0, xi, yj)
        end
    end
    return f0
end

"""
    eval_space_l2_residual(y, x, p)

Return `norm(A*y - u0 - dt*f0)`, where `u0` is the state part of `x` (see
`split_state_and_control`), `f0 = evaluate_f(p)`, and `A`, `dt` come from `p[:prob]`.
The control part of `x` is not used.

NOTE(review): this returns the plain norm. `eval_space_l2_drdy` matches the gradient of
`0.5*norm(...)^2` (for symmetric `A`), so finite differences of this function differ
from that gradient by a factor of the norm. Confirm which objective is intended.
"""
function eval_space_l2_residual(y, x, p)
    prob = p[:prob]
    _, u_prev = split_state_and_control(x, p)
    forcing = evaluate_f(p)
    mismatch = prob.A * y .- u_prev .- prob.dt .* forcing
    return LA.norm(mismatch)
end

"""
    eval_space_l2_drdy(y, x, p)

Return `A'*A*y - A'*v0 - A*v0` with `v0 = 0.5*(u0 + dt*f0)`, where `u0` is the state
part of `x`, `f0 = evaluate_f(p)`, and `A`, `dt` come from `p[:prob]`.

With `c = u0 + dt*f0`, this equals `A'*(A*y - c)` when `A` is symmetric, i.e. the
gradient with respect to `y` of `0.5*norm(A*y - c)^2`.
"""
function eval_space_l2_drdy(y, x, p)
    prob = p[:prob]
    _, u0 = split_state_and_control(x, p)
    A = prob.A

    # Debug `@show` output (which also re-evaluated evaluate_f) was removed: f0 is now
    # computed once and nothing is printed.
    v0 = 0.5 .* (u0 .+ prob.dt .* evaluate_f(p))

    return transpose(A) * A * y .- transpose(A) * v0 .- A * v0
end

In [None]:
"""
    simulate_heat(x, p)

Run the heat problem driven by `my_source` with parameters `x = (a, b, r, h)` and
settings tuple `p = (kappa, tf, dt, N)`, starting from a zero `N`×`N` state. Return the
final state `prob.uk`.

The setup routine is chosen by element type: plain floats use `HE.heat_setup_cpu`,
ForwardDiff duals use `heat_setup_fd_cpu`, and anything else is assumed to be a
ReverseDiff tracked type.
"""
function simulate_heat(x, p)
    a, b, r, h = x
    source(t, xx, yy) = my_source(t, xx, yy, a, b, r, h)
    kappa, tf, dt, N = p

    Tx = eltype(x)
    u_init = zeros(Tx, N, N)
    setup = if Tx <: AbstractFloat
        HE.heat_setup_cpu
    elseif Tx <: FD.Dual
        heat_setup_fd_cpu
    else
        heat_setup_rd_cpu
    end

    prob = setup(u_init, kappa, source, tf, dt, N, N, -1, :csc)
    HE.heat_loop(prob, nothing; progress=false)
    return prob.uk
end

In [None]:
"""
    heat_setup_fd_cpu(u0, kappa, interior, tf, dt, Nx, Ny, save_rate, format)

Variant of the CPU setup for ForwardDiff inputs. Convert `tf` and `dt` to the value
type underlying `eltype(u0)` (and `kappa` via `HE.convert_kappa`), then forward
everything to `HE.heat_setup`.
"""
function heat_setup_fd_cpu(
    u0::Matrix,
    kappa,
    interior::Function,
    tf::R,
    dt::R,
    Nx::I,
    Ny::I,
    save_rate::I,
    format::Symbol,
) where {I<:Integer, R<:Real}
    V = FD.valtype(eltype(u0))
    tf_v = convert(V, tf)
    dt_v = convert(V, dt)
    kappa_v = HE.convert_kappa(V, kappa)
    # save_rate is already of type I, so it is passed through as-is.
    return HE.heat_setup(u0, kappa_v, interior, tf_v, dt_v, Nx, Ny, save_rate, format)
end
    
"""
    heat_setup_rd_cpu(u0, kappa, interior, tf, dt, Nx, Ny, save_rate, format)

Variant of the CPU setup for ReverseDiff inputs. Convert `tf` and `dt` to the value
type underlying `eltype(u0)` (and `kappa` via `HE.convert_kappa`), then forward
everything to `HE.heat_setup`.
"""
function heat_setup_rd_cpu(
    u0::Matrix,
    kappa,
    interior::Function,
    tf::R,
    dt::R,
    Nx::I,
    Ny::I,
    save_rate::I,
    format::Symbol,
) where {I<:Integer, R<:Real}
    V = RD.valtype(eltype(u0))
    tf_v = convert(V, tf)
    dt_v = convert(V, dt)
    kappa_v = HE.convert_kappa(V, kappa)
    # save_rate is already of type I, so it is passed through as-is.
    return HE.heat_setup(u0, kappa_v, interior, tf_v, dt_v, Nx, Ny, save_rate, format)
end

"""
    setup_heat(x, p)

Build the heat problem for source parameters `x = (a, b, r, h)`, reading the settings
from the Dict `p` (keys `:kappa`, `:tf`, `:dt`, `:N`). Return a shallow copy of `p` with
the problem stored under `:prob`. The setup routine is chosen by the element type of
`x`, as in `simulate_heat`.
"""
function setup_heat(x, p)
    a, b, r, h = x
    source(t, xx, yy) = my_source(t, xx, yy, a, b, r, h)
    kappa, tf, dt, N = p[:kappa], p[:tf], p[:dt], p[:N]

    Tx = eltype(x)
    u_init = zeros(Tx, N, N)
    builder = if Tx <: AbstractFloat
        HE.heat_setup_cpu
    elseif Tx <: FD.Dual
        heat_setup_fd_cpu
    else
        heat_setup_rd_cpu
    end
    prob = builder(u_init, kappa, source, tf, dt, N, N, -1, :csc)

    out = copy(p)
    out[:prob] = prob
    return out
end

"""
    cost_x(x, p)

Control cost: `0.5*norm(x)^2 - log(x[3]) - log(x[4])`. The log terms are only defined
for positive `x[3]` and `x[4]` (the radius `r` and height `h` of the source). `p` is
unused.
"""
function cost_x(x, p)
    quadratic = 0.5 * LA.norm(x, 2)^2
    return quadratic - log(x[3]) - log(x[4])
end

"""
    cost_u(x, p; progress=true)

State cost: half the accumulated space-time squared L2 norm of the heat solution for
source parameters `x` and settings Dict `p`.
"""
cost_u(x, p; progress=true) = 0.5 * eval_spacetime_l2(setup_heat(x, p); progress=progress)

"""
    cost_u_svd(x, p; progress=true)

Like `cost_u`, but using `eval_spacetime_l2_svd`.

NOTE(review): `eval_spacetime_l2_svd` is not defined in the visible notebook; confirm
it exists before calling this.
"""
cost_u_svd(x, p; progress=true) =
    0.5 * eval_spacetime_l2_svd(setup_heat(x, p); progress=progress)

"""
    cost(cu, cx; wu=-1e0, wx=1e0)

Combine the state cost `cu` and the control cost `cx` into `wu*cu + wx*cx`. The default
weights reproduce the previous hard-coded `-cu + cx`: since the optimizer minimizes
this value, the state term is effectively maximized while the control term is
penalized. Pass other weights to change the trade-off.
"""
function cost(cu, cx; wu=-1e0, wx=1e0)
    return wu * cu + wx * cx
end

"""
    obj(x, p; progress=false)

Objective for source parameters `x` and settings Dict `p`: `cost(cost_u(x, p), cost_x(x, p))`.
The control cost is evaluated before the heat simulation.
"""
function obj(x, p; progress=false)
    control_term = cost_x(x, p)
    state_term = cost_u(x, p; progress=progress)
    return cost(state_term, control_term)
end

"""
    obj_svd(x, p; progress=false)

Same as `obj`, but the state cost comes from `cost_u_svd`.
"""
function obj_svd(x, p; progress=false)
    control_term = cost_x(x, p)
    state_term = cost_u_svd(x, p; progress=progress)
    return cost(state_term, control_term)
end

In [None]:
# Problem settings: N×N state grid, time step dt, final time tf, and the
# heat-equation coefficient kappa.
N = 3
dt = 5e-1
# tf = 100*dt
tf = 1.0
kappa = 1.0
# u0 = zeros(N,N)
# Settings Dict consumed by setup_heat / obj / cost_u (looked up by Symbol key).
params = Dict{Symbol,Any}(:kappa => kappa, :tf => tf, :dt => dt, :N => N)
# Initial source parameters (a, b, r, h): center (a, b), radius r, peak height h.
x0 = [0.9, -0.2, 0.5, 1.0]
# Number of leading control entries in a fused [controls; state] vector.
params[:ncontrol] = length(x0)
# x0 = [0.0, -0.0, 0.5, 1.0]
;

In [None]:
# Full objective at the initial guess.
obj(x0, params)

In [None]:
# State (heat) cost alone.
cost_u(x0, params)

In [None]:
# Control cost alone.
cost_x(x0, params)

In [None]:
# Build the problem, take one heat_step by hand from the initial state, and
# evaluate the one-step residual at the result.
p = setup_heat(x0, params)
prob = p[:prob]
A = prob.A
A_fact = prob.A_fact
u0 = prob.u0
# u0 .= ones(N,N)
u1 = copy(u0)
# NOTE(review): heat_step presumably overwrites u1 in place with the next state.
HE.heat_step(A, A_fact, u1, prob.rhs, prob.kappa, prob.f, 0, prob.dt, N, N)
eval_space_l2_residual(u1[:], fuse_state_and_control(x0, u0[:]), p)

In [None]:
# State after the manual step.
u1

In [None]:
# Analytic derivative of the residual model w.r.t. the new state.
eval_space_l2_drdy(u1[:], fuse_state_and_control(x0, u0[:]), p)

In [None]:
# Same derivative, reshaped onto the N×N grid.
dr = eval_space_l2_drdy(u1[:], fuse_state_and_control(x0, u0[:]), p)
reshape(dr, N, N)

In [None]:
# Residual value at the manual step.
eval_space_l2_residual(u1[:], fuse_state_and_control(x0, u0[:]), p)

In [None]:
# Forward finite-difference derivative of the residual w.r.t. each entry of u1.
# Loop over every entry (previously hard-coded to 1:9, which only matched N = 3);
# the fused (controls, state) vector does not change, so build it once.
x_fused = fuse_state_and_control(x0, u0[:])
f0 = eval_space_l2_residual(u1[:], x_fused, p)
du = 1e-3
for k in eachindex(u1)
    du1 = copy(u1)
    du1[k] += du
    f1 = eval_space_l2_residual(du1[:], x_fused, p)
    @show (f1 - f0) / du
end

In [None]:
# ForwardDiff gradient of the residual w.r.t. the new state (compare with the
# finite differences above).
FD.gradient(y->eval_space_l2_residual(y, fuse_state_and_control(x0, u0[:]), p), u1[:])

In [None]:
# Objective gradient via ForwardDiff.
@time FD.gradient(x->obj(x, params), x0)

In [None]:
# Objective gradient via ReverseDiff.
@time RD.gradient(x->obj(x, params), x0)

In [None]:
# Record the ReverseDiff tape of the objective to inspect its size.
gtp = RD.GradientTape(x->obj(x, params), x0)

In [None]:
gtp.tape |> length

In [None]:
gtp.tape

# Optimize

In [None]:
tol = 1e-5
# extended_trace stores each iterate in the trace entry's metadata["x"], which the
# plotting helpers below read.
my_options = Optim.Options(
    g_abstol=tol,
    # g_reltol=tol,
    outer_g_abstol=tol,
    # outer_g_reltol=tol,
    store_trace=true,
    extended_trace=true,
    show_trace=true
)
# Box bounds for (a, b, r, h): center in [-1, 1]^2, r and h nonnegative.
lb = [-1.0, -1.0, 0.0, 0.0]
ub = [1.0, 1.0, Inf, Inf]
# Positional settings tuple: unpacked by simulate_heat (used by the plotting helpers).
# NOTE: obj/setup_heat look settings up by Symbol key and need the Dict `params`.
my_params = (kappa, tf, dt, N)
# my_svd_params = (kappa, tf, dt, N, 1e-3)

In [None]:
# Optim needs a function of x alone, so close over the problem settings.
# obj -> setup_heat reads its settings by Symbol key (p[:kappa], ...), so it must get
# the Dict `params`; indexing the positional tuple `my_params` by Symbol fails.
my_objective(x) = obj(x, params)

# Box-constrained minimization (Fminbox around BFGS). No gradient is supplied, so
# Optim uses its default finite-difference gradients.
res = Optim.optimize(
    my_objective,
    lb,
    ub,
    x0,
    Optim.Fminbox(Optim.BFGS()),
    my_options;
    # autodiff = :forward, # uses ForwardDiff.jl
)
@show Optim.converged(res)
@show Optim.minimum(res)
;

In [None]:
# Summary of the optimization result; keep the minimizer for inspection.
@show Optim.converged(res)
@show Optim.minimum(res)
@show Optim.minimizer(res)
x_sol = Optim.minimizer(res)
;

In [None]:
# Control cost at the solution.
cost_x(x_sol, params)

In [None]:
# State cost at the solution.
cost_u(x_sol, params)

In [None]:
# Objective gradient at the solution (should be small if converged in the interior).
@time FD.gradient(x->obj(x, params), x_sol)

# ReverseDiff Optimize

In [None]:
# Objective of x alone. setup_heat looks settings up by Symbol key, so pass the Dict
# `params` (the positional tuple `my_params` cannot be indexed by Symbol).
my_objective(x) = obj(x, params)

# In-place gradient for Optim: overwrite g with the ReverseDiff gradient at x.
function my_obj_grad(g, x)
    g .= RD.gradient(my_objective, x)
    return
end

res = Optim.optimize(
    my_objective,
    my_obj_grad,
    lb,
    ub,
    x0,
    Optim.Fminbox(Optim.BFGS()),
    my_options;
    # autodiff = :forward, # uses ForwardDiff.jl
)
@show Optim.converged(res)
@show Optim.minimum(res)
;

# Plots

In [None]:
"""
    split_trace_variables(my_trace)

Extract the iterate `metadata["x"] = (a, b, r, h)` from every trace entry. Return
`(n, a, b, r, h)`: the number of entries and one Float64 vector per parameter.
"""
function split_trace_variables(my_trace)
    n = length(my_trace)
    coords = zeros(n, 4)
    for (row, state) in enumerate(my_trace)
        coords[row, :] .= state.metadata["x"][1:4]
    end
    return (n, coords[:, 1], coords[:, 2], coords[:, 3], coords[:, 4])
end

In [None]:
"""
    make_variable_plot(optim_trace, iter)

Plot the history of the parameters a, b, r, h for iterations `0:iter`. The axes span
the whole trace, so frames for different `iter` share the same limits.

Throws `ArgumentError` unless `0 <= iter < length(optim_trace)`.
"""
function make_variable_plot(optim_trace, iter)
    (niter, a, b, r, h) = split_trace_variables(optim_trace)
    # Explicit validation instead of @assert, which can be disabled and also did not
    # reject negative iter.
    0 <= iter < niter ||
        throw(ArgumentError("iter must lie in 0:$(niter - 1), got $iter"))

    series = (a, b, r, h)
    ymin = floor(minimum(minimum, series))
    ymax = ceil(maximum(maximum, series))

    p = plot(xticks=0:2:niter-1, xrange=(0, niter-1), yrange=(ymin, ymax), legend=:topleft)
    for (vals, name) in zip(series, ("a", "b", "r", "h"))
        plot!(p, 0:iter, vals[1:iter+1], label=name)
    end
    return p
end

"""
    make_residual_plot(optim_trace, iter, tol)

Plot the gradient norm (`g_norm` of each trace entry) for iterations `0:iter` on a
log scale. The y-range goes from `tol/10` up to the ceiling of the largest gradient
norm in the trace.

Throws `ArgumentError` unless `0 <= iter < length(optim_trace)`.
"""
function make_residual_plot(optim_trace, iter, tol)
    niter = length(optim_trace)
    0 <= iter < niter ||
        throw(ArgumentError("iter must lie in 0:$(niter - 1), got $iter"))

    # Named gnorm so it does not shadow the notebook's global `res`.
    gnorm = getfield.(optim_trace, :g_norm)
    rmax = ceil(maximum(gnorm))
    p = plot(xticks=0:2:niter-1,
        xrange=(0, niter-1), yrange=(1e-1*tol, rmax),
        yscale=:log10, legend=false
    )
    plot!(p, 0:iter, gnorm[1:iter+1])
    return p
end

"""
    make_source_plot(x0, N, hmax)

Heatmap of `my_source` at time 0 for parameters `x0 = (a, b, r, h)` on the `N`×`N`
grid, with color limits `(0, hmax)`.
"""
function make_source_plot(x0, N, hmax)
    # Axis coordinates: N equally spaced interior points of each domain interval.
    (lo, hi) = HE.endpoints(1, Float64)
    xgrid = (hi - lo) .* (1:N) ./ (N + 1) .+ lo
    (lo, hi) = HE.endpoints(2, Float64)
    ygrid = (hi - lo) .* (1:N) ./ (N + 1) .+ lo

    # Rows index y and columns index x, as heatmap(x, y, z) expects.
    my_heat_source = zeros(N, N)
    for j in 1:N
        # The y coordinate comes from dimension 2; it previously used dimension 1.
        yj = HE.gridpoint(2, j, N, Float64)
        for i in 1:N
            xi = HE.gridpoint(1, i, N, Float64)
            my_heat_source[j, i] = my_source(0.0, xi, yj, x0...)
        end
    end

    p = plot(clim=(0.0, hmax))
    return heatmap!(p, xgrid, ygrid, my_heat_source)
end

# Source heatmap for 0-based iteration `iter` of `optim_trace` (entry iter+1).
# Uses the trace argument rather than the global `res`.
function make_source_plot(optim_trace, iter, N, hmax)
    return make_source_plot(optim_trace[iter+1].metadata["x"], N, hmax)
end

"""
    make_stationary_plot(x0, N, params, umax)

Simulate the heat problem for source parameters `x0` and settings tuple `params` (see
`simulate_heat`), then show the final state as a heatmap with color limits `(0, umax)`.
"""
function make_stationary_plot(x0, N, params, umax)
    # Axis coordinates: N equally spaced interior points of each domain interval.
    (x_lo, x_hi) = HE.endpoints(1, Float64)
    (y_lo, y_hi) = HE.endpoints(2, Float64)
    xgrid = (x_hi - x_lo) .* (1:N) ./ (N + 1) .+ x_lo
    ygrid = (y_hi - y_lo) .* (1:N) ./ (N + 1) .+ y_lo

    u_final = simulate_heat(x0, params)
    # Transposed so rows index y and columns index x, as heatmap(x, y, z) expects.
    return heatmap(xgrid, ygrid, u_final'; clim=(0.0, umax))
end

# Final-state heatmap for 0-based iteration `iter` of `optim_trace` (entry iter+1).
# Uses the trace argument rather than the global `res`.
function make_stationary_plot(optim_trace, iter, N, params, umax)
    return make_stationary_plot(optim_trace[iter+1].metadata["x"], N, params, umax)
end

In [None]:
"""
    make_plot_group(optim_trace, iter, params, tol)

2×2 panel for iteration `iter`: parameter history, gradient-norm history, source
heatmap, and final heat state. `params` is the positional settings tuple
`(kappa, tf, dt, N)` used by `simulate_heat`.

Color limits come from the whole trace (max h) and from the final iterate's solution,
so every frame of an animation uses the same scales.
"""
function make_plot_group(optim_trace, iter, params, tol)
    (_, _, _, N) = params

    # Previously read the globals `res.trace` and `my_params`; use the arguments.
    (_, _, _, _, h) = split_trace_variables(optim_trace)
    hmax = ceil(maximum(h); digits=2)

    u_final = simulate_heat(optim_trace[end].metadata["x"], params)
    umax = ceil(maximum(u_final); digits=2)

    vp = make_variable_plot(optim_trace, iter)
    rp = make_residual_plot(optim_trace, iter, tol)
    srp = make_source_plot(optim_trace, iter, N, hmax)
    stp = make_stationary_plot(optim_trace, iter, N, params, umax)

    return plot(
        vp, rp, srp, stp;
        layout=(2,2),
        size=(1050,800),
        suptitle="Iteration: $(iter)",
        left_margin=[3mm 0mm],
        right_margin=[3mm 3mm 3mm],
        bottom_margin=[3mm 3mm],
    )
end

In [None]:
"""
    make_gif(optim_trace, gif_name, params, tol; fps=2)

Render one `make_plot_group` frame per trace entry (0-based iteration numbers) and
write the animation to `gif_name * ".gif"` at `fps` frames per second.
"""
function make_gif(optim_trace, gif_name::AbstractString, params, tol; fps::Int=2)
    ani = @animate for iter in 0:(length(optim_trace) - 1)
        make_plot_group(optim_trace, iter, params, tol)
    end
    return gif(ani, string(gif_name, ".gif"); fps=fps)
end

In [None]:
# Animate the trace of the most recent optimization run into "test.gif".
make_gif(res.trace, "test", my_params, tol)