Skip to content

Commit

Permalink
Inference support mixed-precision model [3] (#44057)
Browse files Browse the repository at this point in the history
  • Loading branch information
jiweibo committed Jul 8, 2022
1 parent b2c1247 commit 7f95872
Show file tree
Hide file tree
Showing 32 changed files with 651 additions and 268 deletions.
3 changes: 3 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
Expand Up @@ -331,6 +331,9 @@ struct Argument {

// mixed precision related
DECL_ARGUMENT_FIELD(model_precision, ModelPrecision, int);
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);

private:
std::unordered_set<std::string> valid_fields_;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
Expand Up @@ -87,6 +87,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("with_dynamic_shape", new bool(with_dynamic_shape));

pass->Set("model_precision", new int(argument->model_precision()));
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));

if (pass_name == "graph_viz_pass") {
std::string optim_cache_dir = argument->optim_cache_dir();
Expand Down
105 changes: 104 additions & 1 deletion paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.cc
Expand Up @@ -13,26 +13,117 @@
// limitations under the License.

#include "paddle/fluid/inference/analysis/ir_passes/tensorrt_subgraph_pass.h"
#include <cstddef>
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/subgraph_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/inference/analysis/helper.h"
#include "paddle/fluid/inference/analysis/passes/convert_to_mixed_precision.h"
#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/inference/tensorrt/op_teller.h"
#include "paddle/fluid/inference/utils/io_utils.h"
#include "paddle/phi/common/backend.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace inference {
namespace analysis {
namespace {

bool IsFloat(framework::proto::VarType::Type t) {
if (t == framework::proto::VarType::FP16 ||
t == framework::proto::VarType::FP32 ||
t == framework::proto::VarType::FP64 ||
t == framework::proto::VarType::BF16)
return true;
return false;
}

// if in mixed model precision, we should make all tensorrt_engine's output
// floats dtype to float32 dtype.
void OutputProcess(framework::ir::Graph *graph,
const std::unordered_set<framework::ir::Node *> &trt_outputs,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string> &blacklist) {
framework::BlockDesc *block_desc{nullptr};
int suffix = 0;
std::unordered_map<framework::ir::Node *, framework::ir::Node *>
var_to_cast_op_map;

framework::proto::VarType::Type to_type;
if (precision == phi::DataType::FLOAT16) {
to_type = framework::proto::VarType::FP16;
} else if (precision == phi::DataType::BFLOAT16) {
to_type = framework::proto::VarType::BF16;
} else if (precision == phi::DataType::FLOAT32) {
return;
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"mixed_precision currently not supported dtype %d, we now only support "
"fp16 and bf16.",
static_cast<int>(precision)));
}

for (auto *op_node : framework::ir::TopologySortOperations(*graph)) {
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
if (op_type == "feed") block_desc = op_node->Op()->Block();
if (op_type != "tensorrt_engine") continue;
for (auto *var_node : op_node->outputs) {
if (!trt_outputs.count(var_node)) continue;
if (!var_node->Var()->Persistable() &&
IsFloat(var_node->Var()->GetDataType()) &&
var_node->Var()->GetDataType() != framework::proto::VarType::FP32) {
for (auto *next_op : var_node->outputs) {
// if next_op support mixed_precision, we need to add cast op.
if (OpSupportPrecision(
phi::TransToPhiKernelName(next_op->Op()->Type()),
backend,
precision,
blacklist)) {
AddCastOp(graph,
var_node,
next_op,
framework::proto::VarType::FP32,
to_type,
&suffix,
block_desc,
&var_to_cast_op_map);
var_node->Var()->SetDataType(framework::proto::VarType::FP32);
}
}
}
}
}
}

} // namespace

using framework::ir::Node;

void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);

auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision"));
if (model_precision == phi::DataType::BFLOAT16) {
LOG(WARNING)
<< "Paddle-TRT not support bf16 mixed precison, just fallback.";
return;
}

auto enable_int8 = Get<bool>("enable_int8");
auto use_calib_mode = Get<bool>("use_calib_mode");
bool no_calib_int8 = enable_int8 && !(use_calib_mode);
Expand Down Expand Up @@ -181,15 +272,25 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
}
}

auto model_precision =
static_cast<phi::DataType>(Get<int>("model_precision"));
auto mixed_black_list =
Get<std::unordered_set<std::string>>("mixed_black_list");

std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
std::map<std::string, int> origin_name_output_dims;
std::unordered_set<Node *> trt_outputs;
for (auto *x : node->outputs) {
output_names.insert(x->Name());
output_names_with_id.insert(x->Name() + std::to_string(x->id()));
origin_name_output_dims[x->Name()] = x->Var()->GetShape().size();
trt_outputs.insert(x);
}

