Skip to content

Commit

Permalink
fix(compiler): fix the bug that activation function is lost when conv…
Browse files Browse the repository at this point in the history
…erting convolution without bias but with activation function.

GitOrigin-RevId: 15d4378ae5759158b63945d41a4fda038ee2c480
  • Loading branch information
megvii-mge committed May 10, 2024
1 parent 4bb59d2 commit 6708597
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 22 deletions.
8 changes: 5 additions & 3 deletions compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,10 +248,12 @@ class ConvertConvLike final : public ConversionPattern {
return failure();
operand_segment_sizes = {1, 1, 0, 0, 1};
} else if (isa<MGB::ConvBias>(op)) {
// FIXME: only conv_bias(input, weight, bias) is supported now
if (operands.size() != 3)
if (operands.size() == 3)
operand_segment_sizes = {1, 1, 1, 0, 1};
else if (operands.size() == 2)
operand_segment_sizes = {1, 1, 0, 0, 1};
else
return failure();
operand_segment_sizes = {1, 1, 1, 0, 1};
} else if (isa<MGB::ConvolutionBackwardData>(op)) {
if (operands.size() != 2)
return failure();
Expand Down
24 changes: 6 additions & 18 deletions compiler/lib/Target/MGB/importer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -639,24 +639,12 @@ class Importer {
} else if (auto conv = opr->try_cast_final<opr::ConvBiasForward>()) {
auto&& p = conv->param();
auto&& out = opr->output(0);
if (opr->input().size() == 2) {
CC_ASSERT(
!is_dynamic_value(m_var2value.at(opr->input(0))) &&
!is_dynamic_value(m_var2value.at(opr->input(1))));
mlir::Value value = m_builder.create<mlir::MGB::Convolution>(
m_builder.getUnknownLoc(), var_to_shaped_type(out),
m_var2value.at(opr->input(0)), m_var2value.at(opr->input(1)),
p.mode, p.pad_h, p.pad_w, p.stride_h, p.stride_w, p.dilate_h,
p.dilate_w, p.sparse, p.format, p.compute_mode);
m_var2value.emplace(out, value);
} else {
mlir::Value value = m_builder.create<mlir::MGB::ConvBias>(
m_builder.getUnknownLoc(), var_to_shaped_type(out),
var_array_to_value_array(opr->input(), true), p.nonlineMode,
p.mode, p.sparse, p.format, p.pad_h, p.pad_w, p.stride_h,
p.stride_w, p.dilate_h, p.dilate_w, p.compute_mode);
m_var2value.emplace(out, value);
}
mlir::Value value = m_builder.create<mlir::MGB::ConvBias>(
m_builder.getUnknownLoc(), var_to_shaped_type(out),
var_array_to_value_array(opr->input(), true), p.nonlineMode, p.mode,
p.sparse, p.format, p.pad_h, p.pad_w, p.stride_h, p.stride_w,
p.dilate_h, p.dilate_w, p.compute_mode);
m_var2value.emplace(out, value);
} else if (auto resize_opr = opr->try_cast_final<opr::ResizeForward>()) {
auto&& p = resize_opr->param();
auto&& out = opr->output(0);
Expand Down
2 changes: 1 addition & 1 deletion script/build_and_test_not_standard_os.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ cmake --build "$MEGCC_BUILD_DIR" -j$(nproc) --target mgb-to-tinynn --target mgb-

function check_key_words() {
#elf self mangle words, we do not care!!
white_list="@MEGW mgb1 5Mbg6 MGBi O:MgBnWk <mbG =MEG>Yr]< 4emUi0B >HMgE kMEG RmEg MbGV4 MEgIy @MEg mGe#S BMgb MGB( mBg: MBgr8C A&mGB mEg; mGb>/ mEg= .strtab .shstrtab A=MgE= mgb=g MGe= g=MgE <mgE= =Mgb> MGE< 8<MGE= =Mge =Mgb=K <MBG <MGE =MGB="
white_list="@MEGW mgb1 5Mbg6 MGBi O:MgBnWk <mbG =MEG>Yr]< 4emUi0B >HMgE kMEG RmEg MbGV4 MEgIy @MEg mGe#S BMgb MGB( mBg: MBgr8C A&mGB mEg; mGb>/ mEg= .strtab .shstrtab A=MgE= mgb=g MGe= g=MgE <mgE= =Mgb> MGE< 8<MGE= =Mge =Mgb=K <MBG <MGE =MGB= y?=Mgb="
elf_file=$1
if [ ! -f ${elf_file} ];then
echo "ERR: can not find ${elf_file}"
Expand Down

0 comments on commit 6708597

Please sign in to comment.