In [1]:
using Revise
using BilevelTrajOpt

In [6]:
# Random diagonal weight matrix with entries drawn from rand (in [0, 1)).
# It is referenced only by commented-out quadratic-objective variants (x'*A*x).
a = rand(4)
A = zeros(length(a), length(a))
for k in eachindex(a)
    A[k, k] = a[k]
end
A

4×4 Array{Float64,2}:
 0.408196  0.0       0.0       0.0     
 0.0       0.315306  0.0       0.0     
 0.0       0.0       0.480072  0.0     
 0.0       0.0       0.0       0.648007

In [56]:
# test the solution
# Objective of the standalone test problem: the sum of the decision variables.
# Earlier objective variants (kept for reference):
#     0.
#     x[1] * x[4] * (x[1] + x[2] + x[3]) + x[3]
#     x'*A*x
f(x) = sum(x)

# Equality-constraint residuals of the test problem. All three entries are zero at
# the solution x ≈ [6, 0, 1, 17] displayed in this notebook:
#   6 - x[1]              pins x[1] to 6
#   x[1] * x[2]           product constraint: x[1] or x[2] must vanish
#   0.5*x[1] + x[3] - 4   linear coupling of x[1] and x[3]
# Earlier variants (kept for reference):
#     [40. - (x[1]^2 + x[2]^2 + x[3]^2 + x[4]^2)]
#     [x[1] + x[2] - 1., x[3] - 1., x[4]*x[1]]
function h(x)
    x1, x2, x3 = x[1], x[2], x[3]
    return [6.0 - x1, x1 * x2, 0.5 * x1 + x3 - 4.0]
end

# Inequality constraints of the test problem. The solution displayed in this notebook
# (x ≈ [6, 0, 1, 17]) makes every entry ≤ 0, which suggests the solvers use the
# convention g(x) ≤ 0 (TODO: confirm against auglag_solve / ip_solve). Entries:
#   -x[2], -x[3]                    nonnegativity of x[2] and x[3]
#   5 - x[4], 10 - x[4], 17 - x[4]  lower bounds on x[4] (17 is the binding one)
#   5 - x[1]*x[4]                   requires x[1]*x[4] ≥ 5
#
# Earlier variants (kept for reference):
#     [25. - (x[1]*x[2]*x[3]*x[4])]
#     vcat([25. - (x[1]*x[2]*x[3]*x[4])], 1. .- x, x .- 5.)
#     vcat([1. - x[2]], -x)
#     vcat(-x[2:3], 5. - x[4])
#     [0.]
#     vcat([x[1] - .1], 5. - x[4])
function g(x)
    # Spaces around `-` are required: `5.-x[4]` is an ambiguous dot-operator form that
    # is a parse error on Julia >= 0.7. `5.0 - x[4]` gives the same value everywhere.
    return vcat(-x[2:3], 5.0 - x[4], 10.0 - x[4], 17.0 - x[4], 5 - x[1] * x[4])
end

# Cold-start data: zero primal guess and zero multiplier estimates, one per
# equality (h) and inequality (g) constraint.
x0 = zeros(4)
λ0 = zeros(length(h(x0)))
μ0 = zeros(length(g(x0)))
c0 = 1.0    # auglag_solve's `c0` keyword (presumably an initial penalty weight — confirm)

# To warm-start from a previous run instead, uncomment:
# x0 = x_sol
# λ0 = λ_sol
# μ0 = μ_sol
# c0 = c_sol

# Augmented-Lagrangian solve.
x_sol, λ_sol, μ_sol, c_sol = auglag_solve(x0, λ0, μ0, f, h, g, c0=c0)

# Independent solve for comparison; ip_solve takes the constraint counts explicitly.
num_h, num_g = length(λ0), length(μ0)
x_sol_ip = ip_solve(x0, f, h, g, num_h, num_g)

# Known solution for an earlier (commented-out) problem variant:
# x_sol_known = [1.000, 4.743, 3.821, 1.379]

foreach(display, (x_sol, x_sol_ip))
# display(x_sol_known)

4-element Array{Float64,1}:
  6.0        
  6.98608e-14
  1.0        
 17.0        

4-element Array{Float64,1}:
  6.0        
  2.16408e-24
  1.0        
 17.0        

Solve_Succeeded


In [None]:
# Inspect each solution: objective value, equality residuals h(x) and
# inequality values g(x), shown in that order.
x = x_sol
foreach(display, (f(x), h(x), g(x)))

x = x_sol_ip
foreach(display, (f(x), h(x), g(x)))

