Skip to content
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"
Expand All @@ -31,6 +32,7 @@ path = "lib/ReactantCore"
[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
Expand Down Expand Up @@ -58,4 +60,5 @@ julia = "1.10"
[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
84 changes: 84 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,10 @@ extern "C" void RegisterDialects(MlirContext cctx) {
context.loadDialect<mlir::stablehlo::StablehloDialect>();
context.loadDialect<mlir::chlo::ChloDialect>();
}

#include "mlir/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMIRToLLVMTranslation.h"
#include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h"
extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::DialectRegistry &registry = *unwrap(creg);

Expand Down Expand Up @@ -513,6 +517,11 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::affine::registerAffinePasses();
mlir::registerReconcileUnrealizedCasts();

mlir::registerLLVMDialectImport(registry);
mlir::registerNVVMDialectImport(registry);

mlir::LLVM::registerInlinerInterface(registry);

/*
registry.addExtension(+[](MLIRContext *ctx, LLVM::LLVMDialect *dialect) {
LLVM::LLVMFunctionType::attachInterface<MemRefInsider>(*ctx);
Expand Down Expand Up @@ -540,6 +549,81 @@ extern "C" void InitializeRegistryAndPasses(MlirDialectRegistry creg) {
mlir::enzyme::registerEnzymeJaxTransformExtension(registry);
}


/// Returns an unused symbol in `module` for `oldSymbolName` by trying numeric
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
unsigned &lastUsedID,
mlir::ModuleOp source,
mlir::ModuleOp target) {
using namespace llvm;
using namespace mlir;
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');
while (true) {
auto possible = newSymName + Twine(++lastUsedID);
if (!SymbolTable::lookupSymbolIn(source, possible.str()) && !SymbolTable::lookupSymbolIn(target, possible.str())) {
return StringAttr::get(target.getContext(), possible);
}
}
}


/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
static mlir::LogicalResult
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp source, mlir::ModuleOp target,
unsigned &lastUsedID) {
using namespace llvm;
using namespace mlir;

auto opName = op.getName().str();

if (!SymbolTable::lookupSymbolIn(target, opName)) {
return success();
}

StringAttr newSymName =
renameSymbol(opName, lastUsedID, source, target);

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, source)))
return op.emitError("unable to update all symbol uses for ")
<< opName << " to " << newSymName;

SymbolTable::setSymbolName(op, newSymName);
return success();
}

extern "C" MlirOperation LinkInModule(MlirModule prevModC, MlirModule newModC, const char* entryfn) {
auto prevMod = cast<ModuleOp>(*unwrap(prevModC));
auto newMod = cast<ModuleOp>(*unwrap(newModC));

Operation* entryFn = nullptr;

unsigned lastUsedID = 0;

for (auto &op : *newMod.getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;

StringRef oldSymName = symbolOp.getName();

if (oldSymName == entryfn) {
entryFn = &op;
}

if (failed(updateSymbolAndAllUses(symbolOp, newMod, prevMod,
lastUsedID))) {
assert(0 && "failed to update all uses");
}
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
prevMod.getBody()->getOperations().splice(prevMod.getBody()->getOperations().end(),
newMod.getBody()->getOperations());
return wrap(entryFn);
}

#pragma region xla::ifrt

#pragma region xla::ifrt::Value
Expand Down
5 changes: 5 additions & 0 deletions deps/ReactantExtra/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,7 @@ cc_library(
"-Wl,-exported_symbol,_BufferToHost",
"-Wl,-exported_symbol,_FreeClient",
"-Wl,-exported_symbol,_ClientCompile",
"-Wl,-exported_symbol,_LinkInModule",
"-Wl,-exported_symbol,_FreeFuture",
"-Wl,-exported_symbol,_FutureIsReady",
"-Wl,-exported_symbol,_FutureAwait",
Expand Down Expand Up @@ -451,6 +452,10 @@ cc_library(
"@llvm-project//mlir:TransformDialect",
"@llvm-project//mlir:Transforms",

"@llvm-project//mlir:LLVMIRToLLVMTranslation",
"@llvm-project//mlir:LLVMIRToNVVMTranslation",
"@llvm-project//mlir:LLVMIRTransforms",

"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
"@llvm-project//llvm:AArch64AsmParser",
Expand Down
Loading
Loading