Skip to content

Commit

Permalink
syrk functioning (#1916)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jun 3, 2024
1 parent 0afb1d3 commit 9e4e6c1
Show file tree
Hide file tree
Showing 5 changed files with 233 additions and 53 deletions.
13 changes: 9 additions & 4 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def Concat : MagicInst;

def Mul : MagicInst;

def uplo_to_side : MagicInst;
def trans_to_side : MagicInst;

def Lookup : MagicInst;

Expand Down Expand Up @@ -184,7 +184,11 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
[
(InactiveArg),// copy moves x into y, so x is never modified.
(BlasCall<"axpy"> $n, Constant<"1.0">, (Shadow $y), (Shadow $x))
]
],
(Seq<[], ["beta1"]>
(BlasCall<"copy"> (FirstUse<"beta1"> $n, $n), (Shadow $x), (Shadow $y)),
(FirstUse<"beta1"> (BlasCall<"scal"> $n, Constant<"0">, (Shadow $y)))
)
>;

// def swap : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
Expand Down Expand Up @@ -321,7 +325,8 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda,
(Seq<[], []>
(BlasCall<"symm">
$layout,
(uplo_to_side $uplo),
(trans_to_side $uplo),
$uplo,
(Rows $trans,
(Concat $n, $k),
(Concat $k, $n)),
Expand All @@ -332,7 +337,7 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda,
(Shadow $A)
),
(For<"i"> $n,
(BlasCall<"axpy"> $k, (Mul $alpha, (Lookup (Shadow $C), $i, $i)), (Lookup (Concat $A, (ld $A, $trans, $lda, $n, $k)), $i), Constant<"1">, (Shadow $A))
(BlasCall<"axpy"> $k, (Mul $alpha, (Lookup (Shadow $C), $i, $i)), (Lookup (Concat $A, (ld $A, $trans, $lda, $n, $k)), $i), ConstantInt<1>, (Lookup (Shadow $A), $i), ConstantInt<1>)
)
),
/* beta */ (AssertingInactiveArg),
Expand Down
17 changes: 9 additions & 8 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2642,7 +2642,8 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"};
const char *extractable[] = {"dot", "scal", "axpy", "gemv",
"gemm", "spmv", "syrk"};
const char *floatType[] = {"s", "d"}; // c, z
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
Expand Down Expand Up @@ -2943,7 +2944,7 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
"transpose." + name);
}

llvm::Value *uplo_to_side(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Value *trans_to_side(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Type *T = V->getType();
if (cublas) {
assert(0 && "cublas unknown");
Expand All @@ -2962,10 +2963,10 @@ llvm::Value *uplo_to_side(IRBuilder<> &B, llvm::Value *V, bool cublas) {
return sel3;
}

llvm::Value *uplo_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *julia_decl,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name) {
llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *julia_decl,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name) {

if (!byRef) {
// Explicitly support 'N' always, since we use in the rule infra
Expand All @@ -2986,10 +2987,10 @@ llvm::Value *uplo_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
V = B.CreateLoad(charType, V, "ld." + name);
}

V = uplo_to_side(B, V, cublas);
V = trans_to_side(B, V, cublas);

return to_blas_callconv(B, V, byRef, cublas, julia_decl, entryBuilder,
"uplo_to_side." + name);
"trans_to_side." + name);
}

llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,
Expand Down
10 changes: 5 additions & 5 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1860,12 +1860,12 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
// first one assumes V is an Integer
llvm::Value *uplo_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas);
llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas);
// second one assumes V is an Integer or a ptr to an int (depends on byRef)
llvm::Value *uplo_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *IT,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *IT,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
llvm::SmallVector<llvm::Value *, 1>
get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
llvm::ArrayRef<llvm::Value *> row,
Expand Down
172 changes: 172 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/blas/syrk_f.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
;RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -S | FileCheck %s; fi
;RUN: %opt < %s %newLoadEnzyme -passes="enzyme" -S | FileCheck %s

; dsyrk ( character UPLO,
; character TRANS,
; integer N,
; integer K,
; double precision ALPHA,
; double precision, dimension(lda,*) A,
; integer LDA,
; double precision BETA,
; double precision, dimension(ldc,*) C,
; integer LDC
; )

declare void @dsyrk_64_(i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture readonly, i8* nocapture, i8* nocapture readonly, i64, i64)

; Test driver: materializes every scalar argument in a stack slot, then makes a
; single call to @dsyrk_64_ using the caller-provided %A and %C buffers.
define void @f(i8* %C, i8* %A) {
entry:
; Stack slots for the scalar arguments. The i64/double slots are additionally
; bitcast to i8* because @dsyrk_64_ is declared with i8* parameters throughout.
%uplo = alloca i8, align 1
%trans = alloca i8, align 1
%n = alloca i64, align 16
%n_p = bitcast i64* %n to i8*
%k = alloca i64, align 16
%k_p = bitcast i64* %k to i8*
%alpha = alloca double, align 16
%alpha_p = bitcast double* %alpha to i8*
%lda = alloca i64, align 16
%lda_p = bitcast i64* %lda to i8*
; NOTE(review): %ldb / %ldb_p are initialized below but are not passed to the
; @dsyrk_64_ call in this function; possibly a leftover - confirm before removing.
%ldb = alloca i64, align 16
%ldb_p = bitcast i64* %ldb to i8*
%beta = alloca double, align 16
%beta_p = bitcast double* %beta to i8*
%ldc = alloca i64, align 16
%ldc_p = bitcast i64* %ldc to i8*
; uplo = 85 (ASCII 'U'), trans = 78 (ASCII 'N').
store i8 85, i8* %uplo, align 1
store i8 78, i8* %trans, align 1
; n = 4, k = 8, alpha = 1.0, lda = 4, (unused) ldb = 8, beta = 0.0, ldc = 4.
store i64 4, i64* %n, align 16
store i64 8, i64* %k, align 16
store double 1.000000e+00, double* %alpha, align 16
store i64 4, i64* %lda, align 16
store i64 8, i64* %ldb, align 16
store double 0.000000e+00, double* %beta
store i64 4, i64* %ldc, align 16
; Argument order follows the dsyrk signature listed at the top of this test.
; The two trailing `i64 1` operands are presumably the hidden Fortran string
; lengths of the uplo/trans character arguments - TODO confirm.
call void @dsyrk_64_(i8* %uplo, i8* %trans, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1)
ret void
}

declare dso_local void @__enzyme_autodiff(...)

; Entry point for the test: hands @f to @__enzyme_autodiff. Each of %C and %A is
; preceded by the "enzyme_dup" metadata tag and followed by a second pointer
; (%dC / %dA), which shows up as the %"C'" / %"A'" parameters of the generated
; @diffef checked below.
define void @active(i8* %C, i8* %dC, i8* %A, i8* %dA) {
entry:
call void (...) @__enzyme_autodiff(void (i8*,i8*)* @f, metadata !"enzyme_dup", i8* %C, i8* %dC, metadata !"enzyme_dup", i8* %A, i8* %dA)
ret void
}

; CHECK: define internal void @diffef(i8* %C, i8* %"C'", i8* %A, i8* %"A'")
; CHECK-NEXT: entry:
; CHECK-NEXT: %ret = alloca double, align 8
; CHECK-NEXT: %byref.int.one = alloca i64, align 8
; CHECK-NEXT: %byref.trans_to_side.uplo = alloca i8, align 1
; CHECK-NEXT: %byref.constant.fp.1 = alloca double, align 8
; CHECK-NEXT: %byref.for.i = alloca i64, align 8
; CHECK-NEXT: %byref.mul = alloca double, align 8
; CHECK-NEXT: %byref.constant.int.1 = alloca i64, align 8
; CHECK-NEXT: %byref.constant.int.11 = alloca i64, align 8
; CHECK-NEXT: %byref.constant.int.0 = alloca i64, align 8
; CHECK-NEXT: %byref.constant.int.03 = alloca i64, align 8
; CHECK-NEXT: %byref.constant.fp.1.0 = alloca double, align 8
; CHECK-NEXT: %0 = alloca i64, align 8
; CHECK-NEXT: %uplo = alloca i8, align 1
; CHECK-NEXT: %trans = alloca i8, align 1
; CHECK-NEXT: %n = alloca i64, align 16
; CHECK-NEXT: %n_p = bitcast i64* %n to i8*
; CHECK-NEXT: %k = alloca i64, align 16
; CHECK-NEXT: %k_p = bitcast i64* %k to i8*
; CHECK-NEXT: %alpha = alloca double, align 16
; CHECK-NEXT: %alpha_p = bitcast double* %alpha to i8*
; CHECK-NEXT: %lda = alloca i64, align 16
; CHECK-NEXT: %lda_p = bitcast i64* %lda to i8*
; CHECK-NEXT: %beta = alloca double, align 16
; CHECK-NEXT: %beta_p = bitcast double* %beta to i8*
; CHECK-NEXT: %ldc = alloca i64, align 16
; CHECK-NEXT: %ldc_p = bitcast i64* %ldc to i8*
; CHECK-NEXT: store i8 85, i8* %uplo, align 1
; CHECK-NEXT: store i8 78, i8* %trans, align 1
; CHECK-NEXT: store i64 4, i64* %n, align 16
; CHECK-NEXT: store i64 8, i64* %k, align 16
; CHECK-NEXT: store double 1.000000e+00, double* %alpha, align 16
; CHECK-NEXT: store i64 4, i64* %lda, align 16
; CHECK-NEXT: store double 0.000000e+00, double* %beta, align 8
; CHECK-NEXT: store i64 4, i64* %ldc, align 16
; CHECK-NEXT: call void @dsyrk_64_(i8* %uplo, i8* %trans, i8* %n_p, i8* %k_p, i8* %alpha_p, i8* %A, i8* %lda_p, i8* %beta_p, i8* %C, i8* %ldc_p, i64 1, i64 1)
; CHECK-NEXT: br label %invertentry

; CHECK: invertentry: ; preds = %entry
; CHECK-NEXT: store i64 1, i64* %byref.int.one, align 4
; CHECK-NEXT: %intcast.int.one = bitcast i64* %byref.int.one to i8*
; CHECK-NEXT: %ld.uplo = load i8, i8* %uplo, align 1
; CHECK-NEXT: %1 = icmp eq i8 %ld.uplo, 78
; CHECK-NEXT: %2 = select i1 %1, i8 76, i8 108
; CHECK-NEXT: %3 = icmp eq i8 %ld.uplo, 116
; CHECK-NEXT: %4 = select i1 %3, i8 114, i8 %2
; CHECK-NEXT: %5 = icmp eq i8 %ld.uplo, 84
; CHECK-NEXT: %6 = select i1 %5, i8 82, i8 %4
; CHECK-NEXT: store i8 %6, i8* %byref.trans_to_side.uplo, align 1
; CHECK-NEXT: %ld.row.trans = load i8, i8* %trans, align 1
; CHECK-NEXT: %7 = icmp eq i8 %ld.row.trans, 110
; CHECK-NEXT: %8 = icmp eq i8 %ld.row.trans, 78
; CHECK-NEXT: %9 = or i1 %8, %7
; CHECK-NEXT: %10 = select i1 %9, i8* %n_p, i8* %k_p
; CHECK-NEXT: %11 = select i1 %9, i8* %k_p, i8* %n_p
; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1, align 8
; CHECK-NEXT: %fpcast.constant.fp.1 = bitcast double* %byref.constant.fp.1 to i8*
; CHECK-NEXT: call void @dsymm_64_(i8* %byref.trans_to_side.uplo, i8* %uplo, i8* %10, i8* %11, i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %A, i8* %lda_p, i8* %fpcast.constant.fp.1, i8* %"A'", i8* %lda_p)
; CHECK-NEXT: %12 = bitcast i8* %n_p to i64*
; CHECK-NEXT: %13 = load i64, i64* %12, align 4
; CHECK-NEXT: %14 = icmp eq i64 %13, 0
; CHECK-NEXT: br i1 %14, label %invertentry_end, label %invertentry_loop

; CHECK: invertentry_loop: ; preds = %invertentry_loop, %invertentry
; CHECK-NEXT: %15 = phi i64 [ 0, %invertentry ], [ %44, %invertentry_loop ]
; CHECK-NEXT: store i64 %15, i64* %byref.for.i, align 4
; CHECK-NEXT: %intcast.for.i = bitcast i64* %byref.for.i to i8*
; CHECK-NEXT: %16 = bitcast i8* %"C'" to double*
; CHECK-NEXT: %17 = bitcast i8* %ldc_p to i64*
; CHECK-NEXT: %18 = load i64, i64* %17, align 4
; CHECK-NEXT: %19 = bitcast i8* %intcast.for.i to i64*
; CHECK-NEXT: %20 = load i64, i64* %19, align 4
; CHECK-NEXT: %21 = mul i64 %20, %18
; CHECK-NEXT: %22 = bitcast i8* %intcast.for.i to i64*
; CHECK-NEXT: %23 = load i64, i64* %22, align 4
; CHECK-NEXT: %24 = add i64 %21, %23
; CHECK-NEXT: %25 = getelementptr double, double* %16, i64 %24
; CHECK-NEXT: %26 = load double, double* %25, align 8
; CHECK-NEXT: %27 = bitcast i8* %alpha_p to double*
; CHECK-NEXT: %28 = load double, double* %27, align 8
; CHECK-NEXT: %29 = fmul fast double %28, %26
; CHECK-NEXT: store double %29, double* %byref.mul, align 8
; CHECK-NEXT: %30 = bitcast i8* %A to double*
; CHECK-NEXT: %31 = bitcast i8* %lda_p to i64*
; CHECK-NEXT: %32 = load i64, i64* %31, align 4
; CHECK-NEXT: %33 = bitcast i8* %intcast.for.i to i64*
; CHECK-NEXT: %34 = load i64, i64* %33, align 4
; CHECK-NEXT: %35 = mul i64 %34, %32
; CHECK-NEXT: %36 = getelementptr double, double* %30, i64 %35
; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1, align 4
; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8*
; CHECK-NEXT: %37 = bitcast i8* %"A'" to double*
; CHECK-NEXT: %38 = bitcast i8* %lda_p to i64*
; CHECK-NEXT: %39 = load i64, i64* %38, align 4
; CHECK-NEXT: %40 = bitcast i8* %intcast.for.i to i64*
; CHECK-NEXT: %41 = load i64, i64* %40, align 4
; CHECK-NEXT: %42 = mul i64 %41, %39
; CHECK-NEXT: %43 = getelementptr double, double* %37, i64 %42
; CHECK-NEXT: store i64 1, i64* %byref.constant.int.11, align 4
; CHECK-NEXT: %intcast.constant.int.12 = bitcast i64* %byref.constant.int.11 to i8*
; CHECK-NEXT: call void @daxpy_64_(i8* %k_p, double* %byref.mul, double* %36, i8* %intcast.constant.int.1, double* %43, i8* %intcast.constant.int.12)
; CHECK-NEXT: %44 = add nuw nsw i64 %13, 1
; CHECK-NEXT: %45 = icmp eq i64 %13, %44
; CHECK-NEXT: br i1 %45, label %invertentry_end, label %invertentry_loop

; CHECK: invertentry_end: ; preds = %invertentry_loop, %invertentry
; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4
; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8*
; CHECK-NEXT: store i64 0, i64* %byref.constant.int.03, align 4
; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %byref.constant.int.03 to i8*
; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1.0, align 8
; CHECK-NEXT: %fpcast.constant.fp.1.0 = bitcast double* %byref.constant.fp.1.0 to i8*
; CHECK-NEXT: call void @dlascl_64_(i8* %uplo, i8* %intcast.constant.int.0, i8* %intcast.constant.int.04, i8* %fpcast.constant.fp.1.0, i8* %beta_p, i8* %n_p, i8* %n_p, i8* %"C'", i8* %ldc_p, i64* %0, i64 1)
; CHECK-NEXT: ret void
; CHECK-NEXT: }
Loading

0 comments on commit 9e4e6c1

Please sign in to comment.