Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Jul 6, 2022
1 parent 8830b02 commit 4def6c1
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
29 changes: 2 additions & 27 deletions paddle/fluid/inference/tensorrt/convert/op_converter.h
Expand Up @@ -563,33 +563,9 @@ class OpConverter {
const std::string& name) {
auto* var_v = scope.FindVar(name);
auto* var_t = var_v->GetMutable<framework::LoDTensor>();
void* trt_ptr = nullptr;
size_t trt_num = static_cast<size_t>(var_t->numel());
nvinfer1::DataType trt_dtype = nvinfer1::DataType::kFLOAT;
if (var_t->dtype() == phi::DataType::FLOAT32) {
float* data_ptr = engine_->GetWeightCPUData(name, var_t);
trt_ptr = static_cast<void*>(data_ptr);
} else if (var_t->dtype() == phi::DataType::INT32) {
int32_t* data_ptr = engine_->GetWeightCPUData<int32_t>(name, var_t);
trt_ptr = static_cast<void*>(data_ptr);
trt_dtype = nvinfer1::DataType::kINT32;
} else if (var_t->dtype() == phi::DataType::INT64) {
int64_t* data_ptr = engine_->GetWeightCPUData<int64_t>(name, var_t);
// We must create a new framework::Tensor()
std::unique_ptr<framework::Tensor> new_var_t(new framework::Tensor());
new_var_t->Resize({var_t->numel()});
int32_t* new_data_ptr =
new_var_t->mutable_data<int32_t>(platform::CPUPlace());
for (size_t i = 0; i < trt_num; i++) {
new_data_ptr[i] = data_ptr[i];
}
engine_->SetWeights(name, std::move(new_var_t));
trt_ptr = static_cast<void*>(new_data_ptr);
trt_dtype = nvinfer1::DataType::kINT32;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported datatype in TensorRT"));
}
auto weight = engine_->GetTrtWeight(name, *var_t);

// Now we have create weights, then we need create a itensor
auto var_dims = var_t->dims();
nvinfer1::Dims trt_in_shape;
Expand All @@ -605,7 +581,6 @@ class OpConverter {
trt_in_shape.d[i] = trt_in_shape.d[i + 1];
}
}
TensorRTEngine::Weight weight{trt_dtype, trt_ptr, trt_num};
nvinfer1::ILayer* layer =
TRT_ENGINE_ADD_LAYER(engine_, Constant, trt_in_shape, weight.get());
engine_->SetITensor(name, layer->getOutput(0));
Expand Down
15 changes: 15 additions & 0 deletions paddle/fluid/inference/tensorrt/engine.cc
Expand Up @@ -537,6 +537,21 @@ TensorRTEngine::Weight TensorRTEngine::GetTrtWeight(
}
weight.SetDataType(phi::DataType::FLOAT32);
weight.SetValues(fp32_data);
} else if (weight_tensor.dtype() == phi::DataType::INT64) {
framework::Tensor int64_tensor;
int64_tensor.clear();
paddle::framework::TensorCopySync(
weight_tensor, platform::CPUPlace(), &int64_tensor);
weight_map[name_with_suffix]->set_type(
paddle::experimental::DataType::INT32);
auto *int32_data =
weight_map[name_with_suffix]->mutable_data<int>(platform::CPUPlace());
auto *int64_data = int64_tensor.mutable_data<int64_t>(platform::CPUPlace());
for (int i = 0; i < weight_tensor.numel(); i++) {
int32_data[i] = int64_data[i];
}
weight.SetDataType(phi::DataType::FLOAT32);
weight.SetValues(int32_data);
} else {
paddle::framework::TensorCopySync(
weight_tensor, cpu_place, weight_map[name_with_suffix].get());
Expand Down

1 comment on commit 4def6c1

@paddle-bot-old
Copy link

@paddle-bot-old paddle-bot-old bot commented on 4def6c1 Jul 6, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🕵️ CI failures summary

🔍 PR: #44057 Commit ID: 4def6c1 contains failed CI.

🔹 Failed: PR-CI-APPROVAL

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Static-Check

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Inference

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-GpuPS

Unknown Failed
Unknown Failed

🔹 Failed: PR-CI-Coverage

Unknown Failed
Unknown Failed

Please sign in to comment.