Skip to content
This repository has been archived by the owner on Dec 21, 2023. It is now read-only.

Improve mxnet support for activity classifier save/load #129

Merged
merged 2 commits into from
Dec 22, 2017

Conversation

gustavla
Copy link
Collaborator

This PR updates the activity classifier (AC) save/load to work with more versions of mxnet. There are still some issues with export to Core ML for newer versions of mxnet, so this PR alone does not expand support yet (#17).

The issue with the AC is that it saves and loads the network graph using mxnet json files. MXNet is pretty good at backward compatibility, but not forward compatibility. If we support more than one version of mxnet at a time, this creates a problem for us where we can't even be same-version compatible:

  1. User 1 (turicreate==4.0.0 and mxnet==0.12.1) saves a model.
  2. User 2 (turicreate==4.0.0 and mxnet==0.11.0) loads the model.

This won't work due to forward incompatibility in mxnet. The solution is to avoid saving and loading the graph and simply building it up and copying over the weights. This has much better forward compatibility.

Support (old and new)

Let's look at current support and support after this PR. I'll show my testing as matrices where:

  • rows: model saved in
  • columns: model loaded in

Classifier (IC)/Similarity (IS)/Detector (OD): (all work on all combinations of save/load, although currently you get warnings if you do not use MXNet 0.11.0. This PR will eliminate those warnings. This broad support is thanks to the same changes that I'm making to the AC in this PR)

TC/MXNet 4/0.11.0 4/0.12.1 4/1post1
v4/0.11.0
v4/0.12.1
v4/1post1

Activity Classifier (AC): (since this PR changes the AC saver/loader, I tested a bunch of combinations)

TC/MXNet 4/0.11.0 4/0.12.1 4/1post1 PR/0.11.0 PR/0.12.1 PR/1post1
v4/0.11.0
v4/0.12.1 🚫
v4/1post1 🚫 🚫
PR/0.11.0 ⚠️ ⚠️ ⚠️
PR/0.12.1 ⚠️ ⚠️ ⚠️
PR/1post1 ⚠️ ⚠️ ⚠️

✅ Works
🚫 Does not work (ugly error)
⚠️ Does not work (fails gracefully)

"PR" refers to the Turi Create model as defined by this PR's commit. "1post1" is short for 1.0.0.post1 (1.0.0 segfaults the object detector, and this seems to have been resolved in the post1 version).

Top-left: This is the status quo
Right half: Backwards and same-version compatible
Bottom-left: Forward incompatible (with respect to TC version). See note at the bottom.

The "graceful" failure in 4.0 actually says "Corrupted model. Cannot load a model with this version." for OD/AC, and for IC/IS it does not even check the version! This PR in an isolated commit also improves this and makes the message friendlier and tells the user to upgrade Turi Create. Unfortunately, whenever we upgrade the file format for IC/IS, it will fail very ungracefully on 4.0.

Why not be forward compatible?

We could make newer models load in 4.0 as well. However, that is a commitment to write the backward migration to mxnet for all its future versions. For instance, in mxnet 5.0, we would still need to write json graphs that look like 0.11.0. It's better to break this compatibility now, since we would probably break it eventually. At least going forward, we have much better chances of being forward compatible (in mxnet version) for the AC, just like it turned out we are for IC/IS/OD.

context = _mxnet_utils.get_mxnet_context(max_devices=state['num_sessions'])
state['_loss_model'] = _mxnet_utils.load_mxnet_model_from_state(
state['_loss_model'], data, labels, None, context)
Copy link
Collaborator

@igiloh igiloh Dec 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this mean we're no longer backward compatible?
If someone saved a model using version 1, the weights are now saved only in the loss model, and therefore when later in lines 301-303 when loading params from state['_pred_model'] they would be all zeros, won't they?

I can understand not being forward compatible (model saved in new version should not load in old version). But backwards compatibility is important.
We could check for if version==1 or '_loss_model' in state then extract the params from loss model, else extract from pred model.
Right?

Copy link
Collaborator Author

@gustavla gustavla Dec 21, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the review! In the current v4, there is no weight sharing when it gets saved to file. All weights are saved twice. Looking at the actual saved files, a model saved with v4 takes 4 MB while a model saved with v4+ takes 2 MB. Therefore, there is no problem for v4+ to simply ignore half of those weights and load the model entirely from the pred_model.

Also, regarding backward compatibility. Every cell in the 6x6 matrix I showed in the original post is the result of an actual test and not just my hopes (I wanted to be very thorough!), so I have tested and verified full backward compatibility.

@igiloh
Copy link
Collaborator

igiloh commented Dec 21, 2017

Please see my comment about backward compatibility. Otherwise LGTM.
I would still wait for @alonpal's review as well, as he's more familiar with the MXnet arch in the AC.

@gustavla
Copy link
Collaborator Author

Sounds good, I will wait for @alonpal's review. Thanks!

Copy link
Collaborator

@igiloh igiloh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got a message from @alonpal. He's taking a flight so he can't get online, but he reviewed the changes and approves them.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants