
Saving and loading cudNN autotune and graph optimization #16173

Open

QueensGambit opened this issue Sep 14, 2019 · 18 comments

@QueensGambit
Contributor

QueensGambit commented Sep 14, 2019

Hello everyone,

there are several tasks that are executed repeatedly when binding MXNet graphs and that produce the same result as long as the graph is unchanged.
In theory, these results could be saved to disk and reloaded later.
These tasks include cuDNN autotuning, TensorRT graph fusion and Intel MKLDNN graph optimization.

Here is a short overview:

cudNN convolution autotune

  • Description: runs performance tests for the convolutional layers to determine which convolution algorithms are fastest for the given computation graph. cuDNN autotune is enabled by default for cuDNN back-ends.
  • Indicated by:
incubator-mxnet/src/operator/nn/cudnn/cudnn_algoreg-inl.h:97:
Running performance tests to find the best convolution algorithm, 
this can take a while...
(set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)

TensorRT graph fusion

  • Description: attempts to fuse multiple CUDA operations into one, which saves memory transfer time. Can be applied when MXNet was built with the TensorRT back-end.
  • Indicated by: multiple log messages in case the current GPU does not support fp16.
../src/operator/subgraph/tensorrt/onnx_to_tensorrt.cc:121:
 TensorRT can't use fp16 on this platform
  • Optimization time: long (e.g. 24 seconds)
  • Urgency: high

MKLDNN graph optimization

  • Description: applies MKLDNN optimization to Convolution and FullyConnected layers and is enabled by default (previously enabled by setting MXNET_SUBGRAPH_BACKEND=MKLDNN; see the environment-variable sketch after this overview).
  • Indicated by:
src/operator/subgraph/build_subgraph.cc:686:
start to execute MKLDNN convolution optimization pass.
src/operator/subgraph/build_subgraph.cc:686:
start to execute MKLDNN FullyConnected optimization pass.
  • Optimization time: very fast (e.g. < 1 second)
  • Urgency: low
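
For reference, both passes can currently only be switched on or off via environment variables, not cached; a minimal sketch (assuming the variables are set before MXNet is first imported):

import os

# turn off cuDNN convolution autotune (avoids the performance tests quoted above)
os.environ["MXNET_CUDNN_AUTOTUNE_DEFAULT"] = "0"
# request the MKLDNN subgraph backend explicitly (enabled by default in recent builds)
os.environ["MXNET_SUBGRAPH_BACKEND"] = "MKLDNN"

import mxnet as mx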

Experimental support

First, I suggest adding two experimental API methods (Python, C++, ...) for each optimization technique independently, callable as model.save_cache() and model.load_cache().
These methods are only supported for static graphs (e.g. after net.hybridize() in the case of Gluon):

def load_cache(filename, type='cudnn'):
    """Load the optimization cache from a file previously saved by `save_cache`.

    Parameters
    ----------
    filename : str
        Path to cache file.
    type : str, default 'cudnn'
        must be in {'cudnn', 'tensorrt', 'mkldnn'}
    References
    ----------
    `Saving and Loading Optimization Cache for Models \
    `_
    """
    if type == 'cudnn':
        raise NotImplementedError
        # _load_cudnn_cache(filename)
    elif type == 'tensorrt':
        raise NotImplementedError
        # _load_tensorrt_cache(filename)
    elif type == 'mkldnn':
        raise NotImplementedError
        # _load_mkldnn_cache(filename)
    else:
        raise ValueError("type must be in {'cudnn', 'tensorrt', 'mkldnn'}")
def save_cache(filename, type='cudnn'):
    """Save the optimization cache for the graph.
    Must be run after `model.bind(optimize=True)`.

    Parameters
    ----------
    filename : str
        Path to cache file.
    type : str, default 'cudnn'
        must be in {'cudnn', 'tensorrt', 'mkldnn'}
    References
    ----------
    `Saving and Loading Optimization Cache for Models \
    `_
    """
    if type == 'cudnn':
        raise NotImplementedError
        # _save_cudnn_cache(filename)
    elif type == 'tensorrt':
        raise NotImplementedError
        # _save_tensorrt_cache(filename)
    elif type == 'mkldnn':
        raise NotImplementedError
        # _save_mkldnn_cache(filename)
    else:
        raise ValueError("type must be in {'cudnn', 'tensorrt', 'mkldnn'}")
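
For illustration, a possible workflow with these proposed (not yet existing) methods; model, data_shapes and the cache filename are placeholders:

import os

cache_file = "resnet50_cudnn.cache"                        # placeholder filename
if os.path.isfile(cache_file):
    model.bind(data_shapes=data_shapes, optimize=False)    # skip autotune on bind
    model.load_cache(cache_file, type="cudnn")             # reuse the cached results
else:
    model.bind(data_shapes=data_shapes, optimize=True)     # run cuDNN autotune once
    model.save_cache(cache_file, type="cudnn")             # persist for the next start-up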

This addition requires a new boolean parameter for the bind() methods, which is set to True by default for backward compatibility.

def bind(self, ctx, args, args_grad=None, grad_req='write',
         aux_states=None, group2ctx=None, shared_exec=None, optimize=True):
    """
    # ...
    optimize : boolean, optional
        When set to True, the model is optimized with the current back-end (e.g. cuDNN, TensorRT, MKLDNN).
    """

Automatic integration with bind()

As soon as caching and reloading have surpassed experimental status, we can consider integrating them into bind() as an automatic procedure:

As a unified module I imagine the following process:

For every model, a unique fingerprint is generated that is similar in length to a git commit hash.
The last 7 characters of the fingerprint determine the filename of the cache file.
It is based on the following ingredients:

For cudNN convolution autotuning

  • MXNet version (major, minor)
  • model structure
  • CUDA version (major, minor)
  • cudNN version (major, minor)

For TensorRT graph optimization:

  • MXNet version (major, minor)
  • model structure
  • CUDA version (major, minor)
  • cudNN version (major, minor)
  • TensorRT version (major, minor)

For MKLDNN graph optimization:

  • MXNet version (major, minor)
  • model structure
  • MKLDNN version (major, minor)

On every bind() call, MXNet generates the fingerprint and attempts to load the corresponding file from the current working directory if it exists.
If the file is not found or loading fails, the optimization is run and the result is saved to <fingerprint_digits>.cache afterwards.
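
A minimal sketch of how the fingerprint and lookup could work for the cuDNN case, assuming the CUDA/cuDNN versions are passed in explicitly and that load_optimization() / run_optimization() are hypothetical hooks into the respective backend:

import hashlib
import os
import mxnet as mx

def cache_filename(symbol, cuda_version, cudnn_version):
    """Derive the cache filename from the fingerprint ingredients listed above."""
    h = hashlib.sha1()
    h.update(mx.__version__.encode())     # MXNet version
    h.update(symbol.tojson().encode())    # model structure
    h.update(cuda_version.encode())       # CUDA version (major, minor)
    h.update(cudnn_version.encode())      # cuDNN version (major, minor)
    return h.hexdigest()[-7:] + ".cache"  # last 7 digits of the fingerprint

def bind_with_cache(symbol, cuda_version, cudnn_version):
    filename = cache_filename(symbol, cuda_version, cudnn_version)
    if os.path.isfile(filename):
        load_optimization(filename)       # hypothetical backend hook
    else:
        blob = run_optimization(symbol)   # hypothetical backend hook
        with open(filename, "wb") as f:
            f.write(blob)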


Is an important detail missing, or do you recommend changes in certain aspects (e.g. naming conventions)?
I am interested to hear your thoughts on this.

Best regards,
~Johannes Czech

@QueensGambit
Contributor Author

QueensGambit commented Sep 14, 2019

@mxnet-label-bot add [Backend, CUDA, Call for Contribution, MKLDNN, ONNX, Feature request]

@QueensGambit
Contributor Author

ping @KellenSunderland @ThomasDelteil

@pengzhao-intel
Contributor

FYI, MKLDNN graph fusion is already enabled by default :)
One more thing: saving the fused graph may cause portability issues and break backward compatibility, so we need a fall-back solution as well.

@QueensGambit
Contributor Author

Thank you for the reply @pengzhao-intel. I updated the description on MKLDNN.

I see the point about portability and backward compatibility issues.
Maybe it is better to define optimize as a string argument which must be in {'on_bind', 'save_reload', 'disabled'}:

def bind(self, ctx, args, args_grad=None, grad_req='write',
         aux_states=None, group2ctx=None, shared_exec=None, optimize='on_bind'):
    """
    # ...
    optimize : str, optional, default 'on_bind'
        must be in {'on_bind', 'save_reload', 'disabled'}
        'on_bind': graph optimization / cuDNN autotune is executed during model binding
        'save_reload': MXNet attempts to recover previous optimization information;
                       otherwise MXNet will perform the optimization and save it to disk
        'disabled': no graph optimization / cuDNN autotune is performed
    """

In the default case, optimize='on_bind', bind() behaves exactly as it does today, so all existing code keeps working.

As a different aspect, it might be preferable to treat graph optimization (MKLDNN graph optimization / TensorRT graph fusion) as a separate entity from cuDNN autotune, because cuDNN autotune might also be performed on fused graphs in future versions.
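
For illustration, a call with the proposed string-valued argument could look like this (sym, arg_arrays and the context are placeholders; the argument itself does not exist yet):

executor = sym.bind(ctx=mx.gpu(0), args=arg_arrays, optimize='save_reload')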

@pengzhao-intel
Contributor

@ZhennanQin will follow up with the details :)

@ZhennanQin
Contributor

I have roughly gone through this topic, and I think we need a top-level design first: which information needs to be saved, how to save the result (via aux_params or a file), and what format to encode it in (we probably need to save many sections of different types, including some binaries). Once the top-level design is approved by the community, I can contribute the implementation of the MKLDNN parts.

@QueensGambit
Contributor Author

@ZhennanQin This sounds reasonable.
Maybe it would be good to create a cwiki page for the top level design:

@mxnet-label-bot
Contributor

Hey, this is the MXNet Label Bot.
Thank you for submitting the issue! I will try and suggest some labels so that the appropriate MXNet community members can help resolve it.
Here are my recommended label(s): Cuda, Feature

@pengzhao-intel
Contributor

A related discussion from the dev@ list earlier:

https://lists.apache.org/thread.html/dc83f4c16cc47b12a52c55d641cc5c6916cf8aa51589371eaf89a94b@%3Cdev.mxnet.apache.org%3E

@KellenSunderland
Contributor

I agree this is an issue and really like the approach here @QueensGambit. TensorRT and autotuning are the operations I see impacting our current cold-start time. cuDNN is relatively straightforward to cache; for TRT we might want to think about how to save and load the engine properly. I think the keys for caching are well thought through. We would have to make sure the expectation is that this will only work on a certain host type, but I think that's a reasonable restriction.

@QueensGambit
Contributor Author

QueensGambit commented Oct 3, 2019

Thank you for the feedback @KellenSunderland.
I think the mini-batch size should be included in the caching specification as well, because optimization techniques like TensorRT depend on it.

A different approach would be to define save and load functions for the Executor class.
The memory file of an executor handle would contain all additional platform-specific definitions and optimization results. This would allow the user to run the full binding process once on a specific platform and later bind it much more quickly:

# mxnet/executor.py
def save(filename_exec):
    """Saves the executor handle, including the platform-specific optimization of the graph.
    Must be run after the executor handle was bound: `model.bind()`.

    Parameters
    ----------
    filename_exec : str
        Path to the executor file (e.g. "executor.exec").
    References
    ----------
    `Saving and Loading of Executor handles \
    `_
    """

To avoid an additional copy of the model parameters, one would specify the .params and .symbol file paths when loading the executor handle. This also makes it possible to update the model parameters independently of the optimization cache:

# mxnet/executor.py
def load(filename_exec, filename_symbol, filename_params):
    """Loads and binds the executor handle.

    Parameters
    ----------
    filename_exec : str
        Path to the executor file (e.g. "executor.exec").
    filename_symbol : str
        Path to the model architecture definition (e.g. "model.symbol").
    filename_params : str
        Path to the model weights (e.g. "model.params").
    References
    ----------
    `Saving and Loading of Executor handles \
    `_
    """

@QueensGambit
Contributor Author

QueensGambit commented Oct 3, 2019

Regarding the export of a TensorRT executor handle (@Caenorst, @haohuanw),
the ONNX-TensorRT repository provides an executable to generate a TensorRT engine file from an ONNX model:

onnx2trt my_model.onnx -o my_engine.trt

Alternatively, one can use the C++ API instead:

NvOnnxParser.h
NvOnnxParserTypedefs.h

Later, the engine file can be reloaded from disk.
Here is example Python code for this, using code fragments from onnx/onnx-tensorrt#180 and https://github.com/NVIDIA/object-detection-tensorrt-example/blob/master/SSD_Model/utils/common.py.
Unfortunately, I haven't found a C++ example for this yet:

import pycuda.autoinit
import pycuda.driver as cuda
import tensorrt as trt
import numpy as np

trt_engine_path = 'my_engine.trt'
# initialize
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
trt.init_libnvinfer_plugins(TRT_LOGGER, '')
runtime = trt.Runtime(TRT_LOGGER)

# https://github.com/onnx/onnx-tensorrt/issues/180
def allocate_buffers(engine):
    """
    Allocates all buffers required for the specified engine
    """
    inputs = []
    outputs = []
    bindings = []
    # Iterate over binding names in engine
    for binding in engine:
        # Get binding (tensor/buffer) size
        size = trt.volume(engine.get_binding_shape(binding)) * engine.max_batch_size
        # Get binding (tensor/buffer) data type (numpy-equivalent)
        dtype = trt.nptype(engine.get_binding_dtype(binding))
        # Allocate page-locked memory (i.e., pinned memory) buffers
        host_mem = cuda.pagelocked_empty(size, dtype)
        # Allocate linear piece of device memory
        device_mem = cuda.mem_alloc(host_mem.nbytes)
        # Append the device buffer to device bindings
        bindings.append(int(device_mem))
        # Append to inputs/ouputs list
        if engine.binding_is_input(binding):
            inputs.append(HostDeviceMem(host_mem, device_mem))
        else:
            outputs.append(HostDeviceMem(host_mem, device_mem))
    # Create a stream (to eventually copy inputs/outputs and run inference)
    stream = cuda.Stream()
    return inputs, outputs, bindings, stream

def infer(context, bindings, inputs, outputs, stream, batch_size=1):
    """
    Infer outputs on the IExecutionContext for the specified inputs
    """
    # Transfer input data to the GPU
    [cuda.memcpy_htod_async(inp.device, inp.host, stream) for inp in inputs]
    # Run inference
    context.execute_async(batch_size=batch_size, bindings=bindings, stream_handle=stream.handle)
    # Transfer predictions back from the GPU
    [cuda.memcpy_dtoh_async(out.host, out.device, stream) for out in outputs]
    # Synchronize the stream
    stream.synchronize()
    # Return the host outputs
    return [out.host for out in outputs]

# https://github.com/NVIDIA/object-detection-tensorrt-example/blob/master/SSD_Model/utils/common.py
# Simple helper data class that's a little nicer to use than a 2-tuple.
class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

image = np.zeros((1, 3, 224, 224))  # dummy data

# Read the serialized ICudaEngine
with open(trt_engine_path, 'rb') as f, trt.Runtime(TRT_LOGGER) as runtime:
    # Deserialize ICudaEngine
    engine = runtime.deserialize_cuda_engine(f.read())
# Now just as with the onnx2trt samples...
# Create an IExecutionContext (context for executing inference)
with engine.create_execution_context() as context:
    # Allocate memory for inputs/outputs
    inputs, outputs, bindings, stream = allocate_buffers(engine)
    # Set host input to the image
    inputs[0].host = image
    # Inference
    trt_outputs = infer(context, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
    # Prediction
    pred_id = np.argmax(trt_outputs[-1])

@chinakook
Contributor

Maybe we can save the optimization states in params directly.

@QueensGambit
Contributor Author

@chinakook This is an option, but it could break the ability to deploy the same model across different platforms (e.g. CPU, CPU-MKLDNN, GPU-CUDA, GPU-TensorRT).

@mikeobr

mikeobr commented Nov 25, 2019

This feature would be very useful for us.

We deploy our models via MXNet Model Server (https://github.com/awslabs/multi-model-server) as a custom service. Each worker has its own instance of the model, so we experience memory instability when several workers autotune concurrently. This forces us to either underutilize the available GPU or risk errors at startup. Being able to cache autotune results would help with both cold starts and productionizing models.

@mk-61
Contributor

mk-61 commented Mar 12, 2020

How about adding a method to Symbol to calculate a digest over certain aspects of a model? In the C API it would look something like:

int MXCalculateDigest(SymbolHandle handle, /* type TBD */ selector)?

It would calculate a hash of

  • MXNet version
  • model structure
  • attributes, filtered by the selector argument
  • anything else?

Then any library / optimizer would call it, add some other optimization-specific details (CUDA version, hardware details, its own library version, etc.) and decide whether to use a saved optimization result or perform a new optimization step. I can imagine saving the result as a graph attribute, for instance.

This would save us from piping the particulars of specific libraries / optimisations through a centralized API, and would give more control to the optimisers, which decide which aspects are relevant to them and which are not.
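
A rough Python sketch of what such a digest could cover (the selector is a caller-supplied filter; this only illustrates the proposal and is not an existing API):

import hashlib
import mxnet as mx

def calculate_digest(sym, attr_selector):
    """Illustrative counterpart of the proposed MXCalculateDigest."""
    h = hashlib.sha1()
    h.update(mx.__version__.encode())           # MXNet version
    h.update(sym.tojson().encode())             # model structure
    for node, attrs in sorted(sym.attr_dict().items()):
        for key in sorted(attrs):
            if attr_selector(key):              # attributes, filtered by the selector
                h.update("{}.{}={}".format(node, key, attrs[key]).encode())
    return h.hexdigest()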

@QueensGambit
Contributor Author

QueensGambit commented Apr 18, 2020

@mk-61 Thank you for your proposal.
Generating a unique key based on model and inference attributes is an option.
I'm not sure whether I understood your idea correctly, but one downside of this approach is that it might become hard to maintain if every inference backend uses a different set of attributes.

Therefore, I prefer my earlier idea of adding save and load functions to the Executor class.
This way, the programmer can choose a name for the executor object after all optimizations have been done.

I implemented a wrapper for plain TensorRT that saves and reloads the TensorRT optimization as a trt-engine file:

The class implements the same functionality as our MXNet wrapper:

Loading the TensorRT optimization at start-up takes a few seconds, while optimizing the model from scratch takes several minutes.

This process is called serialization and de-serialization.
For a TensorRT engine, it is as simple as dumping a binary memory buffer into a file and reloading it later:

void write_buffer(void* buffer, size_t bufferSize, const string& filePath);
const char* read_buffer(const string& filePath, size_t& bufferSize);
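
For comparison, the same round trip in the TensorRT Python API is short as well (a sketch, assuming engine is an already built ICudaEngine and TRT_LOGGER is defined as in the example above):

# serialize the built engine to disk ...
with open("my_engine.trt", "wb") as f:
    f.write(engine.serialize())

# ... and deserialize it again later
with open("my_engine.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())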

@mikeobr Would exporting and importing MXNet executor objects suffice in your case?

@mikeobr

mikeobr commented May 7, 2020

@QueensGambit I have a bit of a noob-level understanding of the internals, but if that allows us to load and start doing inference without taking the tuning hit, then it will help us out.

My team generally uses models as SymbolBlocks and loads them via gluon.SymbolBlock.imports(...), so if we're able to use an imported executor alongside our params and symbols, we're good.
