diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index 405824dc7b..42569f0784 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -505,6 +505,25 @@ function initTrace(; trace::IR.Type, location=Location()) ) end +function load(cache::Value, indices::Vector{Value}; result::IR.Type, location=Location()) + op_ty_results = IR.Type[result,] + operands = Value[cache, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.load", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function placeholder(; output::IR.Type, location=Location()) op_ty_results = IR.Type[output,] operands = Value[] @@ -654,6 +673,25 @@ function simulate( ) end +function store(value::Value, cache::Value, indices::Vector{Value}; location=Location()) + op_ty_results = IR.Type[] + operands = Value[value, cache, indices...] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "enzyme.store", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `untracedCall` diff --git a/src/mlir/Dialects/MemRef.jl b/src/mlir/Dialects/MemRef.jl index d4a43622cc..4d361376b2 100755 --- a/src/mlir/Dialects/MemRef.jl +++ b/src/mlir/Dialects/MemRef.jl @@ -19,7 +19,7 @@ import ...API The `assume_alignment` operation takes a memref and an integer alignment value. It returns a new SSA value of the same memref type, but associated with the assumption that the underlying buffer is aligned to the given -alignment. +alignment. If the buffer isn\'t aligned to the given alignment, its result is poison. 
This operation doesn\'t affect the semantics of a program where the @@ -151,6 +151,50 @@ function copy(source::Value, target::Value; location=Location()) ) end +""" +`distinct_objects` + +The `distinct_objects` operation takes a list of memrefs and returns the same +memrefs, with the additional assumption that accesses to them will never +alias with each other. This means that loads and stores to different +memrefs in the list can be safely reordered. + +If the memrefs do alias, the load/store behavior is undefined. This +operation doesn\'t affect the semantics of a valid program. It is +intended for optimization purposes, allowing the compiler to generate more +efficient code based on the non-aliasing assumption. The optimization is +best-effort. + +# Example + +```mlir +%1, %2 = memref.distinct_objects %a, %b : memref, memref +``` +""" +function distinct_objects( + operands::Vector{Value}; + results=nothing::Union{Nothing,Vector{IR.Type}}, + location=Location(), +) + op_ty_results = IR.Type[] + operands = Value[operands...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(results) && push!(op_ty_results, results...) + + return create_operation( + "memref.distinct_objects", + location; + operands, + owned_regions, + successors, + attributes, + results=(length(op_ty_results) == 0 ? nothing : op_ty_results), + result_inference=(length(op_ty_results) == 0 ? true : false), + ) +end + """ `generic_atomic_rmw` @@ -1115,6 +1159,10 @@ The input and result must have the same shape, element type, rank, and layout. If the source and target address spaces are the same, this operation is a noop. +Finally, if the target memory-space is the generic/default memory-space, +then it is assumed this cast can be bubbled down safely. See the docs of +`MemorySpaceCastOpInterface` interface for more details. 
+ # Example ```mlir diff --git a/src/mlir/Dialects/Nvvm.jl b/src/mlir/Dialects/Nvvm.jl index ae71d73503..88954d7613 100755 --- a/src/mlir/Dialects/Nvvm.jl +++ b/src/mlir/Dialects/Nvvm.jl @@ -693,6 +693,86 @@ function read_ptx_sreg_clusterid_z(; res::IR.Type, range=nothing, location=Locat ) end +""" +`clusterlaunchcontrol_query_cancel` + +`clusterlaunchcontrol.query.cancel` queries the response of a +`clusterlaunchcontrol.try.cancel` operation specified by operand +`try_cancel_response`. + +Operand `query_type` specifies the type of query to perform and can be one +of the following: +- `is_canceled` : Returns true if the try cancel request succeeded, +and false otherwise. +- `get_first_cta_id_{x/y/z}` : Returns the x, y, or z coordinate of the +first CTA in the canceled cluster. Behaviour is defined only if the try +cancel request succeeded. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel) +""" +function clusterlaunchcontrol_query_cancel( + try_cancel_response::Value; res::IR.Type, query_type, location=Location() +) + op_ty_results = IR.Type[res,] + operands = Value[try_cancel_response,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("query_type", query_type),] + + return create_operation( + "nvvm.clusterlaunchcontrol.query.cancel", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +""" +`clusterlaunchcontrol_try_cancel` + +`clusterlaunchcontrol.try.cancel` requests atomically canceling the launch +of a cluster that has not started running yet. It asynchronously writes an +opaque response to shared memory indicating whether the operation succeeded +or failed. 
+ +Operand `smemAddress` specifies the naturally aligned address of the +16-byte wide shared memory location where the request\'s response is written. + +Operand `mbarrier` specifies the mbarrier object used to track the +completion of the asynchronous operation. + +If `multicast` is specified, the response is asynchronously written to the +corresponding local shared memory location (specified by `addr`) of each CTA +in the requesting cluster. + +[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel) +""" +function clusterlaunchcontrol_try_cancel( + smemAddress::Value, mbarrier::Value; multicast=nothing, location=Location() +) + op_ty_results = IR.Type[] + operands = Value[smemAddress, mbarrier] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + !isnothing(multicast) && push!(attributes, namedattribute("multicast", multicast)) + + return create_operation( + "nvvm.clusterlaunchcontrol.try.cancel", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + """ `cluster_wait` @@ -741,13 +821,13 @@ respectively. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) """ function convert_bf16x2_to_f8x2( - a::Value; dst::IR.Type, type, rnd=nothing, sat=nothing, location=Location() + a::Value; dst::IR.Type, rnd=nothing, sat=nothing, dstTy, location=Location() ) op_ty_results = IR.Type[dst,] operands = Value[a,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("type", type),] + attributes = NamedAttribute[namedattribute("dstTy", dstTy),] !isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd)) !isnothing(sat) && push!(attributes, namedattribute("sat", sat)) @@ -782,13 +862,13 @@ the cvt instruction. 
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) """ function convert_f16x2_to_f8x2( - a::Value; dst::IR.Type, type, relu=nothing, location=Location() + a::Value; dst::IR.Type, relu=nothing, dstTy, location=Location() ) op_ty_results = IR.Type[dst,] operands = Value[a,] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("type", type),] + attributes = NamedAttribute[namedattribute("dstTy", dstTy),] !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) return create_operation( @@ -821,13 +901,13 @@ the cvt instruction. [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) """ function convert_f32x2_to_f6x2( - a::Value, b::Value; dst::IR.Type, type, relu=nothing, location=Location() + a::Value, b::Value; dst::IR.Type, relu=nothing, dstTy, location=Location() ) op_ty_results = IR.Type[dst,] operands = Value[a, b] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("type", type),] + attributes = NamedAttribute[namedattribute("dstTy", dstTy),] !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) return create_operation( @@ -863,17 +943,17 @@ function convert_f32x2_to_f8x2( a::Value, b::Value; dst::IR.Type, - type, rnd=nothing, sat=nothing, relu=nothing, + dstTy, location=Location(), ) op_ty_results = IR.Type[dst,] operands = Value[a, b] owned_regions = Region[] successors = Block[] - attributes = NamedAttribute[namedattribute("type", type),] + attributes = NamedAttribute[namedattribute("dstTy", dstTy),] !isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd)) !isnothing(sat) && push!(attributes, namedattribute("sat", sat)) !isnothing(relu) && push!(attributes, namedattribute("relu", relu)) @@ -1148,16 +1228,8 @@ end `cp_async_bulk_tensor_shared_cluster_global` 
Initiates an asynchronous copy operation on the tensor data from global -memory to shared memory. - -The Op operates has two load modes: -1) Tiled Mode: It\'s the default mode. The source multi-dimensional tensor -layout is preserved at the destination. - -2) Im2col Mode: This mode is used when `im2colOffsets` operands are present. -the elements in the Bounding Box of the source tensor are rearranged into -columns at the destination. In this mode, the tensor has to be at least -3-dimensional. +memory to shared::cluster (or) shared::cta memory. This Op supports all +the load modes specified in `TMALoadMode`. The `multicastMask` operand is optional. When it is present, the Op copies data from global memory to shared memory of multiple CTAs in the cluster. @@ -1168,6 +1240,10 @@ the `nvvm.read.ptx.sreg.ctaid` of the destination CTA. The `l2CacheHint` operand is optional, and it is used to specify cache eviction policy that may be used during the memory access. +When the `isCTAOnly` attribute is set to true, the destination is +shared::cta only. Hence, `multicastMask` and `CTAGroup` are not applicable +when `isCTAOnly` is true. + [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) """ function cp_async_bulk_tensor_shared_cluster_global( @@ -1179,6 +1255,9 @@ function cp_async_bulk_tensor_shared_cluster_global( multicastMask=nothing::Union{Nothing,Value}; l2CacheHint=nothing::Union{Nothing,Value}, predicate=nothing::Union{Nothing,Value}, + mode=nothing, + isCTAOnly=nothing, + group=nothing, location=Location(), ) op_ty_results = IR.Type[] @@ -1202,6 +1281,9 @@ function cp_async_bulk_tensor_shared_cluster_global( (predicate == nothing) ? 
0 : 1, ]), ) + !isnothing(mode) && push!(attributes, namedattribute("mode", mode)) + !isnothing(isCTAOnly) && push!(attributes, namedattribute("isCTAOnly", isCTAOnly)) + !isnothing(group) && push!(attributes, namedattribute("group", group)) return create_operation( "nvvm.cp.async.bulk.tensor.shared.cluster.global", diff --git a/src/mlir/Dialects/Shardy.jl b/src/mlir/Dialects/Shardy.jl index c2c3bfd769..23d7837a80 100755 --- a/src/mlir/Dialects/Shardy.jl +++ b/src/mlir/Dialects/Shardy.jl @@ -80,6 +80,7 @@ affect the order of the corresponding replica groups. **Constraints:** - Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`. - `reduction_axes` must satisfy the constraints listed in `AxisRefListAttr`. +- `reduction_axes` must be sorted w.r.t. the mesh. - The operand sharding and `out_sharding` must have equivalent dimension shardings. - `reduction_axes` must not overlap with the operand dimension sharding and diff --git a/src/mlir/Dialects/TPU.jl b/src/mlir/Dialects/TPU.jl index 056b1241d4..4bdeecf7e1 100755 --- a/src/mlir/Dialects/TPU.jl +++ b/src/mlir/Dialects/TPU.jl @@ -1361,6 +1361,48 @@ function shuffled_store( ) end +function stochastic_convert_elementwise( + input::Value, random::Value; output::IR.Type, dst_type, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[input, random] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("dst_type", dst_type),] + + return create_operation( + "tpu.stochastic_convert_elementwise", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + +function stochastic_convert( + input::Value, random::Value; output::IR.Type, location=Location() +) + op_ty_results = IR.Type[output,] + operands = Value[input, random] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[] + + return create_operation( + "tpu.stochastic_convert", + location; + 
operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function store( valueToStore::Value, base::Value, @@ -1564,7 +1606,12 @@ function truncf(in::Value; out::IR.Type, rounding_mode, location=Location()) end function unpack_subelements( - source::Value; output::IR.Type, index, pack_format, location=Location() + source::Value; + output::IR.Type, + index, + pack_format, + sign_extended=nothing, + location=Location(), ) op_ty_results = IR.Type[output,] operands = Value[source,] @@ -1573,6 +1620,8 @@ function unpack_subelements( attributes = NamedAttribute[ namedattribute("index", index), namedattribute("pack_format", pack_format) ] + !isnothing(sign_extended) && + push!(attributes, namedattribute("sign_extended", sign_extended)) return create_operation( "tpu.unpack_subelements", diff --git a/src/mlir/Dialects/Triton.jl b/src/mlir/Dialects/Triton.jl index d8d8fb38a9..0bb374da31 100755 --- a/src/mlir/Dialects/Triton.jl +++ b/src/mlir/Dialects/Triton.jl @@ -1011,6 +1011,7 @@ function make_tensor_descriptor( shape::Vector{Value}, strides::Vector{Value}; result::IR.Type, + padding=nothing, location=Location(), ) op_ty_results = IR.Type[result,] @@ -1018,6 +1019,7 @@ function make_tensor_descriptor( owned_regions = Region[] successors = Block[] attributes = NamedAttribute[] + !isnothing(padding) && push!(attributes, namedattribute("padding", padding)) return create_operation( "tt.make_tensor_descriptor", diff --git a/src/mlir/libMLIR_h.jl b/src/mlir/libMLIR_h.jl index 6e9e695241..ce9c4c230b 100755 --- a/src/mlir/libMLIR_h.jl +++ b/src/mlir/libMLIR_h.jl @@ -1346,6 +1346,15 @@ function mlirOperationGetLocation(op) @ccall mlir_c.mlirOperationGetLocation(op::MlirOperation)::MlirLocation end +""" + mlirOperationSetLocation(op, loc) + +Sets the location of the operation. 
+""" +function mlirOperationSetLocation(op, loc) + @ccall mlir_c.mlirOperationSetLocation(op::MlirOperation, loc::MlirLocation)::Cvoid +end + """ mlirOperationGetTypeID(op) @@ -7178,12 +7187,20 @@ end end """ - mlirLLVMDICompileUnitAttrGet(ctx, id, sourceLanguage, file, producer, isOptimized, emissionKind, nameTableKind) + mlirLLVMDICompileUnitAttrGet(ctx, id, sourceLanguage, file, producer, isOptimized, emissionKind, nameTableKind, splitDebugFilename) Creates a LLVM DICompileUnit attribute. """ function mlirLLVMDICompileUnitAttrGet( - ctx, id, sourceLanguage, file, producer, isOptimized, emissionKind, nameTableKind + ctx, + id, + sourceLanguage, + file, + producer, + isOptimized, + emissionKind, + nameTableKind, + splitDebugFilename, ) @ccall mlir_c.mlirLLVMDICompileUnitAttrGet( ctx::MlirContext, @@ -7194,6 +7211,7 @@ function mlirLLVMDICompileUnitAttrGet( isOptimized::Bool, emissionKind::MlirLLVMDIEmissionKind, nameTableKind::MlirLLVMDINameTableKind, + splitDebugFilename::MlirAttribute, )::MlirAttribute end @@ -8982,6 +9000,27 @@ function mlirPassManagerEnableTiming(passManager) @ccall mlir_c.mlirPassManagerEnableTiming(passManager::MlirPassManager)::Cvoid end +""" + MlirPassDisplayMode + +Enumerated type of pass display modes. Mainly used in [`mlirPassManagerEnableStatistics`](@ref). +""" +@cenum MlirPassDisplayMode::UInt32 begin + MLIR_PASS_DISPLAY_MODE_LIST = 0x0000000000000000 + MLIR_PASS_DISPLAY_MODE_PIPELINE = 0x0000000000000001 +end + +""" + mlirPassManagerEnableStatistics(passManager, displayMode) + +Enable pass statistics. 
+""" +function mlirPassManagerEnableStatistics(passManager, displayMode) + @ccall mlir_c.mlirPassManagerEnableStatistics( + passManager::MlirPassManager, displayMode::MlirPassDisplayMode + )::Cvoid +end + """ mlirPassManagerGetNestedUnder(passManager, operationName) @@ -9167,6 +9206,14 @@ struct MlirRewritePatternSet ptr::Ptr{Cvoid} end +struct MlirPatternRewriter + ptr::Ptr{Cvoid} +end + +struct MlirRewritePattern + ptr::Ptr{Cvoid} +end + """ mlirRewriterBaseGetContext(rewriter) @@ -9258,6 +9305,17 @@ function mlirRewriterBaseGetBlock(rewriter) @ccall mlir_c.mlirRewriterBaseGetBlock(rewriter::MlirRewriterBase)::MlirBlock end +""" + mlirRewriterBaseGetOperationAfterInsertion(rewriter) + +Returns the operation right after the current insertion point of the rewriter. A null [`MlirOperation`](@ref) will be returned if the current insertion point is at the end of the block. +""" +function mlirRewriterBaseGetOperationAfterInsertion(rewriter) + @ccall mlir_c.mlirRewriterBaseGetOperationAfterInsertion( + rewriter::MlirRewriterBase + )::MlirOperation +end + """ mlirRewriterBaseCreateBlockBefore(rewriter, insertBefore, nArgTypes, argTypes, locations) @@ -9583,18 +9641,25 @@ function mlirIRRewriterDestroy(rewriter) end """ - mlirFreezeRewritePattern(op) + mlirFreezeRewritePattern(set) -FrozenRewritePatternSet API +Freeze the given [`MlirRewritePatternSet`](@ref) to a [`MlirFrozenRewritePatternSet`](@ref). Note that the ownership of the input set is transferred into the frozen set after this call. """ -function mlirFreezeRewritePattern(op) +function mlirFreezeRewritePattern(set) @ccall mlir_c.mlirFreezeRewritePattern( - op::MlirRewritePatternSet + set::MlirRewritePatternSet )::MlirFrozenRewritePatternSet end -function mlirFrozenRewritePatternSetDestroy(op) - @ccall mlir_c.mlirFrozenRewritePatternSetDestroy(op::MlirFrozenRewritePatternSet)::Cvoid +""" + mlirFrozenRewritePatternSetDestroy(set) + +Destroy the given [`MlirFrozenRewritePatternSet`](@ref). 
+""" +function mlirFrozenRewritePatternSetDestroy(set) + @ccall mlir_c.mlirFrozenRewritePatternSetDestroy( + set::MlirFrozenRewritePatternSet + )::Cvoid end function mlirApplyPatternsAndFoldGreedilyWithOp(op, patterns, arg3) @@ -9613,6 +9678,80 @@ function mlirApplyPatternsAndFoldGreedily(op, patterns, arg3) )::MlirLogicalResult end +""" + mlirPatternRewriterAsBase(rewriter) + +Cast the PatternRewriter to a RewriterBase +""" +function mlirPatternRewriterAsBase(rewriter) + @ccall mlir_c.mlirPatternRewriterAsBase(rewriter::MlirPatternRewriter)::MlirRewriterBase +end + +""" + MlirRewritePatternCallbacks + +Callbacks to construct a rewrite pattern. + +| Field | Note | +| :-------------- | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| construct | Optional constructor for the user data. Set to nullptr to disable it. | +| destruct | Optional destructor for the user data. Set to nullptr to disable it. | +| matchAndRewrite | The callback function to match against code rooted at the specified operation, and perform the rewrite if the match is successful, corresponding to RewritePattern::matchAndRewrite. | +""" +struct MlirRewritePatternCallbacks + construct::Ptr{Cvoid} + destruct::Ptr{Cvoid} + matchAndRewrite::Ptr{Cvoid} +end + +""" + mlirOpRewritePattenCreate(rootName, benefit, context, callbacks, userData, nGeneratedNames, generatedNames) + +Create a rewrite pattern that matches the operation with the given rootName, corresponding to mlir::OpRewritePattern. 
+""" +function mlirOpRewritePattenCreate( + rootName, benefit, context, callbacks, userData, nGeneratedNames, generatedNames +) + @ccall mlir_c.mlirOpRewritePattenCreate( + rootName::MlirStringRef, + benefit::Cuint, + context::MlirContext, + callbacks::MlirRewritePatternCallbacks, + userData::Ptr{Cvoid}, + nGeneratedNames::Csize_t, + generatedNames::Ptr{MlirStringRef}, + )::MlirRewritePattern +end + +""" + mlirRewritePatternSetCreate(context) + +Create an empty [`MlirRewritePatternSet`](@ref). +""" +function mlirRewritePatternSetCreate(context) + @ccall mlir_c.mlirRewritePatternSetCreate(context::MlirContext)::MlirRewritePatternSet +end + +""" + mlirRewritePatternSetDestroy(set) + +Destruct the given [`MlirRewritePatternSet`](@ref). +""" +function mlirRewritePatternSetDestroy(set) + @ccall mlir_c.mlirRewritePatternSetDestroy(set::MlirRewritePatternSet)::Cvoid +end + +""" + mlirRewritePatternSetAdd(set, pattern) + +Add the given [`MlirRewritePattern`](@ref) into a [`MlirRewritePatternSet`](@ref). Note that the ownership of the pattern is transferred to the set after this call. +""" +function mlirRewritePatternSetAdd(set, pattern) + @ccall mlir_c.mlirRewritePatternSetAdd( + set::MlirRewritePatternSet, pattern::MlirRewritePattern + )::Cvoid +end + """ mlirTranslateModuleToSMTLIB(arg1, arg2, userData, inlineSingleUseValues, indentLetBody) @@ -11203,6 +11342,7 @@ struct MlirTpuApplyVectorLayoutContext target_shape::MlirTpuI64TargetTuple mxu_shape::MlirTpuMxuShape max_sublanes_in_scratch::Int64 + shape_invariant_numerics::Bool end function mlirTpuVectorLayoutCreate(bitwidth, offsets, tiling, implicit_dim)