Skip to content

Commit

Permalink
Fix julia decl (#1933)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 21, 2024
1 parent 9da1683 commit cd5e51f
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern,
<< "), res);\n";
}
os << "SmallVector<Value *, 1>vs = { to_blas_callconv(Builder2, res, "
"byRef, cublas, nullptr, allocationBuilder, \""
"byRef, cublas, julia_decl_type, allocationBuilder, \""
<< Def->getName() << "." << name << "\") }; vs; })";
return;
}
Expand Down Expand Up @@ -999,7 +999,7 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern,
os << "marg_" << i << "[marg_" << i << ".size() == 1 ? 0 : i]";
}
if (op != "Select")
os << "), byRef, cublas, nullptr, "
os << "), byRef, cublas, julia_decl_type, "
"allocationBuilder, \""
<< Def->getValueAsString("s") << "\" )";
else
Expand Down Expand Up @@ -1260,8 +1260,12 @@ void rev_call_args(bool forward, Twine argName, const TGPattern &pattern,
int n = 0;
if (func == "gemv" || func == "lascl")
n = 1;
if (func == "gemm")
if (func == "gemm" || func == "syrk")
n = 2;
if (func == "trmv")
n = 3;
if (func == "trmm")
n = 4;
for (int i = 0; i < n; i++)
os << " " << argName
<< ".push_back(ConstantInt::get(intType, 1));\n";
Expand Down

0 comments on commit cd5e51f

Please sign in to comment.