Skip to content

Commit

Permalink
Add trmv (#1917)
Browse files Browse the repository at this point in the history
* Add trmm

* Fix input mats

* add blasinfra

* fixup

* continue

* trmm tested

* more fixes

* correct trmm [colmajor]

* All tr tests passing

* all functioning now
  • Loading branch information
wsmoses committed Jun 5, 2024
1 parent 9e4e6c1 commit 46647d1
Show file tree
Hide file tree
Showing 13 changed files with 1,905 additions and 446 deletions.
163 changes: 142 additions & 21 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,65 @@ class MagicInst : Inst<"blas">;
def Rows : MagicInst; // given a transpose, normal rows, normal cols get the true rows, aka normal rows if N else normal cols
def Concat : MagicInst;

def Mul : MagicInst;
def ShadowNoInc : MagicInst;

def trans_to_side : MagicInst;
class Binop<string _s, list<string> _tys> {
string s = _s;
list<string> tys = _tys;
}
def BXor : Binop<"Xor", ["intType", "intType"]>;
def BAnd : Binop<"And", ["intType", "intType"]>;
def Not : Binop<"Not", ["intType", "intType"]>;
def Sub : Binop<"Sub", ["intType", "intType"]>;
def Add : Binop<"Add", ["intType", "intType"]>;
def Mul : Binop<"Mul", ["intType", "intType"]>;
def BFMul : Binop<"FMul", ["fpType", "fpType"]>;
def BFDiv : Binop<"FDiv", ["fpType", "fpType"]>;
def FSelect : Binop<"Select", ["Builder2.getInt1Ty()", "fpType", "fpType"]>;
def ISelect : Binop<"Select", ["Builder2.getInt1Ty()", "intType", "intType"]>;

class IntMatchers<string _inty, string _outty, list<string> _before, list<string> _after> {
string inty = _inty;
string outty = _outty;
list<string> before = _before;
list<string> after = _after;
}

def trans_to_side : IntMatchers<
"charType", "charType",
["'n'", "'N'", "'T'", "'t'"],
["'l'", "'L'", "'U'", "'u'"]
>;

def side_to_trans : IntMatchers<
"charType", "charType",
["'l'", "'L'", "'R'", "'r'"],
["'n'", "'N'", "'T'", "'t'"]
>;

def is_upper : IntMatchers<
"charType", "Builder2.getInt1Ty()",
["'u'", "'U'", "'L'", "'l'"],
["true", "true", "false", "false"]
>;

def is_diag_int : IntMatchers<
"charType", "intType",
["'u'", "'U'", "'N'", "'n'"],
["1", "1", "0", "0"]
>;

def is_left : IntMatchers<
"charType", "Builder2.getInt1Ty()",
["'l'", "'L'", "'R'", "'r'"],
["true", "true", "false", "false"]
>;

def First : MagicInst;
def Lookup : MagicInst;
def LoadLookup : MagicInst;

class Add<string _tmp=""> {
class FAdd<string _tmp=""> {
string unused = _tmp;
}

Expand Down Expand Up @@ -100,9 +152,11 @@ class Seq<list<string> _args = [], list<string> _vars = []> {
list<string> vars = _vars;

}
class For<string idx_> {
class For<string idx_, bit offset_> {
string idx = idx_;
bit offset = offset_;
}

class FirstUse<string var_> {
string var = var_;
}
Expand Down Expand Up @@ -170,13 +224,15 @@ def dot : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
(BlasCall<"axpy"> $n, DiffeRet, $y, (Shadow $x)),
(BlasCall<"axpy"> $n, DiffeRet, $x, (Shadow $y)),
],
(Add<""> (BlasCall<"dot"> $n, (Shadow $x), $y), (BlasCall<"dot"> $n, $x, (Shadow $y)))
(FAdd<""> (BlasCall<"dot"> $n, (Shadow $x), $y), (BlasCall<"dot"> $n, $x, (Shadow $y)))
>;

// def nrm2 : CallBlasPattern<(Op $n, $x, $incx),
// [],[len, vinc],
// [(FDiv (BlasCall<"scal"> $n, DiffeRet, $x, $incx), Ret<"">)]
// >;
//def nrm2 : CallBlasPattern<(Op $n, $x, $incx),
// [],[len, vinc<["n"]>],
// [
// (AssertingInactiveArg) # (BlasCall<"axpy"> $n, (BFDiv DiffeRet, (BlasCall<"nrm2"> $n, $x)), $x, (Shadow $x))
// ]
// >;


def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy),
Expand Down Expand Up @@ -238,11 +294,22 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $

// x = Ax
// currently assumes for vector dimensions that transa = 'N' and gets dimensions wrong otherwise
def trmv : CallBlasPattern<(Op $layout, $transa, $diag, $n, $A, $lda, $x, $incx),
["x"], [cblas_layout, trans, diag, len, mld<["diag", "n", "n"]>, vinc<["n"]>],
def trmv : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $A, $lda, $x, $incx),
["x"], [cblas_layout, uplo, trans, diag, len, mld<["diag", "n", "n"]>, vinc<["n"]>],
[
/* A */ (AssertingInactiveArg), //(BlasCall<"ger"> $layout, $m, $n, $alpha, (Rows $transa, (Concat (Shadow $y), $x), (Concat $x, (Shadow $y))), (Shadow $A)),
/* x */ (BlasCall<"trmv"> $layout, transpose<"transa">, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x))
/* A */ (For<"i", 1> (ISelect (is_upper $uplo), $n, (Sub $n, (is_diag_int $diag))),
(BlasCall<"axpy">
(ISelect (is_upper $uplo),
(Sub $i, (is_diag_int $diag)),
(Add (Sub (Sub $n, (is_diag_int $diag)), $i), ConstantInt<1>)
),
(LoadLookup $layout, (Rows $trans, input<"x">, (Shadow $x)), (Sub $i, ConstantInt<1>)),
(Lookup $layout, (Rows $trans, (Shadow $x), input<"x">), (ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>))),
(First (Lookup $layout, (Shadow $A), (ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)), (Sub $i, ConstantInt<1>))),
ConstantInt<1>
)
),
/* x */ (BlasCall<"trmv"> $layout, $uplo, transpose<"trans">, $diag, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $x))
]
>;
//
Expand Down Expand Up @@ -301,6 +368,46 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
)
>;

