Skip to content

Commit 7851b1b

Browse files
authored
[mlir][gpu] Change GPU modules to globals (llvm#135478)
Load/unload GPU modules in global ctors/dtors instead of on every kernel launch. Loading a GPU module is a heavyweight operation and synchronizes the GPU context. Now that the modules are loaded ahead of time, asynchronously launched kernels can run concurrently; see https://discourse.llvm.org/t/how-to-lower-the-combination-of-async-gpu-ops-in-gpu-dialect for background. The implementations of `embedBinary()` and `launchKernel()` use slightly different mechanics at the moment, but I prefer not to change the latter more than necessary as part of this PR. I will prepare a follow-up NFC (no functional change) patch for `launchKernel()` to align them again.
1 parent c60f24d commit 7851b1b

File tree

5 files changed

+220
-173
lines changed

5 files changed

+220
-173
lines changed

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
#include "mlir/ExecutionEngine/CRunnerUtils.h"
1616

17-
#include <stdio.h>
17+
#include <cstdio>
1818

1919
#include "cuda.h"
2020
#include "cuda_bf16.h"
@@ -56,14 +56,10 @@
5656

5757
thread_local static int32_t defaultDevice = 0;
5858

59-
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
60-
6159
/// Helper method that checks environment value for debugging.
6260
bool isDebugEnabled() {
63-
static bool isInitialized = false;
64-
static bool isEnabled = false;
65-
if (!isInitialized)
66-
isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
61+
const char *kDebugEnvironmentVariable = "MLIR_CUDA_DEBUG";
62+
static bool isEnabled = getenv(kDebugEnvironmentVariable) != nullptr;
6763
return isEnabled;
6864
}
6965

mlir/lib/Target/LLVMIR/Dialect/GPU/SelectObjectAttr.cpp

Lines changed: 128 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@
1818
#include "mlir/Target/LLVMIR/Export.h"
1919
#include "mlir/Target/LLVMIR/ModuleTranslation.h"
2020

21+
#include "llvm/ADT/ScopeExit.h"
2122
#include "llvm/IR/Constants.h"
2223
#include "llvm/IR/IRBuilder.h"
2324
#include "llvm/IR/LLVMContext.h"
2425
#include "llvm/IR/Module.h"
2526
#include "llvm/Support/FormatVariadic.h"
27+
#include "llvm/Transforms/Utils/ModuleUtils.h"
2628

2729
using namespace mlir;
2830

@@ -31,9 +33,13 @@ namespace {
3133
class SelectObjectAttrImpl
3234
: public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
3335
SelectObjectAttrImpl> {
36+
// Returns the selected object for embedding.
37+
gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
38+
3439
public:
3540
// Translates a `gpu.binary`, embedding the binary into a host LLVM module as
36-
// global binary string.
41+
// global binary string which gets loaded/unloaded into a global module
42+
// object through a global ctor/dtor.
3743
LogicalResult embedBinary(Attribute attribute, Operation *operation,
3844
llvm::IRBuilderBase &builder,
3945
LLVM::ModuleTranslation &moduleTranslation) const;
@@ -45,23 +51,9 @@ class SelectObjectAttrImpl
4551
Operation *binaryOperation,
4652
llvm::IRBuilderBase &builder,
4753
LLVM::ModuleTranslation &moduleTranslation) const;
48-
49-
// Returns the selected object for embedding.
50-
gpu::ObjectAttr getSelectedObject(gpu::BinaryOp op) const;
5154
};
52-
// Returns an identifier for the global string holding the binary.
53-
std::string getBinaryIdentifier(StringRef binaryName) {
54-
return binaryName.str() + "_bin_cst";
55-
}
5655
} // namespace
5756

58-
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
59-
DialectRegistry &registry) {
60-
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
61-
SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
62-
});
63-
}
64-
6557
gpu::ObjectAttr
6658
SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
6759
ArrayRef<Attribute> objects = op.getObjectsAttr().getValue();
@@ -96,6 +88,94 @@ SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
9688
return mlir::dyn_cast<gpu::ObjectAttr>(objects[index]);
9789
}
9890

