diff --git a/src/mlir/Dialects/Arith.jl b/src/mlir/Dialects/Arith.jl index e4fab6b297..bf17d15995 100755 --- a/src/mlir/Dialects/Arith.jl +++ b/src/mlir/Dialects/Arith.jl @@ -1601,17 +1601,30 @@ broadcasted to ``. Note that there could be multiple quantization axes. Internally, `arith.scaling_extf` would perform the following: - ``` - resultTy = get_type(result) - scaleTy = get_type(scale) - inputTy = get_type(input) - scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 - scale.extf = arith.extf(scale.exponent) : f8E8M0 to resultTy - input.extf = arith.extf(input) : inputTy to resultTy - result = arith.mulf(scale.extf, input.extf) - ``` - It propagates NaN values. Therefore, if either scale or the input element - contains NaN, then the output element value will also be a NaN. +```mlir +// Cast scale to result type. +%0 = arith.truncf %1 : f32 to f8E8M0FNU +%1 = arith.extf %0 : f8E8M0FNU to f16 + +// Cast input to result type. +%2 = arith.extf %3 : f4E2M1FN to f16 + +// Perform scaling +%3 = arith.mulf %2, %1 : f16 +``` +It propagates NaN values. Therefore, if either scale or the input element +contains NaN, then the output element value will also be a NaN. + +# Example + +```mlir +// Upcast from f4E2M1FN to f32. +%a = arith.scaling_extf %b, %c : f4E2M1FN, f8E8M0FNU to f32 + +// Element-wise upcast with broadcast (blockSize = 32). +%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> +%h = arith.scaling_extf %i, %f : vector<32xf4E2M1FN>, vector<32xf8E8M0FNU> to vector<32xbf16> +``` """ function scaling_extf( in::Value, scale::Value; out::IR.Type, fastmath=nothing, location=Location() @@ -1662,14 +1675,27 @@ broadcasted to ``. Note that there could be multiple quantization axes. Internally, `arith.scaling_truncf` would perform the following: +```mlir +// Cast scale to input type. +%0 = arith.truncf %1 : f32 to f8E8M0FNU +%1 = arith.extf %0 : f8E8M0FNU to f16 + +// Perform scaling. +%3 = arith.divf %2, %1 : f16 + +// Cast to result type. +%4 = arith.truncf %3 : f16 to f4E2M1FN ``` -scaleTy = get_type(scale) -inputTy = get_type(input) -resultTy = get_type(result) -scale.exponent = arith.truncf(scale) : scaleTy to f8E8M0 -scale.extf = arith.extf(scale.exponent) : f8E8M0 to inputTy -result = arith.divf(input, scale.extf) -result.cast = arith.truncf(result, resultTy) + +# Example + +```mlir +// Downcast from f32 to f4E2M1FN. +%a = arith.scaling_truncf %b, %c : f32, f8E8M0FNU to f4E2M1FN + +// Element-wise downcast with broadcast (blockSize = 32). +%f = vector.broadcast %g : vector<1xf8E8M0FNU> to vector<32xf8E8M0FNU> +%h = arith.scaling_truncf %i, %f : vector<32xbf16>, vector<32xf8E8M0FNU> to vector<32xf4E2M1FN> ``` """ function scaling_truncf( diff --git a/src/mlir/Dialects/EnzymeXLA.jl b/src/mlir/Dialects/EnzymeXLA.jl index ca507515a8..4de6d8a13f 100755 --- a/src/mlir/Dialects/EnzymeXLA.jl +++ b/src/mlir/Dialects/EnzymeXLA.jl @@ -72,6 +72,27 @@ function barrier(indices::Vector{Value}; location=Location()) ) end +function cacheload( + memref::Value, indices::Vector{Value}; result::IR.Type, location=Location() +) + op_ty_results = IR.Type[result,] + operands = Value[memref, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.cacheload", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function comm_region(; result_0::Vector{IR.Type}, body::Region, location=Location()) op_ty_results = IR.Type[result_0...,] operands = Value[] @@ -240,8 +261,8 @@ end `gpu_wrapper` The optional arguments to this operation are suggestions about what block -dimensions this gpu kernel should have - usually taken from kernel launch -params +dimensions this gpu kernel should have - usually taken f rom kernel + launch params """ function gpu_wrapper( blockDims::Vector{Value}; result::IR.Type, region::Region, location=Location() @@ -327,10 +348,12 @@ end This operation computes the QR factorization of a matrix using Householder reflections. Mathematically, it decomposes A into the product of an -orthogonal matrix Q and an upper triangular matrix R, such that A = QR. +orthogonal matri x Q and an upper triangular matrix R, + such that A = QR. -This operation is modeled after LAPACK\'s *GEQRF routines, which returns the -result in the QR packed format. + This operation is modeled after + LAPACK\'s *GEQRF routines, which returns the result in + the QR packed format. """ function lapack_geqrf( input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location() @@ -859,6 +882,25 @@ function stream2token(source::Value; result::IR.Type, location=Location()) ) end +function subindex(source::Value, index::Value; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[source, index] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzymexla.subindex", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `lapack_symm` @@ -893,6 +935,25 @@ function lapack_symm( ) end +function typeAlign(; result::IR.Type, source, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("source", source),] + + return create_operation( + "enzymexla.typeAlign", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function wrap( operand::Value; result=nothing::Union{Nothing,IR.Type}, diff --git a/src/mlir/Dialects/Gpu.jl b/src/mlir/Dialects/Gpu.jl index e0bb4c04ca..b39a5c5890 100755 --- a/src/mlir/Dialects/Gpu.jl +++ b/src/mlir/Dialects/Gpu.jl @@ -995,7 +995,7 @@ end This operation provides a memref pointer to the start of dynamic shared memory, often referred to as workgroup memory. It\'s important to note that this dynamic shared memory needs to be allocated at kernel launch. One can -conveniently utilize `the dynamic_shared_memory_size` parameter of +conveniently utilize the `dynamic_shared_memory_size` parameter of `gpu.launch` for this purpose. Examples: diff --git a/src/mlir/Dialects/Llvm.jl b/src/mlir/Dialects/Llvm.jl index 9afa56daad..2cf3900080 100755 --- a/src/mlir/Dialects/Llvm.jl +++ b/src/mlir/Dialects/Llvm.jl @@ -1949,7 +1949,6 @@ function func(; reciprocal_estimates=nothing, prefer_vector_width=nothing, target_features=nothing, - unsafe_fp_math=nothing, no_infs_fp_math=nothing, no_nans_fp_math=nothing, no_signed_zeros_fp_math=nothing, @@ -1960,6 +1959,7 @@ function func(; instrument_function_exit=nothing, no_inline=nothing, always_inline=nothing, + inline_hint=nothing, no_unwind=nothing, will_return=nothing, optimize_none=nothing, @@ -2026,8 +2026,6 @@ function func(; push!(attributes, namedattribute("prefer_vector_width", prefer_vector_width)) !isnothing(target_features) && push!(attributes, namedattribute("target_features", target_features)) - !isnothing(unsafe_fp_math) && - push!(attributes, namedattribute("unsafe_fp_math", unsafe_fp_math)) !isnothing(no_infs_fp_math) && push!(attributes, namedattribute("no_infs_fp_math", no_infs_fp_math)) !isnothing(no_nans_fp_math) && @@ -2050,6 +2048,7 @@ function func(; !isnothing(no_inline) && push!(attributes, namedattribute("no_inline", no_inline)) !isnothing(always_inline) && push!(attributes, namedattribute("always_inline", always_inline)) + !isnothing(inline_hint) && push!(attributes, namedattribute("inline_hint", inline_hint)) !isnothing(no_unwind) && push!(attributes, namedattribute("no_unwind", no_unwind)) !isnothing(will_return) && push!(attributes, namedattribute("will_return", will_return)) !isnothing(optimize_none) && diff --git a/src/mlir/Dialects/Nvvm.jl b/src/mlir/Dialects/Nvvm.jl index 88954d7613..26e0e95654 100755 --- a/src/mlir/Dialects/Nvvm.jl +++ b/src/mlir/Dialects/Nvvm.jl @@ -843,6 +843,134 @@ function convert_bf16x2_to_f8x2( ) end +""" +`convert_f4x2_to_f16x2` + +This Op converts the given f4 inputs in a packed i8 to f16. + +The result `dst` is represented as a vector of f16 elements. +The `relu` attribute, when set, lowers to the \'.relu\' variant of +the cvt instruction.\" + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function convert_f4x2_to_f16x2( + src::Value; dst::IR.Type, relu=nothing, srcType, location=Location() +) + op_ty_results = IR.Type[dst,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("srcType", srcType),] + !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) + + return create_operation( + "nvvm.convert.f4x2.to.f16x2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`convert_f6x2_to_f16x2` + +This Op converts the given f6 inputs in a i8x2 vector to f16. + +The result `dst` is represented as a vector of f16 elements. +The `relu` attribute, when set, lowers to the \'.relu\' variant of +the cvt instruction.\" + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function convert_f6x2_to_f16x2( + src::Value; dst::IR.Type, relu=nothing, srcType, location=Location() +) + op_ty_results = IR.Type[dst,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("srcType", srcType),] + !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) + + return create_operation( + "nvvm.convert.f6x2.to.f16x2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`convert_f8x2_to_bf16x2` + +This Op converts the given f8 inputs in a i8x2 vector to bf16. + +The result `dst` is represented as a vector of bf16 elements. + + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function convert_f8x2_to_bf16x2(src::Value; dst::IR.Type, srcType, location=Location()) + op_ty_results = IR.Type[dst,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("srcType", srcType),] + + return create_operation( + "nvvm.convert.f8x2.to.bf16x2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`convert_f8x2_to_f16x2` + +This Op converts the given f8 inputs in a i8x2 vector to f16. + +The result `dst` is represented as a vector of f16 elements. +The `relu` attribute, when set, lowers to the \'.relu\' variant of +the cvt instruction.\" + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function convert_f8x2_to_f16x2( + src::Value; dst::IR.Type, relu=nothing, srcType, location=Location() +) + op_ty_results = IR.Type[dst,] + operands = Value[src,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("srcType", srcType),] + !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) + + return create_operation( + "nvvm.convert.f8x2.to.f16x2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `convert_f16x2_to_f8x2` @@ -883,6 +1011,41 @@ function convert_f16x2_to_f8x2( ) end +""" +`convert_f32x2_to_f4x2` + +This Op converts each of the given float inputs to the specified fp4 type. +The result `dst` is returned as an i8 type where the converted values are +packed such that the value converted from `a` is stored in the upper 4 bits +of `dst` and the value converted from `b` is stored in the lower 4 bits of +`dst`. +The `relu` attribute, when set, lowers to the \'.relu\' variant of +the cvt instruction. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) +""" +function convert_f32x2_to_f4x2( + a::Value, b::Value; dst::IR.Type, relu=nothing, dstTy, location=Location() +) + op_ty_results = IR.Type[dst,] + operands = Value[a, b] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dstTy", dstTy),] + !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) + + return create_operation( + "nvvm.convert.f32x2.to.f4x2", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `convert_f32x2_to_f6x2` diff --git a/src/mlir/Dialects/TPU.jl b/src/mlir/Dialects/TPU.jl index 4bdeecf7e1..6e1cca871b 100755 --- a/src/mlir/Dialects/TPU.jl +++ b/src/mlir/Dialects/TPU.jl @@ -70,6 +70,21 @@ function assume_layout(input::Value; result::IR.Type, location=Location()) ) end +""" +`assume_multiple` + +This operation is a hint to the compiler that the input `value` is guaranteed +to be a multiple of `multiple`. This can be used to satisfy divisibility checks +in some compiler passes. + +The result is the same as the input `value`. + +# Example + +```mlir +%val = tpu.assume_multiple %arg0, 16 : index +``` +""" function assume_multiple( value::Value; result=nothing::Union{Nothing,IR.Type}, multiple, location=Location() ) diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index ce9c4c230b..33c37243a7 100755 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -9705,14 +9705,14 @@ struct MlirRewritePatternCallbacks end """ - mlirOpRewritePattenCreate(rootName, benefit, context, callbacks, userData, nGeneratedNames, generatedNames) + mlirOpRewritePatternCreate(rootName, benefit, context, callbacks, userData, nGeneratedNames, generatedNames) Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::OpRewritePattern. """ -function mlirOpRewritePattenCreate( +function mlirOpRewritePatternCreate( rootName, benefit, context, callbacks, userData, nGeneratedNames, generatedNames ) - @ccall mlir_c.mlirOpRewritePattenCreate( + @ccall mlir_c.mlirOpRewritePatternCreate( rootName::MlirStringRef, benefit::Cuint, context::MlirContext, @@ -10007,6 +10007,10 @@ function mlirTranslateModuleToLLVMIR(_module, context) )::LLVMModuleRef end +function mlirTranslateModuleToLLVMIRToString(_module) + @ccall mlir_c.mlirTranslateModuleToLLVMIRToString(_module::MlirOperation)::Cstring +end + struct MlirTypeFromLLVMIRTranslator ptr::Ptr{Cvoid} end