OutputProcess(
graph, trt_outputs, phi::Backend::GPU, model_precision, mixed_black_list);

std::unordered_map<std::string, std::string> output_name_map;
std::unordered_map<std::string, framework::ir::Node *> graph_var_map;

Expand Down Expand Up @@ -285,6 +386,7 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
op_desc->SetAttr("allow_build_at_runtime", allow_build_at_runtime);
op_desc->SetAttr("shape_range_info_path", shape_range_info_path);
op_desc->SetAttr("use_inspector", Get<bool>("use_inspector"));
op_desc->SetAttr("model_precision", Get<int>("model_precision"));

// we record all inputs' shapes in attr to check if they are consistent
// with the real inputs' shapes retrieved from scope when trt runs.
Expand Down Expand Up @@ -404,7 +506,8 @@ void TensorRtSubgraphPass::CreateTensorRTOp(
min_input_shape,
max_input_shape,
opt_input_shape,
disable_trt_plugin_fp16);
disable_trt_plugin_fp16,
static_cast<phi::DataType>(Get<int>("model_precision")));
trt_engine->SetUseOSS(Get<bool>("use_varseqlen"));
trt_engine->SetWithInterleaved(Get<bool>("with_interleaved"));
trt_engine->SetTransformerPosid(
Expand Down
Expand Up @@ -18,6 +18,7 @@

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
Expand Down Expand Up @@ -379,27 +380,21 @@ void ConvertToMixedPrecision(const std::string& model_file,
};

std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : paddle::framework::ir::TopologySortOperations(*graph)) {
if (!node->IsOp()) continue;
auto* op_desc = node->Op();
if (op_desc->Type() == "feed" || op_desc->Type() == "fetch") continue;

if (op_desc->Type() == "batch_norm") {
auto vecs = op_desc->Input("Bias");
for (auto s : vecs) {
weights_should_be_fp32.insert(s);
}
vecs = op_desc->Input("Mean");
for (auto s : vecs) {
weights_should_be_fp32.insert(s);
}
vecs = op_desc->Input("Scale");
for (auto s : vecs) {
weights_should_be_fp32.insert(s);
}
vecs = op_desc->Input("Variance");
for (auto s : vecs) {
weights_should_be_fp32.insert(s);
for (auto* node : graph->Nodes()) {
if (!node->IsVar()) continue;
if (node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR ||
node->Var()->GetType() ==
paddle::framework::proto::VarType::LOD_TENSOR_ARRAY ||
node->Var()->GetType() == paddle::framework::proto::VarType::STRINGS ||
node->Var()->GetType() == paddle::framework::proto::VarType::VOCAB) {
if (node->Var()->Persistable() &&
node->Var()->GetDataType() ==
paddle::framework::proto::VarType::FP32) {
VLOG(2) << "weights keep to fp32: " << node->Name();
weights_should_be_fp32.insert(node->Name());
}
}
}
Expand Down
10 changes: 10 additions & 0 deletions paddle/fluid/inference/api/analysis_config.cc
Expand Up @@ -256,6 +256,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(gpu_device_id_);
CP_MEMBER(memory_pool_init_size_mb_);

// Mixed related.
CP_MEMBER(mixed_black_list_);

CP_MEMBER(enable_memory_optim_);
// TensorRT related.
CP_MEMBER(use_tensorrt_);
Expand Down Expand Up @@ -871,6 +874,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << ipu_available_memory_proportion_;
ss << ipu_enable_half_partial_;

for (auto &op : mixed_black_list_) ss << op.c_str();
return ss.str();
}

Expand Down Expand Up @@ -1188,4 +1192,10 @@ bool AnalysisConfig::tuned_tensorrt_dynamic_shape() {
bool AnalysisConfig::trt_allow_build_at_runtime() {
return trt_allow_build_at_runtime_;
}

void AnalysisConfig::Exp_SetBlackListOpsForMixedModel(
const std::unordered_set<std::string> &black_list) {
mixed_black_list_ = black_list;
}

} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
Expand Up @@ -1216,7 +1216,9 @@ void AnalysisPredictor::PrepareArgument() {
argument_.SetAnalysisPasses(config_.pass_builder()->AnalysisPasses());
argument_.SetScopeNotOwned(scope_.get());

// mixed precison.
argument_.SetModelPrecision(static_cast<int>(model_precision_));
argument_.SetMixedBlackList(config_.mixed_black_list_);
}

// NOTE All the members in AnalysisConfig should be copied to Argument.
Expand Down
11 changes: 11 additions & 0 deletions paddle/fluid/inference/api/paddle_analysis_config.h
Expand Up @@ -914,6 +914,14 @@ struct PD_INFER_DECL AnalysisConfig {

const DistConfig& dist_config() const { return dist_config_; }

///
/// \brief Set a list of operators that do not support mixed precision. This
/// interface is in the experimental stage and may change in the future. Note
/// that the blacklist must be the same as the model conversion blacklist.
///
void Exp_SetBlackListOpsForMixedModel(
const std::unordered_set<std::string>& black_list);

protected:
// Update the config.
void Update();
Expand All @@ -926,6 +934,9 @@ struct PD_INFER_DECL AnalysisConfig {
mutable std::string prog_file_;
mutable std::string params_file_;

// Mixed precision.
std::unordered_set<std::string> mixed_black_list_;

// GPU related.
bool use_gpu_{false};
int gpu_device_id_{0};
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/inference/api/paddle_pass_builder.cc
Expand Up @@ -160,6 +160,10 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
const std::vector<std::string> kTrtLowerPrecisionPasses{
// "conv_bn_fuse_pass",
// "conv_eltwiseadd_bn_fuse_pass",
"trt_map_matmul_v2_to_mul_pass",
"trt_map_matmul_v2_to_matmul_pass",
"trt_map_matmul_to_mul_pass",
"fc_fuse_pass",
"tensorrt_subgraph_pass",
};

Expand Down
20 changes: 12 additions & 8 deletions paddle/fluid/inference/tensorrt/convert/affine_channel_op.cc
Expand Up @@ -50,22 +50,26 @@ class AffineChannelOpConverter : public OpConverter {

auto* scale_v = scope.FindVar(scale_name);
auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t);
float* scale_ptr = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(scale_name, *scale_t).get().values));

auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t);
float* bias_ptr = const_cast<float*>(static_cast<const float*>(
engine_->GetFp32TrtWeight(bias_name, *bias_t).get().values));

// tensorrt scalend layer only support spatial dims >= 2,
// so nhwc is not availabe (spatial dims == 0)
const int channel_axis = engine_->with_dynamic_shape();

TensorRTEngine::Weight scale_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(scale_ptr),
(size_t)idim.d[channel_axis]};
TensorRTEngine::Weight bias_weights{nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr),
(size_t)idim.d[channel_axis]};
TensorRTEngine::Weight scale_weights{
nvinfer1::DataType::kFLOAT,
static_cast<void*>(scale_ptr),
static_cast<size_t>(idim.d[channel_axis])};
TensorRTEngine::Weight bias_weights{
nvinfer1::DataType::kFLOAT,
static_cast<void*>(bias_ptr),
static_cast<size_t>(idim.d[channel_axis])};
TensorRTEngine::Weight power_weights{
nvinfer1::DataType::kFLOAT, nullptr, 0};