91+
static Twine getModuleIdentifier(StringRef moduleName) {
92+
return moduleName + "_module";
93+
}
94+
95+
namespace llvm {
96+
static LogicalResult embedBinaryImpl(StringRef moduleName,
97+
gpu::ObjectAttr object, Module &module) {
98+
99+
// Embed the object as a global string.
100+
// Add null for assembly output for JIT paths that expect null-terminated
101+
// strings.
102+
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
103+
StringRef serializedStr = object.getObject().getValue();
104+
Constant *serializedCst =
105+
ConstantDataArray::getString(module.getContext(), serializedStr, addNull);
106+
GlobalVariable *serializedObj =
107+
new GlobalVariable(module, serializedCst->getType(), true,
108+
GlobalValue::LinkageTypes::InternalLinkage,
109+
serializedCst, moduleName + "_binary");
110+
serializedObj->setAlignment(MaybeAlign(8));
111+
serializedObj->setUnnamedAddr(GlobalValue::UnnamedAddr::None);
112+
113+
// Default JIT optimization level.
114+
auto optLevel = APInt::getZero(32);
115+
116+
if (DictionaryAttr objectProps = object.getProperties()) {
117+
if (auto section = dyn_cast_or_null<StringAttr>(
118+
objectProps.get(gpu::elfSectionName))) {
119+
serializedObj->setSection(section.getValue());
120+
}
121+
// Check if there's an optimization level embedded in the object.
122+
if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get("O")))
123+
optLevel = optAttr.getValue();
124+
}
125+
126+
IRBuilder<> builder(module.getContext());
127+
auto i32Ty = builder.getInt32Ty();
128+
auto i64Ty = builder.getInt64Ty();
129+
auto ptrTy = builder.getPtrTy(0);
130+
auto voidTy = builder.getVoidTy();
131+
132+
// Embed the module as a global object.
133+
auto *modulePtr = new GlobalVariable(
134+
module, ptrTy, /*isConstant=*/false, GlobalValue::InternalLinkage,
135+
/*Initializer=*/ConstantPointerNull::get(ptrTy),
136+
getModuleIdentifier(moduleName));
137+
138+
auto *loadFn = Function::Create(FunctionType::get(voidTy, /*IsVarArg=*/false),
139+
GlobalValue::InternalLinkage,
140+
moduleName + "_load", module);
141+
loadFn->setSection(".text.startup");
142+
auto *loadBlock = BasicBlock::Create(module.getContext(), "entry", loadFn);
143+
builder.SetInsertPoint(loadBlock);
144+
Value *moduleObj = [&] {
145+
if (object.getFormat() == gpu::CompilationTarget::Assembly) {
146+
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
147+
"mgpuModuleLoadJIT", FunctionType::get(ptrTy, {ptrTy, i32Ty}, false));
148+
Constant *optValue = ConstantInt::get(i32Ty, optLevel);
149+
return builder.CreateCall(moduleLoadFn, {serializedObj, optValue});
150+
} else {
151+
FunctionCallee moduleLoadFn = module.getOrInsertFunction(
152+
"mgpuModuleLoad", FunctionType::get(ptrTy, {ptrTy, i64Ty}, false));
153+
Constant *binarySize =
154+
ConstantInt::get(i64Ty, serializedStr.size() + (addNull ? 1 : 0));
155+
return builder.CreateCall(moduleLoadFn, {serializedObj, binarySize});
156+
}
157+
}();
158+
builder.CreateStore(moduleObj, modulePtr);
159+
builder.CreateRetVoid();
160+
appendToGlobalCtors(module, loadFn, /*Priority=*/123);
161+
162+
auto *unloadFn = Function::Create(
163+
FunctionType::get(voidTy, /*IsVarArg=*/false),
164+
GlobalValue::InternalLinkage, moduleName + "_unload", module);
165+
unloadFn->setSection(".text.startup");
166+
auto *unloadBlock =
167+
BasicBlock::Create(module.getContext(), "entry", unloadFn);
168+
builder.SetInsertPoint(unloadBlock);
169+
FunctionCallee moduleUnloadFn = module.getOrInsertFunction(
170+
"mgpuModuleUnload", FunctionType::get(voidTy, ptrTy, false));
171+
builder.CreateCall(moduleUnloadFn, builder.CreateLoad(ptrTy, modulePtr));
172+
builder.CreateRetVoid();
173+
appendToGlobalDtors(module, unloadFn, /*Priority=*/123);
174+
175+
return success();
176+
}
177+
} // namespace llvm
178+
99179
LogicalResult SelectObjectAttrImpl::embedBinary(
100180
Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
101181
LLVM::ModuleTranslation &moduleTranslation) const {
@@ -113,29 +193,8 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
113193
if (!object)
114194
return failure();
115195

116-
llvm::Module *module = moduleTranslation.getLLVMModule();
117-
118-
// Embed the object as a global string.
119-
// Add null for assembly output for JIT paths that expect null-terminated
120-
// strings.
121-
bool addNull = (object.getFormat() == gpu::CompilationTarget::Assembly);
122-
llvm::Constant *binary = llvm::ConstantDataArray::getString(
123-
builder.getContext(), object.getObject().getValue(), addNull);
124-
llvm::GlobalVariable *serializedObj =
125-
new llvm::GlobalVariable(*module, binary->getType(), true,
126-
llvm::GlobalValue::LinkageTypes::InternalLinkage,
127-
binary, getBinaryIdentifier(op.getName()));
128-
129-
if (object.getProperties()) {
130-
if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
131-
object.getProperties().get(gpu::elfSectionName))) {
132-
serializedObj->setSection(section.getValue());
133-
}
134-
}
135-
serializedObj->setLinkage(llvm::GlobalValue::LinkageTypes::InternalLinkage);
136-
serializedObj->setAlignment(llvm::MaybeAlign(8));
137-
serializedObj->setUnnamedAddr(llvm::GlobalValue::UnnamedAddr::None);
138-
return success();
196+
return embedBinaryImpl(op.getName(), object,
197+
*moduleTranslation.getLLVMModule());
139198
}
140199

141200
namespace llvm {
@@ -153,15 +212,6 @@ class LaunchKernel {
153212
// Get the module function callee.
154213
FunctionCallee getModuleFunctionFn();
155214

156-
// Get the module load callee.
157-
FunctionCallee getModuleLoadFn();
158-
159-
// Get the module load JIT callee.
160-
FunctionCallee getModuleLoadJITFn();
161-
162-
// Get the module unload callee.
163-
FunctionCallee getModuleUnloadFn();
164-
165215
// Get the stream create callee.
166216
FunctionCallee getStreamCreateFn();
167217

@@ -261,24 +311,6 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
261311
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false));
262312
}
263313

264-
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn() {
265-
return module.getOrInsertFunction(
266-
"mgpuModuleLoad",
267-
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false));
268-
}
269-
270-
llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn() {
271-
return module.getOrInsertFunction(
272-
"mgpuModuleLoadJIT",
273-
FunctionType::get(ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false));
274-
}
275-
276-
llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn() {
277-
return module.getOrInsertFunction(
278-
"mgpuModuleUnload",
279-
FunctionType::get(voidTy, ArrayRef<Type *>({ptrTy}), false));
280-
}
281-
282314
llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn() {
283315
return module.getOrInsertFunction("mgpuStreamCreate",
284316
FunctionType::get(ptrTy, false));
@@ -301,9 +333,9 @@ llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
301333
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName(StringRef moduleName,
302334
StringRef kernelName) {
303335
std::string globalName =
304-
std::string(formatv("{0}_{1}_kernel_name", moduleName, kernelName));
336+
std::string(formatv("{0}_{1}_name", moduleName, kernelName));
305337

306-
if (GlobalVariable *gv = module.getGlobalVariable(globalName))
338+
if (GlobalVariable *gv = module.getGlobalVariable(globalName, true))
307339
return gv;
308340

309341
return builder.CreateGlobalString(kernelName, globalName);
@@ -346,16 +378,13 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
346378
}
347379

348380
// Emits LLVM IR to launch a kernel function:
349-
// %0 = call %binarygetter
350-
// %1 = call %moduleLoad(%0)
351-
// %2 = <see generateKernelNameConstant>
352-
// %3 = call %moduleGetFunction(%1, %2)
353-
// %4 = call %streamCreate()
354-
// %5 = <see generateParamsArray>
355-
// call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
356-
// call %streamSynchronize(%4)
357-
// call %streamDestroy(%4)
358-
// call %moduleUnload(%1)
381+
// %1 = load %global_module_object
382+
// %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
383+
// %3 = call @mgpuStreamCreate()
384+
// %4 = <see createKernelArgArray()>
385+
// call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
386+
// call @mgpuStreamSynchronize(%3)
387+
// call @mgpuStreamDestroy(%3)
359388
llvm::LogicalResult
360389
llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
361390
mlir::gpu::ObjectAttr object) {
@@ -385,58 +414,29 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
385414
// Create the argument array.
386415
Value *argArray = createKernelArgArray(op);
387416

388-
// Default JIT optimization level.
389-
llvm::Constant *optV = llvm::ConstantInt::get(i32Ty, 0);
390-
// Check if there's an optimization level embedded in the object.
391-
DictionaryAttr objectProps = object.getProperties();
392-
mlir::Attribute optAttr;
393-
if (objectProps && (optAttr = objectProps.get("O"))) {
394-
auto optLevel = dyn_cast<IntegerAttr>(optAttr);
395-
if (!optLevel)
396-
return op.emitError("the optimization level must be an integer");
397-
optV = llvm::ConstantInt::get(i32Ty, optLevel.getValue());
398-
}
399-
400-
// Load the kernel module.
401-
StringRef moduleName = op.getKernelModuleName().getValue();
402-
std::string binaryIdentifier = getBinaryIdentifier(moduleName);
403-
Value *binary = module.getGlobalVariable(binaryIdentifier, true);
404-
if (!binary)
405-
return op.emitError() << "Couldn't find the binary: " << binaryIdentifier;
406-
407-
auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
408-
if (!binaryVar)
409-
return op.emitError() << "Binary is not a global variable: "
410-
<< binaryIdentifier;
411-
llvm::Constant *binaryInit = binaryVar->getInitializer();
412-
auto binaryDataSeq =
413-
dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
414-
if (!binaryDataSeq)
415-
return op.emitError() << "Couldn't find binary data array: "
416-
<< binaryIdentifier;
417-
llvm::Constant *binarySize =
418-
llvm::ConstantInt::get(i64Ty, binaryDataSeq->getNumElements() *
419-
binaryDataSeq->getElementByteSize());
420-
421-
Value *moduleObject =
422-
object.getFormat() == gpu::CompilationTarget::Assembly
423-
? builder.CreateCall(getModuleLoadJITFn(), {binary, optV})
424-
: builder.CreateCall(getModuleLoadFn(), {binary, binarySize});
425-
426417
// Load the kernel function.
427-
Value *moduleFunction = builder.CreateCall(
428-
getModuleFunctionFn(),
429-
{moduleObject,
430-
getOrCreateFunctionName(moduleName, op.getKernelName().getValue())});
418+
StringRef moduleName = op.getKernelModuleName().getValue();
419+
Twine moduleIdentifier = getModuleIdentifier(moduleName);
420+
Value *modulePtr = module.getGlobalVariable(moduleIdentifier.str(), true);
421+
if (!modulePtr)
422+
return op.emitError() << "Couldn't find the binary: " << moduleIdentifier;
423+
Value *moduleObj = builder.CreateLoad(ptrTy, modulePtr);
424+
Value *functionName = getOrCreateFunctionName(moduleName, op.getKernelName());
425+
Value *moduleFunction =
426+
builder.CreateCall(getModuleFunctionFn(), {moduleObj, functionName});
431427

432428
// Get the stream to use for execution. If there's no async object then create
433429
// a stream to make a synchronous kernel launch.
434430
Value *stream = nullptr;
435-
bool handleStream = false;
431+
// Sync & destroy the stream, for synchronous launches.
432+
auto destroyStream = make_scope_exit([&]() {
433+
builder.CreateCall(getStreamSyncFn(), {stream});
434+
builder.CreateCall(getStreamDestroyFn(), {stream});
435+
});
436436
if (mlir::Value asyncObject = op.getAsyncObject()) {
437437
stream = llvmValue(asyncObject);
438+
destroyStream.release();
438439
} else {
439-
handleStream = true;
440440
stream = builder.CreateCall(getStreamCreateFn(), {});
441441
}
442442

@@ -462,14 +462,12 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
462462
argArray, nullPtr, paramsCount}));
463463
}
464464

465-
// Sync & destroy the stream, for synchronous launches.
466-
if (handleStream) {
467-
builder.CreateCall(getStreamSyncFn(), {stream});
468-
builder.CreateCall(getStreamDestroyFn(), {stream});
469-
}
470-
471-
// Unload the kernel module.
472-
builder.CreateCall(getModuleUnloadFn(), {moduleObject});
473-
474465
return success();
475466
}
467+
468+
void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels(
469+
DialectRegistry &registry) {
470+
registry.addExtension(+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
471+
SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
472+
});
473+
}

0 commit comments

Comments
 (0)