Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/src/.vitepress/config.mts
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ export default defineConfig({
{ text: "TPU", link: "/api/tpu" },
{ text: "Triton", link: "/api/triton" },
{ text: "Shardy", link: "/api/shardy" },
{ text: "MPI", link: "/api/mpi" },
],
},
{
Expand Down Expand Up @@ -147,6 +148,7 @@ export default defineConfig({
{ text: "TPU", link: "/api/tpu" },
{ text: "Triton", link: "/api/triton" },
{ text: "Shardy", link: "/api/shardy" },
{ text: "MPI", link: "/api/mpi" },
],
},
{
Expand Down
12 changes: 12 additions & 0 deletions docs/src/api/mpi.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
```@meta
CollapsedDocStrings = true
```

# MPI Dialect

Refer to the [official documentation](https://mlir.llvm.org/docs/Dialects/MPI/) for
more details.

```@autodocs
Modules = [Reactant.MLIR.Dialects.mpi]
```
236 changes: 236 additions & 0 deletions src/mlir/Dialects/MPI.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
module mpi
using ...IR
import ...IR:
NamedAttribute,
Value,
Location,
Block,
Region,
Attribute,
create_operation,
context,
IndexType
import ..Dialects: namedattribute, operandsegmentsizes
import ...API

"""
`comm_rank`

Communicators other than `MPI_COMM_WORLD` are not supported for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
"""
function comm_rank(;
retval=nothing::Union{Nothing,IR.Type}, rank::IR.Type, location=Location()
)
op_ty_results = IR.Type[rank,]
operands = Value[]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(retval) && push!(op_ty_results, retval)

return create_operation(
"mpi.comm_rank",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`error_class`

`MPI_Error_class` maps return values from MPI calls to a set of well-known
MPI error classes.
"""
function error_class(val::Value; errclass::IR.Type, location=Location())
op_ty_results = IR.Type[errclass,]
operands = Value[val,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"mpi.error_class",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`finalize`

This function cleans up the MPI state. Afterwards, no MPI methods may
be invoked (excpet for MPI_Get_version, MPI_Initialized, and MPI_Finalized).
Notably, MPI_Init cannot be called again in the same program.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
"""
function finalize(; retval=nothing::Union{Nothing,IR.Type}, location=Location())
op_ty_results = IR.Type[]
operands = Value[]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(retval) && push!(op_ty_results, retval)

return create_operation(
"mpi.finalize",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`init`

This operation must preceed most MPI calls (except for very few exceptions,
please consult with the MPI specification on these).

Passing &argc, &argv is not supported currently.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
"""
function init(; retval=nothing::Union{Nothing,IR.Type}, location=Location())
op_ty_results = IR.Type[]
operands = Value[]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(retval) && push!(op_ty_results, retval)

return create_operation(
"mpi.init",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`recv`

MPI_Recv performs a blocking receive of `size` elements of type `dtype`
from rank `dest`. The `tag` value and communicator enables the library to
determine the matching of multiple sends and receives between the same
ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.
The MPI_Status is set to `MPI_STATUS_IGNORE`, as the status object
is not yet ported to MLIR.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
"""
function recv(
ref::Value,
tag::Value,
rank::Value;
retval=nothing::Union{Nothing,IR.Type},
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[ref, tag, rank]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(retval) && push!(op_ty_results, retval)

return create_operation(
"mpi.recv",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`retval_check`

This operation compares MPI status codes to known error class
constants such as `MPI_SUCCESS`, or `MPI_ERR_COMM`.
"""
function retval_check(val::Value; res::IR.Type, errclass, location=Location())
op_ty_results = IR.Type[res,]
operands = Value[val,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("errclass", errclass),]

return create_operation(
"mpi.retval_check",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

"""
`send`

MPI_Send performs a blocking send of `size` elements of type `dtype` to rank
`dest`. The `tag` value and communicator enables the library to determine
the matching of multiple sends and receives between the same ranks.

Communicators other than `MPI_COMM_WORLD` are not supprted for now.

This operation can optionally return an `!mpi.retval` value that can be used
to check for errors.
"""
function send(
ref::Value,
tag::Value,
rank::Value;
retval=nothing::Union{Nothing,IR.Type},
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[ref, tag, rank]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]
!isnothing(retval) && push!(op_ty_results, retval)

return create_operation(
"mpi.send",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

end # mpi
Loading