// B := alpha*op( A )*B, or B := alpha*B*op( A ),
def trmm : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $alpha, $A, $lda, $B, $ldb),
["B"],
[cblas_layout, side, uplo, trans, diag, len, len, fp, mld<["side", "m", "n"]>, mld<["m","n"]>],
[
/*alpha*/ (AssertingInactiveArg),
/* A */ (For<"i", 1> (Sub (ISelect (is_left $side), $m, $n), (ISelect (is_upper $uplo), ConstantInt<0>, (is_diag_int $diag))),
(BlasCall<"gemv">
$layout,
(side_to_trans $side),
(ISelect (is_left $side),
(Concat
(ISelect (is_upper $uplo),
(Sub $i, (is_diag_int $diag)),
(Sub $m, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>))),
$n),
(Concat $m,
(ISelect (is_upper $uplo),
(Sub $i, (is_diag_int $diag)),
(Sub $n, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)))
)
),
$alpha,
(Lookup $layout, (ISelect (BXor (is_left $side), (Not (Rows $transa))), (Shadow $B), (Concat input<"B">, $m)),
(ISelect (BAnd (is_left $side), (Not (is_upper $uplo))), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>),
(ISelect (BAnd (Not (is_left $side)), (Not (is_upper $uplo))), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>)),
(First (Lookup $layout, (ISelect (BXor (is_left $side), (Not (Rows $transa))), (Concat input<"B">, $m), (Shadow $B)),
(ISelect (is_left $side), (Sub $i, ConstantInt<1>), ConstantInt<0>),
(ISelect (is_left $side), ConstantInt<0>, (Sub $i, ConstantInt<1>)))),
(ISelect (is_left $side), (ISelect (BXor (is_left $side), (Not (Rows $transa))), $m, $ldb), ConstantInt<1>),
Constant<"1">,
(First (Lookup $layout, (Shadow $A),
(ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)),
(Sub $i, ConstantInt<1>))),
ConstantInt<1>
)
),
/* B */ (BlasCall<"trmm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, $side, $lda, $m, $n), (Shadow $B))
]
>;

def symm: CallBlasPattern<(Op $layout, $side, $uplo, $m, $n, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
["C"],
Expand All @@ -325,7 +432,7 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda,
(Seq<[], []>
(BlasCall<"symm">
$layout,
(trans_to_side $uplo),
(Rows $trans, Char<"l">, Char<"r">),
$uplo,
(Rows $trans,
(Concat $n, $k),
Expand All @@ -336,8 +443,27 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda,
Constant<"1">,
(Shadow $A)
),
(For<"i"> $n,
(BlasCall<"axpy"> $k, (Mul $alpha, (Lookup (Shadow $C), $i, $i)), (Lookup (Concat $A, (ld $A, $trans, $lda, $n, $k)), $i), ConstantInt<1>, (Lookup (Shadow $A), $i), ConstantInt<1>)
(For<"i", 0> $n,
(BlasCall<"axpy">
$k,
(BFMul $alpha, (LoadLookup $layout, (Shadow $C), $i, $i)),
(First
(Lookup $layout,
(Concat $A, (ld $A, $trans, $lda, $n, $k)),
(Rows $trans, $i, ConstantInt<0>),
(Rows $trans, ConstantInt<0>, $i)
)
),
(Rows $trans, (ld $A, $trans, $lda, $n, $k), ConstantInt<1>),
(First
(Lookup $layout,
(Shadow $A),
(Rows $trans, $i, ConstantInt<0>),
(Rows $trans, ConstantInt<0>, $i)
)
),
(Rows $trans, $lda, ConstantInt<1>)
)
)
),
/* beta */ (AssertingInactiveArg),
Expand Down Expand Up @@ -411,11 +537,6 @@ def spr2 : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy,
//
// // Lv 3
//
// def : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $alpha, $a, $lda, $b, $ldb),
// ["trmm"],
// [cblas_layout, side, uplo, trans, diag, len, len, fp, vld, vld],
// []
// >;
//
// def : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $alpha, $a, $lda, $b, $ldb),
// ["trsm"],
Expand Down
10 changes: 7 additions & 3 deletions enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,13 +566,17 @@ class GradientUtils : public CacheUtility {
assert(llvm::cast<llvm::ArrayType>(vals[i]->getType())
->getNumElements() == width);

llvm::Type *wrappedType = llvm::ArrayType::get(diffType, width);
llvm::Value *res = llvm::UndefValue::get(wrappedType);
llvm::Type *wrappedType = diffType->isVoidTy()
? nullptr
: llvm::ArrayType::get(diffType, width);
llvm::Value *res =
diffType->isVoidTy() ? nullptr : llvm::UndefValue::get(wrappedType);
for (unsigned int i = 0; i < getWidth(); ++i) {
auto tup = std::tuple<Args...>{
(args ? extractMeta(Builder, args, i) : nullptr)...};
auto diff = std::apply(rule, std::move(tup));
res = Builder.CreateInsertValue(res, diff, {i});
if (!diffType->isVoidTy())
res = Builder.CreateInsertValue(res, diff, {i});
}
return res;
} else {
Expand Down
65 changes: 11 additions & 54 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2642,8 +2642,8 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {"dot", "scal", "axpy", "gemv",
"gemm", "spmv", "syrk"};
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv",
"syrk", "nrm2", "trmm", "trmv", "symm"};
const char *floatType[] = {"s", "d"}; // c, z
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
Expand Down Expand Up @@ -2944,55 +2944,6 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
"transpose." + name);
}

llvm::Value *trans_to_side(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Type *T = V->getType();
if (cublas) {
assert(0 && "cublas unknown");
}

auto isn = B.CreateICmpEQ(V, ConstantInt::get(T, 'N'));
auto sel1 =
B.CreateSelect(isn, ConstantInt::get(T, 'L'), ConstantInt::get(T, 'l'));

auto isN = B.CreateICmpEQ(V, ConstantInt::get(T, 't'));
auto sel2 = B.CreateSelect(isN, ConstantInt::get(T, 'r'), sel1);

auto ist = B.CreateICmpEQ(V, ConstantInt::get(T, 'T'));
auto sel3 = B.CreateSelect(ist, ConstantInt::get(T, 'R'), sel2);

return sel3;
}

llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *julia_decl,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name) {

if (!byRef) {
// Explicitly support 'N' always, since we use in the rule infra
if (auto CI = dyn_cast<ConstantInt>(V)) {
if (CI->getValue() == 'N')
return ConstantInt::get(CI->getType(), 'L');
if (CI->getValue() == 'n')
return ConstantInt::get(CI->getType(), 'l');
if (CI->getValue() == 'T')
return ConstantInt::get(CI->getType(), 'R');
if (CI->getValue() == 't')
return ConstantInt::get(CI->getType(), 'r');
}
}

if (byRef) {
auto charType = IntegerType::get(V->getContext(), 8);
V = B.CreateLoad(charType, V, "ld." + name);
}

V = trans_to_side(B, V, cublas);

return to_blas_callconv(B, V, byRef, cublas, julia_decl, entryBuilder,
"trans_to_side." + name);
}

llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,
llvm::Value *V, bool byRef) {
if (!byRef)
Expand All @@ -3006,8 +2957,6 @@ llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::Type *intType,

SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
ArrayRef<llvm::Value *> transA,
ArrayRef<llvm::Value *> row,
ArrayRef<llvm::Value *> col,
bool byRef, bool cublas) {
assert(transA.size() == 1);
auto trans = transA[0];
Expand All @@ -3031,10 +2980,18 @@ SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
// TODO: verify
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
}
return {cond};
}
SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
ArrayRef<llvm::Value *> transA,
ArrayRef<llvm::Value *> row,
ArrayRef<llvm::Value *> col,
bool byRef, bool cublas) {
auto conds = get_blas_row(B, transA, byRef, cublas);
assert(row.size() == col.size());
SmallVector<Value *, 1> toreturn;
for (size_t i = 0; i < row.size(); i++) {
toreturn.push_back(B.CreateSelect(cond, row[i], col[i]));
toreturn.push_back(B.CreateSelect(conds[0], row[i], col[i]));
}
return toreturn;
}
Expand Down
11 changes: 4 additions & 7 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1859,18 +1859,15 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *IT,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
// first one assume V is an Integer
llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas);
// secon one assume V is an Integer or a ptr to an int (depends on byRef)
llvm::Value *trans_to_side(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *IT,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
llvm::SmallVector<llvm::Value *, 1>
get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
llvm::ArrayRef<llvm::Value *> row,
llvm::ArrayRef<llvm::Value *> col, bool byRef, bool cublas);

llvm::SmallVector<llvm::Value *, 1>
get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
bool byRef, bool cublas);

#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wunused-variable"
Expand Down
Loading

0 comments on commit 46647d1

Please sign in to comment.