Plans for autodiff on fitted models? #220

Open
pat-alt opened this issue Mar 31, 2023 · 2 comments

@pat-alt
Collaborator

pat-alt commented Mar 31, 2023

Motivation and description

Maybe this is a more general topic for MLJ, not only related to Flux. I know that autodiff has been discussed in the past, and with MLJFlux now being developed, I was wondering whether this topic has come back into focus.

In an ideal world, it would be possible to differentiate through any SupervisedModel and get gradients with respect to parameters or inputs. This would, for example, greatly increase the scope of models we can explain through Counterfactual Explanations (see plans outlined here).

MLJFlux seems like a good place to start, since the underlying models are compatible with Zygote. But even here we quickly run into issues: for example, it does not seem possible to differentiate through a predict call.

An example:

using MLJ
using Random
Random.seed!(1234)

X, y = make_blobs(1000, 2, centers=2)
X = MLJ.table(Float32.(MLJ.matrix(X)))
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux
clf = NeuralNetworkClassifier()
mach = machine(clf, X, y)
fit!(mach)

# Two different methods to return softmax output:
using Flux
f(x) = permutedims(pdf(predict(mach, MLJ.table(Float32.(x'))), levels(y)))
g(x) = mach.fitresult[1](Float32.(x))

x = rand(2,1)

Both f and g can be used to return the softmax output for x:

julia> f(x)
2×1 Matrix{Float32}:
 0.24943815
 0.75056183

julia> g(x)
2×1 Matrix{Float32}:
 0.24943815
 0.75056183

Autodiff only works for g,

loss(x, y, fun) = Flux.Losses.crossentropy(fun(x), y)
julia> gradient(loss, x, 1, g)
([2.226835250854492; 0.937971830368042;;], 6.655112266540527, nothing)

but not for f:

julia> gradient(loss, x, 1, f)
ERROR: Zygote.CompileError

Stacktrace (abridged):
  (1) top-level scope
      ~/.julia/packages/CUDA/BbliS/src/initialization.jl:52
  (2) top-level scope
      REPL[202]:1
      [... 16 frames skipped in Zygote, Base ...]
 (19) (::Core.GeneratedFunctionStub)(::Any, ::Vararg)
      ./boot.jl:582
 (20) var"#s2948#1107"(::Any, ctx::Any, f::Any, args::Any)   [in Zygote]
      ./none:0
      [... 6 frames skipped in Zygote ...]
 (28) error(s::String)
      ./error.jl:35

no message for error of type Zygote.CompileError, sorry.

A simple workaround for this specific issue is to use the Chain directly to produce the softmax output, but this approach does not generalise to other MLJ models.
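
For reference, a minimal sketch of that workaround (it relies on the chain being stored as the first element of the fitresult, as in g above, which is an implementation detail rather than a public API):

# Sketch only: bypass predict and differentiate the loss w.r.t. the input
# through the extracted Flux chain. Gradients w.r.t. inputs are what
# Counterfactual Explanations need.
using Flux

chain = mach.fitresult[1]   # as in g above; not a public accessor
x = rand(Float32, 2, 1)

∇x, = gradient(x) do x
    Flux.Losses.crossentropy(chain(x), 1)
end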

I appreciate that this is a very ambitious idea (perhaps previous discussions have concluded that this is simply asking too much), but I would be curious to hear what others think.

It's worth mentioning that, for the plans outlined above, I will soon get some support from a group of CS students. So if you have any plans or ongoing work in this space, perhaps there's something we can help with.

Thanks!

Possible Implementation

No response

@ablaom
Collaborator

ablaom commented Apr 4, 2023

I appreciate that this is a very ambitious idea

Yeah. I think this is a very noble goal but indeed challenging. Still, given the nobility of the goal, I think it's definitely worth scoping out where the issues lie.

I'm not sure what the particular issue raised above is.

The problem is that MLJ started when Flux was still in relative infancy (no Zygote), and there's a lot of mutation where Zygote just spits the dummy.

When I last played with this, I ran into a rather serious obstacle for probabilistic classifiers. The implementation of UnivariateFinite (now at CategoricalDistributions.jl) uses (mutable) dictionaries, which Zygote did not like. I wonder if this is still the case? For example, can I differentiate

p -> pdf(UnivariateFinite(["x", "y"], p, pool=missing), "x")

which is equivalent to p -> p[1] for vectors p?
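
Concretely, the check might look something like the sketch below (not verified against current versions of Zygote and CategoricalDistributions, so treat it as illustrative):

# Sketch: can Zygote differentiate through the UnivariateFinite constructor and pdf?
# If it can, the gradient should come out as [1.0, 0.0], matching p -> p[1].
using MLJ      # re-exports UnivariateFinite and pdf
using Zygote

h(p) = pdf(UnivariateFinite(["x", "y"], p, pool=missing), "x")

Zygote.gradient(h, [0.3, 0.7])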

@pat-alt
Collaborator Author

pat-alt commented Apr 7, 2023

Thanks for sharing your thoughts on this, Anthony.

We'll be looking at this in the coming weeks/months and I have no doubt we'll run into lots of issues related to mutation. Nonetheless, I think it's worth exploring. I think MLJFlux is a good starting point, since CounterfactualExplanations is currently tailored to Flux. Alternatively, the logistic classifier from MLJLinearModels also seems like a natural first candidate. From there we're most interested in adding support for tree-based models, which will most likely involve a detour to classifier calibration.
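
As a first scoping step beyond MLJFlux, something like the sketch below could probe the same predict-differentiation question for the logistic classifier (hypothetical, not run here; it reuses X, y and the f-style wrapper from the example above):

# Sketch: repeat the experiment with MLJLinearModels' LogisticClassifier to see
# whether differentiating through predict hits the same obstacles as MLJFlux.
using MLJ, Zygote

LogisticClassifier = @load LogisticClassifier pkg=MLJLinearModels
mach = machine(LogisticClassifier(), X, y)
fit!(mach)

f(x) = permutedims(pdf(predict(mach, MLJ.table(Float32.(x'))), levels(y)))
Zygote.gradient(x -> sum(f(x)), rand(2, 1))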

If it's alright, I'll keep this open for now and we may come back here with updates.
