Table Of Contents
- Description
- How does this sample work?
- Prerequisites
- Running the sample
- Additional resources
- License
- Changelog
- Known issues
This sample, engine_refit_mnist, trains an MNIST model in PyTorch, recreates the network in TensorRT with dummy weights, and finally refits the TensorRT engine with weights from the trained model. Refitting allows us to quickly modify the weights in a TensorRT engine without needing to rebuild it.
This sample first reconstructs the model using the TensorRT network API. In the first pass, the weights for one of the conv layers (conv_1) are populated with dummy values, resulting in an incorrect inference result. In the second pass, we refit the engine with the trained weights for the conv_1 layer and run inference again. With the weights now set correctly, inference should provide correct results.
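The second pass can be sketched as follows. This is a minimal illustration under stated assumptions, not the sample's actual code: refit_layer is a hypothetical helper, and it only assumes a refitter object that exposes set_weights, get_missing, and refit_cuda_engine, which is the interface of the TensorRT Python Refitter class. The engine must have been built as refittable for this to work.

```python
def refit_layer(refitter, layer_name, weights_by_role):
    """Replace the weights of one layer, then refit the engine.

    refitter: object with the trt.Refitter interface (assumption).
    weights_by_role: dict mapping a weights role to the new weights.
    """
    for role, weights in weights_by_role.items():
        refitter.set_weights(layer_name, role, weights)

    # get_missing() reports weights that must still be supplied
    # before the refit can succeed.
    missing_layers, missing_roles = refitter.get_missing()
    if missing_layers:
        pending = list(zip(missing_layers, missing_roles))
        raise RuntimeError(f"weights still missing for: {pending}")

    if not refitter.refit_cuda_engine():
        raise RuntimeError("refit_cuda_engine() reported failure")


# Hypothetical usage with TensorRT installed (not executed here);
# conv1_kernel and conv1_bias stand for the trained weights:
#   refitter = trt.Refitter(engine, logger)
#   refit_layer(refitter, "conv_1", {
#       trt.WeightsRole.KERNEL: conv1_kernel,
#       trt.WeightsRole.BIAS: conv1_bias,
#   })
```

Passing the refitter in as an argument keeps the helper independent of how the engine and logger were created.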
In this sample, the following layers are used. For more information about these layers, see the TensorRT Developer Guide: Layers documentation.
Activation layer
The Activation layer implements element-wise activation functions. Specifically, this sample uses the Activation layer with the type kRELU.
Convolution layer
The Convolution layer computes a 2D (channel, height, and width) convolution, with or without bias.
FullyConnected layer
The FullyConnected layer implements a matrix-vector product, with or without bias.
Pooling layer
The Pooling layer implements pooling within a channel. Supported pooling types are maximum, average, and maximum-average blend.
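To make the layer descriptions concrete, here is a small pure-Python sketch of what three of these layers compute. It illustrates the math only and does not use the TensorRT API. The function names are hypothetical, the 2x2 window with stride 2 is an assumed pooling configuration, and convolution is omitted for brevity.

```python
def relu(values):
    """Element-wise ReLU: negative entries become zero."""
    return [max(0.0, v) for v in values]


def fully_connected(weights, vector, bias=None):
    """Matrix-vector product, with or without bias."""
    out = [sum(w * x for w, x in zip(row, vector)) for row in weights]
    if bias is not None:
        out = [o + b for o, b in zip(out, bias)]
    return out


def max_pool_2x2(channel):
    """Maximum pooling over one channel (a list of rows), 2x2 window, stride 2."""
    return [
        [
            max(channel[r][c], channel[r][c + 1],
                channel[r + 1][c], channel[r + 1][c + 1])
            for c in range(0, len(channel[0]) - 1, 2)
        ]
        for r in range(0, len(channel) - 1, 2)
    ]


print(relu([-1.0, 0.5]))                                       # [0.0, 0.5]
print(fully_connected([[1, 2], [3, 4]], [1, 1], bias=[0, 1]))  # [3, 8]
print(max_pool_2x2([[1, 2, 0, 1],
                    [3, 4, 1, 0],
                    [0, 0, 5, 6],
                    [0, 1, 7, 8]]))                            # [[4, 1], [1, 8]]
```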
- Install the dependencies for Python.
python3 -m pip install -r requirements.txt
To run this sample you must be using Python 3.6 or newer.
On PowerPC systems, you will need to manually install PyTorch using IBM's PowerAI.
- Run the sample to create a TensorRT engine and run inference:
python3 sample.py [-d DATA_DIR]
Note: If the TensorRT sample data is not installed in the default location, for example /usr/src/tensorrt/data/, the data directory must be specified. For example:
python3 sample.py -d /path/to/my/data/
- Verify that the sample ran successfully. If it did, accuracy should be low before the refit and close to the expected value after it:
Accuracy Before Engine Refit
Got 892 correct predictions out of 10000 (8.9%)
Accuracy After Engine Refit (expecting 98.0% correct predictions)
Got 9798 correct predictions out of 10000 (98.0%)
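For reference, the accuracy lines above are a count of matching predictions and a percentage rounded to one decimal place. The sketch below shows one way such a line could be produced. The accuracy_report function is a hypothetical helper, not taken from the sample.

```python
def accuracy_report(predictions, labels):
    """Format the number and percentage of predictions that match the labels."""
    correct = sum(int(p == t) for p, t in zip(predictions, labels))
    total = len(labels)
    percent = 100.0 * correct / total
    return f"Got {correct} correct predictions out of {total} ({percent:.1f}%)"


print(accuracy_report([7, 2, 1, 0], [7, 2, 1, 4]))
# Got 3 correct predictions out of 4 (75.0%)
```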
To see the full list of available options and their descriptions, use the -h or --help command line option. For example:
usage: sample.py [-h]
Description for this sample
optional arguments:
-h, --help show this help message and exit
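As an illustration of how the -d DATA_DIR option from the run command might be declared, here is a minimal argparse sketch. It is an assumption about the structure, not the sample's actual parser: build_parser and the data_dir attribute are hypothetical names, and the default is the default data location mentioned in this README.

```python
import argparse


def build_parser():
    """Build a parser with a single optional data-directory argument."""
    parser = argparse.ArgumentParser(
        prog="sample.py", description="Description for this sample"
    )
    parser.add_argument(
        "-d",
        dest="data_dir",          # hypothetical attribute name
        metavar="DATA_DIR",
        default="/usr/src/tensorrt/data/",
        help="directory containing the TensorRT sample data",
    )
    return parser


args = build_parser().parse_args(["-d", "/path/to/my/data/"])
print(args.data_dir)  # /path/to/my/data/
```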
The following resources provide a deeper understanding about the engine refitting functionality and the network used in this sample:
Network
Dataset
Documentation
- Introduction to NVIDIA’s TensorRT Samples
- Working with TensorRT Using the Python API
- Refitting an Engine
- NVIDIA’s TensorRT Documentation Library
For terms and conditions for use, reproduction, and distribution, see the TensorRT Software License Agreement documentation.
March 2021: Documented the Python version limitations.
March 2019: This README.md file was recreated, updated, and reviewed.
This sample only supports Python 3.6+ due to torch and torchvision version requirements.