Skip to content

Commit

Permalink
add maximum op (PaddlePaddle#43)
Browse files Browse the repository at this point in the history
* add maximum
  • Loading branch information
Zhang Ting committed Sep 2, 2021
1 parent 5cd1723 commit 023f984
Show file tree
Hide file tree
Showing 9 changed files with 138 additions and 71 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/compiler/piano/backends/llvm_ir/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cc_library(llvm_utils SRCS llvm_utils.cc DEPS llvm)

cc_library(primitive_ir_emitter SRCS primitive_ir_emitter.cc DEPS llvm)
cc_library(gpu_primitive_ir_emitter SRCS gpu_primitive_ir_emitter.cc DEPS primitive_ir_emitter llvm)
cc_library(gpu_primitive_ir_emitter SRCS gpu_primitive_ir_emitter.cc DEPS primitive_ir_emitter llvm note_template_util)
cc_library(nvptx_primitive_ir_emitter SRCS nvptx_primitive_ir_emitter.cc
DEPS gpu_primitive_ir_emitter llvm_utils llvm)

Expand All @@ -14,7 +14,7 @@ cc_library(gpu_ir_emitter SRCS gpu_ir_emitter.cc DEPS ir_emitter llvm_utils llvm
cc_library(nvptx_ir_emitter SRCS nvptx_ir_emitter.cc DEPS gpu_ir_emitter nvptx_primitive_ir_emitter llvm)

cc_test(nvptx_ir_emitter_test SRCS nvptx_ir_emitter_test.cc
DEPS nvptx_ir_emitter note_ir note_proto)
DEPS nvptx_ir_emitter note_ir note_proto piano_symbolization_builder)
target_link_libraries(nvptx_ir_emitter_test ${LLVM_LIBS})

cc_library(llvm_compiler SRCS llvm_compiler.cc DEPS nvptx_ir_emitter nvptx_primitive_ir_emitter note_ir llvm)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/compiler/piano/backends/llvm_ir/llvm_utils.h"
#include "paddle/fluid/compiler/piano/backends/llvm_ir/nvptx_primitive_ir_emitter.h"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {
Expand All @@ -29,6 +30,12 @@ void GpuIrEmitter::VisitElementwiseBinary(const note::Instruction& instr) {
auto lhs_type = instr.operand(0).shape().element_type();
auto rhs_type = instr.operand(1).shape().element_type();
auto out_type = instr.shape().element_type();
PADDLE_ENFORCE_EQ(
lhs_type, rhs_type,
platform::errors::InvalidArgument(
"The inputs of Binary Op should have the same data type, "
"but received the types of inputs are %s and %s.",
lhs_type, rhs_type));

auto func = CreateLLVMFunction(instr.name(), {lhs_type, rhs_type, out_type},
llvm_module_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "paddle/fluid/compiler/piano/backends/llvm_ir/gpu_primitive_ir_emitter.h"
#include "paddle/fluid/compiler/piano/backends/llvm_ir/primitive_ir_emitter.h"
#include "paddle/fluid/compiler/piano/note/element_type_util.h"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/opcode.h"

Expand All @@ -23,13 +24,17 @@ namespace backends {

BinaryFunction GpuPrimitiveIrEmitter::GetBinaryComputation(
const note::Instruction& instr) {
return [&instr, this](llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* builder) -> llvm::Value* {
auto lhs_type = instr.operand(0).shape().element_type();
bool is_signed = note::IsSignedInt(lhs_type);
return [&instr, is_signed, this](llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* builder) -> llvm::Value* {
switch (instr.opcode()) {
case note::OpCode::kAdd:
return this->Add(lhs, rhs, builder);
case note::OpCode::kMultiply:
return this->Multiply(lhs, rhs, builder);
case note::OpCode::kMaximum:
return this->Maximum(lhs, rhs, is_signed, builder);
default:
PADDLE_THROW(platform::errors::InvalidArgument("Invalid OpCode."));
}
Expand Down
7 changes: 7 additions & 0 deletions paddle/fluid/compiler/piano/backends/llvm_ir/llvm_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
// limitations under the License.

#include "paddle/fluid/compiler/piano/backends/llvm_ir/llvm_compiler.h"
#include "llvm/IR/Verifier.h"
#include "paddle/fluid/compiler/piano/backends/llvm_ir/nvptx_ir_emitter.h"

namespace paddle {
Expand All @@ -33,6 +34,12 @@ KernelExecutableMap LlvmCompiler::Apply(note::Module* note_module) {
// convert operators to LLVM IR
ConvertToIr(*note_module, &llvm_module, &kernel_executable_map);

// verify llvm module
std::string errors;
llvm::raw_string_ostream llvm_errors(errors);
PADDLE_ENFORCE_NE(llvm::verifyModule(llvm_module, &llvm_errors), true,
llvm_errors.str());

// compile LLVM IR to lower-level IR
Compile(*note_module, &llvm_module);

Expand Down
136 changes: 85 additions & 51 deletions paddle/fluid/compiler/piano/backends/llvm_ir/nvptx_ir_emitter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,68 +19,102 @@ limitations under the License. */
#include "paddle/fluid/compiler/piano/backends/note_visitor_base.h"
#include "paddle/fluid/compiler/piano/note/function.h"
#include "paddle/fluid/compiler/piano/note/instruction.h"
#include "paddle/fluid/compiler/piano/note/module.h"
#include "paddle/fluid/compiler/piano/note/note.pb.h"
#include "paddle/fluid/compiler/piano/note/opcode.h"
#include "paddle/fluid/compiler/piano/shape.h"
#include "paddle/fluid/compiler/piano/symbolization/note_builder.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {
namespace backends {

void CreadInstructionProto(const Shape& shape, const std::string& name,
const std::string& op_code, uint64_t id,
uint64_t params_index,
note::InstructionProto* instr_proto) {
instr_proto->set_name(name);
instr_proto->set_opcode(op_code);
instr_proto->set_id(id);
instr_proto->set_parameter_index(params_index);
*instr_proto->mutable_shape() = shape.ToProto();
class BinaryOpTest {
public:
void SetInstructionProto(const std::vector<int64_t>& arg1_shape_vec,
const std::vector<int64_t>& arg2_shape_vec,
const std::vector<int64_t>& result_shape_vec,
note::ElementTypeProto type, note::OpCode op_code) {
const Shape arg1_shape(type, arg1_shape_vec);
const Shape arg2_shape(type, arg2_shape_vec);
const Shape result_shape(type, result_shape_vec);
op_code_ = op_code;

SetProto(arg1_shape, "arg1.1", "parameter", 0, &arg1_proto_);
SetProto(arg2_shape, "arg2.2", "parameter", 1, &arg2_proto_);
SetProto(result_shape, note::GetOpName(op_code), note::GetOpName(op_code),
2, &instr_proto_);
}

void SetProto(const Shape& shape, const std::string& name,
const std::string& op_code, uint64_t params_index,
note::InstructionProto* instr_proto) {
instr_proto->set_name(name);
instr_proto->set_opcode(op_code);
instr_proto->set_parameter_index(params_index);
*instr_proto->mutable_shape() = shape.ToProto();
}

void GenLLVMIR() {
// build note module
symbolization::NoteBuilder note_builder("test_note_builder");
std::vector<symbolization::Operand> ops;
ops.push_back(note_builder.AppendInstruction(std::move(arg1_proto_),
note::OpCode::kParameter, {}));
ops.push_back(note_builder.AppendInstruction(std::move(arg2_proto_),
note::OpCode::kParameter, {}));
note_builder.AppendInstruction(std::move(instr_proto_), op_code_, ops);

auto note_proto = note_builder.Build();
note::Module note_module(note_proto);

auto& entry_function = note_module.entry_function();
auto instr = entry_function.instruction(2);

llvm::LLVMContext llvm_context;
llvm::Module llvm_module("", llvm_context);
KernelExecutableMap kernel_executable_map;

NvptxIrEmitter nvptx_ir_emitter(&llvm_module, &kernel_executable_map);
instr->Accept(&nvptx_ir_emitter);

// Printing may be disabled with the increase of test cases.
llvm_module.print(llvm::errs(), nullptr);

std::string errors;
llvm::raw_string_ostream llvm_errors(errors);
PADDLE_ENFORCE_NE(llvm::verifyModule(llvm_module, &llvm_errors), true,
llvm_errors.str());
}

private:
note::InstructionProto arg1_proto_;
note::InstructionProto arg2_proto_;
note::InstructionProto instr_proto_;
note::OpCode op_code_;
};

// Emits and verifies LLVM IR for the kAdd and kMaximum opcodes with 3x6 F32
// operands and result.
TEST(NvptxIrEmitter, FP32OpTest) {
  const std::vector<int64_t> dims = {3, 6};
  BinaryOpTest fp32_test;
  for (const note::OpCode op_code :
       {note::OpCode::kAdd, note::OpCode::kMaximum}) {
    fp32_test.SetInstructionProto(dims, dims, dims,
                                  note::ElementTypeProto::F32, op_code);
    fp32_test.GenLLVMIR();
  }
}

TEST(NvptxIrEmitter, AddOp) {
const Shape arg1_shape(note::ElementTypeProto::F32, {3, 6});
const Shape arg2_shape(note::ElementTypeProto::F32, {3, 6});
const Shape result_shape(note::F32, {3, 6});

// set arg1_proto
note::InstructionProto arg1_proto;
CreadInstructionProto(arg1_shape, "arg1.1", "parameter", 1, 0, &arg1_proto);
std::unordered_map<std::int64_t, note::Instruction*> instr1_index;
std::unordered_map<std::int64_t, note::Function*> func_index;
note::Instruction arg1_instr(arg1_proto, instr1_index, func_index);

// set arg2_proto
note::InstructionProto arg2_proto;
CreadInstructionProto(arg2_shape, "arg2.2", "parameter", 2, 1, &arg2_proto);
std::unordered_map<std::int64_t, note::Instruction*> instr2_index;
note::Instruction arg2_instr(arg2_proto, instr2_index, func_index);

// set add_proto
note::InstructionProto add_proto;
CreadInstructionProto(result_shape, "add", "add", 3, 2, &add_proto);
add_proto.add_operand_ids(1);
add_proto.add_operand_ids(2);
std::unordered_map<std::int64_t, note::Instruction*> instr3_index;
instr3_index.insert(
std::pair<std::int64_t, note::Instruction*>(1, &arg1_instr));
instr3_index.insert(
std::pair<std::int64_t, note::Instruction*>(2, &arg2_instr));
note::Instruction instr(add_proto, instr3_index, func_index);

llvm::LLVMContext llvm_context;
llvm::Module llvm_module("add", llvm_context);
KernelExecutableMap kernel_executable_map;

NvptxIrEmitter nvptx_ir_emitter(&llvm_module, &kernel_executable_map);
nvptx_ir_emitter.VisitAdd(instr);
llvm_module.print(llvm::errs(), nullptr);

std::string errors;
llvm::raw_string_ostream llvm_errors(errors);
PADDLE_ENFORCE_NE(llvm::verifyModule(llvm_module, &llvm_errors), true,
llvm_errors.str());
// Emits and verifies LLVM IR for the kAdd and kMaximum opcodes with 3x6 S32
// operands and result.
TEST(NvptxIrEmitter, Int32OpTest) {
  const std::vector<int64_t> dims = {3, 6};
  BinaryOpTest int32_test;
  for (const note::OpCode op_code :
       {note::OpCode::kAdd, note::OpCode::kMaximum}) {
    int32_test.SetInstructionProto(dims, dims, dims,
                                   note::ElementTypeProto::S32, op_code);
    int32_test.GenLLVMIR();
  }
}

} // namespace backends
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
// limitations under the License.

#include "paddle/fluid/compiler/piano/backends/llvm_ir/primitive_ir_emitter.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace piano {
Expand Down Expand Up @@ -89,31 +88,37 @@ void PrimitiveIrEmitter::VisitXor(const note::Instruction& instr) {

llvm::Value* PrimitiveIrEmitter::Add(llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* ir_builder) {
if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy()) {
if (lhs->getType()->isIntegerTy()) {
return ir_builder->CreateAdd(lhs, rhs);
} else if (lhs->getType()->isFloatingPointTy() &&
rhs->getType()->isFloatingPointTy()) {
return ir_builder->CreateFAdd(lhs, rhs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"The inputs of Add Op should have the same data type, "
"but received the types of inputs are %s and %s.",
lhs->getType(), rhs->getType()));
return ir_builder->CreateFAdd(lhs, rhs);
}
}

llvm::Value* PrimitiveIrEmitter::Multiply(llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* ir_builder) {
if (lhs->getType()->isIntegerTy() && rhs->getType()->isIntegerTy()) {
if (lhs->getType()->isIntegerTy()) {
return ir_builder->CreateMul(lhs, rhs);
} else if (lhs->getType()->isFloatingPointTy() &&
rhs->getType()->isFloatingPointTy()) {
} else {
return ir_builder->CreateFMul(lhs, rhs);
}
}

// Emits IR that yields the larger of `lhs` and `rhs`.
//
// Integer operands: LLVM integer types carry no signedness, so `is_signed`
// chooses between the signed (sge) and unsigned (uge) comparison.
// Any other operand type is handled as floating point with NaN-propagating
// semantics: if either operand is NaN, the result is NaN. `is_signed` is
// ignored on that path.
//
// Returns the selected value (`lhs` on ties).
llvm::Value* PrimitiveIrEmitter::Maximum(llvm::Value* lhs, llvm::Value* rhs,
                                         bool is_signed,
                                         llvm::IRBuilder<>* ir_builder) {
  if (lhs->getType()->isIntegerTy()) {
    const llvm::CmpInst::Predicate predicate =
        is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE;
    return ir_builder->CreateSelect(
        ir_builder->CreateICmp(predicate, lhs, rhs), lhs, rhs);
  }

  // Select lhs when it is NaN or when lhs >= rhs under an ordered compare;
  // otherwise select rhs. The ordered compare is false when rhs is NaN, so a
  // NaN in rhs is selected as well. A single unordered `uge` compare is not
  // enough: it is true whenever either operand is NaN and would therefore
  // return a non-NaN lhs when only rhs is NaN.
  llvm::Value* lhs_ge_rhs = ir_builder->CreateFCmpOGE(lhs, rhs);
  // `x != x` (unordered) is true exactly when x is NaN.
  llvm::Value* lhs_is_nan = ir_builder->CreateFCmpUNE(lhs, lhs);
  llvm::Value* pick_lhs = ir_builder->CreateOr(lhs_ge_rhs, lhs_is_nan);
  return ir_builder->CreateSelect(pick_lhs, lhs, rhs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ class PrimitiveIrEmitter : public NoteVisitorBase<const note::Instruction&> {
llvm::IRBuilder<>* ir_builder);
llvm::Value* Multiply(llvm::Value* lhs, llvm::Value* rhs,
llvm::IRBuilder<>* ir_builder);
// Emits IR that selects the larger of lhs and rhs. For integer operands,
// is_signed chooses between a signed and an unsigned comparison; other
// operand types are compared as floating point.
llvm::Value* Maximum(llvm::Value* lhs, llvm::Value* rhs, bool is_signed,
llvm::IRBuilder<>* ir_builder);

llvm::Value* Load(llvm::Value* input, llvm::Value* index,
llvm::IRBuilder<>* ir_builder);
Expand Down
5 changes: 5 additions & 0 deletions paddle/fluid/compiler/piano/note/element_type_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ ElementTypeProto NativeToElementTypeProto<double>() {
return F64;
}

// Reports whether `type` is a signed integer element type.
// Only S8, S16, S32 and S64 qualify; every other value yields false.
bool IsSignedInt(ElementTypeProto type) {
  switch (type) {
    case ElementTypeProto::S8:
    case ElementTypeProto::S16:
    case ElementTypeProto::S32:
    case ElementTypeProto::S64:
      return true;
    default:
      return false;
  }
}

} // namespace note
} // namespace piano
} // namespace paddle
2 changes: 2 additions & 0 deletions paddle/fluid/compiler/piano/note/element_type_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ struct ElementTypeProtoToNativeT<F64> {
using type = double;
};

// Returns true when type is one of the signed integer element types
// (S8, S16, S32, S64), and false for every other element type.
bool IsSignedInt(ElementTypeProto type);

} // namespace note
} // namespace piano
} // namespace paddle

0 comments on commit 023f984

Please sign in to comment.