# Image Segmentation with Pre-trained Models

We can often leverage pre-trained models contributed by the community. The [MXNet Model Zoo](http://mxnet.io/model_zoo/) contains fast implementations of many state-of-the-art models, with pre-trained weights included.

In this tutorial, we demonstrate how to use a pre-trained network and perform prediction on a new image.

The task here is image segmentation, where the network learns to assign each pixel to a category, as shown below:

<img src="images/seg_image.png">

Here we use a pre-trained segmentation model called FCN-xs, based on the paper [Fully Convolutional Networks for Semantic Segmentation](https://people.eecs.berkeley.edu/~jonlong/long_shelhamer_fcn.pdf). It is a convolutional neural network trained on the PASCAL VOC 2011 dataset, which includes 2,207 images similar to the one above.

Each image was annotated to segment twenty classes of objects: `person, bird, cat, cow, dog, horse, sheep, aeroplane, bicycle, boat, bus, car, motorbike, train, bottle, chair, dining table, potted plant, sofa, tv monitor`.
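
To make the per-pixel classification concrete, the sketch below turns a small array of class scores into a label map by picking the highest-scoring class at each pixel. It assumes the conventional PASCAL VOC indexing (0 for background, then the twenty classes in alphabetical order) and uses hand-built scores as a stand-in for real network output. `VOC_CLASSES` and `labels_from_scores` are illustrative names, not part of MXNet.

```python
import numpy as np

# Assumed index order: the standard PASCAL VOC convention (0 = background).
VOC_CLASSES = [
    "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus",
    "car", "cat", "chair", "cow", "dining table", "dog", "horse",
    "motorbike", "person", "potted plant", "sheep", "sofa", "train",
    "tv monitor",
]

def labels_from_scores(scores):
    """Map a (num_classes, h, w) score array to an (h, w) array of class indices."""
    return np.argmax(scores, axis=0)

# Toy 2x3 "image": everything is background except two pixels.
scores = np.zeros((len(VOC_CLASSES), 2, 3), dtype=np.float32)
scores[0] = 1.0                                   # background baseline
scores[VOC_CLASSES.index("person"), 0, 1] = 5.0   # one person pixel
scores[VOC_CLASSES.index("bicycle"), 1, 2] = 5.0  # one bicycle pixel

label_map = labels_from_scores(scores)
# Names of the classes that occur at least once in the label map
present = [VOC_CLASSES[i] for i in np.unique(label_map)]
print(label_map)
print(present)
```

A real segmentation output has the same structure, only with one score per class for every pixel of the input image.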


# Prepare the Data

The required model files are hosted at https://bitbucket.org/krishnasumanthm/mxnet_image_segmentation

We first define a download function, then use it to fetch the pre-trained model, the symbol file, and a test image.

Note: The pre-trained model is about 500 MB, so the first download may take some time.

In [None]:
import logging
import numpy as np
import mxnet as mx
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os, urllib2, time
import warnings
warnings.simplefilter('ignore',DeprecationWarning)
import sys
from PIL import Image
import utils 
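
The download helper itself is not shown here (it may live in `utils`). Below is a minimal sketch of what such a helper could look like, using only the standard library. `download` and its parameters are hypothetical names, and since this tutorial does not list the individual file names in the repository above, none are filled in.

```python
import os

try:
    from urllib.request import urlretrieve  # Python 3
except ImportError:
    from urllib import urlretrieve          # Python 2

def download(url, dest_path, overwrite=False):
    """Fetch url into dest_path unless the file is already there.

    Returns True if a download happened, False if the cached file was reused.
    """
    if os.path.exists(dest_path) and not overwrite:
        return False
    dest_dir = os.path.dirname(dest_path)
    if dest_dir and not os.path.isdir(dest_dir):
        os.makedirs(dest_dir)
    # Write to a temporary name first so an interrupted transfer
    # does not leave a truncated file that looks complete.
    tmp_path = dest_path + ".part"
    urlretrieve(url, tmp_path)
    if os.path.exists(dest_path):
        os.remove(dest_path)
    os.rename(tmp_path, dest_path)
    return True

# Hypothetical usage; substitute the real URL and file name:
# download("<url of a model file>", "data/<file name>")
```

Skipping files that already exist matters here because the model is large. Re-running the notebook should not fetch it again.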

# Load the model

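We first load the model by calling the `mx.model.load_checkpoint()` function. Given the prefix `data/FCN8s_VGG16` and epoch 19, it reads the network definition from `data/FCN8s_VGG16-symbol.json` and the pre-trained weights from `data/FCN8s_VGG16-0019.params`, and returns the network symbol, the weights, and the auxiliary states.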

In [None]:
model, params, states = mx.model.load_checkpoint('data/FCN8s_VGG16', epoch=19) 

The model's weights are stored in the `params` dictionary and are easy to inspect. For example, here we compute the average magnitude of the weights in the first convolutional layer:

In [None]:
params['conv1_1_weight']  # this is an array of size (64, 3, 3, 3)

magnitude = np.mean(np.abs(params['conv1_1_weight'].asnumpy()))
print("First layer weights have average magnitude of: {}".format(magnitude))


# Preprocess the data for evaluation

Below we create a helper function that loads the image, shrinks it to at most 800 pixels per side, subtracts the per-channel mean, and converts the result to a NumPy array in the layout the network expects.

In [None]:
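def load_image(img_path):
    """Return the image at img_path as a (1, 3, h, w) np.array."""
    # Per-channel mean subtracted from every pixel, in (R, G, B) order
    mean = np.array([123.68, 116.779, 103.939])

    img = Image.open(img_path)
    # Shrink in place so that neither side exceeds 800 pixels (aspect ratio is kept).
    # Image.LANCZOS is the current name of the filter formerly called Image.ANTIALIAS.
    img.thumbnail((800, 800), Image.LANCZOS)
    img = np.array(img, dtype=np.float32)  # shape (h, w, 3)
    reshaped_mean = mean.reshape(1, 1, 3)
    img = img - reshaped_mean
    # Reorder (h, w, 3) -> (3, h, w), then add a leading batch axis
    img = np.swapaxes(img, 0, 2)
    img = np.swapaxes(img, 1, 2)
    img = np.expand_dims(img, axis=0)
    return img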
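
The axis shuffling in `load_image` is easier to follow on a tiny array. The sketch below uses a synthetic 2x4 "image" instead of a file (an assumption made so it runs without any data) and checks that the two `swapaxes` calls are equivalent to a single `np.transpose` that moves the channel axis to the front.

```python
import numpy as np

mean = np.array([123.68, 116.779, 103.939])  # per-channel (R, G, B) mean

# Stand-in for a decoded 2x4 RGB image in (h, w, channel) layout.
img = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)

# Broadcasting subtracts the matching channel mean from every pixel.
centered = img - mean.reshape(1, 1, 3)

# The two swapaxes calls used by load_image ...
chw_swap = np.swapaxes(np.swapaxes(centered, 0, 2), 1, 2)
# ... give the same result as one transpose to (channel, h, w).
chw_transpose = np.transpose(centered, (2, 0, 1))

# Add the leading batch axis expected by the network input.
batch = np.expand_dims(chw_transpose, axis=0)

print(batch.shape)
print(np.array_equal(chw_swap, chw_transpose))
```

Either form works. The single `transpose` states the target layout directly, which can be easier to read.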

# Evaluate the model

Then, we create a function that:
1. Tells the model what input image size to expect.
2. Creates an executor, binding the loaded parameters and states.
3. Runs the inference pass via the `forward()` method.
4. Saves the resulting images and displays them in the notebook.



In [None]:
def predict(img_path, model, params, states):
    img = load_image(img_path)

    # Input images are stored in a 4-D array of shape (batch, channel, height, width)
    params["data"] = mx.nd.array(img, mx.cpu())
    data_shape = params["data"].shape

    # The label array has one entry per input pixel
    label_shape = (1, data_shape[2]*data_shape[3])
    params["softmax_label"] = mx.nd.empty(label_shape, mx.cpu())

    # create an 'executor' and bind the parameters and states
    executor = model.bind(mx.cpu(), params, aux_states=states)

    tic = time.time()

    # run the inference pass; MXNet executes asynchronously, so wait for
    # the output to be ready before stopping the timer
    executor.forward(is_train=False)
    executor.outputs[0].wait_to_read()
    print("Time taken for forward pass: {:.3f} ms".format((time.time()-tic)*1000))

    # save the images and print the results
    print("Saving images...")
    utils.print_images(executor.outputs[0], img_path)
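
`utils.print_images` is not shown in this tutorial. One plausible way such a helper could color its output, assuming the network emits one score channel per class, is to take the per-pixel argmax and look up each class index in the PASCAL VOC color map. The sketch below covers only that lookup. `voc_palette` and `colorize` are hypothetical names, and the bit-interleaving scheme is the one used by the VOC development kit, not something confirmed from `utils`.

```python
def voc_palette(num_colors=256):
    """Return a list of (r, g, b) tuples following the PASCAL VOC color map."""
    palette = []
    for index in range(num_colors):
        r = g = b = 0
        bits = index
        # Spread the bits of the class index across the three channels,
        # filling each channel from its most significant bit downward.
        for shift in range(7, -1, -1):
            r |= (bits & 1) << shift
            g |= ((bits >> 1) & 1) << shift
            b |= ((bits >> 2) & 1) << shift
            bits >>= 3
        palette.append((r, g, b))
    return palette

def colorize(label_map, palette):
    """Replace each class index in a 2-D nested list with its RGB color."""
    return [[palette[label] for label in row] for row in label_map]

palette = voc_palette()
print(palette[0])    # background -> (0, 0, 0)
print(palette[15])   # person under the standard VOC class order -> (192, 128, 128)
```

Because consecutive class indices map to clearly distinct colors, neighboring regions in the segmentation stay easy to tell apart.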



In [None]:
predict("person_bicycle.jpg", model, params, states)