Skip to content

Commit

Permalink
Cublas byref fixup (#1877)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed May 12, 2024
1 parent 2851f39 commit d18f50c
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 65 deletions.
2 changes: 2 additions & 0 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1117,8 +1117,10 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {
}
if (auto I = dyn_cast<Instruction>(Val)) {
EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
exit(1);
} else if (auto I = dyn_cast_or_null<Instruction>(Origin)) {
EmitFailure("IllegalUpdateAnalysis", I->getDebugLoc(), I, ss.str());
exit(1);
} else {
llvm::errs() << ss.str() << "\n";
}
Expand Down
29 changes: 16 additions & 13 deletions enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ void my_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *__restrict__ A, int lda,
double *__restrict__ X, int incx, double beta,
double *__restrict__ Y, int incy) {
cublasDgemv(handle, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
inDerivative = true;
}

// Wrapper around cublasDgemv: forwards its arguments to cublasDgemv and then
// sets inDerivative to true. alpha and beta are taken by value here.
// NOTE(review): two cublasDgemv calls are visible in this span, one passing
// alpha/beta by value and one passing their addresses. This looks like the
// removed and added lines of a diff rendered together; confirm which is live.
void ow_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *A, int lda, double *X, int incx,
double beta, double *Y, int incy) {
cublasDgemv(handle, trans, M, N, alpha, A, lda, X, incx, beta, Y, incy);
cublasDgemv(handle, trans, M, N, &alpha, A, lda, X, incx, &beta, Y, incy);
inDerivative = true;
}

Expand All @@ -55,8 +55,8 @@ void my_dgemm(cublasHandle_t *handle, cublasOperation_t transA,
cublasOperation_t transB, int M, int N, int K, double alpha,
double *__restrict__ A, int lda, double *__restrict__ B, int ldb,
double beta, double *__restrict__ C, int ldc) {
cublasDgemm(handle, transA, transB, M, N, K, alpha, A, lda, B, ldb, beta, C,
ldc);
cublasDgemm(handle, transA, transB, M, N, K, &alpha, A, lda, B, ldb, &beta, C,
ldc);
inDerivative = true;
}

Expand Down Expand Up @@ -212,10 +212,10 @@ static void gemvTests() {

inDerivative = true;
// dA = alpha * X * transpose(Y) + dA
cublasDger(handle, M, N, alpha, trans ? B : dC, trans ? incB : incC,
trans ? dC : B, trans ? incC : incB, dA, lda);
cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
trans ? dC : B, trans ? incC : incB, dA, lda);
// dY = beta * dY
cublasDscal(handle, trans ? N : M, beta, dC, incC);
cublasDscal(handle, trans ? N : M, &beta, dC, incC);

checkTest(Test);

Expand All @@ -241,15 +241,16 @@ static void gemvTests() {

inDerivative = true;
// dA = alpha * X * transpose(Y) + dA
cublasDger(handle, M, N, alpha, trans ? B : dC, trans ? incB : incC,
trans ? dC : B, trans ? incC : incB, dA, lda);
cublasDger(handle, M, N, &alpha, trans ? B : dC, trans ? incB : incC,
trans ? dC : B, trans ? incC : incB, dA, lda);

// dB = alpha * trans(A) * dC + dB
cublasDgemv(handle, transpose(transA), M, N, alpha, A, lda, dC, incC,
1.0, dB, incB);
double c1 = 1.0;
cublasDgemv(handle, transpose(transA), M, N, &alpha, A, lda, dC, incC,
&c1, dB, incB);

// dY = beta * dY
cublasDscal(handle, trans ? N : M, beta, dC, incC);
cublasDscal(handle, trans ? N : M, &beta, dC, incC);

checkTest(Test);

Expand Down Expand Up @@ -391,7 +392,9 @@ static void gemmTests() {
transB_bool ? A : dC, transB_bool ? lda : incC, 1.0, dB, incB);

// TODO we are currently faking support here, this needs to be actually implemented
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, 1.0, beta, M, N, dC, incC, 0);
double c10 = 1.0;
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, &c10, &beta, M, N,
dC, incC, 0);

