shared memory support for SonicTriton #33801

Merged
merged 10 commits, Jun 18, 2021
21 changes: 14 additions & 7 deletions HeterogeneousCore/CUDAUtilities/interface/cudaCheck.h
@@ -5,6 +5,8 @@
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <string_view>

// CUDA headers
#include <cuda.h>
@@ -21,19 +23,22 @@ namespace cms {
const char* cmd,
const char* error,
const char* message,
const char* description = nullptr) {
std::string_view description = std::string_view()) {
Review comment (Contributor):

@fwyzard Do you see any potential problems in using std::string_view here? (all relevant compilers for CUDA should have supported C++17 for some time already, right?)

Reply (Contributor):

Sorry for the delay, the TDR has kept me fully busy the last few days (...).

My main concern was what happens in the vast majority of the cases, when no description is passed. Both @makortel and I have made some checks on godbolt, and it looks like the compiler should optimise the std::string_view away in that case.

As for C++17 vs earlier versions of the standard: yes, CUDA 11 fully supports C++ 17, so no problem there either.
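
For illustration, a minimal standalone sketch of the point above (not part of this diff): a default-constructed `std::string_view` is empty, so the extra output line is skipped whenever no description is passed.

```cpp
#include <iostream>
#include <string_view>

// Sketch of the default-argument pattern discussed above: a default-constructed
// std::string_view is empty, so the description is printed only when one is given.
void report(const char* message, std::string_view description = std::string_view()) {
  std::cout << message << "\n";
  if (!description.empty())
    std::cout << description << "\n";
}

int main() {
  report("cudaErrorInvalidValue");                     // common case: no description
  report("cudaErrorInvalidValue", "while copying X");  // description appended
}
```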

std::ostringstream out;
out << "\n";
out << file << ", line " << line << ":\n";
out << "cudaCheck(" << cmd << ");\n";
out << error << ": " << message << "\n";
if (description)
if (!description.empty())
out << description << "\n";
throw std::runtime_error(out.str());
}

inline bool cudaCheck_(
const char* file, int line, const char* cmd, CUresult result, const char* description = nullptr) {
inline bool cudaCheck_(const char* file,
int line,
const char* cmd,
CUresult result,
std::string_view description = std::string_view()) {
if (LIKELY(result == CUDA_SUCCESS))
return true;

@@ -45,8 +50,11 @@ namespace cms {
return false;
}

inline bool cudaCheck_(
const char* file, int line, const char* cmd, cudaError_t result, const char* description = nullptr) {
inline bool cudaCheck_(const char* file,
int line,
const char* cmd,
cudaError_t result,
std::string_view description = std::string_view()) {
if (LIKELY(result == cudaSuccess))
return true;

@@ -55,7 +63,6 @@ namespace cms {
abortOnCudaError(file, line, cmd, error, message, description);
return false;
}

} // namespace cuda
} // namespace cms

4 changes: 4 additions & 0 deletions HeterogeneousCore/SonicCore/src/SonicClientBase.cc
@@ -73,6 +73,10 @@ void SonicClientBase::finish(bool success, std::exception_ptr eptr) {
holder_.reset();
} else if (eptr)
std::rethrow_exception(eptr);

//reset client data now (usually done at end of produce())
if (eptr)
reset();
}

void SonicClientBase::fillBasePSetDescription(edm::ParameterSetDescription& desc, bool allowRetry) {
4 changes: 4 additions & 0 deletions HeterogeneousCore/SonicTriton/BuildFile.xml
@@ -4,8 +4,12 @@
<use name="FWCore/ParameterSet"/>
<use name="FWCore/Utilities"/>
<use name="HeterogeneousCore/SonicCore"/>
<use name="HeterogeneousCore/CUDAUtilities"/>
<use name="triton-inference-server"/>
<use name="protobuf"/>
<iftool name="cuda">
<use name="cuda"/>
</iftool>
<export>
<lib name="1"/>
</export>
31 changes: 27 additions & 4 deletions HeterogeneousCore/SonicTriton/README.md
@@ -30,6 +30,12 @@ The model information from the server can be printed by enabling `verbose` output
* `preferredServer`: name of preferred server, for testing (see [Services](#services) below)
* `timeout`: maximum allowed time for a request
* `outputs`: optional, specify which output(s) the server should send
* `verbose`: enable verbose printouts (default: false)
* `useSharedMemory`: enable use of shared memory (see [below](#shared-memory)) with local servers (default: true)

The batch size should be set using the client accessor, in order to ensure a consistent value across all inputs:
* `setBatchSize()`: set a new batch size
* some models may not support batching

Useful `TritonData` accessors include:
* `variableDims()`: return true if any variable dimensions
@@ -39,8 +45,6 @@ Useful `TritonData` accessors include:
* `byteSize()`: return number of bytes for data type
* `dname()`: return name of data type
* `batchSize()`: return current batch size
* `setBatchSize()`: set a new batch size
* some models may not support batching

To update the `TritonData` shape in the variable-dimension case:
* `setShape(const std::vector<int64_t>& newShape)`: update all (variable) dimensions with values provided in `newShape`
@@ -49,9 +53,28 @@ To update the `TritonData` shape in the variable-dimension case:
There are specific local input and output containers that should be used in producers.
Here, `T` is a primitive type, and the two aliases listed below are passed to `TritonInputData::toServer()`
and returned by `TritonOutputData::fromServer()`, respectively:
* `TritonInput<T> = std::vector<std::vector<T>>`
* `TritonInputContainer<T> = std::shared_ptr<TritonInput<T>> = std::shared_ptr<std::vector<std::vector<T>>>`
* `TritonOutput<T> = std::vector<edm::Span<const T*>>`

The `TritonInputContainer` object should be created using the helper function described below.
It expects one vector per batch entry (i.e. the size of the outer vector is the batch size).
Therefore, it is best to call `TritonClient::setBatchSize()`, if necessary, before calling the helper.
It will also reserve the expected size of the input in each inner vector (by default),
if the concrete shape is available (i.e. `setShape()` was already called, if the input has variable dimensions).
* `allocate<T>()`: return a `TritonInputContainer` properly allocated for the batch and input sizes
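
As a minimal sketch of how these pieces fit together inside `acquire()` (the input name, variable names, and the `client_` member are assumptions for illustration; the accessors are the ones documented above):

```cpp
// Sketch only: "features" is a hypothetical input with one variable dimension,
// and client_ is assumed to be the producer's SONIC client member.
auto& input1 = iInput.at("features");    // TritonInputData from the TritonInputMap
client_.setBatchSize(nObjects);          // set the batch size before allocate()
input1.setShape(0, nFeatures);           // only needed for variable dimensions
auto data1 = input1.allocate<float>();   // TritonInputContainer<float>: one inner vector per batch entry
for (unsigned i = 0; i < nObjects; ++i)
  (*data1)[i].push_back(someValue);      // fill each batch entry
input1.toServer(data1);                  // hand the filled container to the client
```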

### Shared memory

If the local fallback server (see [Services](#services) below) is in use,
input and output data can be transferred via shared memory rather than gRPC.
Both CPU and GPU (CUDA) shared memory are supported.
This is more efficient for some algorithms;
for algorithms where it is not, shared memory can be disabled in the client's Python configuration.

For outputs, shared memory can only be used if the batch size and concrete shape are known in advance,
because the shared memory region for the output must be registered before the inference call is made.
As with the inputs, this is handled automatically, and the use of shared memory can be disabled if desired.
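
Besides the Python configuration flag, this PR adds `useSharedMemory()`/`setUseSharedMemory()` accessors to `TritonClient` (see the header diff below), so the setting can also be toggled from code. A minimal sketch (the surrounding producer context and the `client_` member name are assumptions):

```cpp
// Sketch: fall back to gRPC transfer for this client if shared memory does not pay off.
if (client_.useSharedMemory())
  client_.setUseSharedMemory(false);
```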

## Modules

SONIC Triton supports producers, filters, and analyzers.
@@ -71,7 +94,7 @@ If an `edm::GlobalCache` of type `T` is needed, there are two changes:
In a SONIC Triton producer, the basic flow should follow this pattern:
1. `acquire()`:
a. access input object(s) from `TritonInputMap`
b. allocate input data using `std::make_shared<TritonInput<T>>()`
b. allocate input data using `allocate<T>()`
c. fill input data
d. set input shape(s) (optional, only if any variable dimensions)
e. convert using `toServer()` function of input object(s)
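
A skeleton of this flow is sketched below. The class name, base-class usage, input/output names, and the `produce()` step are illustrative assumptions, not code from this PR; only the `TritonInputMap`/`TritonOutputMap` accessors and the `allocate()`/`toServer()`/`fromServer()` calls follow the interfaces described above.

```cpp
#include "HeterogeneousCore/SonicCore/interface/SonicEDProducer.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonClient.h"

// Hypothetical producer skeleton; constructor and fillDescriptions are omitted.
class MyTritonProducer : public SonicEDProducer<TritonClient> {
public:
  void acquire(edm::Event const& iEvent, edm::EventSetup const& iSetup, Input& iInput) override {
    auto& input1 = iInput.at("features");     // a. access input object
    auto data1 = input1.allocate<float>();    // b. allocate input data
    // c. fill data1, d. setShape() if there are variable dimensions ...
    input1.toServer(data1);                   // e. convert and send
  }

  void produce(edm::Event& iEvent, edm::EventSetup const& iSetup, Output const& iOutput) override {
    const auto& output1 = iOutput.at("scores");
    const auto& result = output1.fromServer<float>();  // TritonOutput<float>: one span per batch entry
    // ... use result[i][j] and put products into the event ...
  }
};
```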
24 changes: 21 additions & 3 deletions HeterogeneousCore/SonicTriton/interface/TritonClient.h
@@ -5,6 +5,7 @@
#include "FWCore/ParameterSet/interface/ParameterSetDescription.h"
#include "HeterogeneousCore/SonicCore/interface/SonicClient.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonData.h"
#include "HeterogeneousCore/SonicTriton/interface/TritonService.h"

#include <map>
#include <vector>
@@ -31,20 +32,28 @@ class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
//constructor
TritonClient(const edm::ParameterSet& params, const std::string& debugName);

//destructor
~TritonClient() override;

//accessors
unsigned batchSize() const { return batchSize_; }
bool verbose() const { return verbose_; }
bool useSharedMemory() const { return useSharedMemory_; }
void setUseSharedMemory(bool useShm) { useSharedMemory_ = useShm; }
bool setBatchSize(unsigned bsize);
void reset() override;
bool noBatch() const { return noBatch_; }
TritonServerType serverType() const { return serverType_; }

//for fillDescriptions
static void fillPSetDescription(edm::ParameterSetDescription& iDesc);

protected:
//helper
bool getResults(std::shared_ptr<nvidia::inferenceserver::client::InferResult> results);

//helpers
void getResults(std::shared_ptr<nvidia::inferenceserver::client::InferResult> results);
void evaluate() override;
template <typename F>
bool handle_exception(F&& call);

void reportServerSideStats(const ServerSideStats& stats) const;
ServerSideStats summarizeServerStats(const inference::ModelStatistics& start_status,
@@ -57,6 +66,8 @@ class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
unsigned batchSize_;
bool noBatch_;
bool verbose_;
bool useSharedMemory_;
TritonServerType serverType_;

//IO pointers for triton
std::vector<nvidia::inferenceserver::client::InferInput*> inputsTriton_;
@@ -65,6 +76,13 @@ class TritonClient : public SonicClient<TritonInputMap, TritonOutputMap> {
std::unique_ptr<nvidia::inferenceserver::client::InferenceServerGrpcClient> client_;
//stores timeout, model name and version
nvidia::inferenceserver::client::InferOptions options_;

private:
friend TritonInputData;
friend TritonOutputData;

//private accessors only used by data
auto client() { return client_.get(); }
};

#endif
72 changes: 60 additions & 12 deletions HeterogeneousCore/SonicTriton/interface/TritonData.h
@@ -10,20 +10,34 @@
#include <numeric>
#include <algorithm>
#include <memory>
#include <any>
#include <atomic>

#include "grpc_client.h"
#include "grpc_service.pb.h"

//forward declaration
class TritonClient;
template <typename IO>
class TritonMemResource;
template <typename IO>
class TritonHeapResource;
template <typename IO>
class TritonCpuShmResource;
#ifdef TRITON_ENABLE_GPU
template <typename IO>
class TritonGpuShmResource;
#endif

//aliases for local input and output types
template <typename DT>
using TritonInput = std::vector<std::vector<DT>>;
template <typename DT>
using TritonOutput = std::vector<edm::Span<const DT*>>;

//other useful typedefs
template <typename DT>
using TritonInputContainer = std::shared_ptr<TritonInput<DT>>;

//store all the info needed for triton input and output
template <typename IO>
class TritonData {
@@ -34,15 +48,18 @@ class TritonData {
using ShapeView = edm::Span<ShapeType::const_iterator>;

//constructor
TritonData(const std::string& name, const TensorMetadata& model_info, bool noBatch);
TritonData(const std::string& name, const TensorMetadata& model_info, TritonClient* client, const std::string& pid);

//some members can be modified
bool setShape(const ShapeType& newShape) { return setShape(newShape, true); }
bool setShape(unsigned loc, int64_t val) { return setShape(loc, val, true); }
void setShape(const ShapeType& newShape);
void setShape(unsigned loc, int64_t val);

//io accessors
template <typename DT>
void toServer(std::shared_ptr<TritonInput<DT>> ptr);
TritonInputContainer<DT> allocate(bool reserve = true);
template <typename DT>
void toServer(TritonInputContainer<DT> ptr);
void prepare();
template <typename DT>
TritonOutput<DT> fromServer() const;

@@ -60,14 +77,23 @@ class TritonData {

private:
friend class TritonClient;
friend class TritonMemResource<IO>;
friend class TritonHeapResource<IO>;
friend class TritonCpuShmResource<IO>;
#ifdef TRITON_ENABLE_GPU
friend class TritonGpuShmResource<IO>;
#endif

//private accessors only used by client
bool setShape(const ShapeType& newShape, bool canThrow);
bool setShape(unsigned loc, int64_t val, bool canThrow);
//private accessors only used internally or by client
unsigned fullLoc(unsigned loc) const { return loc + (noBatch_ ? 0 : 1); }
void setBatchSize(unsigned bsize);
void reset();
void setResult(std::shared_ptr<Result> result) { result_ = result; }
IO* data() { return data_.get(); }
void updateMem(size_t size);
void computeSizes();
void resetSizes();
nvidia::inferenceserver::client::InferenceServerGrpcClient* client();

//helpers
bool anyNeg(const ShapeView& vec) const {
@@ -76,11 +102,20 @@ class TritonData {
int64_t dimProduct(const ShapeView& vec) const {
return std::accumulate(vec.begin(), vec.end(), 1, std::multiplies<int64_t>());
}
void createObject(IO** ioptr) const;
void createObject(IO** ioptr);
//generates a unique id number for each instance of the class
unsigned uid() const {
static std::atomic<unsigned> uid{0};
return ++uid;
}
std::string xput() const;

//members
std::string name_;
std::shared_ptr<IO> data_;
TritonClient* client_;
bool useShm_;
std::string shmName_;
const ShapeType dims_;
bool noBatch_;
unsigned batchSize_;
@@ -91,7 +126,11 @@
std::string dname_;
inference::DataType dtype_;
int64_t byteSize_;
std::any holder_;
size_t sizeShape_;
size_t byteSizePerBatch_;
size_t totalByteSize_;
std::shared_ptr<void> holder_;
std::shared_ptr<TritonMemResource<IO>> memResource_;
std::shared_ptr<Result> result_;
};

@@ -102,19 +141,28 @@ using TritonOutputMap = std::unordered_map<std::string, TritonOutputData>;

//avoid "explicit specialization after instantiation" error
template <>
std::string TritonInputData::xput() const;
template <>
std::string TritonOutputData::xput() const;
template <>
template <typename DT>
TritonInputContainer<DT> TritonInputData::allocate(bool reserve);
template <>
template <typename DT>
void TritonInputData::toServer(std::shared_ptr<TritonInput<DT>> ptr);
template <>
void TritonOutputData::prepare();
template <>
template <typename DT>
TritonOutput<DT> TritonOutputData::fromServer() const;
template <>
void TritonInputData::reset();
template <>
void TritonOutputData::reset();
template <>
void TritonInputData::createObject(nvidia::inferenceserver::client::InferInput** ioptr) const;
void TritonInputData::createObject(nvidia::inferenceserver::client::InferInput** ioptr);
template <>
void TritonOutputData::createObject(nvidia::inferenceserver::client::InferRequestedOutput** ioptr) const;
void TritonOutputData::createObject(nvidia::inferenceserver::client::InferRequestedOutput** ioptr);

//explicit template instantiation declarations
extern template class TritonData<nvidia::inferenceserver::client::InferInput>;
14 changes: 14 additions & 0 deletions HeterogeneousCore/SonicTriton/interface/TritonException.h
@@ -0,0 +1,14 @@
#ifndef HeterogeneousCore_SonicTriton_TritonException
#define HeterogeneousCore_SonicTriton_TritonException

#include "FWCore/Utilities/interface/Exception.h"

#include <string>

class TritonException : public cms::Exception {
public:
explicit TritonException(std::string const& aCategory);
void convertToWarning() const;
};

#endif