In [59]:
# test the gradient
"""
    solve_prob(z; verbose=true)

Solve a parametric version of the notebook's test problem with `auglag_solve` and
return the primal solution `x_sol` (a 4-vector). This is the solution map
z -> x*(z) whose Jacobian is checked with ForwardDiff and finite differences.

Roles of the parameters `z` (4 entries):
- `z[1]`: value imposed on `x[1]` through the equality residual `z[1] - x[1]`
- `z[2]`: right-hand side of the equality residual `0.5*x[1] + x[3] - z[2]`
- `z[3]`, `z[4]`: appear in the inequality terms `z[3] - x[4]` and `z[4] - x[4]`
  (lower bounds on `x[4]` if the solver uses the g(x) ≤ 0 convention — confirm)

When `verbose` is `true` (the default, which keeps the original behavior), the
solution is displayed on every call. Pass `verbose=false` to silence it, e.g.
inside finite-difference loops.
"""
function solve_prob(z; verbose::Bool=true)
    # Objective, independent of z. Earlier variant: x'*A*x + z[1]
    function f(x)
        return sum(x)
    end

    # Equality residuals.
    # Earlier variant: [x[1] + x[2] - 1., x[3] - z[2], x[4]*x[1]]
    function h(x)
        return [z[1] - x[1], x[1] * x[2], x[1] * 0.5 + x[3] - z[2]]
    end

    # Inequality terms. Spaces around `-` matter: `10.-x[4]` is a parse error on
    # Julia >= 0.7, while `10.0 - x[4]` gives the same value everywhere.
    # Earlier variant: vcat([x[1] - z[4]], z[3] - x[4])
    function g(x)
        return vcat(-x[2:3], z[3] - x[4], 10.0 - x[4], z[4] - x[4], 5 - x[1] * x[4])
    end

    # Cold start: zero primal guess and multiplier estimates sized from the constraints.
    x0 = zeros(4)
    λ0 = zeros(length(h(x0)))
    μ0 = zeros(length(g(x0)))

    # Cross-check with the interior-point solver if needed:
    # x_sol_ip = ip_solve(x0, f, h, g, length(h(x0)), length(g(x0)))
    # display(x_sol_ip)

    x_sol, λ_sol, μ_sol, c_sol = auglag_solve(x0, λ0, μ0, f, h, g)
    verbose && display(x_sol)

    return x_sol
end

solve_prob (generic function with 1 method)

In [61]:
# Check the derivative of the solution map z -> x*(z) at a nominal parameter.
# Alternative nominal parameter: [2., -1., 0., 1.]
z0 = [6., 4., 5., 17.]

sol = solve_prob(z0)

# Jacobian by forward-mode automatic differentiation through the solver.
J_auto = ForwardDiff.jacobian(solve_prob, z0)

# Jacobian by one-sided finite differences, one column per perturbed parameter.
# In the recorded run the two agree to within about 1e-4.
ϵ = 1e-8
J_num = zeros(size(J_auto))
for (col, zc) in enumerate(z0)
    z_pert = copy(z0)
    z_pert[col] = zc + ϵ
    J_num[:, col] = (solve_prob(z_pert) - sol) / ϵ
end

display("----")
display(sol)
display(J_auto)
display(J_num)

4-element Array{Float64,1}:
  6.0       
  2.7361e-14
  1.0       
 17.0       

4-element Array{ForwardDiff.Dual{ForwardDiff.Tag{#solve_prob,Float64},Float64,4},1}:
 Dual{ForwardDiff.Tag{#solve_prob,Float64}}(6.0,1.0,1.67627e-11,1.96769e-16,-7.80404e-9)                  
 Dual{ForwardDiff.Tag{#solve_prob,Float64}}(5.97395e-14,-5.34526e-10,7.15195e-13,3.89951e-14,-1.91572e-10)
 Dual{ForwardDiff.Tag{#solve_prob,Float64}}(1.0,-0.5,1.0,-8.76288e-15,3.54326e-9)                         
 Dual{ForwardDiff.Tag{#solve_prob,Float64}}(17.0,-1.83056e-8,-2.38944e-9,6.61031e-9,1.0)                  

4-element Array{Float64,1}:
  6.0       
 -6.4048e-14
  1.0       
 17.0       

4-element Array{Float64,1}:
  6.0        
 -2.66781e-14
  1.0        
 17.0        

4-element Array{Float64,1}:
  6.0        
  1.89441e-14
  1.0        
 17.0        

4-element Array{Float64,1}:
  6.0        
  3.70634e-14
  1.0        
 17.0        

"----"

4-element Array{Float64,1}:
  6.0       
  2.7361e-14
  1.0       
 17.0       

4×4 Array{Float64,2}:
  1.0           1.67627e-11   1.96769e-16  -7.80404e-9 
 -5.34526e-10   7.15195e-13   3.89951e-14  -1.91572e-10
 -0.5           1.0          -8.76288e-15   3.54326e-9 
 -1.83056e-8   -2.38944e-9    6.61031e-9    1.0        

4×4 Array{Float64,2}:
  0.999985     1.45661e-5  -1.97176e-5  -4.3876e-5 
 -9.1409e-6   -5.40391e-6  -8.41684e-7   9.70242e-7
 -0.500037     0.999909    -1.85407e-5  -4.03455e-5
  4.36984e-5   4.40536e-5   5.47118e-5   1.00009   

