Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions src/mlir/Dialects/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,25 @@ function initTrace(; trace::IR.Type, location=Location())
)
end

"""
    load(cache, indices; result, location=Location())

Build an `enzyme.load` operation that reads from `cache` at the given
`indices`, producing a single value of type `result`.
"""
function load(cache::Value, indices::Vector{Value}; result::IR.Type, location=Location())
    # Operand list is the cache followed by every index value.
    all_operands = Value[cache]
    append!(all_operands, indices)

    return create_operation(
        "enzyme.load",
        location;
        operands=all_operands,
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[],
        results=IR.Type[result],
        result_inference=false,
    )
end

function placeholder(; output::IR.Type, location=Location())
op_ty_results = IR.Type[output,]
operands = Value[]
Expand Down Expand Up @@ -654,6 +673,25 @@ function simulate(
)
end

"""
    store(value, cache, indices; location=Location())

Build an `enzyme.store` operation that writes `value` into `cache` at the
given `indices`. The operation produces no results.
"""
function store(value::Value, cache::Value, indices::Vector{Value}; location=Location())
    # Operand order mandated by the op definition: value, cache, then indices.
    all_operands = Value[value, cache]
    append!(all_operands, indices)

    return create_operation(
        "enzyme.store",
        location;
        operands=all_operands,
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[],
        results=IR.Type[],
        result_inference=false,
    )
end

"""
`untracedCall`

Expand Down
50 changes: 49 additions & 1 deletion src/mlir/Dialects/MemRef.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import ...API
The `assume_alignment` operation takes a memref and an integer alignment
value. It returns a new SSA value of the same memref type, but associated
with the assumption that the underlying buffer is aligned to the given
alignment.
alignment.

If the buffer isn\'t aligned to the given alignment, its result is poison.
This operation doesn\'t affect the semantics of a program where the
Expand Down Expand Up @@ -151,6 +151,50 @@ function copy(source::Value, target::Value; location=Location())
)
end

"""
`distinct_objects`

The `distinct_objects` operation takes a list of memrefs and returns the same
memrefs, with the additional assumption that accesses to them will never
alias with each other. This means that loads and stores to different
memrefs in the list can be safely reordered.

If the memrefs do alias, the load/store behavior is undefined. This
operation doesn't affect the semantics of a valid program. It is
intended for optimization purposes, allowing the compiler to generate more
efficient code based on the non-aliasing assumption. The optimization is
best-effort.

# Example

```mlir
%1, %2 = memref.distinct_objects %a, %b : memref<?xf32>, memref<?xf32>
```
"""
function distinct_objects(
    operands::Vector{Value};
    results=nothing::Union{Nothing,Vector{IR.Type}},
    location=Location(),
)
    # When no result types were supplied, fall back to result-type inference.
    op_ty_results = isnothing(results) ? IR.Type[] : IR.Type[results...]
    infer = isempty(op_ty_results)

    return create_operation(
        "memref.distinct_objects",
        location;
        operands=Value[operands...],
        owned_regions=Region[],
        successors=Block[],
        attributes=NamedAttribute[],
        results=infer ? nothing : op_ty_results,
        result_inference=infer,
    )
end

"""
`generic_atomic_rmw`

Expand Down Expand Up @@ -1115,6 +1159,10 @@ The input and result must have the same shape, element type, rank, and layout.

If the source and target address spaces are the same, this operation is a noop.

Finally, if the target memory-space is the generic/default memory-space,
then it is assumed this cast can be bubbled down safely. See the docs of
`MemorySpaceCastOpInterface` interface for more details.

# Example

```mlir
Expand Down
118 changes: 100 additions & 18 deletions src/mlir/Dialects/Nvvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,86 @@ function read_ptx_sreg_clusterid_z(; res::IR.Type, range=nothing, location=Locat
)
end

"""
`clusterlaunchcontrol_query_cancel`

`clusterlaunchcontrol.query.cancel` queries the response of a
`clusterlaunchcontrol.try.cancel` operation specified by operand
`try_cancel_response`.

Operand `query_type` specifies the type of query to perform and can be one
of the following:
- `is_canceled` : Returns true if the try cancel request succeeded,
and false otherwise.
- `get_first_cta_id_{x/y/z}` : Returns the x, y, or z coordinate of the
first CTA in the canceled cluster. Behaviour is defined only if the try
cancel request succeeded.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-query-cancel)
"""
function clusterlaunchcontrol_query_cancel(
    try_cancel_response::Value; res::IR.Type, query_type, location=Location()
)
    # `query_type` is a required attribute selecting which query is emitted.
    attrs = NamedAttribute[namedattribute("query_type", query_type)]

    return create_operation(
        "nvvm.clusterlaunchcontrol.query.cancel",
        location;
        operands=Value[try_cancel_response],
        owned_regions=Region[],
        successors=Block[],
        attributes=attrs,
        results=IR.Type[res],
        result_inference=false,
    )
end

"""
`clusterlaunchcontrol_try_cancel`

`clusterlaunchcontrol.try.cancel` requests atomically canceling the launch
of a cluster that has not started running yet. It asynchronously writes an
opaque response to shared memory indicating whether the operation succeeded
or failed.

Operand `smemAddress` specifies the naturally aligned address of the
16-byte wide shared memory location where the request's response is written.

Operand `mbarrier` specifies the mbarrier object used to track the
completion of the asynchronous operation.

If `multicast` is specified, the response is asynchronously written to the
corresponding local shared memory location (specified by `addr`) of each CTA
in the requesting cluster.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#parallel-synchronization-and-communication-instructions-clusterlaunchcontrol-try-cancel)
"""
function clusterlaunchcontrol_try_cancel(
    smemAddress::Value, mbarrier::Value; multicast=nothing, location=Location()
)
    # `multicast` is an optional unit attribute; attach it only when given.
    attrs = NamedAttribute[]
    if !isnothing(multicast)
        push!(attrs, namedattribute("multicast", multicast))
    end

    return create_operation(
        "nvvm.clusterlaunchcontrol.try.cancel",
        location;
        operands=Value[smemAddress, mbarrier],
        owned_regions=Region[],
        successors=Block[],
        attributes=attrs,
        results=IR.Type[],
        result_inference=false,
    )
end

"""
`cluster_wait`

Expand Down Expand Up @@ -741,13 +821,13 @@ respectively.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
"""
function convert_bf16x2_to_f8x2(
a::Value; dst::IR.Type, type, rnd=nothing, sat=nothing, location=Location()
a::Value; dst::IR.Type, rnd=nothing, sat=nothing, dstTy, location=Location()
)
op_ty_results = IR.Type[dst,]
operands = Value[a,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("type", type),]
attributes = NamedAttribute[namedattribute("dstTy", dstTy),]
!isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd))
!isnothing(sat) && push!(attributes, namedattribute("sat", sat))

Expand Down Expand Up @@ -782,13 +862,13 @@ the cvt instruction.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
"""
function convert_f16x2_to_f8x2(
a::Value; dst::IR.Type, type, relu=nothing, location=Location()
a::Value; dst::IR.Type, relu=nothing, dstTy, location=Location()
)
op_ty_results = IR.Type[dst,]
operands = Value[a,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("type", type),]
attributes = NamedAttribute[namedattribute("dstTy", dstTy),]
!isnothing(relu) && push!(attributes, namedattribute("relu", relu))

return create_operation(
Expand Down Expand Up @@ -821,13 +901,13 @@ the cvt instruction.
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
"""
function convert_f32x2_to_f6x2(
a::Value, b::Value; dst::IR.Type, type, relu=nothing, location=Location()
a::Value, b::Value; dst::IR.Type, relu=nothing, dstTy, location=Location()
)
op_ty_results = IR.Type[dst,]
operands = Value[a, b]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("type", type),]
attributes = NamedAttribute[namedattribute("dstTy", dstTy),]
!isnothing(relu) && push!(attributes, namedattribute("relu", relu))

return create_operation(
Expand Down Expand Up @@ -863,17 +943,17 @@ function convert_f32x2_to_f8x2(
a::Value,
b::Value;
dst::IR.Type,
type,
rnd=nothing,
sat=nothing,
relu=nothing,
dstTy,
location=Location(),
)
op_ty_results = IR.Type[dst,]
operands = Value[a, b]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("type", type),]
attributes = NamedAttribute[namedattribute("dstTy", dstTy),]
!isnothing(rnd) && push!(attributes, namedattribute("rnd", rnd))
!isnothing(sat) && push!(attributes, namedattribute("sat", sat))
!isnothing(relu) && push!(attributes, namedattribute("relu", relu))
Expand Down Expand Up @@ -1148,16 +1228,8 @@ end
`cp_async_bulk_tensor_shared_cluster_global`

Initiates an asynchronous copy operation on the tensor data from global
memory to shared memory.

The Op operates has two load modes:
1) Tiled Mode: It\'s the default mode. The source multi-dimensional tensor
layout is preserved at the destination.

2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
the elements in the Bounding Box of the source tensor are rearranged into
columns at the destination. In this mode, the tensor has to be at least
3-dimensional.
memory to shared::cluster (or) shared::cta memory. This Op supports all
the load modes specified in `TMALoadMode`.

The `multicastMask` operand is optional. When it is present, the Op copies
data from global memory to shared memory of multiple CTAs in the cluster.
Expand All @@ -1168,6 +1240,10 @@ the `nvvm.read.ptx.sreg.ctaid` of the destination CTA.
The `l2CacheHint` operand is optional, and it is used to specify cache
eviction policy that may be used during the memory access.

When the `isCTAOnly` attribute is set to true, the destination is
shared::cta only. Hence, `multicastMask` and `CTAGroup` are not applicable
when `isCTAOnly` is true.

[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor)
"""
function cp_async_bulk_tensor_shared_cluster_global(
Expand All @@ -1179,6 +1255,9 @@ function cp_async_bulk_tensor_shared_cluster_global(
multicastMask=nothing::Union{Nothing,Value};
l2CacheHint=nothing::Union{Nothing,Value},
predicate=nothing::Union{Nothing,Value},
mode=nothing,
isCTAOnly=nothing,
group=nothing,
location=Location(),
)
op_ty_results = IR.Type[]
Expand All @@ -1202,6 +1281,9 @@ function cp_async_bulk_tensor_shared_cluster_global(
(predicate == nothing) ? 0 : 1,
]),
)
!isnothing(mode) && push!(attributes, namedattribute("mode", mode))
!isnothing(isCTAOnly) && push!(attributes, namedattribute("isCTAOnly", isCTAOnly))
!isnothing(group) && push!(attributes, namedattribute("group", group))

return create_operation(
"nvvm.cp.async.bulk.tensor.shared.cluster.global",
Expand Down
1 change: 1 addition & 0 deletions src/mlir/Dialects/Shardy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ affect the order of the corresponding replica groups.
**Constraints:**
- Must satisfy the constraints listed in `Sdy_CollectiveOpInterface`.
- `reduction_axes` must satisfy the constraints listed in `AxisRefListAttr`.
- `reduction_axes` must be sorted w.r.t. the mesh.
- The operand sharding and `out_sharding` must have equivalent dimension
shardings.
- `reduction_axes` must not overlap with the operand dimension sharding and
Expand Down
Loading