diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index cdede7c364..18636d89d0 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -373,21 +373,33 @@ enzymeActivityAttrGet(MlirContext ctx, int32_t val) { (mlir::enzyme::Activity)val)); } -extern "C" MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { +REACTANT_ABI MLIR_CAPI_EXPORTED MlirType enzymeTraceTypeGet(MlirContext ctx) { return wrap(mlir::enzyme::TraceType::get(unwrap(ctx))); } -extern "C" MLIR_CAPI_EXPORTED MlirType +REACTANT_ABI MLIR_CAPI_EXPORTED MlirType enzymeConstraintTypeGet(MlirContext ctx) { return wrap(mlir::enzyme::ConstraintType::get(unwrap(ctx))); } -extern "C" MLIR_CAPI_EXPORTED MlirAttribute +REACTANT_ABI MLIR_CAPI_EXPORTED MlirAttribute enzymeSymbolAttrGet(MlirContext ctx, uint64_t symbol) { mlir::Attribute attr = mlir::enzyme::SymbolAttr::get(unwrap(ctx), symbol); return wrap(attr); } +REACTANT_ABI MLIR_CAPI_EXPORTED MlirAttribute +enzymeRngDistributionAttrGet(MlirContext ctx, int32_t val) { + return wrap(mlir::enzyme::RngDistributionAttr::get( + unwrap(ctx), (mlir::enzyme::RngDistribution)val)); +} + +REACTANT_ABI MLIR_CAPI_EXPORTED MlirAttribute +enzymeMCMCAlgorithmAttrGet(MlirContext ctx, int32_t val) { + return wrap(mlir::enzyme::MCMCAlgorithmAttr::get( + unwrap(ctx), (mlir::enzyme::MCMCAlgorithm)val)); +} + // Create profiler session and start profiling REACTANT_ABI tsl::ProfilerSession * CreateProfilerSession(uint32_t device_tracer_level,