[MXNET-433] Tutorial on saving and loading gluon models #11002
Conversation
Great reference! Just a few suggestions.
batch_size = 64

# Helper to preprocess data for training
def transform(data, label):
Would be better to use the new transforms. Something like `train_data = gluon.data.vision.MNIST(train=True).transform_first(transforms.ToTensor())`
batch_size, shuffle=True)

# Build a simple convolutional network
def build_lenet(net):
Why does this need to take an argument? Can't we just create `net` inside the function here?
So that I can use this function to build either a Block or a HybridBlock. I'm building the network as a Block to demonstrate saving and loading parameters. I'm then building the network as a HybridBlock to demonstrate saving and loading both parameters and model architecture.
import numpy as np
```

## Build and train a simple model
Setup: build and train a simple model
new_net.load_params(file_name, ctx=ctx)
```

Note that to do this, we need the definition of the network as Python code. If our network is [Hybrid](https://mxnet.incubator.apache.org/tutorials/gluon/hybrid.html), we can even save the network architecture to files, so the network definition in a Python file is not needed to load the network. We'll see how to do that in the next section.
Would make it a little more explicit that, if on a different machine, you'd need to import the same function and run it to create the same object before loading params.
Model predictions: [1. 1. 4. 5. 0. 5. 7. 0. 3. 6.] <!--notebook-skip-line-->

## Saving model architecture and weights to file
Saving model parameters AND architecture to file
So that it matches with title format above.
That's it! `export` in this case creates `lenet-symbol.json` and `lenet-0001.params` in the current directory.
## Loading saved model architecture and weights from a different frontend |
Loading model parameters AND architecture from file
So that it matches with title format above. As major heading. Then create subheadings for 'from different frontend' and 'from Python'
# Saving and Loading Gluon Models

Training large models takes a lot of time, and it is a good idea to save the trained models to files to avoid training them again and again. There are a number of reasons to do this. For example, you might want to run inference on a machine that is different from the one where the model was trained. Sometimes a model's performance on the validation set decreases towards the end of training because of overfitting. If you save your model parameters after every epoch, you can later pick the model that performs best on the validation set.
Would introduce that the reader will be looking at two methods: params only, and params plus architecture. And would try to mention somewhere which method is recommended for which situations; currently this isn't really discussed.
Done.
# Load the network architecture and parameters
sym, arg_params, aux_params = mx.model.load_checkpoint('lenet', 1)
# Create a Gluon Block using the loaded network architecture
deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
Would be good to explain `data` here. Where does the name come from?
docs/tutorials/index.md
Outdated
@@ -38,7 +38,7 @@ Select API:
* [Visual Question Answering](http://gluon.mxnet.io/chapter08_computer-vision/visual-question-answer.html) <img src="https://upload.wikimedia.org/wikipedia/commons/6/6a/External_link_font_awesome.svg" alt="External link" height="15px" style="margin: 0px 0px 3px 3px;"/>
* Practitioner Guides
* [Multi-GPU training](http://gluon.mxnet.io/chapter07_distributed-learning/multiple-gpus-gluon.html) <img src="https://upload.wikimedia.org/wikipedia/commons/6/6a/External_link_font_awesome.svg" alt="External link" height="15px" style="margin: 0px 0px 3px 3px;"/>
* [Checkpointing and Model Serialization (a.k.a. saving and loading)](http://gluon.mxnet.io/chapter03_deep-neural-networks/serialization.html) <img src="https://upload.wikimedia.org/wikipedia/commons/6/6a/External_link_font_awesome.svg" alt="External link" height="15px" style="margin: 0px 0px 3px 3px;"/>
* [Saving and Loading Models](/tutorials/gluon/save_load_params.html)
Could leave the Straight Dope article and add this as an alternative link? Unless it's misleading, in which case we should submit a PR for the Straight Dope.
# Create a Gluon Block using the loaded network architecture
deserialized_net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data'))
# Set the parameters
net_params = deserialized_net.collect_params()
Recently learned you can load the parameters like this:
```python
deserialized_net.collect_params().load('lenet-0001.params')
```
rather than:
```python
net_params = deserialized_net.collect_params()
for param in arg_params:
    if param in net_params:
        net_params[param]._load_init(arg_params[param], ctx=ctx)
for param in aux_params:
    if param in net_params:
        net_params[param]._load_init(aux_params[param], ctx=ctx)
```
```python
# Load the network architecture and parameters
sym, arg_params, aux_params = mx.model.load_checkpoint('lenet', 1)
Since we just need `sym`, I would suggest to rather use:
`sym = mx.sym.load_json(open('lenet-symbol.json', 'r').read())`
…arameters need to be loaded with Block.load_params()
LGTM.
Few nit comments. Will merge post changes.
Thanks for your contribution.
@@ -0,0 +1,269 @@
# Saving and Loading Gluon Models

Training large models take a lot of time and it is a good idea to save the trained models to files to avoid training them again and again. There is a number of reasons to do this. For example, you might want to do inference on a machine that is different from the one where the model was trained. Sometimes model's performance on validation set decreases towards the end of the training because of overfitting. If you saved your model parameters after every epoch, at the end you can decide to use the model that performs best on the validation set.
nit: There -are- a number of reasons..
nit: another motivation would be to separate research from production by using Python, which is more research-friendly, for training and Scala/C++ for production inference.
# Train a given model using MNIST data
def train_model(model):
    # Initialize the parameters with Xavier initializer
    net.collect_params().initialize(mx.init.Xavier(), ctx=ctx)
`model.collect_params()`? `net` still works here only because it is a global variable in your script.
Good catch, thanks!
    # Use cross entropy loss
    softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
    # Use Adam optimizer
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': .001})
same as above. net->model
* Add tutorial to save and load parameters
* Add outputs in markdown
* Add image. Fix some formatting.
* Add tutorial to index. Add to tests.
* Minor language changes
* Add download notebook button
* Absorb suggestions for review
* Add as alternate link
* Use Symbol.load instead of model.load_checkpoint
* Add a note discouraging the use of Block.collect_params().save() if parameters need to be loaded with Block.load_params()
* Fix a bug. Also some language corrections.
Description
Tutorial on saving and loading gluon models
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.