## Old manual implementation (hand-written forward pass)

In [None]:
# Random weights/biases for a 5 → 5 → 3 → 1 network. `rand(rng, dims...)` samples
# Float64 by default; the draw order (w1, b1, w2, b2, w3, b3) fixes the values for a given `rng`.
w1, b1 = rand(rng, 5, 5), rand(rng, 5)
w2, b2 = rand(rng, 3, 5), rand(rng, 3)
w3, b3 = rand(rng, 1, 3), rand(rng, 1)

In [None]:
# Take the weight matrix `W` and bias `b` of each of the first three layers of `nn`.
w1, b1, w2, b2, w3, b3 = let layers = nn.layers
    (layers[1].W, layers[1].b, layers[2].W, layers[2].b, layers[3].W, layers[3].b)
end

In [None]:
# Take the first six entries of the Flux parameter collection, in order, as
# w1, b1, w2, b2, w3, b3 (assumes weight/bias alternate per layer — confirm against model_flux).
params_flux = params(model_flux)
w1, b1, w2, b2, w3, b3 = ntuple(i -> params_flux[i], 6)

In [None]:
# Convert every weight/bias to Float64, flatten them into one parameter vector, and record
# each array's shape plus cumulative offsets: array i occupies params_vec[idxs[i]+1:idxs[i+1]].
nn_params = [Float64.(p) for p in (w1, b1, w2, b2, w3, b3)]
params_vec = reduce(vcat, vec.(nn_params));
shapes = map(size, nn_params)
idxs = [0; cumsum(length.(nn_params))];
"""
    extract(params_vec)

Split the flat parameter vector `params_vec` back into one array per entry of the
global `shapes`, in order (for the default layout: `w1, b1, w2, b2, w3, b3`).
Array `i` is the slice `params_vec[idxs[i]+1:idxs[i+1]]` reshaped to `shapes[i]`,
where `idxs` is the global vector of cumulative offsets.
"""
function extract(params_vec)
    # Iterate over every recorded shape rather than a hard-coded count, so the result
    # stays consistent with `shapes`/`idxs` if the number of parameter arrays changes.
    return [
        reshape(params_vec[idxs[i]+1:idxs[i+1]], shapes[i]) for i in eachindex(shapes)
    ]
end

In [None]:
"""
    model_manual(positions)

Forward pass of the hand-written network using the global weights `w1, b1, w2, b2, w3, b3`:
three affine maps applied in sequence (no activation between them), then an elementwise
`exp`. Returns the first entry of the output as a scalar.
"""
function model_manual(positions)
    hidden1 = w1 * positions + b1
    hidden2 = w2 * hidden1 + b2
    logit = w3 * hidden2 + b3
    return exp.(logit)[1]
end

In [None]:
# Time one forward pass with the global weights. The first `@time` call also includes
# compilation; re-run the cell for a steady-state number.
# NOTE(review): `positions` is only assigned in a later cell
# (`positions = rand(rng, Float64, 5)`); run that cell first.
@time model_manual(positions)

In [None]:
"""
    model_manual(positions, params_vec)

Same forward pass as the single-argument method, but with the weights and biases unpacked
from the flat vector `params_vec` via `extract`. This lets the output be differentiated
with respect to the parameters (the ReverseDiff gradient tape in this notebook does so).
"""
function model_manual(positions, params_vec)
    W1, B1, W2, B2, W3, B3 = extract(params_vec)
    z = W3 * (W2 * (W1 * positions + B1) + B2) + B3
    return exp.(z)[1]
end

In [None]:
# Time the parameter-vector variant (includes the `extract` unpacking). The first `@time`
# call also includes compilation; re-run the cell for a steady-state number.
@time model_manual(positions, params_vec)

## Gradient of manual model

In [None]:
# Record and compile a ReverseDiff tape for the gradient of the model output with respect
# to the flat parameter vector, and allocate the output buffer `result`.
# The closure reads the global `positions` while the tape is recorded.
# NOTE(review): the compiled tape presumably keeps referring to the `positions` array seen
# at record time, so rebinding `positions` later likely requires re-recording — confirm.
config = ReverseDiff.GradientConfig(params_vec)
tape = ReverseDiff.compile(
    ReverseDiff.GradientTape(p -> model_manual(positions, p), params_vec, config),
)
result = zero(params_vec);

In [None]:
"""
    grad_manual(result, tape, params_vec)

Evaluate the gradient recorded in `tape` at `params_vec`, writing it into `result`
(overwritten in place), and return whatever `ReverseDiff.gradient!` returns.
"""
grad_manual(result, tape, params_vec) = ReverseDiff.gradient!(result, tape, params_vec);

In [None]:
# `@btime` runs its expression inside a generated benchmark function, so an assignment
# written inside the macro would not define a global `grad_manual_res` (the next cell needs
# it). Bind the value `@btime` returns instead. Interpolate the globals with `$` so the
# timing excludes dynamic dispatch on untyped global variables.
grad_manual_res = @btime grad_manual($result, $tape, $params_vec);

In [None]:
# First array of the unflattened gradient, i.e. the block with the shape of w1.
first(extract(grad_manual_res))

In [None]:
# Random 5-element Float64 input (`rand(rng, dims...)` defaults to Float64).
# NOTE(review): earlier cells (`@time model_manual(positions)`, the gradient tape) already
# read `positions`, so this cell must be run before them.
positions = rand(rng, 5)

## Hessian of the manual model with respect to positions

In [None]:
# Compiled ReverseDiff tapes for the single-argument `model_manual`, one per
# (element type, input size) combination seen by `inner`.
const CACHE = Dict{Tuple{DataType, Tuple{Vararg{Int}}}, Any}()

"""
    inner(y, positions::Array{T}) where {T<:Real}

Write the gradient of the single-argument `model_manual` with respect to `positions` into
`y` and return it. It has the in-place `f!(y, x)` form that `ForwardDiff.jacobian!` expects.

A compiled tape is recorded the first time a given element type `T` and input size is seen,
and reused afterwards. The size is part of the key because a compiled ReverseDiff tape is
only valid for inputs with the shape it was recorded with. A tape recorded for a different
length must therefore not be reused.

NOTE(review): tapes are recorded against whatever globals `model_manual` reads at record
time (w1, b1, ...). If those are rebound later, presumably `empty!(CACHE)` is needed —
confirm.
"""
function inner(y, positions::Array{T}) where {T<:Real}
    tape = get!(CACHE, (T, size(positions))) do
        config = ReverseDiff.GradientConfig(positions)
        ReverseDiff.compile(ReverseDiff.GradientTape(model_manual, positions, config))
    end
    return ReverseDiff.gradient!(y, tape, positions)
end

"""
    kineticMX(positions, inner, config, result, y)

Fill the preallocated matrix `result` with the Jacobian of the in-place map `inner(y, x)`
at `x = positions`, and return `result`. The Jacobian is evaluated with
`ForwardDiff.jacobian!`, reusing the prebuilt `config` and the buffer `y`.

When `inner` writes the gradient of a scalar model, as this notebook's `inner` does,
`result` is that model's Hessian with respect to `positions`.

NOTE(review): the original source had a commented-out division of the result by
`SimpleG(positions)`. That normalization is currently disabled — confirm whether it is
still needed.
"""
kineticMX(positions, inner, config, result, y) =
    (ForwardDiff.jacobian!(result, inner, y, positions, config); result)

In [None]:
# Buffers for the Hessian benchmark:
#   `result` — n×n matrix receiving the Jacobian of `inner`;
#   `y`      — output buffer that `inner` writes into;
#   `config` — ForwardDiff's preallocated work buffers for (`inner`, `y`, `positions`).
# NOTE(review): this rebinds the globals `result` and `config` that the earlier gradient
# cells use (there a Vector and a ReverseDiff config). Re-run those setup cells before
# re-running the gradient benchmark.
result = zeros(Float64, length(positions), length(positions));
y = similar(positions)
config = ForwardDiff.JacobianConfig(inner, y, positions);

In [None]:
# Interpolate the globals with `$` so BenchmarkTools times the Hessian evaluation itself
# rather than dynamic dispatch on untyped (non-const) global variables.
@btime kineticMX($positions, $inner, $config, $result, $y)