Skip to content

Commit

Permalink
Merge pull request #39 from axelparmentier/fix_ci
Browse files Browse the repository at this point in the history
Clean up tests to diagnose Frank-Wolfe linear solver error:
- Revert specification of test dependencies to the more robust version (see this issue)
- Adjust default parameters for Frank-Wolfe (increase precision and switch step size selection) because a recent release or FrankWolfe.jl broke our tests (see Broken tests #40)
- Deactivate test verbosity to make CI logs easier to read
- Switch data to Float32 to suppress Flux.jl warning
- Try to parallelize tests but they no longer show up in the test set so I gave up (see this discourse thread, it seems like an open problem)
  • Loading branch information
BatyLeo committed Mar 28, 2023
2 parents a0c0cf9 + a840f89 commit 72e54d6
Show file tree
Hide file tree
Showing 12 changed files with 78 additions and 1,059 deletions.
17 changes: 16 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,22 @@ ThreadsX = "0.1.11"
julia = "1.7"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Aqua", "Documenter", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "JuliaFormatter", "LinearAlgebra", "Literate", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "UnicodePlots", "Zygote"]
2 changes: 1 addition & 1 deletion src/InferOpt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module InferOpt

using ChainRulesCore
using FrankWolfe: FrankWolfe
using FrankWolfe: ActiveSet, Adaptive, LinearMinimizationOracle
using FrankWolfe: ActiveSet, Agnostic, LinearMinimizationOracle
using FrankWolfe: away_frank_wolfe, compute_extreme_point
using Krylov: gmres
using LinearAlgebra
Expand Down
13 changes: 12 additions & 1 deletion src/frank_wolfe/differentiable_frank_wolfe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ function DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver=gmres)
return DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver)
end

struct SolverFailureException{S} <: Exception
msg::String
stats::S
end

function Base.show(io::IO, sfe::SolverFailureException)
return println(io, "SolverFailureException: $(sfe.msg)\nSolver stats: $(sfe.stats)")
end

## Forward pass

"""
Expand Down Expand Up @@ -117,7 +126,9 @@ function ChainRulesCore.rrule(
weights_tangent = probadist_tangent.weights
dp = convert(Vector{R}, unthunk(weights_tangent))
u, stats = linear_solver(Aᵀ, dp)
stats.solved || error("Linear solver failed to converge")
if !stats.solved
throw(SolverFailureException("Linear solver failed to converge", stats))
end
dθ_vec = Bᵀ * u
= reshape(dθ_vec, size(θ))
return (NoTangent(), NoTangent(), dθ, NoTangent())
Expand Down
4 changes: 2 additions & 2 deletions src/frank_wolfe/frank_wolfe_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Default configuration for the Frank-Wolfe wrapper.
- `away_steps=true`: activate away steps to avoid zig-zagging
- `epsilon=1e-4`: precision
- `lazy=true`: caching strategy
- `line_search=FrankWolfe.Adaptive()`: step size selection
- `line_search=FrankWolfe.Agnostic()`: step size selection
- `max_iteration=10`: number of iterations
- `timeout=1.0`: maximum time in seconds
- `verbose=false`: console output
Expand All @@ -16,7 +16,7 @@ const DEFAULT_FRANK_WOLFE_KWARGS = (
away_steps=true,
epsilon=1e-4,
lazy=true,
line_search=Adaptive(),
line_search=Agnostic(),
max_iteration=10,
timeout=1.0,
verbose=false,
Expand Down
Loading

0 comments on commit 72e54d6

Please sign in to comment.