From 4102362aea28d70f4e27665aeabb06b304db3115 Mon Sep 17 00:00:00 2001 From: Pratyush Das Date: Wed, 4 Aug 2021 12:33:51 +0530 Subject: [PATCH] Calculate adjoint of dgemm and sgemm --- enzyme/Enzyme/AdjointGenerator.h | 624 +++++++++++++----- enzyme/Enzyme/Enzyme.cpp | 8 + .../ReverseMode/blas/cblas_dgemm_col_nomod.ll | 121 ++++ .../blas/cblas_dgemm_col_nomod_transa.ll | 123 ++++ .../blas/cblas_dgemm_col_nomod_transb.ll | 112 ++++ .../blas/cblas_dgemm_col_nomod_transboth.ll | 114 ++++ ...las_dgemm_col_transboth_inactive_second.ll | 121 ++++ .../ReverseMode/blas/cblas_dgemm_row_nomod.ll | 120 ++++ .../blas/cblas_dgemm_row_nomod_transa.ll | 108 +++ .../blas/cblas_dgemm_row_nomod_transb.ll | 116 ++++ .../blas/cblas_dgemm_row_nomod_transboth.ll | 116 ++++ .../ReverseMode/blas/cblas_sgemm_col_nomod.ll | 109 +++ .../blas/cblas_sgemm_col_nomod_transa.ll | 111 ++++ .../blas/cblas_sgemm_col_nomod_transb.ll | 100 +++ .../blas/cblas_sgemm_col_nomod_transboth.ll | 102 +++ ...las_sgemm_col_transboth_inactive_second.ll | 97 +++ .../ReverseMode/blas/cblas_sgemm_row_nomod.ll | 96 +++ .../blas/cblas_sgemm_row_nomod_transa.ll | 96 +++ .../blas/cblas_sgemm_row_nomod_transb.ll | 104 +++ .../blas/cblas_sgemm_row_nomod_transboth.ll | 104 +++ 20 files changed, 2431 insertions(+), 171 deletions(-) create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transa.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transb.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transboth.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_transboth_inactive_second.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transa.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transb.ll create 
mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transboth.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transa.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transb.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transboth.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_transboth_inactive_second.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transa.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transb.ll create mode 100644 enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transboth.ll diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ebb7be706eaa..92339bbb3f41 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -5938,6 +5938,456 @@ class AdjointGenerator llvm_unreachable("Unhandled MPI FUNCTION"); } + void handleBlas(llvm::CallInst &call, Function *called, StringRef funcName, + CallInst *newCall, + std::map uncacheable_args) { + IRBuilder<> BuilderZ(newCall); + BuilderZ.setFastMathFlags(getFast()); + CallInst *orig = &call; + if (funcName == "cblas_ddot" || funcName == "cblas_sdot") { + Type *innerType; + std::string dfuncName; + if (funcName == "cblas_ddot") { + innerType = Type::getDoubleTy(call.getContext()); + dfuncName = "cblas_daxpy"; + } else if (funcName == "cblas_sdot") { + innerType = Type::getFloatTy(call.getContext()); + dfuncName = "cblas_saxpy"; + } else { + assert(false && "Unreachable"); + } + Type *castvals[2] = {call.getArgOperand(1)->getType(), + call.getArgOperand(3)->getType()}; + auto *cachetype = StructType::get(call.getContext(), ArrayRef(castvals)); + Value *undefinit = 
UndefValue::get(cachetype); + Value *cacheval; + auto in_arg = call.getCalledFunction()->arg_begin(); + in_arg++; + Argument *xfuncarg = in_arg; + in_arg++; + in_arg++; + Argument *yfuncarg = in_arg; + bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) && + uncacheable_args.find(xfuncarg)->second; + bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) && + uncacheable_args.find(yfuncarg)->second; + if ((Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModePrimal) && + (xcache || ycache)) { + Value *arg1, *arg2; + auto size = ConstantExpr::getSizeOf(innerType); + if (xcache) { + auto dmemcpy = + getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(), + PointerType::getUnqual(innerType), 0, 0); + auto malins = CallInst::CreateMalloc( + gutils->getNewFromOriginal(&call), size->getType(), innerType, + size, call.getArgOperand(0), nullptr, ""); + arg1 = + BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType()); + SmallVector args; + args.push_back(arg1); + args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1))); + args.push_back(call.getArgOperand(0)); + args.push_back(call.getArgOperand(2)); + BuilderZ.CreateCall(dmemcpy, args); + } + if (ycache) { + auto dmemcpy = + getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(), + PointerType::getUnqual(innerType), 0, 0); + auto malins = CallInst::CreateMalloc( + gutils->getNewFromOriginal(&call), size->getType(), innerType, + size, call.getArgOperand(0), nullptr, ""); + arg2 = + BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType()); + SmallVector args; + args.push_back(arg2); + args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3))); + args.push_back(call.getArgOperand(0)); + args.push_back(call.getArgOperand(4)); + BuilderZ.CreateCall(dmemcpy, args); + } + if (xcache && ycache) { + auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0); + cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1); + } else if (xcache) + cacheval = 
arg1; + else if (ycache) + cacheval = arg2; + gutils->cacheForReverse(BuilderZ, cacheval, + getIndex(&call, CacheType::Tape)); + } + if (Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModeGradient) { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction( + dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType, + call.getArgOperand(1)->getType(), Builder2.getInt32Ty(), + call.getArgOperand(3)->getType(), Builder2.getInt32Ty()); + Value *structarg1; + Value *structarg2; + if (xcache || ycache) { + if (Mode == DerivativeMode::ReverseModeGradient && + (!gutils->isConstantValue(call.getArgOperand(1)) || + !gutils->isConstantValue(call.getArgOperand(3)))) { + cacheval = BuilderZ.CreatePHI(cachetype, 0); + } + cacheval = + lookup(gutils->cacheForReverse(BuilderZ, cacheval, + getIndex(&call, CacheType::Tape)), + Builder2); + if (xcache && ycache) { + structarg1 = BuilderZ.CreateExtractValue(cacheval, 0); + structarg2 = BuilderZ.CreateExtractValue(cacheval, 1); + } else if (xcache) + structarg1 = cacheval; + else if (ycache) + structarg2 = cacheval; + } + if (!xcache) + structarg1 = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); + if (!ycache) + structarg2 = lookup( + gutils->getNewFromOriginal(orig->getArgOperand(3)), Builder2); + CallInst *firstdcall, *seconddcall; + if (!gutils->isConstantValue(call.getArgOperand(3))) { + Value *estride; + if (xcache) + estride = Builder2.getInt32(1); + else + estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), + Builder2); + SmallVector args1 = { + lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2), + diffe(orig, Builder2), + structarg1, + estride, + gutils->invertPointerM(orig->getArgOperand(3), Builder2), + lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), + Builder2)}; + firstdcall = Builder2.CreateCall(derivcall, args1); + } + if 
(!gutils->isConstantValue(call.getArgOperand(1))) { + Value *estride; + if (ycache) + estride = Builder2.getInt32(1); + else + estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), + Builder2); + SmallVector args2 = { + lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), + Builder2), + diffe(orig, Builder2), + structarg2, + estride, + gutils->invertPointerM(orig->getArgOperand(1), Builder2), + lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), + Builder2)}; + seconddcall = Builder2.CreateCall(derivcall, args2); + } + setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); + if (xcache) + CallInst::CreateFree(structarg1, firstdcall->getNextNode()); + if (ycache) + CallInst::CreateFree(structarg2, seconddcall->getNextNode()); + } + + if (gutils->knownRecomputeHeuristic.find(orig) != + gutils->knownRecomputeHeuristic.end()) { + if (!gutils->knownRecomputeHeuristic[orig]) { + gutils->cacheForReverse(BuilderZ, newCall, + getIndex(orig, CacheType::Self)); + } + } + + if (Mode == DerivativeMode::ReverseModeGradient) { + eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); + } else { + eraseIfUnused(*orig); + } + return; + } + + if (funcName == "cblas_dgemm" || funcName == "cblas_sgemm") { + Type *innerType; + std::string scfuncname; + if (funcName == "cblas_dgemm") { + innerType = Type::getDoubleTy(call.getContext()); + scfuncname = "cblas_dscal"; + } else if (funcName == "cblas_sgemm") { + innerType = Type::getFloatTy(call.getContext()); + scfuncname = "cblas_sscal"; + } else + assert(false && "Unreachable"); + auto in_arg = call.getCalledFunction()->arg_begin(); + in_arg++; + in_arg++; + in_arg++; + in_arg++; + in_arg++; + in_arg++; + in_arg++; + Argument *Aarg = in_arg; + in_arg++; + in_arg++; + Argument *Barg = in_arg; + auto aactive = !gutils->isConstantValue(call.getArgOperand(7)); + auto bactive = !gutils->isConstantValue(call.getArgOperand(9)); + assert(uncacheable_args.find(Aarg) != uncacheable_args.end()); + 
assert(uncacheable_args.find(Barg) != uncacheable_args.end()); + auto Acache = bactive && uncacheable_args.find(Aarg)->second; + auto Bcache = aactive && uncacheable_args.find(Barg)->second; + Type *castvals[2] = {call.getArgOperand(7)->getType(), + call.getArgOperand(9)->getType()}; + auto *cachetype = StructType::get(call.getContext(), ArrayRef(castvals)); + Value *undefinit = UndefValue::get(cachetype); + Value *cacheval; + if ((Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModePrimal) && + (Acache || Bcache)) { + Value *arg1, *arg2; + auto size = ConstantExpr::getSizeOf(innerType); + if (Acache) { + auto Asize = + BuilderZ.CreateMul(call.getArgOperand(3), call.getArgOperand(5)); + auto malins = CallInst::CreateMalloc( + gutils->getNewFromOriginal(&call), size->getType(), innerType, + size, Asize, nullptr, ""); + arg1 = + BuilderZ.CreateBitCast(malins, call.getArgOperand(7)->getType()); +#if LLVM_VERSION_MAJOR >= 10 + BuilderZ.CreateMemCpy( + arg1, MaybeAlign(), + gutils->getNewFromOriginal(call.getArgOperand(7)), MaybeAlign(), + Asize); +#else + BuilderZ.CreateMemCpy( + arg1, 0, gutils->getNewFromOriginal(call.getArgOperand(7)), 0, + Asize); +#endif + } + if (Bcache) { + auto Bsize = + BuilderZ.CreateMul(call.getArgOperand(4), call.getArgOperand(5)); + auto malins = CallInst::CreateMalloc( + gutils->getNewFromOriginal(&call), size->getType(), innerType, + size, Bsize, nullptr, ""); + arg2 = + BuilderZ.CreateBitCast(malins, call.getArgOperand(9)->getType()); +#if LLVM_VERSION_MAJOR >= 10 + BuilderZ.CreateMemCpy( + arg2, MaybeAlign(), + gutils->getNewFromOriginal(call.getArgOperand(9)), MaybeAlign(), + Bsize); +#else + BuilderZ.CreateMemCpy( + arg2, 0, gutils->getNewFromOriginal(call.getArgOperand(9)), 0, + Bsize); +#endif + } + if (Acache && Bcache) { + auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0); + cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1); + } else if (Acache) + cacheval = arg1; + else if 
(Bcache) + cacheval = arg2; + gutils->cacheForReverse(BuilderZ, cacheval, + getIndex(&call, CacheType::Tape)); + } + if (Mode == DerivativeMode::ReverseModeCombined || + Mode == DerivativeMode::ReverseModeGradient) { + IRBuilder<> Builder2(call.getParent()); + getReverseBuilder(Builder2); + auto dfunc = gutils->oldFunc->getParent()->getOrInsertFunction( + funcName, Builder2.getVoidTy(), Builder2.getInt32Ty(), + Builder2.getInt32Ty(), Builder2.getInt32Ty(), Builder2.getInt32Ty(), + Builder2.getInt32Ty(), Builder2.getInt32Ty(), + call.getArgOperand(6)->getType(), call.getArgOperand(7)->getType(), + Builder2.getInt32Ty(), call.getArgOperand(7)->getType(), + Builder2.getInt32Ty(), call.getArgOperand(6)->getType(), + call.getArgOperand(7)->getType(), Builder2.getInt32Ty()); + auto oneval = Builder2.getInt32(1); + auto doneval = Builder2.CreateSIToFP(oneval, innerType); + Value *sabtrans, *saldb, *sbatrans, *sblda, *saldc, *salda, *sbldb, + *sbldc; + Value *structarg1, *structarg2; + if (Acache || Bcache) { + if (Mode == DerivativeMode::ReverseModeGradient && + (aactive || bactive)) { + cacheval = BuilderZ.CreatePHI(cachetype, 0); + } + cacheval = + lookup(gutils->cacheForReverse(BuilderZ, cacheval, + getIndex(&call, CacheType::Tape)), + Builder2); + if (Acache && Bcache) { + structarg1 = BuilderZ.CreateExtractValue(cacheval, 0); + structarg2 = BuilderZ.CreateExtractValue(cacheval, 1); + } else if (Acache) + structarg1 = cacheval; + else if (Bcache) + structarg2 = cacheval; + } + if (!Acache) + structarg1 = lookup(gutils->getNewFromOriginal(call.getArgOperand(7)), + Builder2); + if (!Bcache) + structarg2 = lookup(gutils->getNewFromOriginal(call.getArgOperand(9)), + Builder2); + if (cast(call.getArgOperand(0))->getValue() == 102) { + if (aactive) { + salda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2); + saldc = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2); + } + if (bactive) { + sbldb = 
lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2); + sbldc = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2); + } + } else if (cast(call.getArgOperand(0))->getValue() == + 101) { + if (aactive) { + salda = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2); + saldc = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2); + } + if (bactive) { + sbldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2); + sbldc = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2); + } + } else + assert(false && "Wrong value"); + CallInst *safunccall, *sbfunccall; + if (aactive) { + if (cast(call.getArgOperand(2))->getValue() == 112 || + cast(call.getArgOperand(2))->getValue() == 113) { + sabtrans = Builder2.getInt32(111); + if (cast(call.getArgOperand(0))->getValue() == 102) + saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2); + else + saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2); + } else if (cast(call.getArgOperand(2))->getValue() == + 111) { + sabtrans = Builder2.getInt32(112); + saldb = lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2); + } else + assert(false && "Wrong value"); + SmallVector safuncargs = { + lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), + Builder2), + Builder2.getInt32(111), + sabtrans, + lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(6)), + Builder2), + gutils->invertPointerM(call.getArgOperand(12), Builder2), + salda, + structarg2, + saldb, + doneval, + gutils->invertPointerM(call.getArgOperand(7), Builder2), + saldc}; + safunccall = Builder2.CreateCall(dfunc, safuncargs); + } + if (bactive) { + if 
(cast(call.getArgOperand(1))->getValue() == 112 || + cast(call.getArgOperand(1))->getValue() == 113) { + sbatrans = Builder2.getInt32(111); + if (cast(call.getArgOperand(0))->getValue() == 102) + sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2); + else + sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2); + } else if (cast(call.getArgOperand(1))->getValue() == + 111) { + sbatrans = Builder2.getInt32(112); + if (cast(call.getArgOperand(0))->getValue() == 102) + sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2); + else + sblda = lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2); + } else + assert(false && "Wrong value"); + SmallVector sbfuncargs = { + lookup(gutils->getNewFromOriginal(call.getArgOperand(0)), + Builder2), + sbatrans, + Builder2.getInt32(111), + lookup(gutils->getNewFromOriginal(call.getArgOperand(5)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(4)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(3)), + Builder2), + lookup(gutils->getNewFromOriginal(call.getArgOperand(6)), + Builder2), + structarg1, + sblda, + gutils->invertPointerM(call.getArgOperand(12), Builder2), + sbldb, + doneval, + gutils->invertPointerM(call.getArgOperand(9), Builder2), + sbldc}; + sbfunccall = Builder2.CreateCall(dfunc, sbfuncargs); + } + auto scfunc = gutils->oldFunc->getParent()->getOrInsertFunction( + scfuncname, Builder2.getVoidTy(), Builder2.getInt32Ty(), + call.getArgOperand(6)->getType(), call.getArgOperand(7)->getType(), + Builder2.getInt32Ty()); + auto clen = Builder2.CreateMul( + gutils->getNewFromOriginal(call.getArgOperand(3)), + gutils->getNewFromOriginal(call.getArgOperand(4))); + SmallVector scfuncargs = { + clen, + lookup(gutils->getNewFromOriginal(call.getArgOperand(11)), + Builder2), + gutils->invertPointerM(call.getArgOperand(12), Builder2), + Builder2.getInt32(1)}; + auto scfunccall = 
Builder2.CreateCall(scfunc, scfuncargs); + if (Acache) + CallInst::CreateFree(structarg1, safunccall->getNextNode()); + if (Bcache) + CallInst::CreateFree(structarg2, sbfunccall->getNextNode()); + } + + if (gutils->knownRecomputeHeuristic.find(orig) != + gutils->knownRecomputeHeuristic.end()) { + if (!gutils->knownRecomputeHeuristic[orig]) { + gutils->cacheForReverse(BuilderZ, newCall, + getIndex(orig, CacheType::Self)); + } + } + + if (Mode == DerivativeMode::ReverseModeGradient) { + eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); + } else { + eraseIfUnused(*orig); + } + return; + } + } + // Return void visitCallInst(llvm::CallInst &call) { CallInst *const newCall = cast(gutils->getNewFromOriginal(&call)); @@ -6134,177 +6584,9 @@ class AdjointGenerator return; } - if ((funcName == "cblas_ddot" || funcName == "cblas_sdot") && - called->isDeclaration()) { - Type *innerType; - std::string dfuncName; - if (funcName == "cblas_ddot") { - innerType = Type::getDoubleTy(call.getContext()); - dfuncName = "cblas_daxpy"; - } else if (funcName == "cblas_sdot") { - innerType = Type::getFloatTy(call.getContext()); - dfuncName = "cblas_saxpy"; - } else { - assert(false && "Unreachable"); - } - Type *castvals[2] = {call.getArgOperand(1)->getType(), - call.getArgOperand(3)->getType()}; - auto *cachetype = - StructType::get(call.getContext(), ArrayRef(castvals)); - Value *undefinit = UndefValue::get(cachetype); - Value *cacheval; - auto in_arg = call.getCalledFunction()->arg_begin(); - in_arg++; - Argument *xfuncarg = in_arg; - in_arg++; - in_arg++; - Argument *yfuncarg = in_arg; - bool xcache = !gutils->isConstantValue(call.getArgOperand(3)) && - uncacheable_args.find(xfuncarg)->second; - bool ycache = !gutils->isConstantValue(call.getArgOperand(1)) && - uncacheable_args.find(yfuncarg)->second; - if ((Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModePrimal) && - (xcache || ycache)) { - Value *arg1, *arg2; - auto size = 
ConstantExpr::getSizeOf(innerType); - if (xcache) { - auto dmemcpy = - getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(), - PointerType::getUnqual(innerType), 0, 0); - auto malins = CallInst::CreateMalloc( - gutils->getNewFromOriginal(&call), size->getType(), innerType, - size, call.getArgOperand(0), nullptr, ""); - arg1 = - BuilderZ.CreateBitCast(malins, call.getArgOperand(1)->getType()); - SmallVector args; - args.push_back(arg1); - args.push_back(gutils->getNewFromOriginal(call.getArgOperand(1))); - args.push_back(call.getArgOperand(0)); - args.push_back(call.getArgOperand(2)); - BuilderZ.CreateCall(dmemcpy, args); - } - if (ycache) { - auto dmemcpy = - getOrInsertMemcpyStrided(*gutils->oldFunc->getParent(), - PointerType::getUnqual(innerType), 0, 0); - auto malins = CallInst::CreateMalloc( - gutils->getNewFromOriginal(&call), size->getType(), innerType, - size, call.getArgOperand(0), nullptr, ""); - arg2 = - BuilderZ.CreateBitCast(malins, call.getArgOperand(3)->getType()); - SmallVector args; - args.push_back(arg2); - args.push_back(gutils->getNewFromOriginal(call.getArgOperand(3))); - args.push_back(call.getArgOperand(0)); - args.push_back(call.getArgOperand(4)); - BuilderZ.CreateCall(dmemcpy, args); - } - if (xcache && ycache) { - auto valins1 = BuilderZ.CreateInsertValue(undefinit, arg1, 0); - cacheval = BuilderZ.CreateInsertValue(valins1, arg2, 1); - } else if (xcache) - cacheval = arg1; - else { - assert(ycache); - cacheval = arg2; - } - gutils->cacheForReverse(BuilderZ, cacheval, - getIndex(&call, CacheType::Tape)); - } - if (Mode == DerivativeMode::ReverseModeCombined || - Mode == DerivativeMode::ReverseModeGradient) { - IRBuilder<> Builder2(call.getParent()); - getReverseBuilder(Builder2); - auto derivcall = gutils->oldFunc->getParent()->getOrInsertFunction( - dfuncName, Builder2.getVoidTy(), Builder2.getInt32Ty(), innerType, - call.getArgOperand(1)->getType(), Builder2.getInt32Ty(), - call.getArgOperand(3)->getType(), Builder2.getInt32Ty()); - 
Value *structarg1; - Value *structarg2; - if (xcache || ycache) { - if (Mode == DerivativeMode::ReverseModeGradient && - (!gutils->isConstantValue(call.getArgOperand(1)) || - !gutils->isConstantValue(call.getArgOperand(3)))) { - cacheval = BuilderZ.CreatePHI(cachetype, 0); - } - cacheval = - lookup(gutils->cacheForReverse(BuilderZ, cacheval, - getIndex(&call, CacheType::Tape)), - Builder2); - if (xcache && ycache) { - structarg1 = BuilderZ.CreateExtractValue(cacheval, 0); - structarg2 = BuilderZ.CreateExtractValue(cacheval, 1); - } else if (xcache) - structarg1 = cacheval; - else if (ycache) - structarg2 = cacheval; - } - if (!xcache) - structarg1 = lookup( - gutils->getNewFromOriginal(orig->getArgOperand(1)), Builder2); - if (!ycache) - structarg2 = lookup( - gutils->getNewFromOriginal(orig->getArgOperand(3)), Builder2); - CallInst *firstdcall, *seconddcall; - if (!gutils->isConstantValue(call.getArgOperand(3))) { - Value *estride; - if (xcache) - estride = Builder2.getInt32(1); - else - estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), - Builder2); - SmallVector args1 = { - lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2), - diffe(orig, Builder2), - structarg1, - estride, - lookup(gutils->invertPointerM(orig->getArgOperand(3), Builder2), - Builder2), - lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), - Builder2)}; - firstdcall = Builder2.CreateCall(derivcall, args1); - } - if (!gutils->isConstantValue(call.getArgOperand(1))) { - Value *estride; - if (ycache) - estride = Builder2.getInt32(1); - else - estride = lookup(gutils->getNewFromOriginal(orig->getArgOperand(4)), - Builder2); - SmallVector args2 = { - lookup(gutils->getNewFromOriginal(orig->getArgOperand(0)), - Builder2), - diffe(orig, Builder2), - structarg2, - estride, - lookup(gutils->invertPointerM(orig->getArgOperand(1), Builder2), - Builder2), - lookup(gutils->getNewFromOriginal(orig->getArgOperand(2)), - Builder2)}; - seconddcall = 
Builder2.CreateCall(derivcall, args2); - } - setDiffe(orig, Constant::getNullValue(orig->getType()), Builder2); - if (xcache) - CallInst::CreateFree(structarg1, firstdcall->getNextNode()); - if (ycache) - CallInst::CreateFree(structarg2, seconddcall->getNextNode()); - } - - if (gutils->knownRecomputeHeuristic.find(orig) != - gutils->knownRecomputeHeuristic.end()) { - if (!gutils->knownRecomputeHeuristic[orig]) { - gutils->cacheForReverse(BuilderZ, newCall, - getIndex(orig, CacheType::Self)); - } - } - - if (Mode == DerivativeMode::ReverseModeGradient) { - eraseIfUnused(*orig, /*erase*/ true, /*check*/ false); - } else { - eraseIfUnused(*orig); - } + if (funcName.startswith("cblas_") && + !gutils->isConstantInstruction(&call) && called->isDeclaration()) { + handleBlas(call, called, funcName, newCall, uncacheable_args); return; } diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 566bd532cc67..d5e33db6b371 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1398,6 +1398,14 @@ class Enzyme : public ModulePass { CI->addParamAttr(3, Attribute::ReadOnly); CI->addParamAttr(3, Attribute::NoCapture); } + if ((Fn->getName() == "cblas_dgemm" || + Fn->getName() == "cblas_sgemm") && + Fn->isDeclaration()) { + CI->addParamAttr(7, Attribute::ReadOnly); + CI->addParamAttr(7, Attribute::NoCapture); + CI->addParamAttr(9, Attribute::ReadOnly); + CI->addParamAttr(9, Attribute::NoCapture); + } if (Fn->getName() == "frexp" || Fn->getName() == "frexpf" || Fn->getName() == "frexpl") { CI->addAttribute(AttributeList::FunctionIndex, Attribute::ArgMemOnly); diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod.ll new file mode 100644 index 000000000000..3f1e582dddcd --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod.ll @@ -0,0 +1,121 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include +; 
+;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, A, 4, B, 2, beta, C, 4); +;} +; +;int main() { +; double A[] = {0.11, 0.12, 0.13, 0.14, +; 0.21, 0.22, 0.23, 0.24}; +; double B[] = {1011, 1012, +; 1021, 1022, +; 1031, 1032}; +; double C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 3, 7, 11, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 1.200000e-01, double 1.300000e-01, double 1.400000e-01, double 2.100000e-01, double 2.200000e-01, double 2.300000e-01, double 2.400000e-01], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, 
double %0, double* %1, i32 4, double* %2, i32 2, double %3, double* %4, i32 4) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x double], align 16 + %B = alloca [6 x double], align 16 + %C = alloca [12 x double], align 16 + %A1 = alloca [8 x double], align 16 + %B1 = alloca [6 x double], align 16 + %C1 = alloca [12 x double], align 16 + %0 = bitcast [8 x double]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) + %1 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false) + %2 = bitcast i8* %1 to [6 x double]* + %3 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 0 + store double 1.011000e+03, double* %3, align 16 + %4 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 1 + store double 1.012000e+03, double* %4, align 8 + %5 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 2 + store double 1.021000e+03, double* %5, align 16 + %6 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 3 + store double 1.022000e+03, double* %6, align 8 + %7 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 4 + store double 1.031000e+03, double* %7, align 16 + %8 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 5 + store double 1.032000e+03, double* %8, align 8 + %9 = bitcast [12 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 96, i1 false) + %10 = bitcast [8 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 64, i1 false) + %11 = bitcast [6 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 48, i1 false) + %12 = bitcast [12 x double]* %C1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 
16 %12, i8 0, i64 96, i1 false) + %13 = bitcast i8* %12 to <{ double, double, double, double, [8 x double] }>* + %14 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 0 + store double 1.000000e+00, double* %14, align 16 + %15 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 1 + store double 3.000000e+00, double* %15, align 8 + %16 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 2 + store double 7.000000e+00, double* %16, align 16 + %17 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 3 + store double 1.100000e+01, double* %17, align 8 + %arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, 
double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %B, i32 2, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 112, i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, double 1.000000e+00, double* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transa.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transa.ll new file mode 100644 index 000000000000..7a923ae0c813 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transa.ll @@ -0,0 +1,123 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, A, 4, B, 2, beta, C, 4); +;} +; +;int main() { +; double A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; double B[] = {1011, 1012, +; 1021, 1022, +; 1031, 1032}; +; double C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 
0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 3, 7, 11, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 2.100000e-01, double 1.200000e-01, double 2.200000e-01, double 1.300000e-01, double 2.300000e-01, double 1.400000e-01, double 2.400000e-01], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, double %0, double* %1, i32 4, double* %2, i32 2, double %3, double* %4, i32 4) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x double], align 16 + %B = alloca [6 x double], align 16 + %C = alloca [12 x double], align 16 + %A1 = alloca [8 x double], align 16 + %B1 = alloca [6 x double], align 16 + %C1 = alloca [12 x double], align 16 + %0 = bitcast [8 x double]* %A 
to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) + %1 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false) + %2 = bitcast i8* %1 to [6 x double]* + %3 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 0 + store double 1.011000e+03, double* %3, align 16 + %4 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 1 + store double 1.012000e+03, double* %4, align 8 + %5 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 2 + store double 1.021000e+03, double* %5, align 16 + %6 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 3 + store double 1.022000e+03, double* %6, align 8 + %7 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 4 + store double 1.031000e+03, double* %7, align 16 + %8 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 5 + store double 1.032000e+03, double* %8, align 8 + %9 = bitcast [12 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 96, i1 false) + %10 = bitcast [8 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 64, i1 false) + %11 = bitcast [6 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 48, i1 false) + %12 = bitcast [12 x double]* %C1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %12, i8 0, i64 96, i1 false) + %13 = bitcast i8* %12 to <{ double, double, double, double, [8 x double] }>* + %14 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 0 + store double 1.000000e+00, double* %14, align 16 + %15 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 1 + store double 3.000000e+00, double* %15, align 8 + 
%16 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 2 + store double 7.000000e+00, double* %16, align 16 + %17 = getelementptr inbounds <{ double, double, double, double, [8 x double] }>, <{ double, double, double, double, [8 x double] }>* %13, i32 0, i32 3 + store double 1.100000e+01, double* %17, align 8 + %arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %B, i32 2, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 112, 
i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, double 1.000000e+00, double* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transb.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transb.ll new file mode 100644 index 000000000000..5b702dcb2145 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transb.ll @@ -0,0 +1,112 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasColMajor, CblasNoTrans, CblasTrans, 4, 3, 2, alpha, A, 4, B, 3, beta, C, 4); +;} +; +;int main() { +; double A[] = {0.11, 0.12, 0.13, 0.14, +; 0.21, 0.22, 0.23, 0.24}; +; double B[] = {1011, 1021, 1031, +; 1012, 1022, 1032}; +; double C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 1.200000e-01, double 1.300000e-01, double 1.400000e-01, double 2.100000e-01, double 
2.200000e-01, double 2.300000e-01, double 2.400000e-01], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x double] [double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 102, i32 111, i32 112, i32 4, i32 3, i32 2, double %0, double* %1, i32 4, double* %2, i32 3, double %3, double* %4, i32 4) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x double], align 16 + %B = alloca [6 x double], align 16 + %C = alloca [12 x double], align 16 + %A1 = alloca [8 x double], align 16 + %B1 = alloca [6 x double], align 16 + %C1 = alloca [12 x double], align 16 + %0 = bitcast [8 x double]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) + %1 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 
%1, i8 0, i64 48, i1 false) + %2 = bitcast i8* %1 to [6 x double]* + %3 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 0 + store double 1.011000e+03, double* %3, align 16 + %4 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 1 + store double 1.021000e+03, double* %4, align 8 + %5 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 2 + store double 1.031000e+03, double* %5, align 16 + %6 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 3 + store double 1.012000e+03, double* %6, align 8 + %7 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 4 + store double 1.022000e+03, double* %7, align 16 + %8 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 5 + store double 1.032000e+03, double* %8, align 8 + %9 = bitcast [12 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 96, i1 false) + %10 = bitcast [8 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 64, i1 false) + %11 = bitcast [6 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 48, i1 false) + %12 = bitcast [12 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %12, i8* align 16 bitcast ([12 x double]* @__const.main.C1 to i8*), i64 96, i1 false) + %arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, 
double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 112, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %B, i32 3, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, double 1.000000e+00, double* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %A, i32 4, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transboth.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transboth.ll new file mode 100644 index 000000000000..cd4d6aaf3810 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_nomod_transboth.ll @@ -0,0 +1,114 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double *, 
double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; double A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; double B[] = {1011, 1021, 1031, +; 1012, 1022, 1032}; +; double C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 2.100000e-01, double 1.200000e-01, double 2.200000e-01, double 1.300000e-01, double 2.300000e-01, double 1.400000e-01, double 2.400000e-01], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x double] [double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, 
double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, double %0, double* %1, i32 2, double* %2, i32 3, double %3, double* %4, i32 4) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x double], align 16 + %B = alloca [6 x double], align 16 + %C = alloca [12 x double], align 16 + %A1 = alloca [8 x double], align 16 + %B1 = alloca [6 x double], align 16 + %C1 = alloca [12 x double], align 16 + %0 = bitcast [8 x double]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) + %1 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false) + %2 = bitcast i8* %1 to [6 x double]* + %3 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 0 + store double 1.011000e+03, double* %3, align 16 + %4 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 1 + store double 1.021000e+03, double* %4, align 8 + %5 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 2 + store double 1.031000e+03, double* %5, align 16 + %6 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 3 + store double 1.012000e+03, double* %6, align 8 + %7 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 4 + store double 1.022000e+03, double* %7, align 16 + %8 = getelementptr inbounds [6 x double], [6 x double]* %2, i32 0, i32 5 + store double 1.032000e+03, double* %8, align 8 + %9 = bitcast [12 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 96, i1 false) + %10 = bitcast [8 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 64, i1 false) + %11 = 
bitcast [6 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 48, i1 false) + %12 = bitcast [12 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %12, i8* align 16 bitcast ([12 x double]* @__const.main.C1 to i8*), i64 96, i1 false) + %arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %B, i32 3, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, 
double 1.000000e+00, double* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_transboth_inactive_second.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_transboth_inactive_second.ll new file mode 100644 index 000000000000..abeffd3ab4c1 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_col_transboth_inactive_second.ll @@ -0,0 +1,121 @@ +;RUN: if [ %llvmver -ge 8 ]; then %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double*, double, double);; +; +;void g(double *restrict A, double *C, double alpha, double beta) { +; double B[] = {1011, 1021, 1031, +; 1012, 1022, 1032}; +; cblas_dgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; double A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; double C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x double] [double 1.100000e-01, double 2.100000e-01, double 1.200000e-01, double 2.200000e-01, double 1.300000e-01, double 2.300000e-01, double 1.400000e-01, double 2.400000e-01], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x double] 
[double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* %C, double %alpha, double %beta) { +entry: + %A.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + %B = alloca [6 x double], align 16 + store double* %A, double** %A.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) + %1 = bitcast i8* %0 to [6 x double]* + %2 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 0 + store double 1.011000e+03, double* %2, align 16 + %3 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 1 + store double 1.021000e+03, double* %3, align 8 + %4 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 2 + store double 1.031000e+03, double* %4, align 16 + %5 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 3 + store double 1.012000e+03, double* %5, align 8 + %6 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 4 + store double 1.022000e+03, double* %6, align 16 + %7 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 5 + store double 1.032000e+03, double* %7, align 8 + %8 = load double, double* %alpha.addr, align 8 + %9 = load double*, double** %A.addr, align 8 + %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %10 = load double, double* %beta.addr, align 8 + %11 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, double 
%8, double* %9, i32 2, double* %arraydecay, i32 3, double %10, double* %11, i32 4) + ret void +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x double], align 16 + %C = alloca [12 x double], align 16 + %A1 = alloca [8 x double], align 16 + %C1 = alloca [12 x double], align 16 + %0 = bitcast [8 x double]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x double]* @__const.main.A to i8*), i64 64, i1 false) + %1 = bitcast [12 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 96, i1 false) + %2 = bitcast [8 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 64, i1 false) + %3 = bitcast [12 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %3, i8* align 16 bitcast ([12 x double]* @__const.main.C1 to i8*), i64 96, i1 false) + %arraydecay = getelementptr inbounds [8 x double], [8 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x double], [8 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x double], [12 x double]* %C, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x double], [12 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double 2.000000e+00, double 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* %C, double* %"C'", double 
%alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: %B = alloca [6 x double], align 16 +;CHECK-NEXT: %0 = bitcast [6 x double]* %B to i8* +;CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) +;CHECK-NEXT: %1 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 +;CHECK-NEXT: store double 1.011000e+03, double* %1, align 16 +;CHECK-NEXT: %2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 1 +;CHECK-NEXT: store double 1.021000e+03, double* %2, align 8 +;CHECK-NEXT: %3 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 2 +;CHECK-NEXT: store double 1.031000e+03, double* %3, align 16 +;CHECK-NEXT: %4 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 3 +;CHECK-NEXT: store double 1.012000e+03, double* %4, align 8 +;CHECK-NEXT: %5 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 4 +;CHECK-NEXT: store double 1.022000e+03, double* %5, align 16 +;CHECK-NEXT: %6 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 5 +;CHECK-NEXT: store double 1.032000e+03, double* %6, align 8 +;CHECK-NEXT: %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %arraydecay, i32 3, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %arraydecay, i32 3, double 1.000000e+00, double* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_dscal(i32 12, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod.ll new file mode 100644 index 
000000000000..245a07bb5f66 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod.ll @@ -0,0 +1,120 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, +; double *, double *, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 2, 2, 3, alpha, A, 3, +; B, 2, beta, C, 2); +;} +; +;int main() { +; double A[] = {0.11, 0.12, 0.13, 0.21, 0.22, 0.23}; +; double B[] = {1011, 1012, 1021, 1022, 1031, 1032}; +; double C[] = {0.00, 0.00, 0.00, 0.00}; +; double A1[] = {0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 3, 7, 11}; +; __enzyme_autodiff((void *)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.C1 = private unnamed_addr constant [4 x double] [double 1.000000e+00, double 3.000000e+00, double 7.000000e+00, double 1.100000e+01], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { ; function being differentiated: forwards its arguments to a single cblas_dgemm call +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 
2, i32 2, i32 3, double %0, double* %1, i32 3, double* %2, i32 2, double %3, double* %4, i32 2) ; 101 = CblasRowMajor, 111 = CblasNoTrans for both A and B (see the C source above) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x double], align 16 + %B = alloca [6 x double], align 16 + %C = alloca [4 x double], align 16 + %A1 = alloca [6 x double], align 16 + %B1 = alloca [6 x double], align 16 + %C1 = alloca [4 x double], align 16 + %0 = bitcast [6 x double]* %A to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) + %1 = bitcast i8* %0 to [6 x double]* + %2 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 0 + store double 1.100000e-01, double* %2, align 16 + %3 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 1 + store double 1.200000e-01, double* %3, align 8 + %4 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 2 + store double 1.300000e-01, double* %4, align 16 + %5 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 3 + store double 2.100000e-01, double* %5, align 8 + %6 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 4 + store double 2.200000e-01, double* %6, align 16 + %7 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 5 + store double 2.300000e-01, double* %7, align 8 + %8 = bitcast [6 x double]* %B to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %8, i8 0, i64 48, i1 false) + %9 = bitcast i8* %8 to [6 x double]* + %10 = getelementptr inbounds [6 x double], [6 x double]* %9, i32 0, i32 0 + store double 1.011000e+03, double* %10, align 16 + %11 = getelementptr inbounds [6 x double], [6 x double]* %9, i32 0, i32 1 + store double 1.012000e+03, double* %11, align 8 + %12 = getelementptr inbounds [6 x double], [6 x double]* %9, i32 0, i32 2 + store double 1.021000e+03, double* %12, align 16 + %13 = getelementptr inbounds [6 x double], [6 x double]* %9, 
i32 0, i32 3 + store double 1.022000e+03, double* %13, align 8 + %14 = getelementptr inbounds [6 x double], [6 x double]* %9, i32 0, i32 4 + store double 1.031000e+03, double* %14, align 16 + %15 = getelementptr inbounds [6 x double], [6 x double]* %9, i32 0, i32 5 + store double 1.032000e+03, double* %15, align 8 + %16 = bitcast [4 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %16, i8 0, i64 32, i1 false) + %17 = bitcast [6 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %17, i8 0, i64 48, i1 false) + %18 = bitcast [6 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %18, i8 0, i64 48, i1 false) + %19 = bitcast [4 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %19, i8* align 16 bitcast ([4 x double]* @__const.main.C1 to i8*), i64 32, i1 false) + %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x double], [6 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x double], [6 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x double], [6 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [4 x double], [4 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [4 x double], [4 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 3.000000e+00) ; differentiate @g: each buffer is followed by its derivative buffer (A/A1, B/B1, C/C1); A1 and B1 were zeroed above, C1 was copied from @__const.main.C1 + ret i32 0 +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, 
double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 2, i32 2, i32 3, double %alpha, double* nocapture readonly %A, i32 3, double* nocapture readonly %B, i32 2, double %beta, double* %C, i32 2) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 112, i32 2, i32 3, i32 2, double %alpha, double* nocapture readonly %"C'", i32 2, double* nocapture readonly %B, i32 2, double 1.000000e+00, double* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 112, i32 111, i32 3, i32 2, i32 2, double %alpha, double* nocapture readonly %A, i32 3, double* nocapture readonly %"C'", i32 2, double 1.000000e+00, double* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_dscal(i32 4, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transa.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transa.ll new file mode 100644 index 000000000000..330c890b25ed --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transa.ll @@ -0,0 +1,108 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, +; double *, double *, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, 2, 4, 3, alpha, A, 2, B, +; 4, beta, C, 4); +;} +; +;int main() { +; double A[] = {1, 4, 2, 5, 3, 6}; +; double B[] = {21, 0.9, 30, 33, 0.3, 1, 31, 34, 0.7, 26, 32, 35}; +; double C[] = {0.00, 0.00, 0.0, 0.0, 0.00, 0.00, 0.0, 0.0}; +; double A1[] = {0, 0, 0, 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +; double C1[] = {1, 
1, 1, 1, 1, 1, 1, 1}; +; __enzyme_autodiff((void *)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.B = private unnamed_addr constant [12 x double] [double 2.100000e+01, double 9.000000e-01, double 3.000000e+01, double 3.300000e+01, double 3.000000e-01, double 1.000000e+00, double 3.100000e+01, double 3.400000e+01, double 0x3FE6666666666666, double 2.600000e+01, double 3.200000e+01, double 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x double] [double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { ; function being differentiated: forwards its arguments to a single cblas_dgemm call +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 101, i32 112, i32 111, i32 2, i32 4, i32 3, double %0, double* %1, i32 2, double* %2, i32 4, double %3, double* %4, i32 4) ; 101 = CblasRowMajor, 112 = CblasTrans for A, 111 = CblasNoTrans for B (see the C source above) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x double], align 16 + %B = alloca [12 x double], align 16 + %C = alloca [8 x double], align 
16 + %A1 = alloca [6 x double], align 16 + %B1 = alloca [12 x double], align 16 + %C1 = alloca [8 x double], align 16 + %0 = bitcast [6 x double]* %A to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) + %1 = bitcast i8* %0 to [6 x double]* + %2 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 0 + store double 1.000000e+00, double* %2, align 16 + %3 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 1 + store double 4.000000e+00, double* %3, align 8 + %4 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 2 + store double 2.000000e+00, double* %4, align 16 + %5 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 3 + store double 5.000000e+00, double* %5, align 8 + %6 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 4 + store double 3.000000e+00, double* %6, align 16 + %7 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 5 + store double 6.000000e+00, double* %7, align 8 + %8 = bitcast [12 x double]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %8, i8* align 16 bitcast ([12 x double]* @__const.main.B to i8*), i64 96, i1 false) + %9 = bitcast [8 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 64, i1 false) + %10 = bitcast [6 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 48, i1 false) + %11 = bitcast [12 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 96, i1 false) + %12 = bitcast [8 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %12, i8* align 16 bitcast ([8 x double]* @__const.main.C1 to i8*), i64 64, i1 false) + %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x double], [6 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x double], [12 x double]* %B, i32 0, i32 0 + %arraydecay3 = 
getelementptr inbounds [12 x double], [12 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x double], [8 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [8 x double], [8 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 2.000000e+00) ; differentiate @g: each buffer is followed by its derivative buffer (A/A1, B/B1, C/C1); A1 and B1 were zeroed above, C1 was copied from @__const.main.C1 + ret i32 0 +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 112, i32 111, i32 2, i32 4, i32 3, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %B, i32 4, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 112, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 4, double 1.000000e+00, double* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 3, i32 4, i32 2, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_dscal(i32 8, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transb.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transb.ll 
new file mode 100644 index 000000000000..0e68253846df --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transb.ll @@ -0,0 +1,116 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, 2, 4, 3, alpha, A, 3, B, 3, beta, C, 4); +;} +; +;int main() { +; double A[] = { 1, 2, 3, +; 4, 5, 6}; +; double B[] = {21, 0.3, 0.7, +; 0.9, 1, 26, +; 30, 31, 32, +; 33, 34, 35}; +; double C[] = { 0.00, 0.00, 0.0, 0.0, +; 0.00, 0.00, 0.0, 0.0}; +; double A1[] = {0, 0, 0, +; 0, 0, 0}; +; double B1[] = {0, 0, 0, +; 0, 0, 0, +; 0, 0, 0, +; 0, 0, 0}; +; double C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.B = private unnamed_addr constant [12 x double] [double 2.100000e+01, double 3.000000e-01, double 0x3FE6666666666666, double 9.000000e-01, double 1.000000e+00, double 2.600000e+01, double 3.000000e+01, double 3.100000e+01, double 3.200000e+01, double 3.300000e+01, double 3.400000e+01, double 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x double] [double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { ; function being differentiated: forwards its arguments to a single cblas_dgemm call +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* 
%A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 101, i32 111, i32 112, i32 2, i32 4, i32 3, double %0, double* %1, i32 3, double* %2, i32 3, double %3, double* %4, i32 4) ; 101 = CblasRowMajor, 111 = CblasNoTrans for A, 112 = CblasTrans for B (see the C source above) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x double], align 16 + %B = alloca [12 x double], align 16 + %C = alloca [8 x double], align 16 + %A1 = alloca [6 x double], align 16 + %B1 = alloca [12 x double], align 16 + %C1 = alloca [8 x double], align 16 + %0 = bitcast [6 x double]* %A to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) + %1 = bitcast i8* %0 to [6 x double]* + %2 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 0 + store double 1.000000e+00, double* %2, align 16 + %3 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 1 + store double 2.000000e+00, double* %3, align 8 + %4 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 2 + store double 3.000000e+00, double* %4, align 16 + %5 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 3 + store double 4.000000e+00, double* %5, align 8 + %6 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 4 + store double 5.000000e+00, double* %6, align 16 + %7 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 5 + store double 6.000000e+00, double* %7, align 8 + %8 = bitcast [12 x double]* %B to i8* + call void 
@llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %8, i8* align 16 bitcast ([12 x double]* @__const.main.B to i8*), i64 96, i1 false) + %9 = bitcast [8 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 64, i1 false) + %10 = bitcast [6 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 48, i1 false) + %11 = bitcast [12 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 96, i1 false) + %12 = bitcast [8 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %12, i8* align 16 bitcast ([8 x double]* @__const.main.C1 to i8*), i64 64, i1 false) + %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x double], [6 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x double], [12 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x double], [12 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x double], [8 x double]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [8 x double], [8 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 2.000000e+00) ; differentiate @g: each buffer is followed by its derivative buffer (A/A1, B/B1, C/C1); A1 and B1 were zeroed above, C1 was copied from @__const.main.C1 + ret i32 0 +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void 
@cblas_dgemm(i32 101, i32 111, i32 112, i32 2, i32 4, i32 3, double %alpha, double* nocapture readonly %A, i32 3, double* nocapture readonly %B, i32 3, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, double 1.000000e+00, double* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 112, i32 111, i32 3, i32 4, i32 2, double %alpha, double* nocapture readonly %A, i32 3, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_dscal(i32 8, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transboth.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transboth.ll new file mode 100644 index 000000000000..12cb32d2be59 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_dgemm_row_nomod_transboth.ll @@ -0,0 +1,116 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern double __enzyme_autodiff(void *, double *, double *, double *, double *, double *, double*, double, double); +; +;void g(double *restrict A, double *restrict B, double *C, double alpha, double beta) { +; cblas_dgemm(CblasRowMajor, CblasTrans, CblasTrans, 2, 4, 3, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; double A[] = {1, 4, +; 2, 5, +; 3, 6}; +; double B[] = {21, 0.3, 0.7, +; 0.9, 1, 26, +; 30, 31, 32, +; 33, 34, 35}; +; double C[] = { 0.00, 0.00, 0.0, 0.0, +; 0.00, 0.00, 0.0, 0.0}; +; double A1[] = {0, 0, 0, +; 0, 0, 0}; +; double B1[] = {0, 0, 0, 0, +; 0, 0, 0, 0, +; 0, 0, 0, 0}; +; double C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" 
+target triple = "x86_64-unknown-linux-gnu" + +@__const.main.B = private unnamed_addr constant [12 x double] [double 2.100000e+01, double 3.000000e-01, double 0x3FE6666666666666, double 9.000000e-01, double 1.000000e+00, double 2.600000e+01, double 3.000000e+01, double 3.100000e+01, double 3.200000e+01, double 3.300000e+01, double 3.400000e+01, double 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x double] [double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00], align 16 + +define dso_local void @g(double* noalias %A, double* noalias %B, double* %C, double %alpha, double %beta) { ; function being differentiated: forwards its arguments to a single cblas_dgemm call +entry: + %A.addr = alloca double*, align 8 + %B.addr = alloca double*, align 8 + %C.addr = alloca double*, align 8 + %alpha.addr = alloca double, align 8 + %beta.addr = alloca double, align 8 + store double* %A, double** %A.addr, align 8 + store double* %B, double** %B.addr, align 8 + store double* %C, double** %C.addr, align 8 + store double %alpha, double* %alpha.addr, align 8 + store double %beta, double* %beta.addr, align 8 + %0 = load double, double* %alpha.addr, align 8 + %1 = load double*, double** %A.addr, align 8 + %2 = load double*, double** %B.addr, align 8 + %3 = load double, double* %beta.addr, align 8 + %4 = load double*, double** %C.addr, align 8 + call void @cblas_dgemm(i32 101, i32 112, i32 112, i32 2, i32 4, i32 3, double %0, double* %1, i32 2, double* %2, i32 3, double %3, double* %4, i32 4) ; 101 = CblasRowMajor, 112 = CblasTrans for both A and B (see the C source above) + ret void +} + +declare dso_local void @cblas_dgemm(i32, i32, i32, i32, i32, i32, double, double*, i32, double*, i32, double, double*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x double], align 16 + %B = alloca [12 x double], align 16 + %C = alloca [8 x double], align 16 + %A1 = alloca [6 x double], align 16 + %B1 = alloca [12 x double], align 16 + %C1 = alloca [8 x double], align 16 + %0 = bitcast [6 x double]* %A to 
i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %0, i8 0, i64 48, i1 false) + %1 = bitcast i8* %0 to [6 x double]* + %2 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 0 + store double 1.000000e+00, double* %2, align 16 + %3 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 1 + store double 4.000000e+00, double* %3, align 8 + %4 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 2 + store double 2.000000e+00, double* %4, align 16 + %5 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 3 + store double 5.000000e+00, double* %5, align 8 + %6 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 4 + store double 3.000000e+00, double* %6, align 16 + %7 = getelementptr inbounds [6 x double], [6 x double]* %1, i32 0, i32 5 + store double 6.000000e+00, double* %7, align 8 + %8 = bitcast [12 x double]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %8, i8* align 16 bitcast ([12 x double]* @__const.main.B to i8*), i64 96, i1 false) + %9 = bitcast [8 x double]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %9, i8 0, i64 64, i1 false) + %10 = bitcast [6 x double]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %10, i8 0, i64 48, i1 false) + %11 = bitcast [12 x double]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %11, i8 0, i64 96, i1 false) + %12 = bitcast [8 x double]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %12, i8* align 16 bitcast ([8 x double]* @__const.main.C1 to i8*), i64 64, i1 false) + %arraydecay = getelementptr inbounds [6 x double], [6 x double]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x double], [6 x double]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x double], [12 x double]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x double], [12 x double]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x double], [8 x double]* %C, i32 0, i32 0 + 
%arraydecay5 = getelementptr inbounds [8 x double], [8 x double]* %C1, i32 0, i32 0 + %call = call double @__enzyme_autodiff(i8* bitcast (void (double*, double*, double*, double, double)* @g to i8*), double* %arraydecay, double* %arraydecay1, double* %arraydecay2, double* %arraydecay3, double* %arraydecay4, double* %arraydecay5, double 2.000000e+00, double 2.000000e+00) ; differentiate @g: each buffer is followed by its derivative buffer (A/A1, B/B1, C/C1); A1 and B1 were zeroed above, C1 was copied from @__const.main.C1 + ret i32 0 +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local double @__enzyme_autodiff(i8*, double*, double*, double*, double*, double*, double*, double, double) + +;CHECK:define internal { double, double } @diffeg(double* noalias %A, double* %"A'", double* noalias %B, double* %"B'", double* %C, double* %"C'", double %alpha, double %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 112, i32 112, i32 2, i32 4, i32 3, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %B, i32 3, double %beta, double* %C, i32 4) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 2, i32 3, i32 4, double %alpha, double* nocapture readonly %"C'", i32 4, double* nocapture readonly %B, i32 3, double 1.000000e+00, double* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_dgemm(i32 101, i32 111, i32 111, i32 3, i32 4, i32 2, double %alpha, double* nocapture readonly %A, i32 2, double* nocapture readonly %"C'", i32 4, double 1.000000e+00, double* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_dscal(i32 8, double %beta, double* %"C'", i32 1) +;CHECK-NEXT: ret { double, double } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod.ll new file mode 100644 index 000000000000..5300f5404f29 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod.ll @@ -0,0 +1,109 @@ +;RUN: %opt < %s 
%loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, A, 4, B, 2, beta, C, 4); +;} +; +;int main() { +; float A[] = {0.11, 0.12, 0.13, 0.14, +; 0.21, 0.22, 0.23, 0.24}; +; float B[] = {1011, 1012, +; 1021, 1022, +; 1031, 1032}; +; float C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 3, 7, 11, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FBEB851E0000000, float 0x3FC0A3D700000000, float 0x3FC1EB8520000000, float 0x3FCAE147A0000000, float 0x3FCC28F5C0000000, float 0x3FCD70A3E0000000, float 0x3FCEB851E0000000], align 16 +@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.012000e+03, float 1.021000e+03, float 1.022000e+03, float 1.031000e+03, float 1.032000e+03], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { ; function being differentiated: forwards its arguments to a single cblas_sgemm call +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 
= load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %0, float* %1, i32 4, float* %2, i32 2, float %3, float* %4, i32 4) ; 102 = CblasColMajor, 111 = CblasNoTrans for both A and B (see the C source above) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x float], align 16 + %B = alloca [6 x float], align 16 + %C = alloca [12 x float], align 16 + %A1 = alloca [8 x float], align 16 + %B1 = alloca [6 x float], align 16 + %C1 = alloca [12 x float], align 16 + %0 = bitcast [8 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) + %1 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) + %2 = bitcast [12 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 48, i1 false) + %3 = bitcast [8 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 32, i1 false) + %4 = bitcast [6 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) + %5 = bitcast [12 x float]* %C1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %5, i8 0, i64 48, i1 false) + %6 = bitcast i8* %5 to <{ float, float, float, float, [8 x float] }>* + %7 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 0 + store float 1.000000e+00, float* %7, align 16 + %8 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 1 + store float 3.000000e+00, float* %8, align 4 + %9 = getelementptr 
inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 2 + store float 7.000000e+00, float* %9, align 8 + %10 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 3 + store float 1.100000e+01, float* %10, align 4 + %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) ; differentiate @g: each buffer is followed by its derivative buffer (A/A1, B/B1, C/C1); A1 and B1 were zeroed above, C1 is zero except for its first four floats (1, 3, 7, 11) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %B, i32 2, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 112, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, 
float 1.000000e+00, float* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transa.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transa.ll new file mode 100644 index 000000000000..e10819b449cf --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transa.ll @@ -0,0 +1,111 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, 4, 3, 2, alpha, A, 4, B, 2, beta, C, 4); +;} +; +;int main() { +; float A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; float B[] = {1011, 1012, +; 1021, 1022, +; 1031, 1032}; +; float C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 3, 7, 11, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FCAE147A0000000, float 0x3FBEB851E0000000, float 0x3FCC28F5C0000000, float 0x3FC0A3D700000000, float 0x3FCD70A3E0000000, float 0x3FC1EB8520000000, float 0x3FCEB851E0000000], align 16 
+@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.012000e+03, float 1.021000e+03, float 1.022000e+03, float 1.031000e+03, float 1.032000e+03], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %0, float* %1, i32 4, float* %2, i32 2, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x float], align 16 + %B = alloca [6 x float], align 16 + %C = alloca [12 x float], align 16 + %A1 = alloca [8 x float], align 16 + %B1 = alloca [6 x float], align 16 + %C1 = alloca [12 x float], align 16 + %0 = bitcast [8 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) + %1 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) + %2 = bitcast [12 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 48, i1 false) + %3 = bitcast [8 x float]* %A1 to i8* + call void 
@llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 32, i1 false) + %4 = bitcast [6 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) + %5 = bitcast [12 x float]* %C1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %5, i8 0, i64 48, i1 false) + %6 = bitcast i8* %5 to <{ float, float, float, float, [8 x float] }>* + %7 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 0 + store float 1.000000e+00, float* %7, align 16 + %8 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 1 + store float 3.000000e+00, float* %8, align 4 + %9 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 2 + store float 7.000000e+00, float* %9, align 8 + %10 = getelementptr inbounds <{ float, float, float, float, [8 x float] }>, <{ float, float, float, float, [8 x float] }>* %6, i32 0, i32 3 + store float 1.100000e+01, float* %10, align 4 + %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture 
readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %B, i32 2, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 112, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, float 1.000000e+00, float* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transb.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transb.ll new file mode 100644 index 000000000000..024983882427 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transb.ll @@ -0,0 +1,100 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasColMajor, CblasNoTrans, CblasTrans, 4, 3, 2, alpha, A, 4, B, 3, beta, C, 4); +;} +; +;int main() { +; float A[] = {0.11, 0.12, 0.13, 0.14, +; 0.21, 0.22, 0.23, 0.24}; +; float B[] = {1011, 1021, 
1031, +; 1012, 1022, 1032}; +; float C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FBEB851E0000000, float 0x3FC0A3D700000000, float 0x3FC1EB8520000000, float 0x3FCAE147A0000000, float 0x3FCC28F5C0000000, float 0x3FCD70A3E0000000, float 0x3FCEB851E0000000], align 16 +@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.021000e+03, float 1.031000e+03, float 1.012000e+03, float 1.022000e+03, float 1.032000e+03], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 102, i32 
111, i32 112, i32 4, i32 3, i32 2, float %0, float* %1, i32 4, float* %2, i32 3, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x float], align 16 + %B = alloca [6 x float], align 16 + %C = alloca [12 x float], align 16 + %A1 = alloca [8 x float], align 16 + %B1 = alloca [6 x float], align 16 + %C1 = alloca [12 x float], align 16 + %0 = bitcast [8 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) + %1 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) + %2 = bitcast [12 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 48, i1 false) + %3 = bitcast [8 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 32, i1 false) + %4 = bitcast [6 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) + %5 = bitcast [12 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([12 x float]* @__const.main.C1 to i8*), i64 48, i1 false) + %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* 
%arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 112, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %B, i32 3, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, float 1.000000e+00, float* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %A, i32 4, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transboth.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transboth.ll new file mode 100644 index 000000000000..9a448a95a6ab --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_nomod_transboth.ll @@ -0,0 +1,102 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float 
*restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; float A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; float B[] = {1011, 1021, 1031, +; 1012, 1022, 1032}; +; float C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FCAE147A0000000, float 0x3FBEB851E0000000, float 0x3FCC28F5C0000000, float 0x3FC0A3D700000000, float 0x3FCD70A3E0000000, float 0x3FC1EB8520000000, float 0x3FCEB851E0000000], align 16 +@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.021000e+03, float 1.031000e+03, float 1.012000e+03, float 1.022000e+03, float 1.032000e+03], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 
+ %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %0, float* %1, i32 2, float* %2, i32 3, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x float], align 16 + %B = alloca [6 x float], align 16 + %C = alloca [12 x float], align 16 + %A1 = alloca [8 x float], align 16 + %B1 = alloca [6 x float], align 16 + %C1 = alloca [12 x float], align 16 + %0 = bitcast [8 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) + %1 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) + %2 = bitcast [12 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 48, i1 false) + %3 = bitcast [8 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 32, i1 false) + %4 = bitcast [6 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) + %5 = bitcast [12 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([12 x float]* @__const.main.C1 to i8*), i64 48, i1 false) + %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr 
inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %B, i32 3, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, float 1.000000e+00, float* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_transboth_inactive_second.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_transboth_inactive_second.ll new file mode 100644 index 000000000000..dbfecba8ab2b --- /dev/null +++ 
b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_col_transboth_inactive_second.ll @@ -0,0 +1,97 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *C, float alpha, float beta) { +; float B[] = {1011, 1021, 1031, +; 1012, 1022, 1032}; +; cblas_sgemm(CblasColMajor, CblasTrans, CblasTrans, 4, 3, 2, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; float A[] = {0.11, 0.21, +; 0.12, 0.22, +; 0.13, 0.23, +; 0.14, 0.24}; +; float C[] = {0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00, +; 0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.g.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.021000e+03, float 1.031000e+03, float 1.012000e+03, float 1.022000e+03, float 1.032000e+03], align 16 +@__const.main.A = private unnamed_addr constant [8 x float] [float 0x3FBC28F5C0000000, float 0x3FCAE147A0000000, float 0x3FBEB851E0000000, float 0x3FCC28F5C0000000, float 0x3FC0A3D700000000, float 0x3FCD70A3E0000000, float 0x3FC1EB8520000000, float 0x3FCEB851E0000000], align 16 +@__const.main.C1 = private unnamed_addr constant [12 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define dso_local void @g(float* noalias %A, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 
4 + %B = alloca [6 x float], align 16 + store float* %A, float** %A.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.g.B to i8*), i64 24, i1 false) + %1 = load float, float* %alpha.addr, align 4 + %2 = load float*, float** %A.addr, align 8 + %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %1, float* %2, i32 2, float* %arraydecay, i32 3, float %3, float* %4, i32 4) + ret void +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [8 x float], align 16 + %C = alloca [12 x float], align 16 + %A1 = alloca [8 x float], align 16 + %C1 = alloca [12 x float], align 16 + %0 = bitcast [8 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([8 x float]* @__const.main.A to i8*), i64 32, i1 false) + %1 = bitcast [12 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %1, i8 0, i64 48, i1 false) + %2 = bitcast [8 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 32, i1 false) + %3 = bitcast [12 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %3, i8* align 16 bitcast ([12 x float]* @__const.main.C1 to i8*), i64 48, i1 false) + %arraydecay = getelementptr inbounds [8 x float], [8 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [8 x float], [8 x float]* %A1, i32 0, i32 0 + %arraydecay2 = 
getelementptr inbounds [12 x float], [12 x float]* %C, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x float], [12 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float 2.000000e+00, float 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* noalias %A, float* %"A'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: %B = alloca [6 x float], align 16 +;CHECK-NEXT: %0 = bitcast [6 x float]* %B to i8* +;CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.g.B to i8*), i64 24, i1 false) +;CHECK-NEXT: %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 112, i32 112, i32 4, i32 3, i32 2, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %arraydecay, i32 3, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 102, i32 111, i32 111, i32 4, i32 2, i32 3, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %arraydecay, i32 3, float 1.000000e+00, float* %"A'", i32 4) +;CHECK-NEXT: call void @cblas_sscal(i32 12, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod.ll new file mode 100644 index 000000000000..d62ce7ae3785 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod.ll @@ -0,0 +1,96 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify 
-simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, +; float *, float *, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, 2, 2, 3, alpha, A, 3, +; B, 2, beta, C, 2); +;} +; +;int main() { +; float A[] = {0.11, 0.12, 0.13, 0.21, 0.22, 0.23}; +; float B[] = {1011, 1012, 1021, 1022, 1031, 1032}; +; float C[] = {0.00, 0.00, 0.00, 0.00}; +; float A1[] = {0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 3, 7, 11}; +; __enzyme_autodiff((void *)g, A, A1, B, B1, C, C1, 2.0, 3.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [6 x float] [float 0x3FBC28F5C0000000, float 0x3FBEB851E0000000, float 0x3FC0A3D700000000, float 0x3FCAE147A0000000, float 0x3FCC28F5C0000000, float 0x3FCD70A3E0000000], align 16 +@__const.main.B = private unnamed_addr constant [6 x float] [float 1.011000e+03, float 1.012000e+03, float 1.021000e+03, float 1.022000e+03, float 1.031000e+03, float 1.032000e+03], align 16 +@__const.main.C1 = private unnamed_addr constant [4 x float] [float 1.000000e+00, float 3.000000e+00, float 7.000000e+00, float 1.100000e+01], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** 
%B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 2, i32 2, i32 3, float %0, float* %1, i32 3, float* %2, i32 2, float %3, float* %4, i32 2) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x float], align 16 + %B = alloca [6 x float], align 16 + %C = alloca [4 x float], align 16 + %A1 = alloca [6 x float], align 16 + %B1 = alloca [6 x float], align 16 + %C1 = alloca [4 x float], align 16 + %0 = bitcast [6 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.main.A to i8*), i64 24, i1 false) + %1 = bitcast [6 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([6 x float]* @__const.main.B to i8*), i64 24, i1 false) + %2 = bitcast [4 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 16, i1 false) + %3 = bitcast [6 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false) + %4 = bitcast [6 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 24, i1 false) + %5 = bitcast [4 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([4 x float]* @__const.main.C1 to i8*), i64 16, i1 false) + %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x float], [6 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [6 x float], [6 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [6 x float], [6 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [4 x float], [4 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [4 x float], [4 x float]* %C1, 
i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 3.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 2, i32 2, i32 3, float %alpha, float* nocapture readonly %A, i32 3, float* nocapture readonly %B, i32 2, float %beta, float* %C, i32 2) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 112, i32 2, i32 3, i32 2, float %alpha, float* nocapture readonly %"C'", i32 2, float* nocapture readonly %B, i32 2, float 1.000000e+00, float* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 112, i32 111, i32 3, i32 2, i32 2, float %alpha, float* nocapture readonly %A, i32 3, float* nocapture readonly %"C'", i32 2, float 1.000000e+00, float* %"B'", i32 2) +;CHECK-NEXT: call void @cblas_sscal(i32 4, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transa.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transa.ll new file mode 100644 index 000000000000..09b0241bc6c5 --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transa.ll @@ -0,0 +1,96 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float 
__enzyme_autodiff(void *, float *, float *, float *, float *, +; float *, float *, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, 2, 4, 3, alpha, A, 2, B, +; 4, beta, C, 4); +;} +; +;int main() { +; float A[] = {1, 4, 2, 5, 3, 6}; +; float B[] = {21, 0.9, 30, 33, 0.3, 1, 31, 34, 0.7, 26, 32, 35}; +; float C[] = {0.00, 0.00, 0.0, 0.0, 0.00, 0.00, 0.0, 0.0}; +; float A1[] = {0, 0, 0, 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, 1, 1, 1, 1}; +; __enzyme_autodiff((void *)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [6 x float] [float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00], align 16 +@__const.main.B = private unnamed_addr constant [12 x float] [float 2.100000e+01, float 0x3FECCCCCC0000000, float 3.000000e+01, float 3.300000e+01, float 0x3FD3333340000000, float 1.000000e+00, float 3.100000e+01, float 3.400000e+01, float 0x3FE6666660000000, float 2.600000e+01, float 3.200000e+01, float 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + 
store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 101, i32 112, i32 111, i32 2, i32 4, i32 3, float %0, float* %1, i32 2, float* %2, i32 4, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x float], align 16 + %B = alloca [12 x float], align 16 + %C = alloca [8 x float], align 16 + %A1 = alloca [6 x float], align 16 + %B1 = alloca [12 x float], align 16 + %C1 = alloca [8 x float], align 16 + %0 = bitcast [6 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.main.A to i8*), i64 24, i1 false) + %1 = bitcast [12 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([12 x float]* @__const.main.B to i8*), i64 48, i1 false) + %2 = bitcast [8 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 32, i1 false) + %3 = bitcast [6 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false) + %4 = bitcast [12 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 48, i1 false) + %5 = bitcast [8 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([8 x float]* @__const.main.C1 to i8*), i64 32, i1 false) + %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x float], [6 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x float], [12 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x float], [12 x 
float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x float], [8 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [8 x float], [8 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 2.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 112, i32 111, i32 2, i32 4, i32 3, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %B, i32 4, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 112, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 4, float 1.000000e+00, float* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 3, i32 4, i32 2, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_sscal(i32 8, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transb.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transb.ll new file mode 100644 index 000000000000..ff69f0ed4537 --- /dev/null +++ 
b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transb.ll @@ -0,0 +1,104 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, 2, 4, 3, alpha, A, 3, B, 3, beta, C, 4); +;} +; +;int main() { +; float A[] = { 1, 2, 3, +; 4, 5, 6}; +; float B[] = {21, 0.3, 0.7, +; 0.9, 1, 26, +; 30, 31, 32, +; 33, 34, 35}; +; float C[] = { 0.00, 0.00, 0.0, 0.0, +; 0.00, 0.00, 0.0, 0.0}; +; float A1[] = {0, 0, 0, +; 0, 0, 0}; +; float B1[] = {0, 0, 0, +; 0, 0, 0, +; 0, 0, 0, +; 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [6 x float] [float 1.000000e+00, float 2.000000e+00, float 3.000000e+00, float 4.000000e+00, float 5.000000e+00, float 6.000000e+00], align 16 +@__const.main.B = private unnamed_addr constant [12 x float] [float 2.100000e+01, float 0x3FD3333340000000, float 0x3FE6666660000000, float 0x3FECCCCCC0000000, float 1.000000e+00, float 2.600000e+01, float 3.000000e+01, float 3.100000e+01, float 3.200000e+01, float 3.300000e+01, float 3.400000e+01, float 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca
float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 101, i32 111, i32 112, i32 2, i32 4, i32 3, float %0, float* %1, i32 3, float* %2, i32 3, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x float], align 16 + %B = alloca [12 x float], align 16 + %C = alloca [8 x float], align 16 + %A1 = alloca [6 x float], align 16 + %B1 = alloca [12 x float], align 16 + %C1 = alloca [8 x float], align 16 + %0 = bitcast [6 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.main.A to i8*), i64 24, i1 false) + %1 = bitcast [12 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([12 x float]* @__const.main.B to i8*), i64 48, i1 false) + %2 = bitcast [8 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 32, i1 false) + %3 = bitcast [6 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false) + %4 = bitcast [12 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 48, i1 false) + %5 = bitcast [8 x float]* %C1 to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([8 x float]* @__const.main.C1 to i8*), i64 32, i1 false) + %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %A, i32 0, i32 0 
+ %arraydecay1 = getelementptr inbounds [6 x float], [6 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x float], [12 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x float], [12 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x float], [8 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [8 x float], [8 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 2.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 112, i32 2, i32 4, i32 3, float %alpha, float* nocapture readonly %A, i32 3, float* nocapture readonly %B, i32 3, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, float 1.000000e+00, float* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 112, i32 111, i32 3, i32 4, i32 2, float %alpha, float* nocapture readonly %A, i32 3, float* nocapture readonly %"C'", i32 4, float 1.000000e+00, float* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_sscal(i32 8, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:} diff --git 
a/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transboth.ll b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transboth.ll new file mode 100644 index 000000000000..d7f1329c2e1f --- /dev/null +++ b/enzyme/test/Enzyme/ReverseMode/blas/cblas_sgemm_row_nomod_transboth.ll @@ -0,0 +1,104 @@ +;RUN: %opt < %s %loadEnzyme -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s + +;#include <cblas.h> +; +;extern float __enzyme_autodiff(void *, float *, float *, float *, float *, float *, float*, float, float); +; +;void g(float *restrict A, float *restrict B, float *C, float alpha, float beta) { +; cblas_sgemm(CblasRowMajor, CblasTrans, CblasTrans, 2, 4, 3, alpha, A, 2, B, 3, beta, C, 4); +;} +; +;int main() { +; float A[] = {1, 4, +; 2, 5, +; 3, 6}; +; float B[] = {21, 0.3, 0.7, +; 0.9, 1, 26, +; 30, 31, 32, +; 33, 34, 35}; +; float C[] = { 0.00, 0.00, 0.0, 0.0, +; 0.00, 0.00, 0.0, 0.0}; +; float A1[] = {0, 0, 0, +; 0, 0, 0}; +; float B1[] = {0, 0, 0, 0, +; 0, 0, 0, 0, +; 0, 0, 0, 0}; +; float C1[] = {1, 1, 1, 1, +; 1, 1, 1, 1}; +; __enzyme_autodiff((void*)g, A, A1, B, B1, C, C1, 2.0, 2.0); +;} + +target datalayout = "e-m:e-i64:64-f80:128-n8:16:32:64-S128" +target triple = "x86_64-unknown-linux-gnu" + +@__const.main.A = private unnamed_addr constant [6 x float] [float 1.000000e+00, float 4.000000e+00, float 2.000000e+00, float 5.000000e+00, float 3.000000e+00, float 6.000000e+00], align 16 +@__const.main.B = private unnamed_addr constant [12 x float] [float 2.100000e+01, float 0x3FD3333340000000, float 0x3FE6666660000000, float 0x3FECCCCCC0000000, float 1.000000e+00, float 2.600000e+01, float 3.000000e+01, float 3.100000e+01, float 3.200000e+01, float 3.300000e+01, float 3.400000e+01, float 3.500000e+01], align 16 +@__const.main.C1 = private unnamed_addr constant [8 x float] [float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00], align 16 + +define
dso_local void @g(float* %A, float* %B, float* %C, float %alpha, float %beta) { +entry: + %A.addr = alloca float*, align 8 + %B.addr = alloca float*, align 8 + %C.addr = alloca float*, align 8 + %alpha.addr = alloca float, align 4 + %beta.addr = alloca float, align 4 + store float* %A, float** %A.addr, align 8 + store float* %B, float** %B.addr, align 8 + store float* %C, float** %C.addr, align 8 + store float %alpha, float* %alpha.addr, align 4 + store float %beta, float* %beta.addr, align 4 + %0 = load float, float* %alpha.addr, align 4 + %1 = load float*, float** %A.addr, align 8 + %2 = load float*, float** %B.addr, align 8 + %3 = load float, float* %beta.addr, align 4 + %4 = load float*, float** %C.addr, align 8 + call void @cblas_sgemm(i32 101, i32 112, i32 112, i32 2, i32 4, i32 3, float %0, float* %1, i32 2, float* %2, i32 3, float %3, float* %4, i32 4) + ret void +} + +declare dso_local void @cblas_sgemm(i32, i32, i32, i32, i32, i32, float, float*, i32, float*, i32, float, float*, i32) + +define dso_local i32 @main() { +entry: + %A = alloca [6 x float], align 16 + %B = alloca [12 x float], align 16 + %C = alloca [8 x float], align 16 + %A1 = alloca [6 x float], align 16 + %B1 = alloca [12 x float], align 16 + %C1 = alloca [8 x float], align 16 + %0 = bitcast [6 x float]* %A to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %0, i8* align 16 bitcast ([6 x float]* @__const.main.A to i8*), i64 24, i1 false) + %1 = bitcast [12 x float]* %B to i8* + call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %1, i8* align 16 bitcast ([12 x float]* @__const.main.B to i8*), i64 48, i1 false) + %2 = bitcast [8 x float]* %C to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %2, i8 0, i64 32, i1 false) + %3 = bitcast [6 x float]* %A1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %3, i8 0, i64 24, i1 false) + %4 = bitcast [12 x float]* %B1 to i8* + call void @llvm.memset.p0i8.i64(i8* align 16 %4, i8 0, i64 48, i1 false) + %5 = bitcast [8 x float]* %C1 to i8* 
+ call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 16 %5, i8* align 16 bitcast ([8 x float]* @__const.main.C1 to i8*), i64 32, i1 false) + %arraydecay = getelementptr inbounds [6 x float], [6 x float]* %A, i32 0, i32 0 + %arraydecay1 = getelementptr inbounds [6 x float], [6 x float]* %A1, i32 0, i32 0 + %arraydecay2 = getelementptr inbounds [12 x float], [12 x float]* %B, i32 0, i32 0 + %arraydecay3 = getelementptr inbounds [12 x float], [12 x float]* %B1, i32 0, i32 0 + %arraydecay4 = getelementptr inbounds [8 x float], [8 x float]* %C, i32 0, i32 0 + %arraydecay5 = getelementptr inbounds [8 x float], [8 x float]* %C1, i32 0, i32 0 + %call = call float @__enzyme_autodiff(i8* bitcast (void (float*, float*, float*, float, float)* @g to i8*), float* %arraydecay, float* %arraydecay1, float* %arraydecay2, float* %arraydecay3, float* %arraydecay4, float* %arraydecay5, float 2.000000e+00, float 2.000000e+00) + ret i32 0 +} + +declare void @llvm.memcpy.p0i8.p0i8.i64(i8* nocapture writeonly, i8* nocapture readonly, i64, i1) + +declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1) + +declare dso_local float @__enzyme_autodiff(i8*, float*, float*, float*, float*, float*, float*, float, float) + +;CHECK:define internal { float, float } @diffeg(float* %A, float* %"A'", float* %B, float* %"B'", float* %C, float* %"C'", float %alpha, float %beta) { +;CHECK-NEXT:entry: +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 112, i32 112, i32 2, i32 4, i32 3, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %B, i32 3, float %beta, float* %C, i32 4) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 2, i32 3, i32 4, float %alpha, float* nocapture readonly %"C'", i32 4, float* nocapture readonly %B, i32 3, float 1.000000e+00, float* %"A'", i32 3) +;CHECK-NEXT: call void @cblas_sgemm(i32 101, i32 111, i32 111, i32 3, i32 4, i32 2, float %alpha, float* nocapture readonly %A, i32 2, float* nocapture readonly %"C'", i32 4, float 
1.000000e+00, float* %"B'", i32 4) +;CHECK-NEXT: call void @cblas_sscal(i32 8, float %beta, float* %"C'", i32 1) +;CHECK-NEXT: ret { float, float } zeroinitializer +;CHECK-NEXT:}