MAINT/ENH: SaveModel based serialization #128
Conversation
…rs and metrics, reorganize serialization stuff into a module
Codecov Report

```diff
@@            Coverage Diff             @@
##           develop     #128   +/-   ##
===========================================
+ Coverage    99.52%   99.70%   +0.18%
===========================================
  Files            5        6       +1
  Lines          627      678      +51
===========================================
+ Hits           624      676      +52
+ Misses           3        2       -1
===========================================
```

Continue to review the full report at Codecov.
Looks like this is only working on …

So tests are mostly working, but I'm seeing …
scikeras/__init__.py (outdated)

```python
keras.Model.__reduce__ = saving_utils.pack_keras_model
keras.losses.Loss.__reduce__ = saving_utils.pack_keras_loss
keras.metrics.Metric.__reduce__ = saving_utils.pack_keras_metric
keras.optimizers.Optimizer.__reduce__ = saving_utils.pack_keras_optimizer
```
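As a TF-free illustration of the monkey-patching pattern above (all names here are hypothetical, not the actual SciKeras functions): assigning a module-level pack function to a class's `__reduce__` makes instances picklable through a custom reconstruction path.

```python
import pickle

class Thing:
    """Stand-in for a class that isn't natively picklable (e.g. keras.Model)."""
    def __init__(self, value):
        self.value = value

def unpack_thing(value):
    # Called by pickle on load to rebuild the instance from its state.
    return Thing(value)

def pack_thing(self):
    # __reduce__ must return (callable, args); pickle stores both and
    # calls unpack_thing(*args) when the object is unpickled.
    return (unpack_thing, (self.value,))

# Monkey-patch, mirroring the keras.Model.__reduce__ assignment above.
Thing.__reduce__ = pack_thing

restored = pickle.loads(pickle.dumps(Thing(42)))
print(restored.value)  # 42
```

The pack/unpack functions must live at module level so pickle can reference them by name.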
Shouldn't these lines go in the definitions of Model/Loss/etc?
Do you mean in the class def within Tensorflow? Maybe I'm not understanding...
Oh, I see. These are Keras classes, not SciKeras classes. Then why not put this into a Keras RFC?
Right now, these things are very hacky. The optimizer hack is never going to make it past any sort of review (I'm not even sure it works for 100% of cases; it relies on private methods). `Model.__reduce__` does a bunch of zip file stuff that is also pretty hacky, and really should be replaced with TF utilities that would need to be implemented on the C++ backend (I think), as discussed in the TF PR (tensorflow/tensorflow#39609 (comment)). TL;DR: this should be upstreamed, but it's not going to happen in the current form, and not anytime soon.
I think there are two levels for RFCs: the goal and how to achieve that goal.

This PR shows the goal is feasible and practical, and serves as a prototype (`_restore_optimizer_weights` is clearly a rough prototype). It's clear what needs to be implemented in a cleaner way. Inside TF, there's no downside to depending on private functions.
I agree that this PR shows that it's feasible, but a (similar) thing is already at tensorflow/tensorflow#39609.
I guess, let me just ask, what steps do you suggest be taken? Open a new RFC or PR the optimizer weights part specifically?
For this PR/SciKeras I'm thinking of reverting to the old model serialization code (so that it works on TF 2.2.0 and Windows) but keep the optimizer stuff (so that all of the good new tests you added pass).
I might open a new PR to TF after this PR is merged. Or maybe add it to that PR with an explanatory comment if the implementation is simple.
If they want an RFC they'll call for it.
I have no problem opening a hacky PR in TF just for visibility. Just to be clear, we're talking specifically about the optimizer weights part, right (not about Metrics or Model)? And specifically, about the hack to restore the weights for Adam and other optimizers that use slots, not about implementing `__reduce__` for all optimizers?
> we're talking specifically about the optimizer weights part right

My concern is Keras having a working serialization method. I think Models and optimizers are the most important use cases (plus all the metrics I can think of are stateless past some rolling average).

I'm not sure what the difference is between "restoring weights for Adam" and "implementing `__reduce__`." Do they both accomplish a working serialization method? Do they have a different implementation?
Yes, they are fundamentally two distinct things.

In this PR, we are fixing the weight restoration bug by implementing some hacks within `Model.__reduce__`. I reported the bug and showed how this hack works in tensorflow/tensorflow#44670 in hopes that that will make it easier for them to fix the bug. To actually fix the bug, one would have to implement an actual fix within the TF SaveModel ecosystem. I tried and failed to do that. Thus I don't see that there are any further PRs I can submit to TensorFlow.

As for implementing `Model.__reduce__`, I think we should update the existing PR with whatever implementation we land on here, sans the hack to restore optimizer weights, since that is functionally unrelated.

I hope this makes sense! I know it's a convoluted topic.
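To make the distinction concrete (config-only serialization, which Keras promises today, versus also restoring internal state, which is the hack under discussion), here is a TF-free toy sketch. `ToyOptimizer` and every name in it are hypothetical, not Keras or SciKeras API.

```python
import pickle

class ToyOptimizer:
    """Hypothetical stand-in: a static config plus mutable state (slots)."""
    def __init__(self, lr=0.01):
        self.lr = lr      # config: enough to recreate the object
        self.slots = {}   # state: accumulated during training (e.g. momentum)

    def get_config(self):
        return {"lr": self.lr}

def unpack_with_state(config, slots):
    # Recreate from config, then also restore internal state.
    opt = ToyOptimizer(**config)
    opt.slots = slots
    return opt

def pack_with_state(self):
    # Config-only serialization would return just (ToyOptimizer-from-config,
    # (config,)) and silently drop self.slots; here we keep both.
    return (unpack_with_state, (self.get_config(), self.slots))

ToyOptimizer.__reduce__ = pack_with_state

opt = ToyOptimizer(lr=0.1)
opt.slots = {"momentum": [0.5, -0.2]}
restored = pickle.loads(pickle.dumps(opt))
print(restored.lr, restored.slots)  # 0.1 {'momentum': [0.5, -0.2]}
```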
Looks like the problem was something related to collisions of the temporary RAM folders. Fixed by using a uuid instead of `id(object)`.
Force-pushed from 7d8391f to d7c776a
scikeras/_saving_utils.py (outdated)

```python
with tempfile.TemporaryDirectory() as tmpdirname:
    model.save(tmpdirname)
    b = BytesIO()
    with tarfile.open(fileobj=b, mode="w:gz") as tar:
        tar.add(tmpdirname, arcname=os.path.sep)
    b.seek(0)
```
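The directory-to-tarball pattern above can be exercised without TF at all; this stdlib-only sketch (the file name `saved_model.pb` is just a stand-in for whatever `model.save` writes) packs a directory into an in-memory gzipped tarball and then runs the inverse path.

```python
import os
import tarfile
import tempfile
from io import BytesIO

# Stand-in for model.save(tmpdirname): write a file into the directory.
with tempfile.TemporaryDirectory() as tmpdirname:
    with open(os.path.join(tmpdirname, "saved_model.pb"), "w") as f:
        f.write("fake model bytes")
    # Pack the whole directory into an in-memory gzipped tarball.
    b = BytesIO()
    with tarfile.open(fileobj=b, mode="w:gz") as tar:
        tar.add(tmpdirname, arcname=".")
    b.seek(0)

# Inverse path: extract the tarball into a fresh directory and read back.
with tempfile.TemporaryDirectory() as restore_dir:
    with tarfile.open(fileobj=b, mode="r:gz") as tar:
        tar.extractall(restore_dir)
    with open(os.path.join(restore_dir, "saved_model.pb")) as f:
        restored = f.read()

print(restored)  # fake model bytes
```

Note this sketch uses `arcname="."` rather than `arcname=os.path.sep`, so member names stay relative and extract cleanly.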
@stsievert how do you feel about saving to a temp directory? It works (that's the best thing I can say about it), but it seems wrong to me to serialize to disk and then load from disk into memory.
Yeah, I wouldn't write to disk. Python's `io` module has in-memory file objects with `StringIO` and `BytesIO`. Would those work?
That is what we're using here; the pickle transfer happens as bytes. The issue is that Keras/TF can't save to `BytesIO` or `StringIO`. That is:

```python
model.save("some/directory/on/disk")  # works
model.save(BytesIO())                 # does not work
```

So what we are doing here is:

1. Save the TF model to a temporary folder on disk.
2. Load from that temp dir into a `BytesIO` object.
3. Wrap the `BytesIO` object in NumPy and pickle that.
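The last step of that list (wrapping the raw bytes in NumPy so they ride inside a pickle payload) can be sketched like this; this is an illustrative sketch assuming only NumPy, not the actual SciKeras code, and the tarball bytes are faked.

```python
import pickle
from io import BytesIO
import numpy as np

# Pretend this BytesIO holds the tarred SaveModel directory.
b = BytesIO(b"fake tarball bytes")

# Wrap the raw bytes in a uint8 array; NumPy arrays pickle efficiently.
packed = np.frombuffer(b.getvalue(), dtype=np.uint8)

# Round-trip through pickle and recover the original bytes.
restored = pickle.loads(pickle.dumps(packed))
print(restored.tobytes())  # b'fake tarball bytes'
```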
Why wouldn't pickle work to serialize a Keras model? I thought that's the point of tensorflow/tensorflow#39609.
This is the same implementation as tensorflow/tensorflow#39609. The only difference is that there I am using TF's `ram://` filesystem instead of writing to disk (i.e., write to a disk-like thing in RAM, then load that into `BytesIO`). But that doesn't work on Windows, hence why I am considering writing to actual disk here. We could stop supporting Windows, in which case we can implement things exactly like tensorflow/tensorflow#39609.

That said, obviously the goal is for this to eventually end up in TensorFlow itself via tensorflow/tensorflow#39609 (plus other PRs), at which point we'd simply delete the implementation from SciKeras.
I think that's a good clarification.
The optimizer hack is technically separate from the move from the old `Model` serialization code (an assortment of Keras methods that is not officially supported by TF/Keras) to this code (which uses SaveModel as the backend).

The move to SaveModel was going to happen at some point, because that's what Keras is going to have more support for going forward and that's what will eventually be upstreamed.

That said, from an implementation perspective, these two things are pretty intermingled, and since they're private implementation details anyway, it makes sense to make both changes in the same PR.
> That said, from an implementation perspective, these two things are pretty intermingled, and since they're private implementation details anyway, it makes sense to make both changes in the same PR.

How complex will the implementation of `optimizer.__reduce__` be? Or, how many LOC do you estimate will need to be added to implement `optimizer.__reduce__` in tensorflow/tensorflow#39609? I'm only talking about the implementation (not the tests/etc).

If it's a pretty simple implementation, I think the `optimizer.__reduce__` implementation should go in that PR with a comment explaining the change. I'm not sure if the TF team will accept the addition – in the minutes (tensorflow/community#286 (comment)), Francois said:
- makes sense for models, not so much for optimizers (since state needs to be maintained, we only serialize configs)
- ...
- we should not implement at the moment for objects (callbacks, optimizers, metrics, losses with lambdas, etc.) where we only promise to save the config, not the internal state (as this breaks guarantees and API safety; state is usually retrieved via different API calls; results of these can be used to set state too)
I was referring to putting them together in this PR. But I agree, currently tensorflow/tensorflow#39609 is in a proof-of-concept state, so adding code that may not make it into the final version in order to chart/scope the project in general is probably a good idea.
I'd be inclined to reach out to the TF team to hear what they have to say; they had some not-mild pushback on serializing optimizer state. But I'd wait until after this PR is firmed up to do so.
I think the pushback wasn't because "it shouldn't be done" but rather because "it's not implemented cleanly in TF, and we don't want a complex `__reduce__` implementation". But yeah, I'll follow up after this PR.

Do we both agree on reviewing/merging this PR into SciKeras and tabling the TF upstreaming discussion for a later date?
…rs and metrics, reorganize serialization stuff into a module
Co-authored-by: Scott <stsievert@users.noreply.github.com>
Building upon the test added in #126.

This PR implements SaveModel-based serialization.

Closes #126, closes #70, closes #125