Refitting an Engine in Python

Description

This sample, engine_refit_mnist, trains an MNIST model in PyTorch, recreates the network in TensorRT with dummy weights for one layer, and finally refits the TensorRT engine with the trained weights from the model. Refitting lets you quickly modify the weights in a TensorRT engine without rebuilding it.

How does this sample work?

This sample first reconstructs the model using the TensorRT network API. In the first pass, the weights of one convolution layer (conv_1) are populated with dummy values, so inference produces incorrect results. In the second pass, the engine is refitted with the trained weights for conv_1 and inference is run again. With the correct weights in place, inference should produce correct results.
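The second pass described above can be sketched roughly as follows. This is a minimal illustration, not the sample's actual code: the helper name `refit_conv_layer` is hypothetical, and it assumes the engine was built with the `trt.BuilderFlag.REFIT` builder flag and that the convolution layer was named `conv_1` when the network was defined. The `Refitter` calls follow the TensorRT Python API as commonly documented and may differ between TensorRT versions.

```python
def refit_conv_layer(engine, logger, kernel, bias, layer_name="conv_1"):
    """Swap new kernel and bias weights into an already built, refittable engine.

    Assumptions (not taken from the sample source): the engine was built with
    trt.BuilderFlag.REFIT, and `kernel` / `bias` are float32 arrays whose sizes
    match the weights of the layer named `layer_name`.
    """
    import tensorrt as trt  # imported lazily so this module loads without TensorRT

    refitter = trt.Refitter(engine, logger)
    refitter.set_weights(layer_name, trt.WeightsRole.KERNEL, kernel)
    refitter.set_weights(layer_name, trt.WeightsRole.BIAS, bias)

    # get_missing() lists weights that must still be supplied before the refit.
    missing_layers, _missing_roles = refitter.get_missing()
    if missing_layers:
        raise RuntimeError("Weights still missing for: {}".format(missing_layers))

    # Apply the new weights in place; no engine rebuild is involved.
    if not refitter.refit_cuda_engine():
        raise RuntimeError("Engine refit failed")
    return engine
```

After a successful call, the same engine object can be used for inference again with the updated weights.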

TensorRT API layers and ops

In this sample, the following layers are used. For more information about these layers, see the TensorRT Developer Guide: Layers documentation.

Activation layer: The Activation layer implements element-wise activation functions. Specifically, this sample uses the Activation layer with the type kRELU.

Convolution layer: The Convolution layer computes a 2D (channel, height, and width) convolution, with or without bias.

FullyConnected layer: The FullyConnected layer implements a matrix-vector product, with or without bias.

Pooling layer: The Pooling layer implements pooling within a channel. Supported pooling types are maximum, average, and maximum-average blend.
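To make the layer descriptions above concrete, here is a small pure-Python sketch of what three of these layers compute on toy inputs. It does not use TensorRT; the function names and the `mode` strings are hypothetical, and the blend formula, `(1 - blend_factor) * max + blend_factor * average`, is an assumption based on how the maximum-average blend is usually defined.

```python
def relu(values):
    """Element-wise ReLU activation: negative inputs become zero."""
    return [max(0.0, v) for v in values]


def pool_window(window, mode="max", blend_factor=0.5):
    """Reduce one pooling window (the values of a single channel) to one value."""
    maximum = max(window)
    average = sum(window) / len(window)
    if mode == "max":
        return maximum
    if mode == "average":
        return average
    if mode == "blend":
        # Assumed definition of the maximum-average blend.
        return (1.0 - blend_factor) * maximum + blend_factor * average
    raise ValueError("unknown pooling mode: {}".format(mode))


def fully_connected(weights, inputs, bias=None):
    """Matrix-vector product, with or without bias."""
    outputs = [sum(w * x for w, x in zip(row, inputs)) for row in weights]
    if bias is not None:
        outputs = [o + b for o, b in zip(outputs, bias)]
    return outputs
```

For example, `pool_window([1.0, 2.0, 3.0, 6.0], "blend")` mixes the window maximum (6.0) and the window average (3.0) in equal parts.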

Prerequisites

  1. Install the dependencies for Python:

    python3 -m pip install -r requirements.txt

This sample requires Python 3.6 or newer.

On PowerPC systems, you will need to manually install PyTorch using IBM's PowerAI.
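If it helps, the Python version requirement can be checked up front with a few lines like the following. This is only an illustrative guard, not part of the sample; the names `MIN_PYTHON` and `python_is_supported` are hypothetical.

```python
import sys

MIN_PYTHON = (3, 6)  # minimum version stated in the prerequisites


def python_is_supported(version_info=sys.version_info, minimum=MIN_PYTHON):
    """Return True when the (major, minor) version meets the minimum."""
    return tuple(version_info[:2]) >= minimum


if not python_is_supported():
    raise SystemExit("This sample requires Python 3.6 or newer.")
```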

Running the sample

  1. Run the sample to create a TensorRT engine and run inference:

    python3 sample.py [-d DATA_DIR]

    Note: If the TensorRT sample data is not installed in the default location (for example, /usr/src/tensorrt/data/), the data directory must be specified. For example:

    python3 sample.py -d /path/to/my/data/

  2. Verify that the sample ran successfully. If it did, the reported accuracy should be low before the refit and rise to about 98% after it:

    Accuracy Before Engine Refit
    Got 892 correct predictions out of 10000 (8.9%)
    Accuracy After Engine Refit (expecting 98.0% correct predictions)
    Got 9798 correct predictions out of 10000 (98.0%)
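The accuracy lines above are plain counts of matching predictions. A minimal sketch of how such a line could be produced is shown below; the function name `accuracy_line` is hypothetical and the sample's real reporting code may differ.

```python
def accuracy_line(predictions, labels):
    """Format a summary line in the same style as the sample output."""
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    total = len(labels)
    return "Got {} correct predictions out of {} ({:.1f}%)".format(
        correct, total, 100.0 * correct / total
    )
```

Calling it with predicted and true digit labels for the test set, once before and once after the refit, yields the two lines to compare.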
    

Sample --help options

To see the full list of available options and their descriptions, use the -h or --help command line option. For example:

usage: sample.py [-h]

Description for this sample

optional arguments:
    -h, --help show this help message and exit
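For orientation, a command-line interface like this is typically built with `argparse`. The sketch below is an assumption about how it might look, not the sample's actual parser: the long option name `--datadir` and the helper `build_parser` are hypothetical, while the `-d DATA_DIR` option and the default path come from the run instructions.

```python
import argparse

DEFAULT_DATA_DIR = "/usr/src/tensorrt/data/"  # default location named in this README


def build_parser():
    """Build a parser with -h/--help plus an optional data directory."""
    parser = argparse.ArgumentParser(
        prog="sample.py", description="Description for this sample"
    )
    parser.add_argument(
        "-d",
        "--datadir",  # hypothetical long name; the README only shows -d
        metavar="DATA_DIR",
        default=DEFAULT_DATA_DIR,
        help="directory containing the TensorRT sample data",
    )
    return parser
```

With this sketch, `build_parser().parse_args(["-d", "/path/to/my/data/"]).datadir` returns the directory given on the command line, and omitting `-d` falls back to the default.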

Additional resources

The following resources provide a deeper understanding of the engine refitting functionality and the network used in this sample:

Network

Dataset

Documentation

License

For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.

Changelog

March 2021: Documented the Python version limitations.

March 2019: This README.md file was recreated, updated, and reviewed.

Known issues

This sample only supports Python 3.6+ due to torch and torchvision version requirements.