Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 38 additions & 22 deletions src/solvers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
for j in 1:(maxits + 1)
allocs_assembly = 0
time_assembly = 0
time_solve_init = 0
allocs_solve_init = 0
time_total = 0
if is_linear && j == 2
nlres = linres
Expand Down Expand Up @@ -246,15 +248,19 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
if SC.parameters[:verbosity] > 0
@info ".... initializing linear solver ($(method_linear))\n"
end
abstol = SC.parameters[:abstol]
reltol = SC.parameters[:reltol]
LP = reduced ? LP_reduced : SC.LP
if precon_linear !== nothing
linsolve = init(LP, method_linear; Pl = precon_linear(A.entries.cscmatrix), abstol = abstol, reltol = reltol)
else
linsolve = init(LP, method_linear; abstol = abstol, reltol = reltol)
time_solve_init += @elapsed begin
allocs_solve_init += @allocated begin
abstol = SC.parameters[:abstol]
reltol = SC.parameters[:reltol]
LP = reduced ? LP_reduced : SC.LP
if precon_linear !== nothing
linsolve = init(LP, method_linear; Pl = precon_linear(A.entries.cscmatrix), abstol = abstol, reltol = reltol)
else
linsolve = init(LP, method_linear; abstol = abstol, reltol = reltol)
end
SC.linsolver = linsolve
end
end
SC.linsolver = linsolve
end

## compute nonlinear residual
Expand Down Expand Up @@ -282,8 +288,8 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
@info "sub-residuals = $(norms(residual))"
end
end
time_final += time_assembly
allocs_final += allocs_assembly
time_final += time_assembly + time_solve_init
allocs_final += allocs_assembly + allocs_solve_init
end
push!(stats[:assembly_allocations], allocs_assembly)
push!(stats[:assembly_times], time_assembly)
Expand Down Expand Up @@ -416,6 +422,8 @@ function CommonSolve.solve(PD::ProblemDescription, FES::Union{<:FESpace, Vector{
time_total += time_solve
time_final += time_solve
allocs_final += allocs_solve
time_solve += time_solve_init
allocs_solve += allocs_solve_init
push!(stats[:solver_allocations], allocs_solve)
push!(stats[:solver_times], time_solve)
push!(stats[:total_times], time_total)
Expand Down Expand Up @@ -559,6 +567,8 @@ function iterate_until_stationarity(
allocs_assembly = 0
time_assembly = 0
time_total = 0
time_solve_init = 0
allocs_solve_init = 0
for p in 1:nPDs
b = bs[p]
A = As[p]
Expand Down Expand Up @@ -609,20 +619,24 @@ function iterate_until_stationarity(
## init solver
linsolve = SC.linsolver
if linsolve === nothing
method_linear = SC.parameters[:method_linear]
precon_linear = SC.parameters[:precon_linear]
if SC.parameters[:verbosity] > 0
@info ".... initializing linear solver ($(method_linear))\n"
end
abstol = SC.parameters[:abstol]
reltol = SC.parameters[:reltol]
LP = SC.LP
if precon_linear !== nothing
linsolve = LinearSolve.init(LP, method_linear; Pl = precon_linear(linsolve.A), abstol = abstol, reltol = reltol)
else
linsolve = LinearSolve.init(LP, method_linear; abstol = abstol, reltol = reltol)
time_solve_init += @elapsed begin
allocs_solve_init += @allocated begin
method_linear = SC.parameters[:method_linear]
precon_linear = SC.parameters[:precon_linear]
abstol = SC.parameters[:abstol]
reltol = SC.parameters[:reltol]
LP = SC.LP
if precon_linear !== nothing
linsolve = LinearSolve.init(LP, method_linear; Pl = precon_linear(linsolve.A), abstol = abstol, reltol = reltol)
else
linsolve = LinearSolve.init(LP, method_linear; abstol = abstol, reltol = reltol)
end
SC.linsolver = linsolve
end
end
SC.linsolver = linsolve
end

## compute nonlinear residual
Expand All @@ -644,8 +658,8 @@ function iterate_until_stationarity(
nlres = norm(residual.entries)
@printf "\tres[%d] = %.2e" p nlres
end
time_final += time_assembly
allocs_final += allocs_assembly
time_final += time_assembly + time_solve_init
allocs_final += allocs_assembly + allocs_solve_init

if nlres < nltol
converged[p] = true
Expand Down Expand Up @@ -705,6 +719,8 @@ function iterate_until_stationarity(
time_total += time_solve
time_final += time_solve
allocs_final += allocs_solve
time_solve += time_solve_init
allocs_solve += allocs_solve_init
if SC.parameters[:verbosity] > -1
@printf " (%.3e)" linres
end
Expand Down
Loading