62 changes: 44 additions & 18 deletions src/mlir/Dialects/Arith.jl
@@ -1601,17 +1601,30 @@ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_extf` would perform the following:

```
resultTy = get_type(result)
scaleTy = get_type(scale)
inputTy = get_type(input)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy
input.extf = arith.extf(input) : inputTy to resultTy
result = arith.mulf(scale.extf, input.extf)
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.
```mlir
// Cast scale to result type (%scale : f32 and %input : f4E2M1FN are assumed inputs).
%0 = arith.truncf %scale : f32 to f8E8M0FNU
%1 = arith.extf %0 : f8E8M0FNU to f16

// Cast input to result type.
%2 = arith.extf %input : f4E2M1FN to f16

// Perform scaling.
%3 = arith.mulf %2, %1 : f16
```
It propagates NaN values. Therefore, if either scale or the input element
contains NaN, then the output element value will also be a NaN.

# Example

```mlir
// Upcast from f4E2M1FN to f32.
%a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32

// Element-wise upcast with broadcast (blockSize = 32).
%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
%h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16>
```
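On the Julia side, the wrapper generated below takes the input and scale SSA
values plus the result type; a minimal sketch of a call, with every name
assumed rather than taken from this PR:

```julia
# Minimal sketch (not from the PR): `x` and `scale` are `Value`s holding the
# f4E2M1FN input and the f8E8M0FNU scale, and `f32ty` is the result `IR.Type`;
# all three are assumed to come from the surrounding builder context.
op = scaling_extf(x, scale; out=f32ty)
```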
"""
function scaling_extf(
in::Value, scale::Value; out::IR.Type, fastmath=nothing, location=Location()
@@ -1662,14 +1675,27 @@ broadcasted to `<dim1 x dim2 x dim3 ... (dimN/blockSize) x blockSize>`. Note
that there could be multiple quantization axes. Internally,
`arith.scaling_truncf` would perform the following:

```mlir
// Cast scale to input type (%scale : f32 and %input : f16 are assumed inputs).
%0 = arith.truncf %scale : f32 to f8E8M0FNU
%1 = arith.extf %0 : f8E8M0FNU to f16

// Perform scaling.
%2 = arith.divf %input, %1 : f16

// Cast to result type.
%3 = arith.truncf %2 : f16 to f4E2M1FN
```
```
scaleTy = get_type(scale)
inputTy = get_type(input)
resultTy = get_type(result)
scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0
scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy
result = arith.divf(input, scale.extf)
result.cast = arith.truncf(result, resultTy)
```

# Example

```mlir
// Downcast from f32 to f4E2M1FN.
%a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN

// Element-wise downcast with broadcast (blockSize = 32).
%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU>
%h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN>
```
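The Julia wrapper for this op is defined just below; its argument list falls
outside this hunk, so the keyword names here are assumed to mirror
`scaling_extf`:

```julia
# Minimal sketch (not from the PR): `x::Value` (f16 input), `scale::Value`
# (f8E8M0FNU scale), and `f4ty::IR.Type` (result type) are assumed; the `out`
# keyword is assumed to mirror scaling_extf, since the signature is not shown
# in this diff hunk.
op = scaling_truncf(x, scale; out=f4ty)
```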
"""
function scaling_truncf(
71 changes: 66 additions & 5 deletions src/mlir/Dialects/EnzymeXLA.jl
@@ -72,6 +72,27 @@ function barrier(indices::Vector{Value}; location=Location())
)
end

function cacheload(
memref::Value, indices::Vector{Value}; result::IR.Type, location=Location()
)
op_ty_results = IR.Type[result,]
operands = Value[memref, indices...]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"enzymexla.cacheload",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function comm_region(; result_0::Vector{IR.Type}, body::Region, location=Location())
op_ty_results = IR.Type[result_0...,]
operands = Value[]
@@ -240,8 +261,8 @@ end
`gpu_wrapper`

The optional arguments to this operation are suggestions about what block
dimensions this gpu kernel should have - usually taken from kernel launch
params
dimensions this gpu kernel should have - usually taken from kernel
launch params
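A minimal sketch of a call to the generated wrapper below, with every name
assumed rather than taken from this PR:

```julia
# Minimal sketch (not from the PR): `bx`, `by`, `bz` are `Value`s holding the
# suggested block dimensions, `res_ty::IR.Type` is the result type, and
# `body::Region` is a region populated elsewhere.
op = gpu_wrapper(Value[bx, by, bz]; result=res_ty, region=body)
```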
"""
function gpu_wrapper(
blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location()
@@ -327,10 +348,12 @@ end

This operation computes the QR factorization of a matrix using Householder
reflections. Mathematically, it decomposes A into the product of an
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
orthogonal matrix Q and an upper triangular matrix R,
such that A = QR.

This operation is modeled after LAPACK\'s *GEQRF routines, which returns the
result in the QR packed format.
This operation is modeled after
LAPACK\'s *GEQRF routines, which returns the result in
the QR packed format.
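A minimal sketch of a call to the generated wrapper below, with every name
assumed rather than taken from this PR:

```julia
# Minimal sketch (not from the PR): `a::Value` is the input matrix; `qr_ty`,
# `tau_ty`, and `info_ty` are `IR.Type`s for the packed QR factors, the
# Householder scalars (tau), and the LAPACK info code, respectively.
op = lapack_geqrf(a; output=qr_ty, tau=tau_ty, info=info_ty)
```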
"""
function lapack_geqrf(
input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()
@@ -859,6 +882,25 @@ function stream2token(source::Value; result::IR.Type, location=Location())
)
end

function subindex(source::Value, index::Value; result::IR.Type, location=Location())
op_ty_results = IR.Type[result,]
operands = Value[source, index]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"enzymexla.subindex",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`lapack_symm`

@@ -893,6 +935,25 @@ function lapack_symm(
)
end

function typeAlign(; result::IR.Type, source, location=Location())
op_ty_results = IR.Type[result,]
operands = Value[]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("source", source),]

return create_operation(
"enzymexla.typeAlign",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function wrap(
operand::Value;
result=nothing::Union{Nothing,IR.Type},
2 changes: 1 addition & 1 deletion src/mlir/Dialects/Gpu.jl
@@ -995,7 +995,7 @@ end
This operation provides a memref pointer to the start of dynamic shared
memory, often referred to as workgroup memory. It\'s important to note that
this dynamic shared memory needs to be allocated at kernel launch. One can
conveniently utilize `the dynamic_shared_memory_size` parameter of
conveniently utilize the `dynamic_shared_memory_size` parameter of
`gpu.launch` for this purpose.

Examples:
5 changes: 2 additions & 3 deletions src/mlir/Dialects/Llvm.jl
@@ -1949,7 +1949,6 @@ function func(;
reciprocal_estimates=nothing,
prefer_vector_width=nothing,
target_features=nothing,
unsafe_fp_math=nothing,
no_infs_fp_math=nothing,
no_nans_fp_math=nothing,
no_signed_zeros_fp_math=nothing,
@@ -1960,6 +1959,7 @@
instrument_function_exit=nothing,
no_inline=nothing,
always_inline=nothing,
inline_hint=nothing,
no_unwind=nothing,
will_return=nothing,
optimize_none=nothing,
@@ -2026,8 +2026,6 @@ function func(;
push!(attributes, namedattribute("prefer_vector_width", prefer_vector_width))
!isnothing(target_features) &&
push!(attributes, namedattribute("target_features", target_features))
!isnothing(unsafe_fp_math) &&
push!(attributes, namedattribute("unsafe_fp_math", unsafe_fp_math))
!isnothing(no_infs_fp_math) &&
push!(attributes, namedattribute("no_infs_fp_math", no_infs_fp_math))
!isnothing(no_nans_fp_math) &&
Expand All @@ -2050,6 +2048,7 @@ function func(;
!isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline))
!isnothing(always_inline) &&
push!(attributes, namedattribute("always_inline", always_inline))
!isnothing(inline_hint) && push!(attributes, namedattribute("inline_hint", inline_hint))
!isnothing(no_unwind) && push!(attributes, namedattribute("no_unwind", no_unwind))
!isnothing(will_return) && push!(attributes, namedattribute("will_return", will_return))
!isnothing(optimize_none) &&