Skip to content

Commit

Permalink
UT passed
Browse files Browse the repository at this point in the history
  • Loading branch information
minghaoBD committed May 27, 2022
1 parent 291bc0b commit 865b673
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions paddle/fluid/inference/tensorrt/plugin/spmm_plugin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data,
size_t length)
: layer_name_(name),
weight_compressed_(nullptr),
weight_compressed_dev_(nullptr),
weight_compressed_dev_global_(nullptr),
bias_(nullptr),
bias_dev_(nullptr) {
Expand All @@ -435,18 +436,26 @@ SpmmPluginDynamic::SpmmPluginDynamic(const std::string name, const void* data,
weight_compressed_ = new char[compressed_size_];
deserialize_value_size(&data, &length, weight_compressed_, compressed_size_);
//MEM: how to deal with deserialization?
auto* p_tmp = weight_compressed_dev_global_.get();
cudaMalloc(reinterpret_cast<void**>(&p_tmp),
compressed_size_);
cudaMemcpy(weight_compressed_dev_global_.get(), weight_compressed_, compressed_size_,
cudaMemcpyHostToDevice);
cudaMalloc(reinterpret_cast<void**>(&weight_compressed_dev_), compressed_size_);
cudaMemcpy(weight_compressed_dev_, weight_compressed_, compressed_size_, cudaMemcpyHostToDevice);
weight_compressed_dev_global_.reset(weight_compressed_dev_, cudaFreeFunc);

std::cout << "compressed weight:";
char* test_weight = new char[compressed_size_];
cudaMemcpy(test_weight, weight_compressed_dev_global_.get(), compressed_size_,
cudaMemcpyDeviceToHost);
std::cout << "compressed weight in deserial:";
for(int i=0; i<10; i++) {
std::cout << " " << static_cast<float>(reinterpret_cast<float*>(weight_compressed_)[i]);
}
std::cout << std::endl;

std::cout << "weight from shared ptr in deserial:";
for(int i=0; i<10; i++) {
std::cout << " " << static_cast<float>(reinterpret_cast<float*>(test_weight)[i]);
}
std::cout << std::endl;


if (has_bias_) {
bias_ = new float[out_dim_];
deserialize_value_size(&data, &length, bias_, sizeof(float) * out_dim_);
Expand Down

0 comments on commit 865b673

Please sign in to comment.