Clean up tests to diagnose Frank-Wolfe linear solver error #39

Merged · 7 commits · Mar 28, 2023
17 changes: 16 additions & 1 deletion Project.toml
@@ -28,7 +28,22 @@ ThreadsX = "0.1.11"
julia = "1.7"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
FrankWolfe = "f55ce6ea-fdc5-4628-88c5-0087fe54bd30"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
GridGraphs = "dd2b58c7-5af7-4f17-9e46-57c68ac813fb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Aqua", "Documenter", "Flux", "FrankWolfe", "Graphs", "GridGraphs", "JuliaFormatter", "LinearAlgebra", "Literate", "ProgressMeter", "Random", "Revise", "Statistics", "Test", "UnicodePlots", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
@@ -7,7 +7,7 @@

## Overview

-`InferOpt.jl` is a toolbox for using combinatorial optimization algorithms within machine learning pipelines.
+InferOpt.jl is a toolbox for using combinatorial optimization algorithms within machine learning pipelines.

It allows you to create differentiable layers from optimization oracles that do not have meaningful derivatives.
Typical examples include mixed integer linear programs or graph algorithms.
6 changes: 3 additions & 3 deletions docs/src/background.md
@@ -1,6 +1,6 @@
# Background

-The goal of `InferOpt.jl` is to make machine learning pipelines more expressive by incorporating combinatorial optimization layers.
+The goal of InferOpt.jl is to make machine learning pipelines more expressive by incorporating combinatorial optimization layers.

## How the math works

@@ -12,9 +12,9 @@ where $\mathcal{V} \subset \mathbb{R}^d$ is a finite set of feasible solutions,
Note that any linear program (LP) or mixed integer linear program (MILP) can be formulated this way.

Unfortunately, the optimal solution $f(\theta)$ is a piecewise constant function of $\theta$, which means its derivative is either zero or undefined.
-Starting with an oracle for $f$, `InferOpt.jl` approximates it with a differentiable "layer", whose derivatives convey meaningful slope information.
+Starting with an oracle for $f$, InferOpt.jl approximates it with a differentiable "layer", whose derivatives convey meaningful slope information.
Such a layer can then be used within a machine learning pipeline, and gradient descent will succeed.
-`InferOpt.jl` also provides adequate loss functions for structured learning.
+InferOpt.jl also provides adequate loss functions for structured learning.

For more details on the theoretical aspects, you can check out our paper:

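To make the smoothing idea concrete, here is a small illustrative sketch (an editorial addition, not part of the diff; the vertex set and helper names are made up). A linear maximizer over a finite set is piecewise constant in $\theta$, while its average under random perturbations of $\theta$ varies smoothly, which is the kind of differentiable layer InferOpt.jl builds from an oracle:

```julia
using LinearAlgebra, Statistics

vertices = [[1.0, 0.0], [0.0, 1.0]]              # a toy feasible set V
maximizer(θ) = argmax(v -> dot(θ, v), vertices)  # oracle for f(θ) = argmax over V of θᵀv

# Piecewise constant: small changes in θ usually leave the argmax unchanged.
maximizer([1.0, 0.9]) == maximizer([1.0, 0.8])   # true

# Monte-Carlo perturbation: the average output is a smooth function of θ,
# so it admits meaningful derivatives (approximated here by sampling).
function perturbed_maximizer(θ; ε=1.0, nb_samples=1_000)
    return mean(maximizer(θ .+ ε .* randn(length(θ))) for _ in 1:nb_samples)
end
```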
2 changes: 1 addition & 1 deletion src/InferOpt.jl
@@ -2,7 +2,7 @@ module InferOpt

using ChainRulesCore
using FrankWolfe: FrankWolfe
-using FrankWolfe: ActiveSet, Adaptive, LinearMinimizationOracle
+using FrankWolfe: ActiveSet, Agnostic, LinearMinimizationOracle
using FrankWolfe: away_frank_wolfe, compute_extreme_point
using Krylov: gmres
using LinearAlgebra
13 changes: 12 additions & 1 deletion src/frank_wolfe/differentiable_frank_wolfe.jl
@@ -28,6 +28,15 @@ function DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver=gmres)
    return DifferentiableFrankWolfe(f, f_grad1, lmo, linear_solver)
end

+struct SolverFailureException{S} <: Exception
+    msg::String
+    stats::S
+end
+
+function Base.show(io::IO, sfe::SolverFailureException)
+    return println(io, "SolverFailureException: $(sfe.msg)\nSolver stats: $(sfe.stats)")
+end

## Forward pass

"""
Expand Down Expand Up @@ -117,7 +126,9 @@ function ChainRulesCore.rrule(
    weights_tangent = probadist_tangent.weights
    dp = convert(Vector{R}, unthunk(weights_tangent))
    u, stats = linear_solver(Aᵀ, dp)
-    stats.solved || error("Linear solver failed to converge")
+    if !stats.solved
+        throw(SolverFailureException("Linear solver failed to converge", stats))
+    end
    dθ_vec = Bᵀ * u
    dθ = reshape(dθ_vec, size(θ))
    return (NoTangent(), NoTangent(), dθ, NoTangent())
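For context, an assumed usage sketch (not taken from the PR): Krylov.jl's `gmres` returns a `(solution, stats)` pair whose `stats.solved` flag signals convergence, so callers of the backward pass can now catch a `SolverFailureException` and inspect those stats rather than a bare `ErrorException`. The matrix and right-hand side below are placeholders, and the exception type is assumed to be in scope:

```julia
using Krylov

A = [4.0 1.0; 1.0 3.0]   # placeholder linear system
b = [1.0, 2.0]

try
    u, stats = gmres(A, b)
    if !stats.solved
        throw(SolverFailureException("Linear solver failed to converge", stats))
    end
catch err
    err isa SolverFailureException || rethrow()
    # The exception now carries the Krylov stats (convergence status, residual history, ...).
    @warn err.msg stats=err.stats
end
```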
8 changes: 4 additions & 4 deletions src/frank_wolfe/frank_wolfe_utils.jl
@@ -5,18 +5,18 @@ Default configuration for the Frank-Wolfe wrapper.

# Parameters
- `away_steps=true`: activate away steps to avoid zig-zagging
-- `epsilon=1e-2`: precision
+- `epsilon=1e-4`: precision
- `lazy=true`: caching strategy
-- `line_search=FrankWolfe.Adaptive()`: step size selection
+- `line_search=FrankWolfe.Agnostic()`: step size selection
- `max_iteration=10`: number of iterations
- `timeout=1.0`: maximum time in seconds
- `verbose=false`: console output
"""
const DEFAULT_FRANK_WOLFE_KWARGS = (
    away_steps=true,
-    epsilon=1e-2,
+    epsilon=1e-4,
    lazy=true,
-    line_search=Adaptive(),
+    line_search=Agnostic(),
    max_iteration=10,
    timeout=1.0,
    verbose=false,
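For reference, an editorial sketch (not part of the diff) of how these defaults are meant to be consumed: the NamedTuple splats straight into FrankWolfe.jl's keyword interface. The toy objective, the `ProbabilitySimplexOracle` choice, and the starting point are assumptions for illustration, and `DEFAULT_FRANK_WOLFE_KWARGS` is taken to be in scope:

```julia
using FrankWolfe, LinearAlgebra

θ = [0.3, -0.1, 0.5]
f(x) = 0.5 * norm(x .- θ)^2                     # toy smooth objective
grad!(g, x) = (g .= x .- θ)                     # in-place gradient, as FrankWolfe expects
lmo = FrankWolfe.ProbabilitySimplexOracle(1.0)  # feasible set: the probability simplex
x0 = FrankWolfe.compute_extreme_point(lmo, θ)   # any vertex works as a starting point

FrankWolfe.away_frank_wolfe(f, grad!, lmo, x0; DEFAULT_FRANK_WOLFE_KWARGS...)
```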