Getting started: using the new features of MIGraphX 0.2

New Features in MIGraphX 0.2

MIGraphX 0.2 supports the following new features:

  • New Python API
  • Support for additional ONNX operators and fixes that now enable a large set of Imagenet models
  • Support for RNN Operators
  • Support for multi-stream Execution
  • [Experimental] Support for Tensorflow frozen protobuf files

This page provides examples of how to use these new features.

Python API

MIGraphX functionality can now be called from Python as well as C++. This support is illustrated below with a "webcam classifier" example. The classifier uses the OpenCV Python modules to capture images from a webcam, reformats each image to NCHW format, and then uses MIGraphX to evaluate the image with an Imagenet-based neural network. The result is a stream of classifications of what is seen in the webcam.

The first release of the Python API targets Python 2.7.

Prerequisites

Prior to running this example, OpenCV must be installed. On Ubuntu this can be done by installing the following package

prompt% apt install python-opencv

The PYTHONPATH variable should be set by the package installation scripts. However, if necessary it can be set using

export PYTHONPATH=/opt/rocm/lib:$PYTHONPATH
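
A quick way to confirm the module can be found is to import it in the Python interpreter:

import migraphx
print(migraphx)   # should show the module loaded from /opt/rocm/lib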

Python code example

Our Python code example starts with setup code for the webcam. In this particular example we capture a small image (240x320) that we will later crop to the Imagenet size (CHW: 3 channels, 224 height, 224 width) and represent as float32 values instead of uint8.

import numpy as np
import cv2
# video settings
cap = cv2.VideoCapture(0)
cap.set(cv2.cv.CV_CAP_PROP_FRAME_WIDTH,320)
cap.set(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT,240)

An additional piece of setup is to read the Imagenet labels from a file that stores them in JSON format

import json
# get labels
with open('imagenet_class_index.json') as json_data:
   class_idx = json.load(json_data)
   idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

With OpenCV and the labels set up, we now initialize the MIGraphX interface. The first step is to read a model, including weights, from an ONNX file

import migraphx
model = migraphx.parse_onnx("resnet50.onnx")

The next step is to "compile" the model. The compilation step runs optimization passes on the model and also loads constant parameter weights to the GPU memory.

model.compile(migraphx.get_target("gpu"))

While the compilation step has loaded model weights, we also need to allocate GPU memory for input, output and scratch parameters found in the model. This is accomplished with the following code

# allocate space on the GPU for model parameters
params = {}
for key,value in model.get_parameter_shapes().items():
   params[key] = migraphx.allocate_gpu(value)

With these steps complete, we now get to the primary loop that will capture images from a webcam, manipulate them to Imagenet format and call MIGraphX to evaluate the model.

while (True):
   # capture frame by frame
   ret,frame = cap.read()

   if ret: # check - some webcams need warmup operations

The following steps process the captured frame to an image for the Resnet50 model

      cropped = frame[16:304,8:232]    # crop to 224x224
      trans = cropped.transpose(2,0,1) # convert HWC to CHW
      image = np.ascontiguousarray(    # contiguous array to feed to MIGraphX
         np.expand_dims(               # change CHW to NCHW
            trans.astype('float32')/256.0,0))  # convert uint8 to float32 and normalize

The following creates a window to display webcam frames in their original uint8 format before conversion

      cv2.imshow('frame',cropped)

The following code copies the converted frame to the GPU.

      params['0'] = migraphx.to_gpu(migraphx.argument(image))

The following code runs the model, retrieves the result from the GPU and puts it in a numpy array

      result = np.array(migraphx.from_gpu(model.run(params)), copy=False)

The result for the Resnet50 model is an array of 1000 elements containing probabilities. We find the highest probability, look up the label name and print it as output

      idx = np.argmax(result[0])
      print idx2label[idx]

The last part of the loop checks whether the 'q' key has been pressed to exit the program

   if cv2.waitKey(1) & 0xFF == ord('q'):
      break

Outside the loop we clean up the OpenCV context and exit the program

# when all is done, release the capture
cap.release()
cv2.destroyAllWindows()

Overall, this program provides an end-to-end example of using the MIGraphX Python API, including parsing ONNX files, compiling models, loading parameters to memory and running programs. The complete program follows:

import numpy as np
import cv2
import json
import migraphx

# video settings
cap = cv2.VideoCapture(0)
cap.set(cv2.cv.CV_CAP_PROP_FRAME_WIDTH, 320)
cap.set(cv2.cv.CV_CAP_PROP_FRAME_HEIGHT, 240)
ret, frame = cap.read()

# neural network settings
model = migraphx.parse_onnx("resnet50.onnx")
model.compile(migraphx.get_target("gpu"))

# allocate space on the GPU for model parameters
params = {}
for key,value in model.get_parameter_shapes().items():
    print("Parameter {} -> {}".format(key,value))
    params[key] = migraphx.allocate_gpu(value)

# get labels
with open('imagenet_class_index.json') as json_data:
    class_idx = json.load(json_data)
    idx2label = [class_idx[str(k)][1] for k in range(len(class_idx))]

# primary loop to read webcam images
while (True):
    # capture frame by frame
    ret, frame = cap.read()

    if ret: # check - some webcams need warmup operations on the frame
        cropped = frame[16:304,8:232]    # 224x224

        trans = cropped.transpose(2,0,1) # convert HWC to CHW

        # convert to float, normalize and make batch size = 1
        image = np.ascontiguousarray(
            np.expand_dims(trans.astype('float32')/256.0,0))

        # display the frame
        cv2.imshow('frame',cropped)

        params['0'] = migraphx.to_gpu(migraphx.argument(image))

        result = np.array(migraphx.from_gpu(model.run(params)), copy=False)

        idx = np.argmax(result[0])

        print idx2label[idx], " ", result[0][idx]

    if cv2.waitKey(1) & 0xFF == ord('q'):
        break

# when everything is done, release the capture
cap.release()
cv2.destroyAllWindows()

Additional Operator Support for new Models

The ONNX operators supported by MIGraphX are listed below. Those in bold are new in release 0.2:

  • Abs
  • Acos
  • Add
  • Asin
  • Atan
  • AveragePool
  • BatchNormalization
  • Concat
  • Constant
  • Conv
  • Cos
  • Cosh
  • Div
  • Dropout
  • Elu
  • Exp
  • Flatten
  • GRU
  • Gather
  • Gemm
  • GlobalAveragePool
  • GlobalMaxPool
  • Identity
  • LRN
  • LeakyRelu
  • Log
  • LogSoftmax
  • LSTM
  • MatMul
  • Max
  • MaxPool
  • Min
  • Mul
  • Pad
  • RNN
  • Relu
  • Reshape
  • Shape
  • Sigmoid
  • Sin
  • Sinh
  • Slice
  • Softmax
  • Squeeze
  • Sub
  • Sum
  • Tan
  • Tanh
  • Transpose
  • Unsqueeze
  • ConstantFill
  • ImageScaler

As a result of adding these operators, a much larger set of models has been shown to work with MIGraphX. Two examples of model collections tried with MIGraphX are the Torchvision PyTorch models and a collection of pretrained PyTorch models.

Example of generating ONNX file from pretrained model

The following code example shows how PyTorch 0.4.0 can be used to generate an ONNX file for use with MIGraphX

import torch
import torchvision.models as models
batchsize = 64
resnet50 = models.resnet50(pretrained=True)
resnet50.eval()
torch.onnx.export(resnet50,torch.randn(batchsize,3,224,224),'resnet50.onnx')
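
As a quick sanity check (not part of the original example), the exported file can be parsed and run with the MIGraphX Python API shown earlier. This is only a sketch: the parameter name '0' follows the webcam example and may differ for other exports, and the random input must match the batch size used at export time.

import numpy as np
import migraphx

batchsize = 64
model = migraphx.parse_onnx("resnet50.onnx")
model.compile(migraphx.get_target("gpu"))

# allocate GPU buffers for the model parameters
params = {}
for key, value in model.get_parameter_shapes().items():
    params[key] = migraphx.allocate_gpu(value)

# feed a random batch with the same shape used during export
image = np.random.rand(batchsize, 3, 224, 224).astype('float32')
params['0'] = migraphx.to_gpu(migraphx.argument(image))

result = np.array(migraphx.from_gpu(model.run(params)), copy=False)
print(result.shape)   # expect (batchsize, 1000) for Resnet50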

RNN Operator support

Recurrent Neural Networks (RNNs) are an important class of neural network. To support RNNs, ONNX introduced three operators (vanilla RNN, GRU, and LSTM), and MIGraphX 0.2 adds all three.
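
As a standalone illustration (separate from the seq2seq example below), the following sketch exports a small GRU module from PyTorch so that the resulting ONNX graph uses the GRU operator. The module and tensor sizes are illustrative only, and the wrapper returns a single tensor because MIGraphX 0.2 programs have one output.

import torch
import torch.nn as nn

class TinyGRU(nn.Module):
    # illustrative wrapper: return only the sequence output so the exported
    # graph has a single output
    def __init__(self):
        super(TinyGRU, self).__init__()
        self.gru = nn.GRU(input_size=10, hidden_size=20, num_layers=1)

    def forward(self, x, h0):
        output, _ = self.gru(x, h0)
        return output

model = TinyGRU()
seq = torch.randn(5, 1, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(1, 1, 20)    # (num_layers, batch, hidden_size)
torch.onnx.export(model, (seq, h0), "tiny_gru.onnx")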

In this section, we use the PyTorch sequence-to-sequence (seq2seq) example as a demo to show how to use MIGraphX to do inference with models exported from PyTorch.

The PyTorch seq2seq example translates French to English sentence by sentence. It contains an encoder and a decoder, each of which is an RNN model that uses the GRU operator. The example works as follows: it reads a French sentence and encodes it into a sequence of indices based on a dictionary (the class Lang) created from the French corpus. The index sequence is then used as input to the encoder, whose output is used as the input to the decoder. Next, the decoder generates a sequence of indices; by looking them up in an English dictionary, it outputs the English sentence.

The PyTorch example first performs the training with the train() function. The encoder and decoder models are trained by calling:

hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = AttnDecoderRNN(hidden_size, output_lang.n_words, dropout_p = 0.1).to(device)
trainIters(encoder1, attn_decoder1, 75000, print_every=5000)

For inference, it calls evaluate() like:

def evaluateRandomly(encoder, decoder, n = 10):
    for i in range(n):
        pair = random.choice(pairs)
        print('>', pair[0])
        print('=', pair[1])
        output_words, attentions = evaluate(encoder, decoder, pair[0])
        output_sentence = ' '.join(output_words)
        print('<', output_sentence)
        print('')

Our MIGraphX example implements the inference functionality based on the exported encoder and decoder models. We implemented the inference in both C++ and Python using the following steps: export the trained models to ONNX files, import the models into MIGraphX, and wrap the inputs and generate outputs in the MIGraphX way.

One note: the decoder model has multiple outputs, which MIGraphX does not support for now, so we changed the PyTorch decoder model to generate only one useful output (by concatenating multiple outputs into one, then splitting them after retrieving the output). Specifically, we concatenate the hidden-state output of the GRU operator and the output of the LogSoftmax operator. After we get the single output, we split it in two according to their sizes. The code changes in the class AttnDecoderRNN(nn.Module) are as follows (the decoder still has two outputs, but the first one is enough for us):

-        output = F.log_softmax(self.out(output[0]), dim=1)
-        return output, hidden, attn_weights
+        uns_output = torch.unsqueeze(output, 0)
+        output = torch.cat((uns_output, hidden), 2)
+        return output, attn_weights

Then in the train() and evaluate() functions, the changes are (concatenate in the model, then split the model output):

-            decode_output, decoder_hidden, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
+            output, decoder_attention = decoder(decoder_input, decoder_hidden, encoder_outputs)
+            decoder_output = output.narrow(2, 0, output_lang.n_words)
+            decoder_output = torch.squeeze(decoder_output, 0)
+            decoder_hidden = output.narrow(2, output_lang.n_words, hidden_size)

Another note: the PyTorch example does some pre-processing of the input data file to remove unnecessary information and drop sentences that are too long. To simplify our implementation, we dumped the processed data from the PyTorch example into a data file and use it as input to our MIGraphX example. We provide a Python program (pre_processing.py, stored in the data folder of the seq2seq.tar file) for the pre-processing, so users can simply download the original data file and generate the data used by our MIGraphX example.
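
For reference, the general idea of this pre-processing can be sketched as follows. The attached pre_processing.py is the authoritative version; the input file name eng-fra.txt and the length limit are assumptions taken from the PyTorch tutorial, and this sketch omits the text normalization the tutorial also applies.

MAX_LENGTH = 10   # sentence-length limit assumed from the PyTorch tutorial

def keep(pair):
    # drop pairs where either sentence is too long for the model
    return all(len(s.split(' ')) < MAX_LENGTH for s in pair)

with open('eng-fra.txt') as fin, open('eng-fra_procd.txt', 'w') as fout:
    for line in fin:
        pair = line.rstrip('\n').split('\t')
        if len(pair) == 2 and keep(pair):
            fout.write('\t'.join(pair) + '\n')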

With the above code changes, we can export the ONNX files of the encoder and decoder models by calling torch.onnx.export() as in the previous section. Specifically, we add the following lines at the end of the PyTorch example:

#export encoder onnx file
input_tensor, target_tensor = tensorsFromPair(random.choice(pairs))
encoder_hidden = encoder1.initHidden()
torch.onnx.export(encoder1, (input_tensor[0], encoder_hidden), "s2s_encoder.onnx", verbose=True)

#export decoder onnx file
decoder_input = torch.tensor([[SOS_token]], device = device)
decoder_hidden = encoder_hidden
encoder_outputs = torch.randn(MAX_LENGTH, encoder1.hidden_size)
torch.onnx.export(attn_decoder1, (decoder_input, decoder_hidden, encoder_outputs), "s2s_decoder.onnx", verbose=True)

After the training is done, we get the encoder onnx file s2s_encoder.onnx and the decoder onnx file s2s_decoder.onnx.

We ported the evaluate() function to the MIGraphX example as follows:

  1. Inputs are the pre-processed data file eng-fra_procd.txt in seq2seq/data, and the ONNX files of the encoder (s2s_encoder.onnx) and the decoder (s2s_decoder.onnx) in seq2seq/model.

  2. In the example, we implemented a class CLanguage (language.cpp) following the algorithm of the Python class Lang in the PyTorch example. With this class, the example first loads the data file line by line, each line containing a pair of sentences, one in English and one in French, with the same meaning. The class then encodes each sentence word by word by calling the function CLanguage::add_sentence() as:

    while (std::getline(ifs, line, '\n'))
    {
        std::size_t pos = 0;
        pos = line.find('\t', pos);
        //std::cout << "line " << line_index++ << ": " << line << std::endl;
        std::string input_sent{}, output_sent{};
        if (pos != std::string::npos)
        {
            input_sent = line.substr(0, pos);
            output_sent = line.substr(pos + 1);
        }
        else
        {
            std::cout << "File " << file_name << ", Line " << line_index;
            std::cout << " does not contain two languages!" << std::endl;
            return 1;
        }

        all_sentences.push_back(std::make_pair(input_sent, output_sent));
        input_lang.add_sentence(input_sent);
        output_lang.add_sentence(output_sent);
    }
  3. The example loads the encoder and decoder models by calling the function load_onnx_file(file_name), which returns two MIGraphX programs for the encoder and decoder, respectively.
    migraphx::program encoder = load_onnx_file(argv[1]); // argv[1] is s2s_encoder.onnx
    migraphx::program decoder = load_onnx_file(argv[2]); // argv[2] is s2s_decoder.onnx
  4. After compiling the programs on either the CPU or the GPU, we call the function evaluate_cpu/gpu() to do the inference, i.e. the translation. Using evaluate_cpu() as an example, it performs:

    a. For the input sentence, it converts its words to indices by calling the function

    auto input_indices = input_lang.get_sentence_indices(sent);

The vector input_indices is used as the input to the encoder, one element per call to the encoder:

    // run the encoder
    for (std::size_t i = 0; i < input_len; ++i)
    {
        migraphx::program::parameter_map m;
        for (auto&& x : encoder.get_parameter_shapes())
        {
            if (x.first == "input.1")
            {   
                m[x.first] = migraphx::argument(x.second, &input_indices.at(i));
            }
            else if (x.first == "hidden")
            {   
                m[x.first] = migraphx::argument(x.second, encoder_hidden.data());
            }
            else
            {   
                m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
            }
        }

        auto concat_hiddens = encoder.eval(m);
        // from the encoder source code, seq_size is 1, so output is the
        // same as the hidden states
        concat_hiddens.visit([&](auto output) { encoder_hidden.assign(output.begin(), output.end()); });
        
        encoder_outputs.insert(encoder_outputs.end(), encoder_hidden.begin(), encoder_hidden.end());
    }

The encoder hidden-state outputs are stored in the vector encoder_outputs one after another and later used as an input to the decoder. The output hidden state is also fed back as the hidden input when the encoder processes the next word, which is the defining behavior of an RNN (the output is used as input for processing the next element).
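
For reference, the same per-word loop can be written with the MIGraphX Python API along the following lines. This is only a sketch (the attached seq2seq_translation.py is the authoritative version): the parameter names "input.1" and "hidden" are taken from the C++ code above, and the shapes, dtypes and example indices are assumptions.

import numpy as np
import migraphx

encoder = migraphx.parse_onnx("s2s_encoder.onnx")
encoder.compile(migraphx.get_target("gpu"))

hidden_size = 256
encoder_hidden = np.zeros((1, 1, hidden_size), dtype=np.float32)
encoder_outputs = []
input_indices = [12, 7, 3]   # example word indices (hypothetical)

for index in input_indices:
    params = {}
    for key, shape in encoder.get_parameter_shapes().items():
        params[key] = migraphx.allocate_gpu(shape)
    # one word index and the previous hidden state per step
    params["input.1"] = migraphx.to_gpu(
        migraphx.argument(np.array([index], dtype=np.int64)))
    params["hidden"] = migraphx.to_gpu(
        migraphx.argument(encoder_hidden))
    out = np.array(migraphx.from_gpu(encoder.run(params)), copy=False)
    # seq_len is 1, so the output equals the new hidden state
    encoder_hidden = out.reshape(1, 1, hidden_size)
    encoder_outputs.append(encoder_hidden.copy())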

b. The decoder has three inputs: the first token "SOS_token" of a sentence, an initial hidden state, and the encoder output encoder_outputs. The code of the decoder processing is:

    std::vector<long> decoder_input{SOS_token};
    std::vector<float> decoder_hidden(encoder_hidden);
    std::vector<std::string> decoder_words{};

    for (std::size_t i = 0; i < max_sent_len; ++i)
    {
        migraphx::program::parameter_map m;
        for (auto&& x : decoder.get_parameter_shapes())
        {
            if (x.first == "input.1")
            {
                m[x.first] = migraphx::argument(x.second, decoder_input.data());
            }
            else if (x.first == "hidden")
            {
                m[x.first] = migraphx::argument(x.second, decoder_hidden.data());
            }
            else if (x.first == "2")
            {
                m[x.first] = migraphx::argument(x.second, encoder_outputs.data());
            }
            else
            {
                m[x.first] = migraphx::generate_argument(x.second, get_hash(x.first));
            }
        }

        auto outputs_arg = decoder.eval(m);
        std::vector<float> outputs;
        outputs_arg.visit([&](auto output) { outputs.assign(output.begin(), output.end()); });

        std::vector<float> decoder_output(outputs.begin(), outputs.begin() + output_lang.get_word_num());
        decoder_hidden.clear();
        decoder_hidden.assign(outputs.begin() + output_lang.get_word_num(), outputs.end());

        // compute the words from the decoder output
        std::size_t max_index = std::distance(decoder_output.begin(), std::max_element(decoder_output.begin(),
                    decoder_output.end()));
        if (max_index == static_cast<std::size_t>(EOS_token))
        {
            break;
        }
        else
        {
            decoder_words.push_back(output_lang.get_word(max_index));
            decoder_input.at(0) = static_cast<long>(max_index);
        }
    }

(Note that the decoder input encoder_outputs should contain (max_sent_len * hidden_size) elements, since that is the number of inputs the decoder expects. If there are fewer elements, append zeros at the end.)
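
In the Python version, that padding could be done along these lines (a sketch; the helper name is ours, and encoder_outputs is assumed to be a flat list of floats):

def pad_encoder_outputs(encoder_outputs, max_sent_len, hidden_size):
    # the decoder expects max_sent_len * hidden_size values for its
    # encoder_outputs input, so append zeros for short sentences
    needed = max_sent_len * hidden_size - len(encoder_outputs)
    if needed > 0:
        encoder_outputs = encoder_outputs + [0.0] * needed
    return encoder_outputs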

The decoder output is the concatenation of the output hidden state and the output of the LogSoftmax operator, which indicates which word was produced (by selecting the index with the maximum log-softmax value). Finally, all words are concatenated to form the output English sentence.

The whole MIGraphX program is attached here as seq2seq.tar. It contains the encoder and decoder ONNX files in the folder seq2seq/model. You can download the original data file and save it to the folder seq2seq/data. This folder also contains the Python program (pre_processing.py) that pre-processes the original data file and generates the input file eng-fra_procd.txt for the MIGraphX example:

python pre_processing.py

To build the program, we need to set the following three lines in the CMakeLists.txt. For example, my settings are:

set (MIGRAPHX_FOLDER /home/scxiao/Workplace/projects/AMDMIGraphX)
set (MIGRAPHX_BUILD  /home/scxiao/Workplace/projects/AMDMIGraphX/build)
set (MIGRAPHX_DEPS   /home/scxiao/Workplace/projects/AMDMIGraphX/deps_py/lib)

You need to change these three folders to their locations on your computer. Then we can build the example as:

cd ${XXX}/seq2seq
mkdir build
cd build
cmake ..
make

If the build succeeds, we get the binary seq2seq_trans, and we can run the example as:

./seq2seq_trans ../model/s2s_encoder.onnx ../model/s2s_decoder.onnx fra eng cpu    

or you can change the last argument to gpu to run on the GPU, which is much faster than the CPU.

The Python implementation is located at ${XXX}/python/seq2seq_translation.py, and you can run it using the same command-line arguments as the C++ version, i.e.

python seq2seq_translation.py ../model/s2s_encoder.onnx ../model/s2s_decoder.onnx fra eng cpu(gpu)

The python program generates the same results as the C++ program.

The changed PyTorch file is also attached as seq2seq.py; it does the training and exports the encoder and decoder ONNX files.

Starting with release 0.6, MIGraphX supports multiple program outputs, so the changes above to make the model have only one output are no longer necessary. We can use the models trained from the seq2seq example directly, and we changed the inference program in seq2seq.tar accordingly.

An Overview of Tensorflow Protobuf Support in 0.2 (Experimental)

Beginning with 0.2, we support tensorflow protobuf files with the following constraints:

  1. The following operators are supported:
  • Add
  • AvgPool
  • BiasAdd
  • ConcatV2
  • Const
  • Conv2D
  • FusedBatchNorm
  • Identity
  • MaxPool
  • Mean (used for GlobalAvgPool only)
  • Pad (zero values only)
  • Relu
  • Reshape
  • Softmax
  • Squeeze
  2. The data format is either entirely "NHWC" or entirely "NCHW"

To enable tensorflow protobuf support in the python API, use the following cmake command in your build directory:

CXX=/opt/rocm/bin/hcc cmake -DMIGRAPHX_ENABLE_TF=ON ..

This will link the tensorflow protobuf library instead of the onnx protobuf library. As noted in the documentation, insert -DCMAKE_PREFIX_PATH=/some/dir if dependencies are located elsewhere.

In the future, we will provide a python example of running a pre-trained Resnet50 V2 model.