# AirSeaFlux: bulkformulae Optimization Problem

The notebook `optim_and_enzyme.ipynb` includes several examples that use Enzyme and Optim to compute adjoints and minimize a cost function. In most of those examples the input is a vector, but the cost compares only a single output of the bulkformulae function with an observation. Here we build a version whose cost combines four outputs of bulkformulae: the latent heat flux (`hl`), the sensible heat flux (`hs`), evaporation (`evap`) and wind stress (`tau`).
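Written out, the cost in `J_bulkformulae` below is the squared Euclidean misfit between the four selected outputs and their observed values. Here $\mathbf{x} \in \mathbb{R}^4$ holds the four arguments passed to `bulkformulae`. Because $J$ is a scalar, its gradient is a vector-Jacobian product, and that is what one reverse-mode Enzyme pass evaluates:

$$
J(\mathbf{x}) = \lVert \mathbf{y}(\mathbf{x}) - \mathbf{y}^{\mathrm{obs}} \rVert_2^2
= \sum_{i=1}^{4} \bigl(y_i(\mathbf{x}) - y_i^{\mathrm{obs}}\bigr)^2,
\qquad \mathbf{y} = (\mathrm{hl}, \mathrm{hs}, \mathrm{evap}, \mathrm{tau}),
$$

$$
\nabla J(\mathbf{x}) = 2 \left(\frac{\partial \mathbf{y}}{\partial \mathbf{x}}\right)^{\top} \bigl(\mathbf{y}(\mathbf{x}) - \mathbf{y}^{\mathrm{obs}}\bigr).
$$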

```julia
if !isdefined(Main, :Enzyme)
    using Pkg; Pkg.add.(["ECCO", "AirSeaFluxes", "LinearAlgebra", "Enzyme", "Optim"])
end

using ECCO
import AirSeaFluxes: bulkformulae
using Enzyme, Optim
using LinearAlgebra
```
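```julia
x0 = [300.0, 0.001, 1.0, 10.0]   # initial guess for the four inputs of bulkformulae

# cost: squared misfit between four bulkformulae outputs and their observations
function J_bulkformulae(x::Vector{Float64})
    obs = [-3.1, 2.1, 5.5e-9, 0.05]             # observed (hl, hs, evap, tau)
    res = bulkformulae(x[1], x[2], x[3], x[4])
    y = [res.hl, res.hs, res.evap, res.tau]     # i.e. res[1], res[2], res[3], res[6]
    return norm(y - obs)^2
end

# J_bulkformulae(x0)

# gradient (adjoint) of the cost, in the in-place form g!(G, x) that Optim expects
function J_ad!(dx, x)
    fill!(dx, 0.0)   # Enzyme accumulates into the shadow array, so zero it first
    Enzyme.autodiff(Reverse, J_bulkformulae, Active, Duplicated(x, dx))
    return dx
end

# for testing: evaluate the gradient at x0
dx0 = zeros(size(x0))
J_ad!(dx0, x0)

# optimization with the cost function and its adjoint
result = Optim.optimize(J_bulkformulae, J_ad!, x0, Optim.Options(show_trace=true))
x1 = Optim.minimizer(result)

# compare the four fitted outputs at x1 with the observations
y1 = bulkformulae(x1[1], x1[2], x1[3], x1[4])
(hl=y1.hl, hs=y1.hs, evap=y1.evap, tau=y1.tau)
```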

```
Iter     Function value   Gradient norm 
     0     8.369224e-03     3.615157e+01
 * time: 0.006735086441040039
     1     7.535613e-03     2.459393e+01
 * time: 0.3312399387359619
     2     6.817075e-03     2.917685e-01
 * time: 0.3314950466156006
     3     1.650358e-03     2.663119e+00
 * time: 0.3316960334777832
     4     1.642103e-03     2.666215e-02
 * time: 0.3319358825683594
     5     1.642075e-03     6.924692e-05
 * time: 0.3321518898010254
     6     1.642075e-03     2.035970e-03
 * time: 0.3323829174041748
     7     1.628204e-03     1.349200e+00
 * time: 0.33255600929260254
     8     1.622217e-03     2.273159e+00
 * time: 0.33275389671325684
     9     1.591295e-03     7.702138e-01
 * time: 0.33296799659729004
    10     1.589195e-03     1.073982e+00
 * time: 0.3331279754638672
    11     1.583926e-03     1.247669e+00
 * time: 0.3332810401916504
    12     1.579176e-03     1.426435e+00
 * time: 0.3334341049194336
    13     1.572024e-03     1.600151e+00
 * time: 0.333604097366333
    14     1.562448e-03     1.176509e+00
 * time: 0.33374905586242676
```

```
(hl = -3.099999999999782, hs = 2.1000000000023693, evap = 1.240248049609835e-9, tau = 0.05000000000017541)
```
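The fitted `hl`, `hs` and `tau` match their observations closely, but `evap` (about 1.24e-9) does not reach its target of 5.5e-9. With the unweighted `norm(y - obs)^2`, that residual adds only about 1.8e-17 to the cost, so the optimizer effectively ignores it. One common remedy is to divide each residual by a scale before squaring. The sketch below needs only the LinearAlgebra standard library and compares the per-term contributions at the printed outputs; the scales (here the magnitude of each observation) are an assumption for illustration, not something taken from the notebook.

```julia
using LinearAlgebra

# observations in the order (hl, hs, evap, tau), as in J_bulkformulae
obs = [-3.1, 2.1, 5.5e-9, 0.05]

# outputs at x1, copied from the printed result above
y1 = [-3.099999999999782, 2.1000000000023693, 1.240248049609835e-9, 0.05000000000017541]

# hypothetical residual scales (an assumption): the magnitude of each observation,
# so that every term of the weighted cost is dimensionless
scales = abs.(obs)

# unweighted cost, as used in J_bulkformulae
J_plain(y) = norm(y - obs)^2

# weighted cost: each residual is divided by its scale before squaring
J_weighted(y) = sum(abs2, (y .- obs) ./ scales)

# contribution of each output (hl, hs, evap, tau) to the two costs
plain_terms    = abs2.(y1 .- obs)
weighted_terms = abs2.((y1 .- obs) ./ scales)

println("unweighted terms: ", plain_terms)
println("weighted terms:   ", weighted_terms)
```

In the notebook, the weighted form would take the place of the unweighted misfit inside `J_bulkformulae`; since the cost is still a scalar function of `x`, the Enzyme and Optim calls should not need to change. Whether all four observations can be matched at once still depends on how bulkformulae couples its outputs; the weighting only keeps the `evap` residual from being ignored.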