Skip to content

Commit

Permalink
cublas v2 support (#1866)
Browse files Browse the repository at this point in the history
* cublas v2 support

* Now with cudamemset
  • Loading branch information
wsmoses committed May 8, 2024
1 parent acfb0f9 commit f3dd860
Show file tree
Hide file tree
Showing 5 changed files with 250 additions and 68 deletions.
14 changes: 9 additions & 5 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2658,13 +2658,17 @@ llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
const char *cuCFloatType[] = {"S", "D"}; // c, z
const char *cuFFloatType[] = {"s", "d"}; // c, z
const char *cuCPrefixes[] = {"cublas"};
const char *cuSuffixes[] = {"", "_v2", "_64", "_v2_64"};
for (auto t : llvm::enumerate(cuCFloatType)) {
for (auto f : extractable) {
for (auto p : cuCPrefixes) {
if (in == (Twine(p) + t.value() + f).str()) {
return BlasInfo{
t.value(), p, "", f, false,
};
for (auto s : cuSuffixes) {
if (in == (Twine(p) + t.value() + f + s).str()) {
bool is64 = llvm::StringRef(s).contains("64");
return BlasInfo{
t.value(), p, s, f, is64,
};
}
}
}
}
Expand Down Expand Up @@ -3131,4 +3135,4 @@ llvm::Value *get1ULP(llvm::IRBuilder<> &builder, llvm::Value *res) {
ArrayRef<Value *>(diff));

return absres;
}
}
67 changes: 63 additions & 4 deletions enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,15 @@ void ow_dgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,

double my_ddot(cublasHandle_t *handle, int N, double *__restrict__ X, int incx,
double *__restrict__ Y, int incy) {
double res = cublasDdot(handle, N, X, incx, Y, incy);
inDerivative = true;
return res;
}

double my_ddot2(cublasHandle_t *handle, int N, double *__restrict__ X, int incx,
double *__restrict__ Y, int incy) {
double res = 0.0;
cublasDdot(handle, N, X, incx, Y, incy, &res);
cublasDdot_v2(handle, N, X, incx, Y, incy, &res);
inDerivative = true;
return res;
}
Expand Down Expand Up @@ -78,19 +85,69 @@ static void dotTests() {
enzyme_const, incB);
foundCalls = calls;

init();

my_ddot(handle, N, A, incA, B, incB);

inDerivative = true;

cublasDaxpy(handle, N, 1.0, B, incB, dA, incA);
cublasDaxpy(handle, N, 1.0, A, incA, dB, incB);

checkTest(Test);

// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}

static void dot2Tests() {

std::string Test = "DOTv2 active both ";
cublasHandle_t *handle = DEFAULT_CUBLAS_HANDLE;
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, N, incA),
/*B*/ BlasInfo(B, N, incB),
/*C*/ BlasInfo(C, M, incC), BlasInfo(), BlasInfo(), BlasInfo(),
};
init();
// cublasHandle_t handle;
my_ddot2(handle, N, A, incA, B, incB);
{
auto primal_stack_ret = (double *)calls[0].pout_arg1;
inputs[3] = BlasInfo(primal_stack_ret, 1, 1);
}

// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);

init();
__enzyme_autodiff((void *)my_ddot2, enzyme_const, handle, enzyme_const, N,
enzyme_dup, A, dA, enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
{
auto primal_stack_ret = (double *)calls[0].pout_arg1;
inputs[3] = BlasInfo(primal_stack_ret, 1, 1);
}
foundCalls = calls;

auto stack_ret = (double*)foundCalls[1].pin_arg2;
inputs[4] = BlasInfo(stack_ret, 1, 1);

init();

my_ddot(handle, N, A, incA, B, incB);
my_ddot2(handle, N, A, incA, B, incB);

calls[0].pout_arg1 = (double*)foundCalls[0].pout_arg1;

inDerivative = true;

cublasDaxpy(handle, N, stack_ret, B, incB, dA, incA);
cublasDaxpy(handle, N, stack_ret, A, incA, dB, incB);
cublasDaxpy_v2(handle, N, stack_ret, B, incB, dA, incA);
cublasDaxpy_v2(handle, N, stack_ret, A, incA, dB, incB);
cudaMemset(stack_ret, 0, sizeof(double));

checkTest(Test);

Expand Down Expand Up @@ -355,4 +412,6 @@ int main() {
gemvTests();

dotTests();

dot2Tests();
}
Loading

0 comments on commit f3dd860

Please sign in to comment.