Expand Down
23 changes: 11 additions & 12 deletions paddle/fluid/inference/tensorrt/convert/conv2d_op.cc
Expand Up @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/engine.h"
#include "paddle/phi/common/data_type.h"

namespace paddle {
namespace framework {
Expand Down Expand Up @@ -48,7 +50,7 @@ void ConvertConv2d(TensorRTEngine* engine,
platform::errors::NotFound("Can not find %s presistale var in scope.",
filter_var_name));
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr;

bool enable_int8 = op_desc.HasAttr("enable_int8");

if (enable_int8) {
Expand All @@ -57,7 +59,6 @@ void ConvertConv2d(TensorRTEngine* engine,
engine->SetTensorDynamicRange(X, in_scale);
#endif
}
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);

PADDLE_ENFORCE_EQ(Y_t->dims().size(),
4UL,
Expand Down Expand Up @@ -104,21 +105,19 @@ void ConvertConv2d(TensorRTEngine* engine,
nv_post_paddings.d[1] = paddings[3];
}

TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<size_t>(Y_t->numel())};
float* bias_data = nullptr;
size_t bias_size = 0;
auto weight = engine->GetTrtWeight(op_desc.Input("Filter").front(), *Y_t);

TensorRTEngine::Weight bias;
bias.SetDataType(weight.get().type);
bias.SetCount(0);
bias.SetValues(nullptr);
if (op_desc.Type() == "conv2d_fusion") {
auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
bias_tensor_data);
bias_size = static_cast<size_t>(bias_tensor_data->numel());
bias =
engine->GetTrtWeight(op_desc.Input("Bias").front(), *bias_tensor_data);
}

TensorRTEngine::Weight bias{
nvinfer1::DataType::kFLOAT, static_cast<void*>(bias_data), bias_size};
// In conv2d_transpose and depthwise_conv2d_transpose,
// output channels = filter_dims[1] * groups
auto* layer = (op_desc.Type() == "conv2d_transpose" ||
Expand Down

0 comments on commit 7f95872

Please sign in to comment.