Skip to content

Commit

Permalink
WIP: cublas dscal2 (#1879)
Browse files Browse the repository at this point in the history
* WIP: cublas dscal2

* Fix cuscal

* fixup cublas scal
  • Loading branch information
wsmoses committed May 14, 2024
1 parent 611912e commit de30014
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 14 deletions.
56 changes: 56 additions & 0 deletions enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ int enzyme_out;
int enzyme_const;
template <typename... T> void __enzyme_autodiff(void *, T...);

// Wrapper around cublasDscal_v2 that accepts alpha by value and forwards its
// address to the call. This is the function handed to __enzyme_autodiff in
// scal2Tests.
//
// handle  cuBLAS handle, forwarded unchanged.
// N       element count, forwarded unchanged.
// alpha   scale factor; passed to cublasDscal_v2 by pointer.
// X       vector argument, forwarded unchanged.
// incx    stride argument for X, forwarded unchanged.
void my_dscal_v2(cublasHandle_t *handle, int N, double alpha,
double *__restrict__ X, int incx) {
cublasDscal_v2(handle, N, &alpha, X, incx);
// The recording stub stores the current inDerivative flag in each call
// record, so calls recorded after this point carry inDerivative == true.
inDerivative = true;
}

void my_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *__restrict__ A, int lda,
double *__restrict__ X, int incx, double beta,
Expand Down Expand Up @@ -60,6 +66,54 @@ void my_dgemm(cublasHandle_t *handle, cublasOperation_t transA,
inDerivative = true;
}

static void scal2Tests() {

std::string Test = "SCAL2 active both ";
cublasHandle_t *handle = DEFAULT_CUBLAS_HANDLE;
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
init();

double alpha = 3.14;
// cublasHandle_t handle;
my_dscal_v2(handle, N, alpha, A, incA);

// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);

init();
__enzyme_autodiff((void *)my_dscal_v2, enzyme_const, handle, enzyme_const, N,
enzyme_out, alpha, enzyme_dup, A, dA, enzyme_const, incA);
foundCalls = calls;

init();

my_dscal_v2(handle, N, alpha, A, incA);

inDerivative = true;

double *dalpha = (double *)foundCalls[1].pout_arg1;
inputs[3] = BlasInfo(dalpha, 1, 1);

cublasDdot_v2(handle, N, A, incA, dA, incA, dalpha);
cublasDscal_v2(handle, N, &alpha, dA, incA);

checkTest(Test);

// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}

static void dotTests() {

std::string Test = "DOT active both ";
Expand Down Expand Up @@ -417,4 +471,6 @@ int main() {
dotTests();

dot2Tests();

scal2Tests();
}
9 changes: 9 additions & 0 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,15 @@ cublasDscal(cublasHandle_t *handle, int N, double *alpha, double *X, int incX) {
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}

// Test stand-in for the cuBLAS v2 dscal entry point. It does not modify X;
// it appends a SCAL record tagged ABIType::CUBLASv2 to the `calls` trace,
// capturing the handle, the current inDerivative flag, X, the value read
// through alpha, N and incX. Slots that do not apply to SCAL are filled with
// UNUSED_* placeholders.
//
// alpha is dereferenced when the record is built, so it must point to a
// readable double. Always returns CUBLAS_STATUS_SUCCESS.
// noinline: NOTE(review) presumably keeps the call visible as a distinct
// function call to the differentiation pass — confirm.
__attribute__((noinline)) cublasStatus_t
cublasDscal_v2(cublasHandle_t *handle, int N, double *alpha, double *X, int incX) {
calls.push_back((BlasCall){
ABIType::CUBLASv2, handle, inDerivative, CallType::SCAL, X, UNUSED_POINTER,
UNUSED_POINTER, *alpha, UNUSED_DOUBLE, CUBLAS_LAYOUT, UNUSED_TRANS,
UNUSED_TRANS, N, UNUSED_INT, UNUSED_INT, incX, UNUSED_INT, UNUSED_INT});
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}

// A = alpha * X * transpose(Y) + A
__attribute__((noinline)) cublasStatus_t
cublasDger(cublasHandle_t *handle, int M, int N, double *alpha, double *X,
Expand Down
31 changes: 20 additions & 11 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,7 +1314,7 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,

// fill the result string and return the number of added args
void rev_call_args(StringRef argName, Rule &rule, size_t actArg,
raw_ostream &os, int subRule, StringRef func) {
raw_ostream &os, int subRule, StringRef func, ArgType ty) {

const auto nameMap = rule.getArgNameMap();

Expand Down Expand Up @@ -1353,6 +1353,10 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg,
os << " " << argName
<< ".push_back(ConstantInt::get(intType, 1));\n";
os << " }\n";
if (ty == ArgType::fp) {
os << " if (cublasv2) " << argName
<< ".push_back(Builder2.CreateAlloca(fpType));\n";
}
}

void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
Expand All @@ -1373,7 +1377,8 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
<< ") tys.push_back(arg->getType());\n";
std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name);
os << " llvm::FunctionType *FT" << dfnc_name << " = FunctionType::get("
<< dfnc_ret_ty << ", tys, false);\n";
<< "cublasv2 ? Type::getVoidTy(fpType->getContext()) : " << dfnc_ret_ty
<< ", tys, false);\n";
os << " auto derivcall_" << dfnc_name
<< " = gutils->oldFunc->getParent()->getOrInsertFunction(\n"
<< " blas.prefix + blas.floatType + \"" << dfnc_name
Expand All @@ -1389,14 +1394,17 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
<< bb << ".CreateCall(derivcall_" << dfnc_name << ", " << argName
<< ", Defs));\n";
}
os << " Value *dres = cubcall;\n";
os << " if (cublasv2) dres = " << bb << ".CreateLoad(fpType, "
<< argName << "[" << argName << ".size()-1]);\n";
os << " if (byRefFloat) {\n"
<< " ((DiffeGradientUtils *)gutils)"
<< "->addToInvertedPtrDiffe(&call, nullptr, fpType, 0, "
<< "(called->getParent()->getDataLayout().getTypeSizeInBits(fpType)/8), "
"orig_"
<< name << ", cubcall, " << bb << ");\n"
<< name << ", dres, " << bb << ");\n"
<< " } else {\n"
<< " addToDiffe(orig_" << name << ", cubcall, " << bb
<< " addToDiffe(orig_" << name << ", dres, " << bb
<< ", fpType);\n"
<< " }\n";
os << "}\n";
Expand Down Expand Up @@ -1604,18 +1612,19 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
emit_runtime_condition(ruleDag, name, " ", "Builder2",
(ty == ArgType::fp), os);
const auto dfnc_name = Def->getValueAsString("s");
rev_call_args("args1", rule, actArg, os, -1, dfnc_name);
rev_call_args("args1", rule, actArg, os, -1, dfnc_name, ty);
os << " const auto Defs = gutils->getInvertedBundles(&call, {"
<< valueTypes << "}, Builder2, /* lookup */ true);\n";

if (ty == ArgType::fp) {
// Extra handling: only a single fp scalar is updated as part of the return
// struct, presumably by setting it to the value returned by this call.
os << " if (!cublas) {\n";
os << " if (!cublas || cublasv2) {\n";
emit_fret_call(dfnc_name, "ArrayRef<Value *>(args1)", name, "Builder2",
os);
os << " } else {\n";
os << " assert(\"unsupported cublas\");\n";
} else {
os << " SmallVector<Type*, 1> tys; for (auto arg : args1) "
"tys.push_back(arg->getType());\n";
Expand Down Expand Up @@ -1654,7 +1663,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
os << " // DiagUpdateSPMV\n";
emit_if_rule_condition(ruleDag, name, " ", os);
emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os);
rev_call_args("args1", rule, actArg, os, -1, "");
rev_call_args("args1", rule, actArg, os, -1, "", ty);
os << " const auto Defs = gutils->getInvertedBundles(&call, {"
<< valueTypes << "}, Builder2, /* lookup */ true);\n";
// Now that we have the defs, we can create the call
Expand All @@ -1670,7 +1679,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
os << " // FrobInnerProd\n";
emit_if_rule_condition(ruleDag, name, " ", os);
emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os);
rev_call_args("args1", rule, actArg, os, -1, "");
rev_call_args("args1", rule, actArg, os, -1, "", ty);
os << " const auto Defs = gutils->getInvertedBundles(&call, {"
<< valueTypes << "}, Builder2, /* lookup */ true);\n";
// Now that we have the defs, we can create the call
Expand Down Expand Up @@ -1701,7 +1710,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
if (sub_Def->isSubClassOf("b")) {
const auto dfnc_name = sub_Def->getValueAsString("s");
std::string argName = "args" + std::to_string(i);
rev_call_args(argName, rule, actArg, os, i, dfnc_name);
rev_call_args(argName, rule, actArg, os, i, dfnc_name, ty);
os << " //handling nested blas: " << std::to_string(i) << "\n";
// emit_deriv_blas_call(sub_Dag, patternMap, handled, os);
if (get_blas_ret_ty(dfnc_name) == "fpType") {
Expand Down Expand Up @@ -1734,13 +1743,13 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
os << " //handled nested blas: " << std::to_string(i) << "\n";
} else if (sub_Def->isSubClassOf("FrobInnerProd")) {
std::string argName = "args" + std::to_string(i);
rev_call_args(argName, rule, actArg, os, i, "");
rev_call_args(argName, rule, actArg, os, i, "", ty);
assert(sub_Dag->getNumArgs() == 4);
assert(ty == ArgType::fp);
emit_fret_call("inner_prod", argName, name, "Builder2", os);
} else if (sub_Def->isSubClassOf("DiagUpdateSPMV")) {
std::string argName = "args" + std::to_string(i);
rev_call_args(argName, rule, actArg, os, i, "");
rev_call_args(argName, rule, actArg, os, i, "", ty);
assert(sub_Dag->getNumArgs() == 6);
assert(ty == ArgType::ap);
os << "callSPMVDiagUpdate(Builder2, *gutils->oldFunc->getParent(), "
Expand Down
6 changes: 3 additions & 3 deletions enzyme/tools/enzyme-tblgen/blasTAUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ void emit_BLASTA(TGPattern &pattern, raw_ostream &os) {
break;
case ArgType::mldData:
os << " updateAnalysis(call.getArgOperand(" << i
<< (lv23 ? " + offset" : "") << "), ttPtr, &call);\n";
<< " + offset), ttPtr, &call);\n";
break;
case ArgType::fp:
os << " updateAnalysis(call.getArgOperand(" << i
<< (lv23 ? " + offset" : "") << "), ttFloat, &call);\n";
<< " + offset), ttFloat, &call);\n";
break;
case ArgType::ap:
// TODO
Expand All @@ -151,7 +151,7 @@ void emit_BLASTA(TGPattern &pattern, raw_ostream &os) {
case ArgType::uplo:
case ArgType::trans:
os << " updateAnalysis(call.getArgOperand(" << i
<< (lv23 ? " + offset" : "") << "), ttChar, &call);\n";
<< " + offset), ttChar, &call);\n";
break;
case ArgType::diag:
case ArgType::side:
Expand Down

0 comments on commit de30014

Please sign in to comment.