checkTest(Test);

Expand Down
78 changes: 43 additions & 35 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -931,14 +931,12 @@ __attribute__((noinline)) void dlacpy(char *uplo_p, int *M_p, int *N_p, double *

// Recording stub for cublasDlascl: appends a BlasCall tagged ABIType::CUBLAS /
// CallType::LASCL (carrying handle, inDerivative, A, the cfrom/cto scalars,
// type, M, N, lda, KL and KU) to `calls` and returns CUBLAS_STATUS_SUCCESS.
// The `info` parameter is not read in the visible body.
// NOTE(review): this span shows two parameter lists and two push_back bodies
// (cfrom/cto by value, then by pointer and dereferenced). It looks like the
// removed and added lines of a diff rendered together; confirm which is live.
__attribute__((noinline)) cublasStatus_t
cublasDlascl(cublasHandle_t *handle, cublasOperation_t type, int KL, int KU,
double cfrom, double cto, int M, int N, double *A, int lda, int info) {
calls.push_back((BlasCall){ABIType::CUBLAS,handle,
inDerivative, CallType::LASCL,
A, UNUSED_POINTER, UNUSED_POINTER,
cfrom, cto,
CUBLAS_LAYOUT,
(char)type, UNUSED_TRANS,
M, N, UNUSED_INT, lda, KL, KU});
double *cfrom, double *cto, int M, int N, double *A, int lda,
int info) {
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
CallType::LASCL, A, UNUSED_POINTER, UNUSED_POINTER,
*cfrom, *cto, CUBLAS_LAYOUT, (char)type,
UNUSED_TRANS, M, N, UNUSED_INT, lda, KL, KU});
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}
__attribute__((noinline)) cublasStatus_t cublasDlacpy(cublasHandle_t *handle, char uplo, int M,
Expand Down Expand Up @@ -1054,47 +1052,57 @@ __attribute__((noinline)) cublasStatus_t cublasDaxpy(cublasHandle_t *handle,
}
// Recording stub for cublasDgemv: builds a BlasCall tagged ABIType::CUBLAS /
// CallType::GEMV from the arguments (Y, A, X, the alpha/beta scalars, trans,
// M, N, lda, incx, incy), appends it to `calls`, and returns
// CUBLAS_STATUS_SUCCESS.
// NOTE(review): this span shows two parameter lists and two initializers for
// `call` (alpha/beta by value, then by pointer and dereferenced). It looks
// like the removed and added lines of a diff rendered together; confirm which
// is live.
__attribute__((noinline)) cublasStatus_t
cublasDgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *A, int lda, double *X, int incx, double beta,
double *Y, int incy) {
BlasCall call = {ABIType::CUBLAS,handle,
inDerivative, CallType::GEMV, Y, A, X, alpha, beta, CUBLAS_LAYOUT,
(char)trans, UNUSED_TRANS, M, N, UNUSED_INT, lda, incx, incy};
double *alpha, double *A, int lda, double *X, int incx,
double *beta, double *Y, int incy) {
BlasCall call = {ABIType::CUBLAS,
handle,
inDerivative,
CallType::GEMV,
Y,
A,
X,
*alpha,
*beta,
CUBLAS_LAYOUT,
(char)trans,
UNUSED_TRANS,
M,
N,
UNUSED_INT,
lda,
incx,
incy};
calls.push_back(call);
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}
// Recording stub for cublasDgemm: appends a BlasCall tagged ABIType::CUBLAS /
// CallType::GEMM (carrying C, A, B, the alpha/beta scalars, transA, transB,
// M, N, K, lda, ldb, ldc) to `calls` and returns CUBLAS_STATUS_SUCCESS.
// NOTE(review): this span shows two parameter lists and two push_back bodies
// (alpha/beta by value, then by pointer and dereferenced). It looks like the
// removed and added lines of a diff rendered together; confirm which is live.
__attribute__((noinline)) cublasStatus_t
cublasDgemm(cublasHandle_t *handle, cublasOperation_t transA,
cublasOperation_t transB, int M, int N, int K, double alpha,
double *A, int lda, double *B, int ldb, double beta, double *C,
int ldc) {
calls.push_back((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GEMM, C, A, B, alpha,
beta,
CUBLAS_LAYOUT,
(char)transA, (char)transB, M, N, K, lda,
ldb, ldc});
cublasOperation_t transB, int M, int N, int K, double *alpha,
double *A, int lda, double *B, int ldb, double *beta, double *C,
int ldc) {
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
CallType::GEMM, C, A, B, *alpha, *beta,
CUBLAS_LAYOUT, (char)transA, (char)transB, M, N, K,
lda, ldb, ldc});
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}
// Recording stub for cublasDscal: appends a BlasCall tagged ABIType::CUBLAS /
// CallType::SCAL (carrying X, the alpha scalar, N and incX; the remaining
// fields are filled with UNUSED_* placeholders) to `calls` and returns
// CUBLAS_STATUS_SUCCESS.
// NOTE(review): this span shows two signatures and two initializer lists
// (alpha by value, then by pointer and dereferenced). It looks like the
// removed and added lines of a diff rendered together; confirm which is live.
__attribute__((noinline)) cublasStatus_t
cublasDscal(cublasHandle_t *handle, int N, double alpha, double *X, int incX) {
cublasDscal(cublasHandle_t *handle, int N, double *alpha, double *X, int incX) {
calls.push_back((BlasCall){
ABIType::CUBLAS,handle,inDerivative, CallType::SCAL, X, UNUSED_POINTER, UNUSED_POINTER, alpha,
UNUSED_DOUBLE,
CUBLAS_LAYOUT,
UNUSED_TRANS, UNUSED_TRANS, N, UNUSED_INT,
UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
ABIType::CUBLAS, handle, inDerivative, CallType::SCAL, X, UNUSED_POINTER,
UNUSED_POINTER, *alpha, UNUSED_DOUBLE, CUBLAS_LAYOUT, UNUSED_TRANS,
UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}

// A = alpha * X * transpose(Y) + A
// Recording stub for cublasDger: appends a BlasCall tagged ABIType::CUBLAS /
// CallType::GER (carrying A, X, Y, the alpha scalar, M, N, incX, incY, lda)
// to `calls` and returns CUBLAS_STATUS_SUCCESS. The rank-1 update described
// above is not computed in the visible body; only the call is recorded.
// NOTE(review): this span shows two parameter lists and two push_back bodies
// (alpha by value, then by pointer and dereferenced). It looks like the
// removed and added lines of a diff rendered together; confirm which is live.
__attribute__((noinline)) cublasStatus_t
cublasDger(cublasHandle_t *handle, int M, int N, double alpha, double *X,
int incX, double *Y, int incY, double *A, int lda) {
calls.push_back((BlasCall){ABIType::CUBLAS,handle,inDerivative, CallType::GER, A, X, Y, alpha,
UNUSED_DOUBLE,
CUBLAS_LAYOUT,
UNUSED_TRANS,
UNUSED_TRANS, M, N, UNUSED_INT, incX, incY,
lda});
cublasDger(cublasHandle_t *handle, int M, int N, double *alpha, double *X,
int incX, double *Y, int incY, double *A, int lda) {
calls.push_back((BlasCall){ABIType::CUBLAS, handle, inDerivative,
CallType::GER, A, X, Y, *alpha, UNUSED_DOUBLE,
CUBLAS_LAYOUT, UNUSED_TRANS, UNUSED_TRANS, M, N,
UNUSED_INT, incX, incY, lda});
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}

Expand Down
30 changes: 19 additions & 11 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,8 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {

os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
os << "(void)byRefFloat;\n";
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
"\"cublas\";\n";
Expand Down Expand Up @@ -355,7 +357,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
auto ty = argTypeMap.lookup(actArgs[i]);
os << " if (";
if (ty == ArgType::fp)
os << "byRef && ";
os << "byRefFloat && ";
os << "active_" << name << ") {\n"
<< " auto shadow_" << name << " = gutils->invertPointerM(orig_"
<< name << ", BuilderZ);\n"
Expand Down Expand Up @@ -385,7 +387,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
auto ty = argTypeMap.lookup(actArgs[i]);
os << " if (";
if (ty == ArgType::fp)
os << "byRef && ";
os << "byRefFloat && ";
os << "active_" << name << ") {\n"
<< " rt_inactive_" << name << " = BuilderZ.CreateOr(rt_inactive_"
<< name << ", rt_inactive_out, \"rt.inactive.\" \"" << name << "\");\n"
Expand All @@ -406,7 +408,8 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
}
}
if (!hasFP)
os << " Type* blasFPType = byRef ? (Type*)PointerType::getUnqual(fpType) "
os << " Type* blasFPType = byRefFloat ? "
"(Type*)PointerType::getUnqual(fpType) "
": (Type*)fpType;\n";

bool hasChar = false;
Expand Down Expand Up @@ -609,25 +612,29 @@ void emit_extract_calls(const TGPattern &pattern, raw_ostream &os) {
<< " if (Mode != DerivativeMode::ForwardModeSplit)\n"
<< " cacheval = lookup(cacheval, Builder2);\n"
<< " }\n"
<< "\n"
<< " if (byRef) {\n";
<< "\n";

for (size_t i = 0; i < nameVec.size(); i++) {
auto ty = typeMap.lookup(i);
auto name = nameVec[i];
// this branch used "true_" << name everywhere instead of "arg_" << name
// before. probably randomly, but check to make sure
if (ty == ArgType::len || ty == ArgType::vincInc || ty == ArgType::mldLD) {
os << " if (byRef) {\n";
extract_scalar(name, "intType", os);
os << " }\n";
} else if (ty == ArgType::fp) {
os << " if (byRefFloat) {\n";
extract_scalar(name, "fpType", os);
os << " }\n";
} else if (ty == ArgType::trans) {
// we are in the byRef branch and trans only exists in lv23.
// So just unconditionally assume that no layout exists and use i-1
os << " if (byRef) {\n";
extract_scalar(name, "charType", os);
os << " }\n";
}
}
os << " }\n";

std::string input_var = "";
size_t actVar = 0;
Expand Down Expand Up @@ -1207,8 +1214,8 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
} else if (Def->isSubClassOf("Constant")) {
auto val = Def->getValueAsString("value");
os << "{to_blas_fp_callconv(Builder2, ConstantFP::get(fpType, " << val
<< "), byRef, blasFPType, allocationBuilder, \"constant.fp." << val
<< "\")}";
<< "), byRefFloat, blasFPType, allocationBuilder, \"constant.fp."
<< val << "\")}";
} else if (Def->isSubClassOf("Char")) {
auto val = Def->getValueAsString("value");
if (val == "N") {
Expand Down Expand Up @@ -1382,7 +1389,7 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
<< bb << ".CreateCall(derivcall_" << dfnc_name << ", " << argName
<< ", Defs));\n";
}
os << " if (byRef) {\n"
os << " if (byRefFloat) {\n"
<< " ((DiffeGradientUtils *)gutils)"
<< "->addToInvertedPtrDiffe(&call, nullptr, fpType, 0, "
<< "(called->getParent()->getDataLayout().getTypeSizeInBits(fpType)/8), "
Expand All @@ -1401,7 +1408,7 @@ void emit_runtime_condition(DagInit *ruleDag, StringRef name, StringRef tab,
StringRef B, bool isFP, raw_ostream &os) {
os << tab << "BasicBlock *nextBlock_" << name << " = nullptr;\n"
<< tab << "if (EnzymeRuntimeActivityCheck && cacheMode"
<< (isFP ? " && byRef" : "") << ") {\n"
<< (isFP ? " && byRefFloat" : "") << ") {\n"
<< tab << " BasicBlock *current = Builder2.GetInsertBlock();\n"
<< tab << " auto activeBlock = gutils->addReverseBlock(current,"
<< "bb_name + \"." << name << ".active\");\n"
Expand All @@ -1415,7 +1422,8 @@ void emit_runtime_condition(DagInit *ruleDag, StringRef name, StringRef tab,

void emit_runtime_continue(DagInit *ruleDag, StringRef name, StringRef tab,
StringRef B, bool isFP, raw_ostream &os) {
os << tab << "if (nextBlock_" << name << (isFP ? " && byRef" : "") << ") {\n"
os << tab << "if (nextBlock_" << name << (isFP ? " && byRefFloat" : "")
<< ") {\n"
<< tab << " " << B << ".CreateBr(nextBlock_" << name << ");\n"
<< tab << " " << B << ".SetInsertPoint(nextBlock_" << name << ");\n"
<< tab << "}\n";
Expand Down
10 changes: 6 additions & 4 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
os << " return;\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
os << "(void)byRefFloat;\n";
if (lv23)
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
Expand Down Expand Up @@ -104,26 +106,26 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
}
}

os << " if (byRef) {\n";

for (size_t argPos = 0; argPos < numArgs; argPos++) {
const auto typeOfArg = argTypeMap.lookup(argPos);
size_t i = (lv23 ? argPos - 1 : argPos);

if (is_char_arg(typeOfArg) || typeOfArg == ArgType::len ||
typeOfArg == ArgType::vincInc || typeOfArg == ArgType::fp ||
typeOfArg == ArgType::mldLD) {
os << " if (" << (typeOfArg == ArgType::fp ? "byRefFloat" : "byRef")
<< ") {\n";
os << " F->removeParamAttr(" << i << " + offset"
<< ", llvm::Attribute::ReadNone);\n"
<< " F->addParamAttr(" << i << " + offset"
<< ", llvm::Attribute::ReadOnly);\n"
<< " F->addParamAttr(" << i << " + offset"
<< ", llvm::Attribute::NoCapture);\n";
os << " }\n";
}
}

os << " }\n"
<< " // Julia declares double* pointers as Int64,\n"
os << " // Julia declares double* pointers as Int64,\n"
<< " // so LLVM won't let us add these Attributes.\n"
<< " if (!julia_decl) {\n";
for (size_t argPos = 0; argPos < numArgs; argPos++) {
Expand Down
4 changes: 3 additions & 1 deletion enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {

os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << "const bool byRefFloat = byRef || blas.prefix == \"cublas\";\n";
os << "(void)byRefFloat;\n";
if (lv23)
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
Expand Down Expand Up @@ -77,7 +79,7 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {

// We need the shadow of the value we're updating
if (typeMap[argPos] == ArgType::fp) {
os << " if (shadow && byRef && active_" << argname
os << " if (shadow && byRefFloat && active_" << argname
<< ") return true;\n";
} else if (typeMap[argPos] == ArgType::vincData ||
typeMap[argPos] == ArgType::mldData) {
Expand Down
5 changes: 4 additions & 1 deletion enzyme/tools/enzyme-tblgen/blasTAUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
void emit_BLASTypes(raw_ostream &os) {
os << "const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << "const bool byRefFloat = byRef || blas.prefix == "
"\"cublas\";\n";
os << "(void)byRefFloat;\n";
os << "const bool cblas = blas.prefix == \"cblas_\";\n";
os << "const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
"\"cublas\";\n";
Expand All @@ -18,7 +21,7 @@ void emit_BLASTypes(raw_ostream &os) {
<< "} else {\n"
<< " llvm_unreachable(\"unknown float type of blas\");\n"
<< "}\n"
<< "if (byRef) {\n"
<< "if (byRefFloat) {\n"
<< " ttFloat.insert({-1},BaseType::Pointer);\n"
<< " ttFloat.insert({-1,0},floatType);\n"
<< "} else { \n"
Expand Down

0 comments on commit d18f50c

Please sign in to comment.