diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index ff55ca0a92..65908921b4 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -1071,6 +1071,79 @@ function blas_symm( ) end +""" +`blas_syrk` + +C := alpha*A*A^T + beta*C, or C := alpha*A^T*A + beta*C, where alpha and beta are scalars. C must be a n x n symmetric matrix.\" +""" +function blas_syrk( + A::Value, + C::Value, + alpha::Value, + beta::Value; + output::IR.Type, + uplo, + transpose=nothing, + location=Location(), +) + op_ty_results = IR.Type[output,] + operands = Value[A, C, alpha, beta] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("uplo", uplo),] + !isnothing(transpose) && push!(attributes, namedattribute("transpose", transpose)) + + return create_operation( + "enzymexla.blas.syrk", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`blas_trmm` + +B := alpha * op(A) x B, or B := alpha * B x op(A), where alpha is a scalar, +B is a m x n matrix, A is a unit, or non-unit, upper or lower triangular +matrix, and op(A) is one of op(A) = A, or op(A) = A^T or A^H. +""" +function blas_trmm( + A::Value, + B::Value, + alpha::Value; + output::IR.Type, + side, + uplo, + transpose, + location=Location(), +) + op_ty_results = IR.Type[output,] + operands = Value[A, B, alpha] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[ + namedattribute("side", side), + namedattribute("uplo", uplo), + namedattribute("transpose", transpose), + ] + + return create_operation( + "enzymexla.blas.trmm", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function typeAlign(; result::IR.Type, source, location=Location()) op_ty_results = IR.Type[result,] operands = Value[]