diff --git a/src/mlir/Dialects/Enzyme.jl b/src/mlir/Dialects/Enzyme.jl index f922304da3..e4306b06a1 100755 --- a/src/mlir/Dialects/Enzyme.jl +++ b/src/mlir/Dialects/Enzyme.jl @@ -282,6 +282,28 @@ function push(cache::Value, value::Value; location=Location()) ) end +function sample( + inputs::Vector{Value}; outputs::Vector{IR.Type}, fn, name=nothing, location=Location() +) + op_ty_results = IR.Type[outputs...,] + operands = Value[inputs...,] + owned_regions = Region[] + successors = Block[] + attributes = NamedAttribute[namedattribute("fn", fn),] + !isnothing(name) && push!(attributes, namedattribute("name", name)) + + return create_operation( + "enzyme.sample", + location; + operands, + owned_regions, + successors, + attributes, + results=op_ty_results, + result_inference=false, + ) +end + function set(gradient::Value, value::Value; location=Location()) op_ty_results = IR.Type[] operands = Value[gradient, value]