From 6bc18675646f5a766b531fab5455a5809bafbf1b Mon Sep 17 00:00:00 2001
From: fpacaud
Date: Mon, 30 Aug 2021 17:09:53 -0500
Subject: [PATCH] support DenseKKTSystem on GPUs

* add dedicated GPU kernels for the KKT operations
* add proper tests for DenseKKTSystem on the GPU
* move the DenseDummyQP test problem into MadNLPTests

---
 lib/MadNLPGPU/Project.toml         |   9 +-
 lib/MadNLPGPU/src/MadNLPGPU.jl     |  12 ++-
 lib/MadNLPGPU/src/kernels.jl       | 123 +++++++++++++++++++++++++++
 lib/MadNLPGPU/src/lapackgpu.jl     |  20 ++---
 lib/MadNLPGPU/test/densekkt_gpu.jl |  47 +++++++++++
 lib/MadNLPGPU/test/runtests.jl     |   5 +-
 lib/MadNLPTests/Project.toml       |   5 +-
 lib/MadNLPTests/src/MadNLPTests.jl | 124 ++++++++++++++++++++++++++-
 src/interiorpointsolver.jl         |  45 +++++-----
 src/kktsystem.jl                   |  22 +++--
 test/madnlp_dense.jl               | 129 ++---------------------------
 11 files changed, 371 insertions(+), 170 deletions(-)
 create mode 100644 lib/MadNLPGPU/src/kernels.jl
 create mode 100644 lib/MadNLPGPU/test/densekkt_gpu.jl

diff --git a/lib/MadNLPGPU/Project.toml b/lib/MadNLPGPU/Project.toml
index 0d1efd187..7b2bb0a15 100644
--- a/lib/MadNLPGPU/Project.toml
+++ b/lib/MadNLPGPU/Project.toml
@@ -5,16 +5,21 @@ version = "0.1.0"

 [deps]
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+CUDAKernels = "72cfdca4-0801-4ab0-bf6a-d52aa10adc57"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MadNLP = "2621e9c9-9eb4-46b1-8089-e8c72242dfb6"

 [compat]
 CUDA = "~2,~3"
+CUDAKernels = "0.3.0"
+KernelAbstractions = "0.7"
 MadNLP = "~0.2"
 julia = "1.5"

 [extras]
-Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 MadNLPTests = "b52a2a03-04ab-4a5f-9698-6a2deff93217"
+Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["Test","MadNLPTests"]
\ No newline at end of file
+test = ["Test", "MadNLPTests"]
diff --git a/lib/MadNLPGPU/src/MadNLPGPU.jl b/lib/MadNLPGPU/src/MadNLPGPU.jl
index 1af44a2d7..46d3f6b51 100644
--- a/lib/MadNLPGPU/src/MadNLPGPU.jl
+++ b/lib/MadNLPGPU/src/MadNLPGPU.jl
@@ -1,12 +1,22 @@
 module MadNLPGPU

-import CUDA: CUBLAS, CUSOLVER, CuVector, CuMatrix, toolkit_version, R_64F, has_cuda
+# CUDA
+import CUDA: CUBLAS, CUSOLVER, CuVector, CuMatrix, CuArray, toolkit_version, R_64F, has_cuda
+# Kernels
+using KernelAbstractions
+using CUDAKernels
+using LinearAlgebra
+using MadNLP
+
 import MadNLP: @kwdef, Logger, @debug, @warn, @error,
     AbstractOptions, AbstractLinearSolver, set_options!, MadNLPLapackCPU,
     SymbolicException,FactorizationException,SolveException,InertiaException,
     introduce, factorize!, solve!, improve!, is_inertia, inertia, tril_to_full!
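+# The kernels in kernels.jl are written with KernelAbstractions and launched on
+# CUDADevice() (provided by CUDAKernels); they implement MadNLP's dense KKT
+# operations on CuArrays.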
+
+include("kernels.jl")
+
 if has_cuda()
     include("lapackgpu.jl")
     export MadNLPLapackGPU

diff --git a/lib/MadNLPGPU/src/kernels.jl b/lib/MadNLPGPU/src/kernels.jl
new file mode 100644
index 000000000..30e1a37a9
--- /dev/null
+++ b/lib/MadNLPGPU/src/kernels.jl
@@ -0,0 +1,123 @@
+#=
+    MadNLP utils
+=#
+
+@kernel function _copy_diag!(dest, src)
+    i = @index(Global)
+    dest[i] = src[i, i]
+end
+
+function MadNLP.diag!(dest::CuVector{T}, src::CuMatrix{T}) where T
+    @assert length(dest) == size(src, 1)
+    ev = _copy_diag!(CUDADevice())(dest, src, ndrange=length(dest))
+    wait(ev)
+end
+
+@kernel function _add_diagonal!(dest, src1, src2)
+    i = @index(Global)
+    dest[i, i] = src1[i] + src2[i]
+end
+
+function MadNLP.diag_add!(dest::CuMatrix, src1::CuVector, src2::CuVector)
+    ev = _add_diagonal!(CUDADevice())(dest, src1, src2, ndrange=size(dest, 1))
+    wait(ev)
+end
+
+#=
+    MadNLP kernels
+=#
+
+# Overload is_valid to avoid falling back to the default implementation, which is slow on the GPU
+MadNLP.is_valid(src::CuArray) = true
+
+# Constraint scaling
+function MadNLP.set_con_scale!(con_scale::AbstractVector, jac::CuMatrix, nlp_scaling_max_gradient)
+    # Compute the row-wise reduction on the GPU with the built-in CUDA.jl function
+    d_con_scale = maximum(abs, jac, dims=2)
+    copyto!(con_scale, d_con_scale)
+    con_scale .= min.(1.0, nlp_scaling_max_gradient ./ con_scale)
+end
+
+@kernel function _treat_fixed_variable_kernel!(dest, ind_fixed)
+    k, j = @index(Global, NTuple)
+    i = ind_fixed[k]
+
+    if i == j
+        dest[i, i] = 1.0
+    else
+        dest[i, j] = 0.0
+        dest[j, i] = 0.0
+    end
+end
+
+function MadNLP.treat_fixed_variable!(kkt::MadNLP.AbstractKKTSystem{T, MT}) where {T, MT<:CuMatrix{T}}
+    length(kkt.ind_fixed) == 0 && return
+    aug = kkt.aug_com
+    d_ind_fixed = kkt.ind_fixed |> CuVector # TODO: allocate ind_fixed directly on the GPU
+    ndrange = (length(d_ind_fixed), size(aug, 1))
+    ev = _treat_fixed_variable_kernel!(CUDADevice())(aug, d_ind_fixed, ndrange=ndrange)
+    wait(ev)
+end
+
+#=
+    DenseKKTSystem kernels
+=#
+
+function MadNLP.mul!(y::AbstractVector, kkt::MadNLP.DenseKKTSystem{T, VT, MT}, x::AbstractVector) where {T, VT<:CuVector{T}, MT<:CuMatrix{T}}
+    @assert length(kkt._w2) == length(x)
+    # x and y can be host arrays. Copy them to the device to avoid side effects.
+    copyto!(kkt._w2, x)
+    mul!(kkt._w1, kkt.aug_com, kkt._w2)
+    copyto!(y, kkt._w1)
+end
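+
+# Hedged usage sketch (sizes are illustrative): with `kkt` allocated on the GPU,
+#     x = ones(size(kkt.aug_com, 1)); y = similar(x)
+#     MadNLP.mul!(y, kkt, x)
+# performs the product on the device while `x` and `y` remain host arrays; the
+# device buffers kkt._w1 / kkt._w2 absorb the host-device round trip.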
+
+function MadNLP.jtprod!(y::AbstractVector, kkt::MadNLP.DenseKKTSystem{T, VT, MT}, x::AbstractVector) where {T, VT<:CuVector{T}, MT<:CuMatrix{T}}
+    # x and y can be host arrays. Copy them to the device to avoid side effects.
+    d_x = x |> CuArray
+    d_y = y |> CuArray
+    mul!(d_y, kkt.jac', d_x)
+    copyto!(y, d_y)
+end
+
+function MadNLP.set_aug_diagonal!(kkt::MadNLP.DenseKKTSystem{T, VT, MT}, ips::MadNLP.Solver) where {T, VT<:CuVector{T}, MT<:CuMatrix{T}}
+    # Broadcasting does not work directly, as the MadNLP arrays are allocated on the CPU,
+    # whereas pr_diag is allocated on the GPU
+    copyto!(kkt.pr_diag, ips.zl./(ips.x.-ips.xl) .+ ips.zu./(ips.xu.-ips.x))
+    fill!(kkt.du_diag, 0.0)
+end
+
+@kernel function _build_dense_kkt_system_kernel!(
+    dest, hess, jac, pr_diag, du_diag, diag_hess, n, m, ns
+)
+    i, j = @index(Global, NTuple)
+    if (i <= n)
+        # Transfer the Hessian
+        if (i == j)
+            dest[i, i] = pr_diag[i] + diag_hess[i]
+        elseif j <= n
+            dest[i, j] = hess[i, j]
+            dest[j, i] = hess[j, i]
+        end
+    elseif i <= n + ns
+        # Transfer the slack diagonal
+        dest[i, i] = pr_diag[i]
+    elseif i <= n + ns + m
+        # Transfer the Jacobian
+        i_ = i - n - ns
+        dest[i, j] = jac[i_, j]
+        dest[j, i] = jac[i_, j]
+        # Transfer the dual regularization
+        dest[i, i] = du_diag[i_]
+    end
+end
+
+function MadNLP._build_dense_kkt_system!(
+    dest::CuMatrix, hess::CuMatrix, jac::CuMatrix,
+    pr_diag::CuVector, du_diag::CuVector, diag_hess::CuVector, n, m, ns
+)
+    ndrange = (n+m+ns, n+ns)
+    ev = _build_dense_kkt_system_kernel!(CUDADevice())(dest, hess, jac, pr_diag, du_diag, diag_hess, n, m, ns, ndrange=ndrange)
+    wait(ev)
+end
+
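+# All wrappers above follow the same KernelAbstractions launch pattern, e.g.
+#     ev = _copy_diag!(CUDADevice())(dest, src, ndrange=length(dest))
+#     wait(ev)
+# Launches are asynchronous, so the returned event must be waited on before the
+# results are read.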
diff --git a/lib/MadNLPGPU/src/lapackgpu.jl b/lib/MadNLPGPU/src/lapackgpu.jl
index d4d85c0f8..db8e0ec7b 100644
--- a/lib/MadNLPGPU/src/lapackgpu.jl
+++ b/lib/MadNLPGPU/src/lapackgpu.jl
@@ -21,8 +21,8 @@ const INPUT_MATRIX_TYPE = :dense
     lapackgpu_algorithm::Algorithms = BUNCHKAUFMAN
 end

-mutable struct Solver <: AbstractLinearSolver
-    dense::Matrix{Float64}
+mutable struct Solver{MT} <: AbstractLinearSolver
+    dense::MT
     fact::CuMatrix{Float64}
     rhs::CuVector{Float64}
     work::CuVector{Float64}
@@ -35,10 +35,10 @@ mutable struct Solver <: AbstractLinearSolver
     logger::Logger
 end

-function Solver(dense::Matrix{Float64};
+function Solver(dense::MT;
                 option_dict::Dict{Symbol,Any}=Dict{Symbol,Any}(),
                 opt=Options(),logger=Logger(),
-                kwargs...)
+                kwargs...) where {MT <: AbstractMatrix}

     set_options!(opt,option_dict,kwargs...)
     fact = CuMatrix{Float64}(undef,size(dense))
@@ -82,10 +82,10 @@ introduce(M::Solver) = "Lapack-GPU ($(M.opt.lapackgpu_algorithm))"
 if toolkit_version() >= v"11.3.1"
     is_inertia(M::Solver) = false # TODO: implement inertia(M::Solver) for BUNCHKAUFMAN
-    
+
     function factorize_bunchkaufman!(M::Solver)
-        haskey(M.etc,:ipiv) || (M.etc[:ipiv] = CuVector{Int32}(undef,size(M.dense,1)))
-        haskey(M.etc,:ipiv64) || (M.etc[:ipiv64] = CuVector{Int64}(undef,length(M.etc[:ipiv])))
+        haskey(M.etc,:ipiv) || (M.etc[:ipiv] = CuVector{Int32}(undef,size(M.dense,1)))
+        haskey(M.etc,:ipiv64) || (M.etc[:ipiv64] = CuVector{Int64}(undef,length(M.etc[:ipiv])))

         copyto!(M.fact,M.dense)
         cusolverDnDsytrf_bufferSize(
@@ -99,7 +99,7 @@ if toolkit_version() >= v"11.3.1"
     end

     function solve_bunchkaufman!(M::Solver,x)
-        
+
         copyto!(M.etc[:ipiv64],M.etc[:ipiv])
         copyto!(M.rhs,x)
         ccall((:cusolverDnXsytrs_bufferSize, libcusolver()), cusolverStatus_t,
@@ -120,7 +120,7 @@ if toolkit_version() >= v"11.3.1"
             size(M.fact,1),1,R_64F,M.fact,size(M.fact,2),
             M.etc[:ipiv64],R_64F,M.rhs,length(M.rhs),M.work,M.lwork[],M.work_host,M.lwork_host[],M.info)
         copyto!(x,M.rhs)
-        
+
         return x
     end
 else
@@ -156,7 +156,7 @@ else
               Cvoid,
              (Ref{Cchar},Ref{Int},Ref{Int},Ptr{Cdouble},Ref{Int},Ptr{Int},Ptr{Cdouble},Ref{Int},Ptr{Int}),
              'L',size(M.fact,1),1,M.etc[:fact_cpu],size(M.fact,2),M.etc[:ipiv_cpu],x,length(x),[1])
-        
+
         return x
     end
 end
diff --git a/lib/MadNLPGPU/test/densekkt_gpu.jl b/lib/MadNLPGPU/test/densekkt_gpu.jl
new file mode 100644
index 000000000..8763df248
--- /dev/null
+++ b/lib/MadNLPGPU/test/densekkt_gpu.jl
@@ -0,0 +1,47 @@
+
+using CUDA
+using MadNLPTests
+
+function _compare_gpu_with_cpu(n, m, ind_fixed)
+    madnlp_options = Dict{Symbol, Any}(
+        :kkt_system=>MadNLP.DENSE_KKT_SYSTEM,
+        :linear_solver=>MadNLPLapackGPU,
+        :print_level=>MadNLP.ERROR,
+    )
+
+    nlp = MadNLPTests.DenseDummyQP(; n=n, m=m, fixed_variables=ind_fixed)
+
+    # Reference solve, with the KKT system on the CPU
+    h_ips = MadNLP.Solver(nlp; option_dict=copy(madnlp_options))
+    MadNLP.optimize!(h_ips)
+
+    # Re-instantiate the model, as optimize! modifies it in place
+    nlp = MadNLPTests.DenseDummyQP(; n=n, m=m, fixed_variables=ind_fixed)
+    ind_cons = MadNLP.get_index_constraints(nlp)
+    ns = length(ind_cons.ind_ineq)
+
+    # Init the KKT system on the GPU
+    kkt = MadNLP.DenseKKTSystem{Float64, CuVector{Float64}, CuMatrix{Float64}}(
+        nlp, ind_cons; buffer_size=(nlp.meta.nvar + nlp.meta.ncon + ns),
+    )
+    # Instantiate the Solver with the KKT system on the GPU
+    d_ips = MadNLP.Solver(nlp, kkt; option_dict=copy(madnlp_options))
+    MadNLP.optimize!(d_ips)
+
+    # Check that both results match (up to a tight tolerance)
+    @test h_ips.cnt.k == d_ips.cnt.k
+    @test h_ips.obj_val ≈ d_ips.obj_val atol=1e-10
+    @test h_ips.x ≈ d_ips.x atol=1e-10
+    @test h_ips.l ≈ d_ips.l atol=1e-10
+end
+
+@testset "MadNLP: CPU versus GPU" begin
+    @testset "Size: ($n, $m)" for (n, m) in [(10, 0), (10, 5), (50, 10)]
+        _compare_gpu_with_cpu(n, m, Int[])
+    end
+    @testset "Fixed variables" begin
+        n, m = 10, 5
+        _compare_gpu_with_cpu(n, m, Int[1, 2])
+    end
+end
+
diff --git a/lib/MadNLPGPU/test/runtests.jl b/lib/MadNLPGPU/test/runtests.jl
index f7a197b5e..4c7e11dda 100644
--- a/lib/MadNLPGPU/test/runtests.jl
+++ b/lib/MadNLPGPU/test/runtests.jl
@@ -30,10 +30,13 @@ testset = [
     ],
 ]

-@testset "MadNLPGPU test" begin
+# Test the LapackGPU wrapper
+@testset "LapackGPU test" begin
     for (name,optimizer_constructor,exclude) in testset
         test_madnlp(name,optimizer_constructor,exclude)
     end
 end

+# Test DenseKKTSystem on the GPU
+include("densekkt_gpu.jl")
diff --git a/lib/MadNLPTests/Project.toml b/lib/MadNLPTests/Project.toml
index cbcb9e7d2..8a97ea694 100644
--- a/lib/MadNLPTests/Project.toml
+++ b/lib/MadNLPTests/Project.toml
@@ -6,8 +6,11 @@ version = "0.1.0"
 [deps]
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
+MadNLP = "2621e9c9-9eb4-46b1-8089-e8c72242dfb6"
+NLPModels = "a4795742-8479-5a88-8948-cc11e1c8c1a6"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [compat]
 JuMP = "~0.21"
-julia = "1.3"
\ No newline at end of file
+julia = "1.3"
diff --git a/lib/MadNLPTests/src/MadNLPTests.jl b/lib/MadNLPTests/src/MadNLPTests.jl
index 5e3699ccb..261d663be 100644
--- a/lib/MadNLPTests/src/MadNLPTests.jl
+++ b/lib/MadNLPTests/src/MadNLPTests.jl
@@ -1,9 +1,14 @@
 module MadNLPTests

-import LinearAlgebra: norm
+import MadNLP
+import LinearAlgebra: norm, I, mul!, dot
+using NLPModels
 import JuMP: Model, @variable, @constraint, @objective, @NLconstraint, @NLobjective,
     optimize!, MOI, termination_status, LowerBoundRef, UpperBoundRef, value, dual
 import Test: @test, @testset
+using Random
+
+export test_madnlp, solcmp

 function solcmp(x,sol;atol=1e-4,rtol=1e-4)
     aerr = norm(x-sol,Inf)
@@ -57,7 +62,7 @@ function lootsma(optimizer_constructor::Function)
         @test solcmp(dual.(l),[2.000024518601535,2.0000305441119535])
         @test solcmp(dual.(LowerBoundRef.(x)),[0.,0.,0.])
         @test solcmp(dual.(UpperBoundRef.(x)),[0.,0.,0.])
-        
+
         @test termination_status(m) == MOI.LOCALLY_SOLVED
     end
 end
@@ -190,6 +195,119 @@ function eigmina(optimizer_constructor::Function)
     end
 end

-export test_madnlp, solcmp
+
+struct DenseDummyQP <: AbstractNLPModel{Float64,Vector{Float64}}
+    meta::NLPModels.NLPModelMeta{Float64, Vector{Float64}}
+    P::Matrix{Float64} # primal Hessian
+    A::Matrix{Float64} # constraint Jacobian
+    q::Vector{Float64}
+    hrows::Vector{Int}
+    hcols::Vector{Int}
+    jrows::Vector{Int}
+    jcols::Vector{Int}
+    counters::Counters
+end
+
+function NLPModels.jac_structure!(qp::DenseDummyQP, I, J)
+    copyto!(I, qp.jrows)
+    copyto!(J, qp.jcols)
+end
+function NLPModels.hess_structure!(qp::DenseDummyQP, I, J)
+    copyto!(I, qp.hrows)
+    copyto!(J, qp.hcols)
+end
+
+function NLPModels.obj(qp::DenseDummyQP, x)
+    return 0.5 * dot(x, qp.P, x) + dot(qp.q, x)
+end
+function NLPModels.grad!(qp::DenseDummyQP, x, g)
+    mul!(g, qp.P, x)
+    g .+= qp.q
+    return
+end
+function NLPModels.cons!(qp::DenseDummyQP, x, c)
+    mul!(c, qp.A, x)
+end
+# Jacobian: sparse callback
+function NLPModels.jac_coord!(qp::DenseDummyQP, x, J::AbstractVector)
+    index = 1
+    for (i, j) in zip(qp.jrows, qp.jcols)
+        J[index] = qp.A[i, j]
+        index += 1
+    end
+end
+# Jacobian: dense callback
+MadNLP.jac_dense!(qp::DenseDummyQP, x, J::AbstractMatrix) = copyto!(J, qp.A)
+# Hessian: sparse callback
+function NLPModels.hess_coord!(qp::DenseDummyQP, x, l, hess::AbstractVector; obj_weight=1.)
+    index = 1
+    for i in 1:get_nvar(qp), j in 1:i
+        hess[index] = obj_weight * qp.P[j, i]
+        index += 1
+    end
+end
+# Hessian: dense callback
+function MadNLP.hess_dense!(qp::DenseDummyQP, x, l, hess::AbstractMatrix; obj_weight=1.)
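+    # MadNLP passes `hess` as a preallocated dense (n, n) buffer; it is fully
+    # overwritten below, so no prior zeroing is required.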
+    hess .= obj_weight .* qp.P
+end
+
+function DenseDummyQP(; n=100, m=10, fixed_variables=Int[])
+    if m >= n
+        error("The number of constraints `m` should be less than the number of variables `n`.")
+    end
+
+    Random.seed!(1)
+
+    # Build the QP problem 0.5 * x' * P * x + q' * x
+    P = randn(n, n)
+    P += P' # make P symmetric
+    P += 100.0 * I
+
+    q = randn(n)
+
+    # Build the constraints gl <= Ax <= gu
+    A = zeros(m, n)
+    for j in 1:m
+        A[j, j] = 1.0
+        A[j, j+1] = -1.0
+    end
+
+    x0 = zeros(n)
+    y0 = zeros(m)
+
+    # Bound constraints
+    xu = ones(n)
+    xl = -ones(n)
+    gl = -ones(m)
+    gu = ones(m)
+
+    xl[fixed_variables] .= xu[fixed_variables]
+
+    hrows = [i for i in 1:n for j in 1:i]
+    hcols = [j for i in 1:n for j in 1:i]
+    nnzh = div(n * (n + 1), 2)
+
+    jrows = [j for i in 1:n for j in 1:m]
+    jcols = [i for i in 1:n for j in 1:m]
+    nnzj = n * m
+
+    return DenseDummyQP(
+        NLPModels.NLPModelMeta(
+            n,
+            ncon = m,
+            nnzj = nnzj,
+            nnzh = nnzh,
+            x0 = x0,
+            y0 = y0,
+            lvar = xl,
+            uvar = xu,
+            lcon = gl,
+            ucon = gu,
+            minimize = true
+        ),
+        P, A, q, hrows, hcols, jrows, jcols,
+        Counters()
+    )
+end
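+
+# Illustrative usage (mirrors the dense tests): build a small instance with
+#     nlp = MadNLPTests.DenseDummyQP(; n=10, m=5, fixed_variables=Int[1])
+# and pass it to MadNLP.Solver together with kkt_system=MadNLP.DENSE_KKT_SYSTEM.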

 end # module
diff --git a/src/interiorpointsolver.jl b/src/interiorpointsolver.jl
index fcf180c70..10f75c422 100644
--- a/src/interiorpointsolver.jl
+++ b/src/interiorpointsolver.jl
@@ -173,14 +173,14 @@ end

 struct MadNLPExecutionStats{T} <: AbstractExecutionStats
     status::Status
-    solution::StrideOneVector{T} 
-    objective::T 
+    solution::StrideOneVector{T}
+    objective::T
     constraints::StrideOneVector{T}
-    dual_feas::T 
-    primal_feas::T 
-    multipliers::StrideOneVector{T} 
-    multipliers_L::StrideOneVector{T} 
-    multipliers_U::StrideOneVector{T} 
+    dual_feas::T
+    primal_feas::T
+    multipliers::StrideOneVector{T}
+    multipliers_L::StrideOneVector{T}
+    multipliers_U::StrideOneVector{T}
     iter::Int
     counters::NLPModelsCounters
     elapsed_time::Real
@@ -191,7 +191,7 @@ struct NotEnoughDegreesOfFreedomException <: Exception end

 MadNLPExecutionStats(ips::Solver) =MadNLPExecutionStats(
     ips.status,view(ips.x,1:get_nvar(ips.nlp)),ips.obj_val,ips.c,
-    ips.inf_du, ips.inf_pr, 
+    ips.inf_du, ips.inf_pr,
     ips.l,view(ips.zl,1:get_nvar(ips.nlp)),view(ips.zu,1:get_nvar(ips.nlp)),
     ips.cnt.k, ips.nlp.counters,ips.cnt.total_time)

 getStatus(result::MadNLPExecutionStats) = STATUS_OUTPUT_DICT[result.status]
@@ -337,7 +337,6 @@ function eval_lag_hess_wrapper!(ipp::Solver, kkt::AbstractKKTSystem, x::Vector{F
     return hess
 end

-
 function eval_jac_wrapper!(ipp::Solver, kkt::DenseKKTSystem, x::Vector{Float64})
     nlp = ipp.nlp
     cnt = ipp.cnt
@@ -367,7 +366,7 @@ function eval_lag_hess_wrapper!(ipp::Solver, kkt::DenseKKTSystem, x::Vector{Floa
     return hess
 end

-function Solver(nlp::AbstractNLPModel;
+function Solver(nlp::AbstractNLPModel, kkt=nothing;
                 option_dict::Dict{Symbol,Any}=Dict{Symbol,Any}(),
                 kwargs...)

@@ -393,16 +392,18 @@
     m = get_ncon(nlp)

     # Initialize KKT
-    kkt = if opt.kkt_system == SPARSE_KKT_SYSTEM
-        MT = (opt.linear_solver.INPUT_MATRIX_TYPE == :csc) ? SparseMatrixCSC{Float64, Int32} : Matrix{Float64}
-        SparseKKTSystem{Float64, MT}(nlp, ind_cons)
-    elseif opt.kkt_system == SPARSE_UNREDUCED_KKT_SYSTEM
-        MT = (opt.linear_solver.INPUT_MATRIX_TYPE == :csc) ? SparseMatrixCSC{Float64, Int32} : Matrix{Float64}
-        SparseUnreducedKKTSystem{Float64, MT}(nlp, ind_cons)
-    elseif opt.kkt_system == DENSE_KKT_SYSTEM
-        MT = Matrix{Float64}
-        VT = Vector{Float64}
-        DenseKKTSystem{Float64, VT, MT}(nlp, ind_cons)
+    if !isa(kkt, AbstractKKTSystem)
+        kkt = if opt.kkt_system == SPARSE_KKT_SYSTEM
+            MT = (opt.linear_solver.INPUT_MATRIX_TYPE == :csc) ? SparseMatrixCSC{Float64, Int32} : Matrix{Float64}
+            SparseKKTSystem{Float64, MT}(nlp, ind_cons)
+        elseif opt.kkt_system == SPARSE_UNREDUCED_KKT_SYSTEM
+            MT = (opt.linear_solver.INPUT_MATRIX_TYPE == :csc) ? SparseMatrixCSC{Float64, Int32} : Matrix{Float64}
+            SparseUnreducedKKTSystem{Float64, MT}(nlp, ind_cons)
+        elseif opt.kkt_system == DENSE_KKT_SYSTEM
+            MT = Matrix{Float64}
+            VT = Vector{Float64}
+            DenseKKTSystem{Float64, VT, MT}(nlp, ind_cons)
+        end
     end

     xl = [get_lvar(nlp);view(get_lcon(nlp),ind_cons.ind_ineq)]
@@ -529,7 +530,7 @@ function initialize!(ips::AbstractInteriorPointSolver)
     @trace(ips.logger,"Computing constraint scaling.")
     eval_jac_wrapper!(ips, ips.kkt, ips.x)
     compress_jacobian!(ips.kkt)
-    if ips.opt.nlp_scaling
+    if (ips.m > 0) && ips.opt.nlp_scaling
         jac = get_raw_jacobian(ips.kkt)
         set_con_scale!(ips.con_scale, jac, ips.opt.nlp_scaling_max_gradient)
         set_jacobian_scaling!(ips.kkt, ips.con_scale)
@@ -1210,7 +1211,7 @@ end
 # Kernel functions ---------------------------------------------------------

 is_valid(val::Real) = !(isnan(val) || isinf(val))
-function is_valid(vec)
+function is_valid(vec::AbstractArray)
     @inbounds for i=1:length(vec)
         is_valid(vec[i]) || return false
     end
diff --git a/src/kktsystem.jl b/src/kktsystem.jl
index 8d90e69f0..90c7ffb41 100644
--- a/src/kktsystem.jl
+++ b/src/kktsystem.jl
@@ -240,7 +240,7 @@ function SparseKKTSystem{T, MT}(nlp::AbstractNLPModel, ind_cons=get_index_constr
     jac_I = Vector{Int32}(undef, get_nnzj(nlp))
     jac_J = Vector{Int32}(undef, get_nnzj(nlp))
     jac_structure!(nlp,jac_I, jac_J)
-    
+
     hess_I = Vector{Int32}(undef, get_nnzh(nlp))
     hess_J = Vector{Int32}(undef, get_nnzh(nlp))
     hess_structure!(nlp,hess_I,hess_J)
@@ -411,19 +411,25 @@ struct DenseKKTSystem{T, VT, MT} <: AbstractKKTSystem{T, MT}
     diag_hess::VT
     # KKT system
     aug_com::MT
+    # Buffers
+    _w1::VT
+    _w2::VT
     # Info
     ind_ineq::Vector{Int}
     ind_fixed::Vector{Int}
     jacobian_scaling::VT
 end

-function DenseKKTSystem{T, VT, MT}(n, m, ind_ineq, ind_fixed) where {T, VT, MT}
+function DenseKKTSystem{T, VT, MT}(n, m, ind_ineq, ind_fixed; buffer_size=0) where {T, VT, MT}
     ns = length(ind_ineq)
     hess = MT(undef, n, n)
     jac = MT(undef, m, n+ns)
     pr_diag = VT(undef, n+ns)
     du_diag = VT(undef, m)
     diag_hess = VT(undef, n)
+    # Buffers (used mostly when the DenseKKTSystem is offloaded to the GPU)
+    _w1 = VT(undef, buffer_size)
+    _w2 = VT(undef, buffer_size)

     # If the problem is unconstrained, then the KKT system is directly equal
     # to the Hessian (+ some regularization terms)
@@ -445,13 +451,13 @@
     fill!(jacobian_scaling, one(T))

     return DenseKKTSystem{T, VT, MT}(
-        hess, jac, pr_diag, du_diag, diag_hess, aug_com, ind_ineq, ind_fixed, jacobian_scaling,
+        hess, jac, pr_diag, du_diag, diag_hess, aug_com, _w1, _w2, ind_ineq, ind_fixed, jacobian_scaling,
     )
 end

-function DenseKKTSystem{T, VT, MT}(nlp::AbstractNLPModel, info_constraints=get_index_constraints(nlp)) where {T, VT, MT}
+function DenseKKTSystem{T, VT, MT}(nlp::AbstractNLPModel, info_constraints=get_index_constraints(nlp); options...)
where {T, VT, MT} return DenseKKTSystem{T, VT, MT}( - get_nvar(nlp), get_ncon(nlp), info_constraints.ind_ineq, info_constraints.ind_fixed, + get_nvar(nlp), get_ncon(nlp), info_constraints.ind_ineq, info_constraints.ind_fixed; options... ) end @@ -468,9 +474,9 @@ get_raw_jacobian(kkt::DenseKKTSystem) = kkt.jac nnz_jacobian(kkt::DenseKKTSystem) = length(kkt.jac) nnz_kkt(kkt::DenseKKTSystem) = length(kkt.aug_com) -function _update_diagonal!(dest::AbstractMatrix, d1::AbstractVector, d2::AbstractVector) +function diag_add!(dest::AbstractMatrix, d1::AbstractVector, d2::AbstractVector) n = length(d1) - for i in 1:n + @inbounds for i in 1:n dest[i, i] = d1[i] + d2[i] end end @@ -505,7 +511,7 @@ function build_kkt!(kkt::DenseKKTSystem{T, VT, MT}) where {T, VT, MT} m = size(kkt.jac, 1) ns = length(kkt.ind_ineq) if m == 0 # If problem is unconstrained, just need to update the diagonal - _update_diagonal!(kkt.aug_com, kkt.diag_hess, kkt.pr_diag) + diag_add!(kkt.aug_com, kkt.diag_hess, kkt.pr_diag) else # otherwise, we update the full matrix _build_dense_kkt_system!(kkt.aug_com, kkt.hess, kkt.jac, kkt.pr_diag, kkt.du_diag, kkt.diag_hess, n, m, ns) end diff --git a/test/madnlp_dense.jl b/test/madnlp_dense.jl index 20b075cf9..d0336024a 100644 --- a/test/madnlp_dense.jl +++ b/test/madnlp_dense.jl @@ -2,125 +2,10 @@ using Test import MadNLP: jac_structure!, hess_structure!, obj, grad!, cons!, jac_coord!, hess_coord!, jac_dense!, hess_dense! using NLPModels using LinearAlgebra +using MadNLPTests using SparseArrays using Random -struct DenseDummyQP <: AbstractNLPModel{Float64,Vector{Float64}} - meta::NLPModels.NLPModelMeta{Float64, Vector{Float64}} - P::Matrix{Float64} # primal hessian - A::Matrix{Float64} # constraint jacobian - q::Vector{Float64} - hrows::Vector{Int} - hcols::Vector{Int} - jrows::Vector{Int} - jcols::Vector{Int} - counters::Counters -end - - -function jac_structure!(qp::DenseDummyQP,I, J) - copyto!(I, qp.jrows) - copyto!(J, qp.jcols) -end -function hess_structure!(qp::DenseDummyQP,I, J) - copyto!(I, qp.hrows) - copyto!(J, qp.hcols) -end - -function obj(qp::DenseDummyQP,x) - return 0.5 * dot(x, qp.P, x) + dot(qp.q, x) -end -function grad!(qp::DenseDummyQP,x,g) - mul!(g, qp.P, x) - g .+= qp.q - return -end -function cons!(qp::DenseDummyQP,x,c) - mul!(c, qp.A, x) -end -# Jacobian: sparse callback -function jac_coord!(qp::DenseDummyQP, x, J::AbstractVector) - index = 1 - for (i, j) in zip(qp.jrows, qp.jcols) - J[index] = qp.A[i, j] - index += 1 - end -end -# Jacobian: dense callback -jac_dense!(qp::DenseDummyQP, x, J::AbstractMatrix) = copyto!(J, qp.A) -# Hessian: sparse callback -function hess_coord!(qp::DenseDummyQP,x, l, hess::AbstractVector; obj_weight=1.) - index = 1 - for i in 1:get_nvar(qp) , j in 1:i - hess[index] = obj_weight * qp.P[j, i] - index += 1 - end -end -# Hessian: dense callback -function hess_dense!(qp::DenseDummyQP, x, l,hess::AbstractMatrix; obj_weight=1.) 
- hess .= obj_weight .* qp.P -end - - -function DenseDummyQP(; n=100, m=10, fixed_variables=Int[]) - if m >= n - error("The number of constraints `m` should be less than the number of variable `n`.") - end - - Random.seed!(1) - - # Build QP problem 0.5 * x' * P * x + q' * x - P = randn(n , n) - P += P' # P is symmetric - P += 100.0 * I - - q = randn(n) - - # Build constraints gl <= Ax <= gu - A = zeros(m, n) - for j in 1:m - A[j, j] = 1.0 - A[j, j+1] = -1.0 - end - - x0 = zeros(n) - y0 = zeros(m) - - # Bound constraints - xu = ones(n) - xl = - ones(n) - gl = -ones(m) - gu = ones(m) - - xl[fixed_variables] .= xu[fixed_variables] - - hrows = [i for i in 1:n for j in 1:i] - hcols = [j for i in 1:n for j in 1:i] - nnzh = div(n * (n + 1), 2) - - jrows = [j for i in 1:n for j in 1:m] - jcols = [i for i in 1:n for j in 1:m] - nnzj = n * m - - return DenseDummyQP( - NLPModels.NLPModelMeta( - n, - ncon = m, - nnzj = nnzj, - nnzh = nnzh, - x0 = x0, - y0 = y0, - lvar = xl, - uvar = xu, - lcon = gl, - ucon = gu, - minimize = true - ), - P,A,q,hrows,hcols,jrows,jcols, - Counters() - ) -end - @testset "MadNLP: dense API" begin n = 10 @testset "Unconstrained" begin @@ -129,7 +14,7 @@ end :linear_solver=>MadNLPLapackCPU, ) m = 0 - nlp = DenseDummyQP(; n=n, m=m) + nlp = MadNLPTests.DenseDummyQP(; n=n, m=m) ipd = MadNLP.Solver(nlp, option_dict=dense_options) kkt = ipd.kkt @@ -146,7 +31,7 @@ end :kkt_system=>MadNLP.DENSE_KKT_SYSTEM, :linear_solver=>MadNLPUmfpack, ) - @test_throws Exception MadNLP.Solver(nlp, dense_options_error) + @test_throws Exception MadNLP.Solver(nlp; option_dict=dense_options_error) end @testset "Constrained" begin dense_options = Dict{Symbol, Any}( @@ -154,7 +39,7 @@ end :linear_solver=>MadNLPLapackCPU, ) m = 5 - nlp = DenseDummyQP(; n=n, m=m) + nlp = MadNLPTests.DenseDummyQP(; n=n, m=m) ipd = MadNLP.Solver(nlp, option_dict=dense_options) ns = length(ipd.ind_ineq) @@ -181,11 +66,11 @@ function _compare_dense_with_sparse(n, m, ind_fixed) :print_level=>MadNLP.ERROR, ) - nlp = DenseDummyQP(; n=n, m=m, fixed_variables=ind_fixed) - + nlp = MadNLPTests.DenseDummyQP(; n=n, m=m, fixed_variables=ind_fixed) + ips = MadNLP.Solver(nlp, option_dict=sparse_options) ipd = MadNLP.Solver(nlp, option_dict=dense_options) - + MadNLP.optimize!(ips) MadNLP.optimize!(ipd)