Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/mlir/Dialects/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,40 @@ function autodiff(
)
end

function autodiff_region(
inputs::Vector{Value};
outputs::Vector{IR.Type},
activity,
ret_activity,
width=nothing,
strong_zero=nothing,
fn=nothing,
body::Region,
location=Location(),
)
op_ty_results = IR.Type[outputs...,]
operands = Value[inputs...,]
owned_regions = Region[body,]
successors = Block[]
attributes = NamedAttribute[
namedattribute("activity", activity), namedattribute("ret_activity", ret_activity)
]
!isnothing(width) && push!(attributes, namedattribute("width", width))
!isnothing(strong_zero) && push!(attributes, namedattribute("strong_zero", strong_zero))
!isnothing(fn) && push!(attributes, namedattribute("fn", fn))

return create_operation(
"enzyme.autodiff_region",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

function batch(
inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, batch_shape, location=Location()
)
Expand Down Expand Up @@ -648,4 +682,23 @@ function untracedCall(
)
end

function yield(operands::Vector{Value}; location=Location())
op_ty_results = IR.Type[]
operands = Value[operands...,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[]

return create_operation(
"enzyme.yield",
location;
operands,
owned_regions,
successors,
attributes,
results=op_ty_results,
result_inference=false,
)
end

end # enzyme
42 changes: 42 additions & 0 deletions src/mlir/Dialects/Gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2729,6 +2729,48 @@ function spmat_get_size(
)
end

"""
`subgroup_broadcast`

Broadcasts a value from one lane to all active lanes in a subgroup. The
result is guaranteed to be uniform across the active lanes in subgroup.

The possible broadcast types are:

* `first_active_lane` - broadcasts the value from the first active lane
in the subgroup.
* `specific_lane` - broadcasts from the specified lane. The lane index
must be uniform and within the subgroup size. The result is poison if the
lane index is invalid, non subgroup-uniform, or if the source lane is not
active.
"""
function subgroup_broadcast(
src::Value,
lane=nothing::Union{Nothing,Value};
result=nothing::Union{Nothing,IR.Type},
broadcast_type,
location=Location(),
)
op_ty_results = IR.Type[]
operands = Value[src,]
owned_regions = Region[]
successors = Block[]
attributes = NamedAttribute[namedattribute("broadcast_type", broadcast_type),]
!isnothing(lane) && push!(operands, lane)
!isnothing(result) && push!(op_ty_results, result)

return create_operation(
"gpu.subgroup_broadcast",
location;
operands,
owned_regions,
successors,
attributes,
results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
result_inference=(length(op_ty_results) == 0 ? true : false),
)
end

"""
`subgroup_id`

Expand Down
22 changes: 19 additions & 3 deletions src/mlir/Dialects/Llvm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1459,6 +1459,22 @@ Examples:
// Alignment is optional
llvm.mlir.global private constant @y(dense<1.0> : tensor<8xf32>) { alignment = 32 : i64 } : !llvm.array<8 x f32>
```

The `target_specific_attrs` attribute provides a mechanism to preserve
target-specific LLVM IR attributes that are not explicitly modeled in the
LLVM dialect.

The attribute is an array containing either string attributes or
two-element array attributes of strings. The value of a standalone string
attribute is interpreted as the name of an LLVM IR attribute on the global.
A two-element array is interpreted as a key-value pair.

# Example

```mlir
llvm.mlir.global external @example() {
target_specific_attrs = [\"value-less-attr\", [\"int-attr\", \"4\"], [\"string-attr\", \"string\"]]} : f64
```
"""
function mlir_global(;
global_type,
Expand All @@ -1476,6 +1492,7 @@ function mlir_global(;
comdat=nothing,
dbg_exprs=nothing,
visibility_=nothing,
target_specific_attrs=nothing,
initializer::Region,
location=Location(),
)
Expand Down Expand Up @@ -1503,6 +1520,8 @@ function mlir_global(;
!isnothing(comdat) && push!(attributes, namedattribute("comdat", comdat))
!isnothing(dbg_exprs) && push!(attributes, namedattribute("dbg_exprs", dbg_exprs))
!isnothing(visibility_) && push!(attributes, namedattribute("visibility_", visibility_))
!isnothing(target_specific_attrs) &&
push!(attributes, namedattribute("target_specific_attrs", target_specific_attrs))

return create_operation(
"llvm.mlir.global",
Expand Down Expand Up @@ -1933,7 +1952,6 @@ function func(;
unsafe_fp_math=nothing,
no_infs_fp_math=nothing,
no_nans_fp_math=nothing,
approx_func_fp_math=nothing,
no_signed_zeros_fp_math=nothing,
denormal_fp_math=nothing,
denormal_fp_math_f32=nothing,
Expand Down Expand Up @@ -2014,8 +2032,6 @@ function func(;
push!(attributes, namedattribute("no_infs_fp_math", no_infs_fp_math))
!isnothing(no_nans_fp_math) &&
push!(attributes, namedattribute("no_nans_fp_math", no_nans_fp_math))
!isnothing(approx_func_fp_math) &&
push!(attributes, namedattribute("approx_func_fp_math", approx_func_fp_math))
!isnothing(no_signed_zeros_fp_math) && push!(
attributes, namedattribute("no_signed_zeros_fp_math", no_signed_zeros_fp_math)
)
Expand Down
Loading
Loading