Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add interface points for accessing the internal state of an exported learning network composite model #644

Merged

Conversation

ablaom
Copy link
Member

@ablaom ablaom commented Sep 24, 2021

This PR is attempt to address some issues raised in JuliaAI/MLJ.jl#841.

Supposing you have a learning network that you would like to export as a new stand-alone model type composite. Having bound composite to some data in a machine mach, you would like to arrange that fitted_params(mach) report(mach) will record some additional information about the internal state of the learning network that was built internally, when you called fit!(mach).

Specifically, this PR allows the following: Given any node N in the network, specified in the export process, you can arrange for the result of N() to be recorded in fitted_params(mach) report(mach). Naturally, the call N() happens immediately after the network is fit (and before the internal data anonymization step that empties the source nodes immediately thereafter).

Since N is called with no arguments, it will never see "production" data, which is a point of difference with the predict and/or transform nodes declared at export, which are always called on the production data Xnew, as in predict(mach, Xnew). However, this also means N can have multiple origin nodes (query origins for details). This is indeed the case in the following example, recording a training error in the composite model report:

(edited to reflect syntax adopted after discussions below)

using MLJ

import MLJModelInterface

struct MyModel <: ProbabilisticComposite
    model
end

function MLJModelInterface.fit(composite::MyModel, verbosity, X, y)

    Xs = source(X)
    ys = source(y)

    mach = machine(composite.model, Xs, ys)
    yhat = predict(mach, Xs)
    e = @node auc(yhat, ys)   # <------  node whose state we wish to export

    network_mach = machine(Probabilistic(),
                           Xs,
                           ys,
                           predict=yhat,
                           report=(training_error=e,))  # <------ how we export additional node(s)

    return!(network_mach, composite, verbosity)
end

# demo

X, y = make_moons()
composite = MyModel(ConstantClassifier())
mach = machine(composite, X, y) |> fit!
err = report(mach).training_error    # <------ accesssing the node state

yhat = predict(mach, rows=:);
@assert err  auc(yhat, y)

This is preliminary proof of concept and criticism is most welcome.

The PR also needs a bit more unit testing.

@codecov-commenter
Copy link

codecov-commenter commented Sep 24, 2021

Codecov Report

Merging #644 (0ec5472) into for-0-point-19-release (3676ea3) will increase coverage by 0.20%.
The diff coverage is 92.42%.

Impacted file tree graph

@@                    Coverage Diff                     @@
##           for-0-point-19-release     #644      +/-   ##
==========================================================
+ Coverage                   86.60%   86.81%   +0.20%     
==========================================================
  Files                          37       37              
  Lines                        3352     3389      +37     
==========================================================
+ Hits                         2903     2942      +39     
+ Misses                        449      447       -2     
Impacted Files Coverage Δ
src/MLJBase.jl 92.85% <ø> (ø)
src/composition/learning_networks/nodes.jl 69.17% <ø> (+1.36%) ⬆️
src/machines.jl 84.02% <ø> (ø)
src/composition/learning_networks/machines.jl 90.36% <91.52%> (+2.66%) ⬆️
src/composition/models/inspection.jl 100.00% <100.00%> (ø)
src/composition/models/methods.jl 100.00% <100.00%> (ø)
src/operations.jl 80.76% <100.00%> (ø)

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 3676ea3...0ec5472. Read the comment docs.

@olivierlabayle
Copy link
Collaborator

The overall high level functionality seems to match what I had in mind and looks great from my perspective. I've added a high level test in that direction but it may be of little value, feel free to revert the commit if this is so.

I am afraid I can't provide a huge feedback on the details of the implementation as I don't fully understand how things are chained together under the hood. I've tried to dig a bit in the learning network implementation but with the asynchrononisity I find it difficult to understand the internals. Especially, I am wondering if the results of a node call are cached (if the node is called again for instance) ?

@ablaom
Copy link
Member Author

ablaom commented Sep 25, 2021

After reflection, I think the extra information should go into the report rather than the fitted_params. Strictly speaking fittted_params is just for the "minimum" required to dispatch predict and/or transform. See this this guideline.

@olivierlabayle Any objections to this change? I understand that in your use-case fitted_params might feel more appropriate, but your use-case really sits outside the current intentions of the API, no?

Another possibility is to allow writing to both report and fitted_params, but I don't see how to design this in a way that's not more complicated, and am not sure the extra complication is warranted. But perhaps you have a suggestion?

@ablaom
Copy link
Member Author

ablaom commented Sep 26, 2021

@olivierlabayle By the way, thanks for the extra test.

Especially, I am wondering if the results of a node call are cached (if the node is called again for instance) ?

The answer to the question I think you are asking is "yes". The call to the node is made immediately after fit! and recorded. When you inspect fitresult, you are inspecting this return value, not recalling the node (which wouldn't work anyway, because of data anonymization.)

Generally, the machines in a learning network and elsewhere cache data, unless they are constructed with machine(... ; cache=false) . If caching is turned off, evaluating a node is purely lazy.

A machine bound to a composite model (subtype Composite) does not cache data by default, although, as just mentioned, learning networks constructed under the hood generally do, as just explained.

Clear as mud, right?

@olivierlabayle
Copy link
Collaborator

No problem at all to export the results to the report instead of the fitted_params , as you say fitted_params should hold the parameter values of the learnt function.

@olivierlabayle
Copy link
Collaborator

Haha it's indeed hard to catch, I was actually wondering in a general manner as you describe second.

From what I "imagine", a fit! of the composite model will necessary call each node in the computational graph to trigger the different fits on the appropriate data. Some nodes are not bound to a machine (static) so they cannot be cached right? Does it mean this kind of node might be evaluated (computed) multiple times if asked for multiple times?

@ablaom
Copy link
Member Author

ablaom commented Sep 26, 2021

From what I "imagine", a fit! of the composite model will necessary call each node in the computational graph to trigger the different fits on the appropriate data. Some nodes are not bound to a machine (static) so they cannot be cached right? Does it mean this kind of node might be evaluated (computed) multiple times if asked for multiple times?

Perhaps there is some confusion about what "caching" means here. The only caching that takes place is for the benefit of training machines. A machine constructed with cache=true internally caches data used to train it. Then, if a hyper-parameter changes, and I have no reason to believe the training nodes have changes the data they deliver if called, then I use the cached data in the next call to fit! the machine. The data cached is generally a model-specific representation of the data (eg, a matrix instead of a table). It was to avoid repeating these internal data pre-processing that caching was introduced (and to allow observation resampling to happen at the level of the model-specific representation).

If you have an static node that performs an expensive computation, then the only benefit caching has is if the output of the node is needed as training data for a machine downstream. However, if you are just calling a node downstream of that static node, the static node will need to re-compute. Similarly, if predict or transform are expensive operations for some internal machine, then caching data is only helpful for training machines downstream of those predict/transform nodes, but calling those nodes is still going to be expensive every time they are called.

Does that help?

@olivierlabayle
Copy link
Collaborator

Allright, thank you for the clarification, that helps a lot!

@ablaom
Copy link
Member Author

ablaom commented Oct 6, 2021

Note to self:

@ablaom
Copy link
Member Author

ablaom commented Oct 27, 2021

Comment to trigger notification to self.

@ablaom ablaom marked this pull request as ready for review December 20, 2021 03:06
@ablaom ablaom changed the base branch from dev to for-0-point-19-release December 20, 2021 03:33
@ablaom ablaom changed the title New interface points for accessing the internal state of an exported learning network composite model Add interface points for accessing the internal state of an exported learning network composite model Dec 20, 2021
@ablaom ablaom mentioned this pull request Dec 20, 2021
18 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants