
Commit

shared_ptr compilation passed
minghaoBD committed May 26, 2022
1 parent 0af0393 commit 3340d6b
Showing 2 changed files with 23 additions and 18 deletions.
35 changes: 18 additions & 17 deletions paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
@@ -63,6 +63,8 @@ inline void deserialize_value_size(void const** buffer, size_t* buffer_size,

inline float round_scale(float x) { return std::floor(x + 0.5f); }

inline void cudaFreeFunc(void* p) { if(p) { cudaFree(p); } }

inline void convertAndCopy(const nvinfer1::Weights& src,
nvinfer1::DataType type, void* dest) {
PADDLE_ENFORCE_EQ(src.type == nvinfer1::DataType::kFLOAT ||
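The cudaFreeFunc helper added above exists so a std::shared_ptr<void> can own cudaMalloc'd memory and release it through cudaFree. A minimal sketch of that pattern, assuming the helper is visible; makeDeviceBuffer and its error handling are illustrative, not part of this commit:

// Sketch: wrap a freshly allocated device buffer in a shared_ptr whose deleter
// is cudaFreeFunc, so cudaFree runs exactly once, when the last owner releases it.
#include <cuda_runtime.h>
#include <cstddef>
#include <memory>

inline void cudaFreeFunc(void* p) { if (p) { cudaFree(p); } }  // same helper as above

std::shared_ptr<void> makeDeviceBuffer(size_t bytes) {  // hypothetical helper
  void* raw = nullptr;
  if (cudaMalloc(&raw, bytes) != cudaSuccess) {
    return nullptr;
  }
  return std::shared_ptr<void>(raw, cudaFreeFunc);
}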
@@ -252,6 +254,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name,
weight_scale_(1.0f),
weight_compressed_(nullptr),
weight_compressed_dev_(nullptr),
weight_compressed_dev_global_(nullptr),
compressed_size_(0),
has_bias_(false),
bias_(nullptr),
@@ -310,12 +313,12 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name,
cudaMemcpy(weight_dev, weight_host.data(), precision_size_ * weight.count,
cudaMemcpyHostToDevice);
}

spmm_context_.compressMatB(out_dim_, k_, convertTrtType(precision_),
weight_dev, &weight_compressed_dev_,
&compressed_size_);
weight_compressed_ = new char[compressed_size_];
cudaMemcpy(weight_compressed_, weight_compressed_dev_, compressed_size_,
weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc);
cudaMemcpy(weight_compressed_, weight_compressed_dev_global_.get(), compressed_size_,
cudaMemcpyDeviceToHost);

has_bias_ = (bias.count != 0);
@@ -352,7 +355,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name,
optim_alg_(optim_alg),
weight_scale_(1.0f),
weight_compressed_(nullptr),
weight_compressed_dev_(nullptr),
weight_compressed_dev_global_(nullptr),
compressed_size_(compressed_size),
has_bias_(false),
bias_(nullptr),
@@ -373,11 +376,6 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string& layer_name,
std::copy_n(static_cast<const char*>(weight_compressed), compressed_size,
static_cast<char*>(weight_compressed_));

cudaMalloc(reinterpret_cast<void**>(&weight_compressed_dev_),
compressed_size);
cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size,
cudaMemcpyHostToDevice);

has_bias_ = (bias != nullptr);
if (has_bias_) {
// Each plugin has a copy of bias
@@ -403,7 +401,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data,
size_t length)
: layer_name_(name),
weight_compressed_(nullptr),
weight_compressed_dev_(nullptr),
weight_compressed_dev_global_(nullptr),
bias_(nullptr),
bias_dev_(nullptr) {
DeserializeValue(&data, &length, &precision_);
@@ -424,9 +422,10 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data,
"Deserialize data should be configured"));
weight_compressed_ = new char[compressed_size_];
deserialize_value_size(&data, &length, weight_compressed_, compressed_size_);
cudaMalloc(reinterpret_cast<void**>(&weight_compressed_dev_),
//MEM: how to deal with deserialization?
cudaMalloc(reinterpret_cast<void**>(weight_compressed_dev_global_.get()),
compressed_size_);
cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_,
cudaMemcpy(weight_compressed_dev_global_.get(), weight_compressed_, compressed_size_,
cudaMemcpyHostToDevice);

if (has_bias_) {
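On the //MEM question above: applying cudaMalloc to weight_compressed_dev_global_.get() cannot populate a still-empty shared_ptr, because get() returns a null pointer value rather than the address of an owned slot. One possible shape for this deserialization path, sketched under the assumption that the shared pointer should own the device copy here as well; this is not what the commit does:

// Sketch: allocate into a raw pointer, copy the host weights up, then hand
// ownership to the shared_ptr via reset() with cudaFreeFunc, mirroring the
// builder-side constructor above.
void* weight_compressed_dev = nullptr;
cudaMalloc(&weight_compressed_dev, compressed_size_);
cudaMemcpy(weight_compressed_dev, weight_compressed_, compressed_size_,
           cudaMemcpyHostToDevice);
weight_compressed_dev_global_.reset(weight_compressed_dev, cudaFreeFunc);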
@@ -451,8 +450,8 @@ nvinfer1::IPluginV2DynamicExt* SpmmPluginDynamic::clone() const noexcept {
weight_compressed_, compressed_size_, bias_,
is_configured_, m_max_, optim_alg_, activation_);
p->weight_scale_ = weight_scale_;
p->weight_compressed_dev_global_ = weight_compressed_dev_global_;
p->setPluginNamespace(namespace_.c_str());

return p;
} catch (const std::exception& e) {
std::cerr << e.what() << std::endl;
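Because clone() copies weight_compressed_dev_global_ rather than allocating and copying again, every clone shares one compressed weight buffer on the device, and the deleter fires only when the last holder is gone; destroy() further down can therefore just reset() its reference. A standalone sketch of that lifetime, with the deleter written inline for brevity:

// Sketch: shared ownership of a device buffer across plugin copies.
#include <cuda_runtime.h>
#include <memory>

int main() {
  void* dev = nullptr;
  cudaMalloc(&dev, 256);
  std::shared_ptr<void> original(dev, [](void* p) { if (p) cudaFree(p); });
  std::shared_ptr<void> copy = original;  // what clone() does; no new cudaMalloc/cudaMemcpy
  copy.reset();                           // one holder gone, buffer still alive
  original.reset();                       // last holder gone, cudaFree runs here
  return 0;
}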
@@ -614,7 +613,7 @@ void SpmmPluginDynamic::configurePlugin(
spmm_context_.workspace_size);
paddle::platform::dynload::cusparseLtMatmulSearch(
&spmm_context_.handle, &spmm_context_.plan, &alpha, dA,
weight_compressed_dev_, &beta, dC, dC, d_workspace, nullptr, 0);
weight_compressed_dev_global_.get(), &beta, dC, dC, d_workspace, nullptr, 0);
paddle::platform::dynload::cusparseLtMatmulAlgGetAttribute(
&spmm_context_.handle, &spmm_context_.alg_sel,
CUSPARSELT_MATMUL_ALG_CONFIG_ID, &optim_alg_, sizeof(optim_alg_));
@@ -658,22 +657,22 @@ int SpmmPluginDynamic::enqueue(const nvinfer1::PluginTensorDesc* inputDesc,
auto* output = static_cast<float*>(outputs[0]);
cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul(
&spmm_context_.handle, &spmm_context_.plan, &alpha, input,
weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1);
return status != CUSPARSE_STATUS_SUCCESS;
} else if (inputDesc->type == nvinfer1::DataType::kHALF) {
const auto* const input = static_cast<const half*>(inputs[0]);
auto* output = static_cast<half*>(outputs[0]);
cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul(
&spmm_context_.handle, &spmm_context_.plan, &alpha, input,
weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1);
return status != CUSPARSE_STATUS_SUCCESS;
} else if (inputDesc->type == nvinfer1::DataType::kINT8) {
alpha = inputDesc->scale * weight_scale_ / outputDesc->scale;
const auto* const input = static_cast<const int8_t*>(inputs[0]);
auto* output = static_cast<int8_t*>(outputs[0]);
cusparseStatus_t status = paddle::platform::dynload::cusparseLtMatmul(
&spmm_context_.handle, &spmm_context_.plan, &alpha, input,
weight_compressed_dev_, &beta, output, output, workSpace, &stream, 1);
weight_compressed_dev_global_.get(), &beta, output, output, workSpace, &stream, 1);
return status != CUSPARSE_STATUS_SUCCESS;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
@@ -749,7 +748,7 @@ void SpmmPluginDynamic::serialize(void* buffer) const noexcept {

void SpmmPluginDynamic::destroy() noexcept {
delete[] reinterpret_cast<char*>(weight_compressed_);
cudaFree(weight_compressed_dev_);
//MEM:
// cudaFree(weight_compressed_dev_);
weight_compressed_dev_global_.reset();
if (has_bias_) {
cudaFree(bias_dev_);
}
6 changes: 5 additions & 1 deletion paddle/fluid/inference/tensorrt/plugin/spmm_plugin.h
@@ -39,6 +39,8 @@
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin.h"
#include "paddle/fluid/platform/dynload/cusparseLt.h"

using namespace std;

namespace paddle {
namespace inference {
namespace tensorrt {
@@ -77,6 +79,7 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept override;

nvinfer1::DataType getOutputDataType(int index,
const nvinfer1::DataType* inputTypes,
int nbInputs) const noexcept override;
@@ -128,7 +131,8 @@ class SpmmPluginDynamic : public nvinfer1::IPluginV2DynamicExt {
int optim_alg_; // the index of optimal algorithm
float weight_scale_; // record the weight scale from constructor
void* weight_compressed_; // host compressed weight
void* weight_compressed_dev_; // device compressed weight
void* weight_compressed_dev_; // device compressed weight
shared_ptr<void> weight_compressed_dev_global_; // shared pointer to the device compressed weight
size_t compressed_size_; // size of compressed weight
bool has_bias_; // there is bias or not
void* bias_; // host bias
