Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 143 additions & 13 deletions src/mlir/Dialects/EnzymeXLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,14 +346,13 @@ end
"""
`lapack_geqrf`

This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
orthogonal matri x Q and an upper triangular matrix R,
such that A = QR.
such that A = QR.

This operation is modeled after
LAPACK\'s *GEQRF routines, which returns the result in
the QR packed format.
This operation is modeled after LAPACK\'s *GEQRF routines, which returns the
result in the QR packed format.
"""
function lapack_geqrf(
input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()
Expand All @@ -379,11 +378,11 @@ end
"""
`lapack_geqrt`

This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.

This operation is modeled after LAPACK\'s *GEQRT routines, which returns the
This operation is modeled after LAPACK\'s *GEQRT routines, which returns the
result in the QR CompactWY format.
"""
function lapack_geqrt(
Expand Down Expand Up @@ -413,6 +412,90 @@ function lapack_geqrt(
)
end

function lapack_gesdd(
input::Value;
U::IR.Type,
S::IR.Type,
Vt::IR.Type,
info::IR.Type,
full=nothing,
location=Location(),
)
op_ty_results = IR.Type[U, S, Vt, info]
operands = Value[input,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(full) && push!(attributes, namedattribute("full", full))

return create_operation(
"enzymexla.lapack.gesdd",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function lapack_gesvd(
input::Value;
U::IR.Type,
S::IR.Type,
Vt::IR.Type,
info::IR.Type,
full=nothing,
location=Location(),
)
op_ty_results = IR.Type[U, S, Vt, info]
operands = Value[input,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(full) && push!(attributes, namedattribute("full", full))

return create_operation(
"enzymexla.lapack.gesvd",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function lapack_gesvj(
input::Value;
U::IR.Type,
S::IR.Type,
Vt::IR.Type,
info::IR.Type,
full=nothing,
location=Location(),
)
op_ty_results = IR.Type[U, S, Vt, info]
operands = Value[input,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(full) && push!(attributes, namedattribute("full", full))

return create_operation(
"enzymexla.lapack.gesvj",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function get_stream(; result::IR.Type, location=Location())
op_ty_results = IR.Type[result,]
operands = Value[]
Expand All @@ -432,6 +515,51 @@ function get_stream(; result::IR.Type, location=Location())
)
end

function lapack_getrf(
input::Value;
output::IR.Type,
pivots::IR.Type,
permutation::IR.Type,
info::IR.Type,
location=Location(),
)
op_ty_results = IR.Type[output, pivots, permutation, info]
operands = Value[input,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"enzymexla.lapack.getrf",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function lapack_getri(input::Value, ipiv::Value; output::IR.Type, location=Location())
op_ty_results = IR.Type[output,]
operands = Value[input, ipiv]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"enzymexla.lapack.getri",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function jit_call(
inputs::Vector{Value};
result_0::Vector{IR.Type},
Expand Down Expand Up @@ -754,15 +882,15 @@ end
"""
`linalg_qr`

This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
such that A = QR.

If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
will be a m x n trapezoidal matrix.

This operation is modeled after the mathematical formulation of the QR
This operation is modeled after the mathematical formulation of the QR
factorization, and not after LAPACK\'s compact formats.
"""
function linalg_qr(
Expand Down Expand Up @@ -842,6 +970,7 @@ function linalg_svd(
Vt::IR.Type,
info::IR.Type,
full=nothing,
algorithm=nothing,
location=Location(),
)
op_ty_results = IR.Type[U, S, Vt, info]
Expand All @@ -850,6 +979,7 @@ function linalg_svd(
successors = Block[]
attributes = NamedAttribute[]
!isnothing(full) && push!(attributes, namedattribute("full", full))
!isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm))

return create_operation(
"enzymexla.linalg.svd",
Expand Down
4 changes: 4 additions & 0 deletions src/mlir/libMLIR_h.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11655,6 +11655,10 @@ function enzymexlaQRAlgorithmAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaQRAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute
end

function enzymexlaSVDAlgorithmAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaSVDAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute
end

function enzymexlaGeluApproximationAttrGet(ctx, mode)
@ccall mlir_c.enzymexlaGeluApproximationAttrGet(
ctx::MlirContext, mode::Int32
Expand Down
Loading