Merge branch 'develop' into my-cool-stuff

Ligoml committed Dec 13, 2022
2 parents 6568bcc + acee3dd commit 11f10ff
Showing 75 changed files with 1,150 additions and 3,007 deletions.
5 changes: 5 additions & 0 deletions paddle/fluid/inference/analysis/argument.h
@@ -371,6 +371,11 @@ struct Argument {
// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);

// custom device
DECL_ARGUMENT_FIELD(use_custom_device, UseCustomDevice, bool);
DECL_ARGUMENT_FIELD(custom_device_type, CustomDeviceType, std::string);
DECL_ARGUMENT_FIELD(custom_device_id, CustomDeviceId, int);

private:
std::unordered_set<std::string> valid_fields_;
};
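For orientation, a minimal sketch of how these fields are driven; the accessor names are the ones this commit uses in analysis_predictor.cc, while the helper function and the "npu" device string are illustrative only:

#include "paddle/fluid/inference/analysis/argument.h"

// Illustrative helper (not part of this commit): populate the new
// custom-device fields on an Argument before the analysis passes run.
void ConfigureCustomDevice(paddle::inference::analysis::Argument *argument) {
  argument->SetUseCustomDevice(true);
  argument->SetCustomDeviceType("npu");  // device type string; "npu" is an example
  argument->SetCustomDeviceId(0);        // device ordinal
  if (argument->use_custom_device()) {
    // downstream passes read custom_device_type() / custom_device_id()
  }
}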
paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.cc
@@ -14,6 +14,7 @@

#include "paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h"

#include <cstdlib>
#include <string>
#include <unordered_set>

@@ -26,6 +27,11 @@
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/common/data_type.h"

DEFINE_bool(
custom_model_save_cpu,
false,
"Keep old mode for developers, the model is saved on cpu not device.");

namespace paddle {
namespace inference {
namespace analysis {
@@ -71,9 +77,9 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToNpu(Argument *argument) {
}
}
}
#endif

#else

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
// The parameters are on the cpu, therefore, synchronization is not necessary.
if (!argument->use_gpu()) return;
@@ -148,21 +154,83 @@ void IrParamsSyncAmongDevicesPass::CopyParamsToGpu(Argument *argument) {
}
}
}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
void IrParamsSyncAmongDevicesPass::CopyParamsToCustomDevice(
Argument *argument) {
if (!argument->use_custom_device()) return;

// In the old mode, the model is saved on the CPU rather than on the device.
if (argument->custom_device_type() == "OpenCL") {
PADDLE_ENFORCE_EQ(
FLAGS_custom_model_save_cpu,
false,
phi::errors::InvalidArgument(
"'FLAGS_custom_model_save_cpu = false' is only for the developers "
"who have not completed custom device memory settings. Setting to "
"true will make "
"model memory reserve on the cpu, and make inference slower."));
}

if (FLAGS_custom_model_save_cpu) return;

auto &graph = argument->main_graph();
std::vector<std::string> repetitive_params;

if (graph.Has(framework::ir::kRepetitiveParamAttr))
repetitive_params = graph.Get<std::vector<std::string>>(
framework::ir::kRepetitiveParamAttr);

LOG(INFO) << "Sync params from CPU to CustomDevice: "
<< argument->custom_device_type() << "/"
<< argument->custom_device_id();

platform::Place place = platform::CustomPlace(argument->custom_device_type(),
argument->custom_device_id());
auto *scope = argument->scope_ptr();
std::vector<std::string> all_vars = scope->LocalVarNames();

for (auto &var_name : all_vars) {
auto *var = scope->FindLocalVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::PreconditionNotMet("The var should not be nullptr"));

if (var->IsType<phi::DenseTensor>()) {
auto *t = var->GetMutable<phi::DenseTensor>();

platform::CPUPlace cpu_place;
phi::DenseTensor temp_tensor;
temp_tensor.Resize(t->dims());

paddle::framework::TensorCopySync(*t, cpu_place, &temp_tensor);
t->clear();
paddle::framework::TensorCopySync(temp_tensor, place, t);
}
}
}
#endif

void IrParamsSyncAmongDevicesPass::RunImpl(Argument *argument) {
PADDLE_ENFORCE_EQ(
argument->scope_valid(),
true,
platform::errors::PreconditionNotMet("The scope field should be valid"));

#ifdef PADDLE_WITH_ASCEND_CL
if (!argument->use_npu_valid()) return;
CopyParamsToNpu(argument);
#else
if (!argument->use_gpu_valid()) return;
CopyParamsToGpu(argument);
if (argument->use_npu_valid()) {
CopyParamsToNpu(argument);
}
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (argument->use_gpu_valid()) {
CopyParamsToGpu(argument);
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
if (argument->use_custom_device_valid()) {
CopyParamsToCustomDevice(argument);
}
#endif
}
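// Sketch of how this pass is driven (assuming the PassRegistry convention
// used elsewhere in paddle/fluid/inference/analysis; not part of this diff):
//
//   auto *pass =
//       PassRegistry::Global().Retrieve("ir_params_sync_among_devices_pass");
//   pass->Run(&argument);
//
// With the backend-specific blocks above guarded independently, a build that
// enables several backends syncs parameters only for the backend the
// Argument actually selects at runtime.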

paddle/fluid/inference/analysis/passes/ir_params_sync_among_devices_pass.h
@@ -37,9 +37,15 @@ class IrParamsSyncAmongDevicesPass : public AnalysisPass {
private:
#ifdef PADDLE_WITH_ASCEND_CL
void CopyParamsToNpu(Argument *argument);
#else
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
void CopyParamsToGpu(Argument *argument);
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
void CopyParamsToCustomDevice(Argument *argument);
#endif
};

} // namespace analysis
9 changes: 9 additions & 0 deletions paddle/fluid/inference/api/analysis_predictor.cc
@@ -1242,6 +1242,15 @@ void AnalysisPredictor::PrepareArgument() {
}
#endif

#ifdef PADDLE_WITH_CUSTOM_DEVICE
argument_.SetUseCustomDevice(config_.use_custom_device());
if (config_.use_custom_device()) {
LOG(INFO) << "CustomDevice is enabled";
argument_.SetCustomDeviceType(config_.custom_device_type());
argument_.SetCustomDeviceId(config_.custom_device_id());
}
#endif
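// From the user side, these Argument fields are populated when a custom
// device is enabled on the config; a sketch (assuming the existing
// AnalysisConfig::EnableCustomDevice API this block reads from; the model
// path is illustrative):
//
//   paddle_infer::Config config;
//   config.SetModel("./model_dir");        // illustrative path
//   config.EnableCustomDevice("npu", 0);   // sets use_custom_device() etc.
//   auto predictor = paddle_infer::CreatePredictor(config);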

auto *pass_builder = config_.pass_builder();
// TODO(inference): Need to reconstruct the pass_builder, pass should be
// processed in a single
3 changes: 2 additions & 1 deletion paddle/fluid/inference/tensorrt/convert/matmul_op.cc
@@ -67,7 +67,8 @@ class MatMulOpConverter : public OpConverter {
if (op_desc.HasAttr("support_int8") &&
PADDLE_GET_CONST(bool, op_desc.GetAttr("support_int8")) &&
engine_->precision() == AnalysisConfig::Precision::kInt8 &&
platform::GetGPUComputeCapability(0) >= 75) {
platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >=
75) {
if (engine_->with_dynamic_shape()) {
VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT "
"MatmulPluginLayer";
12 changes: 6 additions & 6 deletions paddle/fluid/inference/tensorrt/convert/multihead_matmul_op.cc
@@ -88,11 +88,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
engine_->tensorrt_transformer_posid() != "" &&
engine_->tensorrt_transformer_maskid() != "";
if (engine_->with_dynamic_shape()) {
if (engine_->tensorrt_transformer_maskid() != "") {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal(
"use use_varseqlen must be int8 or half, not float32."));
}
if (engine_->tensorrt_transformer_maskid() != "" &&
engine_->precision() != AnalysisConfig::Precision::kFloat32 &&
platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >=
75) {
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
@@ -401,7 +400,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
} else {
if (input_dims.d[1] <= 384 && !bias_qk_attr &&
engine_->precision() != AnalysisConfig::Precision::kFloat32 &&
platform::GetGPUComputeCapability(0) >= 75) {
platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >=
75) {
/*
* input_dims.d[0]: batch(-1)
* input_dims.d[1]: length:256
6 changes: 3 additions & 3 deletions paddle/fluid/inference/tensorrt/plugin/CMakeLists.txt
@@ -39,12 +39,12 @@ list(
generic_plugin.cu
lookup_table.cu
many_emb_layernorm_plugin.cu
many_emb_Layernorm_kernel.cu)
many_emb_layernorm_kernel.cu)

if(${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 7)
list(APPEND TRT_FILES many_emb_layernorm_varseqlen_plugin.cu
many_emb_Layernorm_varseqlen_kernel_mtron.cu
many_emb_Layernorm_varseqlen_kernel_hface.cu)
many_emb_layernorm_varseqlen_kernel_mtron.cu
many_emb_layernorm_varseqlen_kernel_hface.cu)
endif()

if(CUSPARSELT_FOUND AND ${TENSORRT_MAJOR_VERSION} GREATER_EQUAL 8)
29 changes: 21 additions & 8 deletions paddle/fluid/inference/tensorrt/plugin/generic_plugin.cu
@@ -300,35 +300,48 @@ bool GenericPlugin::supportsFormatCombination(
int nb_outputs) TRT_NOEXCEPT {
if (op_desc_.Type() == "gather_nd" || op_desc_.Type() == "yolo_box") {
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
return in_out[0].type == in_out[pos].type &&
in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "scatter_nd_add") {
// input X
if (pos == 0)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Index
if (pos == 1)
return (in_out[pos].type == nvinfer1::DataType::kINT32) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// input Updates
if (pos == 2)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
// output
if (pos == 3)
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
return in_out[0].type == in_out[pos].type &&
in_out[0].format == in_out[pos].format;
} else if (op_desc_.Type() == "pad3d") {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR) &&
(in_out[0].type == in_out[pos].type);
} else {
return (in_out[pos].type == nvinfer1::DataType::kFLOAT) &&
return (in_out[pos].type == nvinfer1::DataType::kFLOAT ||
(isFp16Supported() &&
in_out[pos].type == nvinfer1::DataType::kHALF)) &&
(in_out[pos].format == nvinfer1::TensorFormat::kLINEAR);
}
}
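// Reading aid (TensorRT plugin convention, not introduced by this diff):
// in_out lists the descriptors of all inputs followed by all outputs, and
// pos indexes that combined array. Returning
//   in_out[0].type == in_out[pos].type &&
//   in_out[0].format == in_out[pos].format
// for an output position ties the output's type/format to whatever was
// negotiated for input 0, which is how the new FP16 combinations above
// propagate from inputs to outputs.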
@@ -139,6 +139,8 @@ static void trt_ernie(bool with_fp16, std::vector<float> result) {
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, true, false);
config.SetTRTDynamicShapeInfo(
min_input_shape, max_input_shape, opt_input_shape);
paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
&config, "read_file_0.tmp_4");
AnalysisConfig* config_deser = new AnalysisConfig(config);

std::vector<float> out_data;
@@ -133,6 +133,8 @@ void trt_ernie(bool with_fp16,
config.EnableTensorRtEngine(1 << 30, 1, 5, precision, false, false);
config.SetTRTDynamicShapeInfo(
min_input_shape, max_input_shape, opt_input_shape);
paddle_infer::experimental::InternalUtils::SetTransformerMaskid(
&config, "read_file_0.tmp_4");
std::vector<float> out_data;
run(config, &out_data, batch_size);

@@ -423,7 +425,7 @@ void run(paddle_infer::Predictor* predictor, std::vector<float>* out_data) {

TEST(AnalysisPredictor, ernie_varlen) {
#if IS_TRT_VERSION_GE(7234)
if (platform::GetGPUComputeCapability(0) >= 75) {
if (platform::GetGPUComputeCapability(platform::GetCurrentDeviceId()) >= 75) {
auto predictor = InitPredictor();
std::vector<float> out_data;
run(predictor.get(), &out_data);
@@ -31,7 +31,7 @@ PADDLE_DEFINE_EXPORTED_READONLY_bool(

PADDLE_DEFINE_EXPORTED_READONLY_bool(
free_when_no_cache_hit,
true,
false,
"Whether to free idle chunks when no cache hit. If true, idle "
"chunk would be freed when no cache hit; if false, idle "
"chunk would be freed when out of memory occurs. This flag "
2 changes: 0 additions & 2 deletions paddle/phi/core/flags.cc
@@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
"Predictor",
"Choose default funciton type in JitLayer.");

#ifdef PADDLE_WITH_CUSTOM_DEVICE
/**
* Custom Device NPU related FLAG
* Name: FLAGS_npu_storage_format
@@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
* Note: Enable NPU Storage Format for Ascend910 performance improvement.
*/
PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
#endif

#ifdef PADDLE_WITH_CUDNN_FRONTEND
/**
@@ -137,7 +137,7 @@ def _setup_nccl_op(self, startup_program, main_program, build_strategy):
attrs={
"trainers": trainer_endpoints,
"trainer_id": trainer_id,
"nccl_comm_num": build_strategy.nccl_comm_num,
"bkcl_comm_num": build_strategy.bkcl_comm_num,
"use_hierarchical_allreduce": build_strategy.use_hierarchical_allreduce,
"hierarchical_allreduce_inter_ranks": build_strategy.hierarchical_allreduce_inter_nranks,
},