
[MXNET-307] Fix flaky tutorial tests from CI #10956

Merged
merged 14 commits on May 17, 2018
5 changes: 3 additions & 2 deletions docs/tutorials/index.md
@@ -83,16 +83,17 @@ Select API:
* [Movie Review Classification using Convolutional Networks](/tutorials/nlp/cnn.html)
* [Generative Adversarial Networks (GANs)](/tutorials/unsupervised_learning/gan.html)
* [Recommender Systems using Matrix Factorization](/tutorials/python/matrix_factorization.html)
* [Speech Recognition with Connectionist Temporal Classification Loss](https://mxnet.incubator.apache.org/tutorials/speech_recognition/ctc.html)
* [Speech Recognition with Connectionist Temporal Classification Loss](/tutorials/speech_recognition/ctc.html)
* Practitioner Guides
* [Predicting on new images using a pre-trained ImageNet model](/tutorials/python/predict_image.html)
* [Fine-Tuning a pre-trained ImageNet model with a new dataset](/faq/finetune.html)
* [Large-Scale Multi-Host Multi-GPU Image Classification](/tutorials/vision/large_scale_classification.html)
* API Guides
* Core APIs
* NDArray
* [NDArray API](/tutorials/gluon/ndarray.html)
* [Advanced NDArray API](/tutorials/basic/ndarray.html)
* [NDArray Indexing](https://mxnet.incubator.apache.org/tutorials/basic/ndarray_indexing.html)
* [NDArray Indexing](/tutorials/basic/ndarray_indexing.html)
* Sparse NDArray
* [Sparse Gradient Updates (RowSparseNDArray)](/tutorials/sparse/row_sparse.html)
* [Compressed Sparse Row Storage Format (CSRNDArray)](/tutorials/sparse/csr.html)
22 changes: 11 additions & 11 deletions docs/tutorials/onnx/fine_tuning_gluon.md
@@ -40,7 +40,7 @@ logging.basicConfig(level=logging.INFO)


### Downloading supporting files
These are images and a visualization script
These are images and a visualization script:


```python
@@ -59,12 +59,12 @@ from utils import *

## Downloading a model from the ONNX model zoo

We download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1409.1556) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file and some sample input/output data.
We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file.


```python
base_url = "https://s3.amazonaws.com/download.onnx/models/"
current_model = "vgg16"
base_url = "https://s3.amazonaws.com/download.onnx/models/opset_3/"
current_model = "bvlc_googlenet"
model_folder = "model"
archive_file = "{}.tar.gz".format(current_model)
archive_path = os.path.join(model_folder, archive_file)
@@ -230,15 +230,15 @@ sym.get_internals()



```<Symbol group [gpu_0/data_0, gpu_0/conv1_w_0, gpu_0/conv1_b_0, convolution0, relu0, lrn0, pad0, pooling0, gpu_0/conv2_w_0, gpu_0/conv2_b_0, convolution1, relu1, lrn1, pad1, pooling1, gpu_0/conv3_w_0, gpu_0/conv3_b_0, convolution2, relu2, gpu_0/conv4_w_0, gpu_0/conv4_b_0, convolution3, relu3, gpu_0/conv5_w_0, gpu_0/conv5_b_0, convolution4, relu4, pad2, pooling2, flatten0, gpu_0/fc6_w_0, linalg_gemm20, gpu_0/fc6_b_0, _mulscalar0, broadcast_add0, relu5, flatten1, gpu_0/fc7_w_0, linalg_gemm21, gpu_0/fc7_b_0, _mulscalar1, broadcast_add1, relu6, flatten2, gpu_0/fc8_w_0, linalg_gemm22, gpu_0/fc8_b_0, _mulscalar2, broadcast_add2, softmax0]>```<!--notebook-skip-line-->
```<Symbol group [data_0, pad0, conv1/7x7_s2_w_0, conv1/7x7_s2_b_0, convolution0, relu0, pad1, pooling0, lrn0, pad2, conv2/3x3_reduce_w_0, conv2/3x3_reduce_b_0, convolution1, relu1, pad3, conv2/3x3_w_0, conv2/3x3_b_0, convolution2, relu2, lrn1, pad4, pooling1, pad5, inception_3a/1x1_w_0, inception_3a/1x1_b_0, convolution3, relu3, pad6, .................................................................................inception_5b/pool_proj_b_0, convolution56, relu56, concat8, pad70, pooling13, dropout0, flatten0, loss3/classifier_w_0, linalg_gemm20, loss3/classifier_b_0, _mulscalar0, broadcast_add0, softmax0]>```<!--notebook-skip-line-->
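
As an aside (not part of the original notebook), if the full symbol group is hard to scan, listing just the internal output names makes it easier to pick a cut-off layer:

```python
# Print the internal output names that mention 'flatten' (illustrative aside)
internal_outputs = sym.get_internals().list_outputs()
print([name for name in internal_outputs if 'flatten' in name])
```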



We take the network up to the output of the `relu6` layer:
We take the network up to the output of the `flatten0` layer:


```python
new_sym, new_arg_params, new_aux_params = get_layer_output(sym, arg_params, aux_params, 'relu6')
new_sym, new_arg_params, new_aux_params = get_layer_output(sym, arg_params, aux_params, 'flatten0')
```
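
The `get_layer_output` helper comes from the tutorial's `utils.py` downloaded earlier; it is not reproduced in this diff, but it is roughly along these lines. Treat this as a hedged sketch, since the exact lookup and parameter filtering may differ:

```python
def get_layer_output(symbol, arg_params, aux_params, layer_name):
    """Truncate a loaded symbol at `layer_name` and keep only the parameters it still uses."""
    all_layers = symbol.get_internals()
    net = all_layers[layer_name + '_output']
    # Drop parameters that the truncated graph no longer references
    new_args = {k: v for k, v in arg_params.items() if k in net.list_arguments()}
    new_aux = {k: v for k, v in aux_params.items() if k in net.list_auxiliary_states()}
    return net, new_args, new_aux
```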

### Fine-tuning in gluon
@@ -258,7 +258,7 @@ We create a symbol block that is going to hold all our pre-trained layers, and a


```python
pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('gpu_0/data_0'))
pre_trained = gluon.nn.SymbolBlock(outputs=new_sym, inputs=mx.sym.var('data_0'))
net_params = pre_trained.collect_params()
for param in new_arg_params:
    if param in net_params:
@@ -299,7 +299,7 @@ Initialize trainer with common training parameters


```python
LEARNING_RATE = 0.001
LEARNING_RATE = 0.0005
WDECAY = 0.00001
MOMENTUM = 0.9
```
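
For context, these hyperparameters are typically handed to a `gluon.Trainer` over the network's parameters; the exact call sits in the collapsed part of the diff, so treat this as an illustrative sketch (assuming the composed network is named `net`):

```python
trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {
    'learning_rate': LEARNING_RATE,  # base step size
    'wd': WDECAY,                    # weight decay (L2 regularization)
    'momentum': MOMENTUM})           # SGD momentum
```
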
@@ -349,7 +349,7 @@ print("Untrained network Test Accuracy: {0:.4f}".format(evaluate_accuracy_gluon(
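
The `evaluate_accuracy_gluon` helper referenced above lives in the collapsed part of the notebook; a hedged sketch of what such a helper typically looks like (variable names such as `ctx` follow the tutorial, the exact metric handling is an assumption):

```python
def evaluate_accuracy_gluon(data_iterator, net):
    # Accumulate classification accuracy of `net` over a DataLoader
    acc = mx.metric.Accuracy()
    for data, label in data_iterator:
        data = data.astype(np.float32).as_in_context(ctx)
        label = label.as_in_context(ctx)
        output = net(data)
        predictions = nd.argmax(output, axis=1)
        acc.update(preds=predictions, labels=label)
    return acc.get()[1]
```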

```python
val_accuracy = 0
for epoch in range(20):
for epoch in range(5):
for i, (data, label) in enumerate(dataloader_train):
data = data.astype(np.float32).as_in_context(ctx)
label = label.as_in_context(ctx)
@@ -430,4 +430,4 @@ plot_predictions(caltech101_images_test, result, categories, TOP_P)

**Great!** The network classified these images correctly after being fine-tuned on a dataset that contains images of `wrench`, `dolphin` and `lotus`.

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
62 changes: 10 additions & 52 deletions docs/tutorials/onnx/inference_on_onnx_model.md
@@ -51,12 +51,12 @@ from utils import *

## Downloading a model from the ONNX model zoo

We download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1409.1556) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file and some sample input/output data.
We download a pre-trained model, in our case the [GoogleNet](https://arxiv.org/abs/1409.4842) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged in a `tar.gz` archive containing a `model.onnx` model file.


```python
base_url = "https://s3.amazonaws.com/download.onnx/models/"
current_model = "vgg16"
base_url = "https://s3.amazonaws.com/download.onnx/models/opset_3/"
current_model = "bvlc_googlenet"
model_folder = "model"
archive = "{}.tar.gz".format(current_model)
archive_file = os.path.join(model_folder, archive)
@@ -97,11 +97,11 @@ We get the symbol and parameter objects
sym, arg_params, aux_params = onnx_mxnet.import_model(onnx_path)
```

We pick a context: GPU if available, otherwise CPU.
We pick a context; CPU is fine for inference. Switch to `mx.gpu()` if you want to use your GPU.


```python
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()
ctx = mx.cpu()
```

We obtain the data names of the inputs to the model by using the model metadata API:
@@ -121,14 +121,13 @@ data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
print(data_names)
```

```
[u'gpu_0/data_0']
```

```[u'data_0']```<!--notebook-skip-line-->
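
For reference, the collapsed lines above obtain these names through MXNet's contrib ONNX metadata API; a sketch (the exact module path and signature for this MXNet version are assumptions):

```python
from mxnet.contrib import onnx as onnx_mxnet

# Query the model's declared input tensors and keep just their names
model_metadata = onnx_mxnet.get_model_metadata(onnx_path)
data_names = [inputs[0] for inputs in model_metadata.get('input_tensor_data')]
print(data_names)  # e.g. ['data_0']
```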

And load them into an MXNet Gluon symbol block.

```python
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('gpu_0/data_0'))
net = gluon.nn.SymbolBlock(outputs=sym, inputs=mx.sym.var('data_0'))
net_params = net.collect_params()
for param in arg_params:
    if param in net_params:
@@ -146,30 +145,6 @@ We can now cache the computational graph through [hybridization](https://mxnet.i
net.hybridize()
```

## Test using sample inputs and outputs
The model comes with sample inputs/outputs we can use to test whether the model is correctly loaded


```python
numpy_path = os.path.join(model_folder, current_model, 'test_data_0.npz')
sample = np.load(numpy_path, encoding='bytes')
inputs = sample['inputs']
outputs = sample['outputs']
```


```python
print("Input format: {}".format(inputs[0].shape))
print("Output format: {}".format(outputs[0].shape))
```

`Input format: (1, 3, 224, 224)` <!--notebook-skip-line-->


`Output format: (1, 1000)` <!--notebook-skip-line-->



We can visualize the network (requires graphviz installed)


@@ -178,9 +153,7 @@ mx.visualization.plot_network(sym, node_attrs={"shape":"oval","fixedsize":"fals
```




![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network.png?raw=true)<!--notebook-skip-line-->
![png](https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/network2.png?raw=true)<!--notebook-skip-line-->



@@ -196,21 +169,6 @@ def run_batch(net, data):
    return np.array(results)
```
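
Only the signature and the return line of `run_batch` survive the collapsed diff above; its body is roughly along these lines (a sketch, not the verbatim tutorial code):

```python
def run_batch(net, data):
    """Run the hybridized network over a list of batches and stack the outputs."""
    results = []
    for batch in data:
        outputs = net(batch)                      # forward pass on one batch
        results.extend(list(outputs.asnumpy()))   # collect per-sample predictions
    return np.array(results)
```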


```python
result = run_batch(net, nd.array([inputs[0]], ctx))
```


```python
print("Loaded model and sample output predict the same class: {}".format(np.argmax(result) == np.argmax(outputs[0])))
```

Loaded model and sample output predict the same class: True <!--notebook-skip-line-->


Good, the sample output and our prediction match; now we can run against real data.

## Test using real images


@@ -274,4 +232,4 @@ We show that in our next tutorial:

- [Fine-tuning an ONNX Model using the modern imperative MXNet/Gluon](http://mxnet.incubator.apache.org/tutorials/onnx/fine_tuning_gluon.html)

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
87 changes: 49 additions & 38 deletions docs/tutorials/python/predict_image.md
@@ -1,33 +1,28 @@
# Predict with pre-trained models

This tutorial explains how to recognize objects in an image with a
pre-trained model, and how to perform feature extraction.
This tutorial explains how to recognize objects in an image with a pre-trained model, and how to perform feature extraction.

## Prerequisites

To complete this tutorial, we need:

- MXNet. See the instructions for your operating system in [Setup and Installation](http://mxnet.io/install/index.html)

- [Python Requests](http://docs.python-requests.org/en/master/), [Matplotlib](https://matplotlib.org/) and [Jupyter Notebook](http://jupyter.org/index.html).
- [Matplotlib](https://matplotlib.org/) and [Jupyter Notebook](http://jupyter.org/index.html).

```
$ pip install requests matplotlib jupyter opencv-python
$ pip install matplotlib
```

## Loading

We first download a pre-trained, 152-layer ResNet model that is trained on the full
ImageNet dataset with over 10 million images and 10 thousand classes. A
pre-trained model contains two parts: a json file containing the model
definition and a binary file containing the parameters. In addition, there may be
a text file for the labels.
We first download a pre-trained, 18-layer ResNet model that is trained on the ImageNet dataset with over 1 million images and one thousand classes. A pre-trained model contains two parts: a json file containing the model definition and a binary file containing the parameters. In addition, there may be a `synset.txt` text file for the labels.

```python
import mxnet as mx
path='http://data.mxnet.io/models/imagenet-11k/'
[mx.test_utils.download(path+'resnet-152/resnet-152-symbol.json'),
mx.test_utils.download(path+'resnet-152/resnet-152-0000.params'),
path='http://data.mxnet.io/models/imagenet/'
[mx.test_utils.download(path+'resnet/18-layers/resnet-18-0000.params'),
mx.test_utils.download(path+'resnet/18-layers/resnet-18-symbol.json'),
mx.test_utils.download(path+'synset.txt')]
```
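
The collapsed lines that follow read the label list out of `synset.txt`; roughly (a sketch, the tutorial's exact variable names may differ):

```python
# Load the human-readable ImageNet class descriptions, one per line
with open('synset.txt', 'r') as f:
    labels = [line.rstrip() for line in f]
print(labels[:3])  # first few class descriptions
```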

@@ -39,7 +34,7 @@ ctx = mx.cpu()
```

```python
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-152', 0)
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-18', 0)
mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))],
         label_shapes=mod._label_shapes)
@@ -56,7 +51,6 @@ prediction:
```python
%matplotlib inline
import matplotlib.pyplot as plt
import cv2
import numpy as np
# define a simple data batch
from collections import namedtuple
@@ -65,23 +59,22 @@ Batch = namedtuple('Batch', ['data'])
def get_image(url, show=False):
    # download and show the image
    fname = mx.test_utils.download(url)
    img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)
    img = mx.image.imread(fname)
    if img is None:
        return None
    return None
    if show:
        plt.imshow(img)
        plt.axis('off')
        plt.imshow(img.asnumpy())
        plt.axis('off')
    # convert into format (batch, RGB, width, height)
    img = cv2.resize(img, (224, 224))
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    img = img[np.newaxis, :]
    img = mx.image.imresize(img, 224, 224) # resize
    img = img.transpose((2, 0, 1)) # Channel first
    img = img.expand_dims(axis=0) # batchify
    return img

def predict(url):
    img = get_image(url, show=True)
    # compute the predict probabilities
    mod.forward(Batch([mx.nd.array(img)]))
    mod.forward(Batch([img]))
    prob = mod.get_outputs()[0].asnumpy()
    # print the top-5
    prob = np.squeeze(prob)
@@ -96,31 +89,48 @@ Now, we can perform prediction with any downloadable URL:
predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')
```

`probability=0.249607, class=n02119022 red fox, Vulpes vulpes` <!--notebook-skip-line-->

`probability=0.172868, class=n02119789 kit fox, Vulpes macrotis` <!--notebook-skip-line-->

![](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true) <!--notebook-skip-line-->

```python
predict('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true')
```

`probability=0.873920, class=n02110958 pug, pug-dog` <!--notebook-skip-line-->

`probability=0.102659, class=n02108422 bull mastiff` <!--notebook-skip-line-->

![](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/dog.jpg?raw=true) <!--notebook-skip-line-->

## Feature extraction

By feature extraction, we mean representing the input images by the output of an
internal layer rather than the last softmax layer. These outputs, which can be
viewed as the features of the raw input image, can then be used by other
applications such as object detection.
By feature extraction, we mean representing the input images by the output of an internal layer rather than the last softmax layer. These outputs, which can be viewed as the features of the raw input image, can then be used by other applications such as object detection.

We can use the ``get_internals`` method to get all internal layers from a
Symbol.
We can use the ``get_internals`` method to get all internal layers from a Symbol.

```python
# list the last 10 layers
all_layers = sym.get_internals()
all_layers.list_outputs()[-10:]
```

A commonly used layer for feature extraction is the one before the last fully
connected layer. For ResNet, and also Inception, it is the flattened layer, named
`flatten0`, which reshapes the 4-D convolutional layer output into 2-D for
the fully connected layer. The following source code extracts a new Symbol which
outputs the flattened layer and creates a model.
```
['bn1_moving_var',
'bn1_output',
'relu1_output',
'pool1_output',
'flatten0_output',
'fc1_weight',
'fc1_bias',
'fc1_output',
'softmax_label',
'softmax_output']
```

A commonly used layer for feature extraction is the one before the last fully connected layer. For ResNet, and also Inception, it is the flattened layer, named `flatten0`, which reshapes the 4-D convolutional layer output into 2-D for the fully connected layer. The following source code extracts a new Symbol that outputs the flattened layer and creates a model.

```python
fe_sym = all_layers['flatten0_output']
@@ -133,10 +143,11 @@ We can now invoke `forward` to obtain the features:

```python
img = get_image('https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/python/predict_image/cat.jpg?raw=true')
fe_mod.forward(Batch([mx.nd.array(img)]))
features = fe_mod.get_outputs()[0].asnumpy()
print(features)
assert features.shape == (1, 2048)
fe_mod.forward(Batch([img]))
features = fe_mod.get_outputs()[0]
print('Shape',features.shape)
print(features.asnumpy())
assert features.shape == (1, 512)
```
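
The collapsed portion above binds the truncated symbol into a new module before `forward` is called; that step is roughly as follows (a sketch that assumes the same 224x224 input shape is reused):

```python
# Build a feature-extraction module from the truncated symbol and the loaded weights
fe_mod = mx.mod.Module(symbol=fe_sym, context=ctx, label_names=None)
fe_mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))])
fe_mod.set_params(arg_params, aux_params)
```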

<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
4 changes: 2 additions & 2 deletions tests/tutorials/test_tutorials.py
@@ -79,12 +79,12 @@ def _test_tutorial_nb(tutorial):
    os.makedirs(working_dir)
    try:
        notebook = nbformat.read(tutorial_path + '.ipynb', as_version=IPYTHON_VERSION)
        time.sleep(0.5) # Adding a small delay to allow time for sockets to be freed
        if kernel is not None:
            eprocessor = ExecutePreprocessor(timeout=TIME_OUT, kernel_name=kernel)
        else:
            eprocessor = ExecutePreprocessor(timeout=TIME_OUT)
        nb, stuff = eprocessor.preprocess(notebook, {'metadata': {'path': working_dir}})
        print(stuff)
        nb, _ = eprocessor.preprocess(notebook, {'metadata': {'path': working_dir}})
    except Exception as err:
        err_msg = str(err)
        errors.append(err_msg)