18
18
#include " mlir/Target/LLVMIR/Export.h"
19
19
#include " mlir/Target/LLVMIR/ModuleTranslation.h"
20
20
21
+ #include " llvm/ADT/ScopeExit.h"
21
22
#include " llvm/IR/Constants.h"
22
23
#include " llvm/IR/IRBuilder.h"
23
24
#include " llvm/IR/LLVMContext.h"
24
25
#include " llvm/IR/Module.h"
25
26
#include " llvm/Support/FormatVariadic.h"
27
+ #include " llvm/Transforms/Utils/ModuleUtils.h"
26
28
27
29
using namespace mlir ;
28
30
@@ -31,9 +33,13 @@ namespace {
31
33
class SelectObjectAttrImpl
32
34
: public gpu::OffloadingLLVMTranslationAttrInterface::FallbackModel<
33
35
SelectObjectAttrImpl> {
36
+ // Returns the selected object for embedding.
37
+ gpu::ObjectAttr getSelectedObject (gpu::BinaryOp op) const ;
38
+
34
39
public:
35
40
// Translates a `gpu.binary`, embedding the binary into a host LLVM module as
36
- // global binary string.
41
+ // global binary string which gets loaded/unloaded into a global module
42
+ // object through a global ctor/dtor.
37
43
LogicalResult embedBinary (Attribute attribute, Operation *operation,
38
44
llvm::IRBuilderBase &builder,
39
45
LLVM::ModuleTranslation &moduleTranslation) const ;
@@ -45,23 +51,9 @@ class SelectObjectAttrImpl
45
51
Operation *binaryOperation,
46
52
llvm::IRBuilderBase &builder,
47
53
LLVM::ModuleTranslation &moduleTranslation) const ;
48
-
49
- // Returns the selected object for embedding.
50
- gpu::ObjectAttr getSelectedObject (gpu::BinaryOp op) const ;
51
54
};
52
- // Returns an identifier for the global string holding the binary.
53
- std::string getBinaryIdentifier (StringRef binaryName) {
54
- return binaryName.str () + " _bin_cst" ;
55
- }
56
55
} // namespace
57
56
58
- void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels (
59
- DialectRegistry ®istry) {
60
- registry.addExtension (+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
61
- SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
62
- });
63
- }
64
-
65
57
gpu::ObjectAttr
66
58
SelectObjectAttrImpl::getSelectedObject (gpu::BinaryOp op) const {
67
59
ArrayRef<Attribute> objects = op.getObjectsAttr ().getValue ();
@@ -96,6 +88,94 @@ SelectObjectAttrImpl::getSelectedObject(gpu::BinaryOp op) const {
96
88
return mlir::dyn_cast<gpu::ObjectAttr>(objects[index ]);
97
89
}
98
90
91
+ static Twine getModuleIdentifier (StringRef moduleName) {
92
+ return moduleName + " _module" ;
93
+ }
94
+
95
+ namespace llvm {
96
+ static LogicalResult embedBinaryImpl (StringRef moduleName,
97
+ gpu::ObjectAttr object, Module &module) {
98
+
99
+ // Embed the object as a global string.
100
+ // Add null for assembly output for JIT paths that expect null-terminated
101
+ // strings.
102
+ bool addNull = (object.getFormat () == gpu::CompilationTarget::Assembly);
103
+ StringRef serializedStr = object.getObject ().getValue ();
104
+ Constant *serializedCst =
105
+ ConstantDataArray::getString (module.getContext (), serializedStr, addNull);
106
+ GlobalVariable *serializedObj =
107
+ new GlobalVariable (module, serializedCst->getType (), true ,
108
+ GlobalValue::LinkageTypes::InternalLinkage,
109
+ serializedCst, moduleName + " _binary" );
110
+ serializedObj->setAlignment (MaybeAlign (8 ));
111
+ serializedObj->setUnnamedAddr (GlobalValue::UnnamedAddr::None);
112
+
113
+ // Default JIT optimization level.
114
+ auto optLevel = APInt::getZero (32 );
115
+
116
+ if (DictionaryAttr objectProps = object.getProperties ()) {
117
+ if (auto section = dyn_cast_or_null<StringAttr>(
118
+ objectProps.get (gpu::elfSectionName))) {
119
+ serializedObj->setSection (section.getValue ());
120
+ }
121
+ // Check if there's an optimization level embedded in the object.
122
+ if (auto optAttr = dyn_cast_or_null<IntegerAttr>(objectProps.get (" O" )))
123
+ optLevel = optAttr.getValue ();
124
+ }
125
+
126
+ IRBuilder<> builder (module.getContext ());
127
+ auto i32Ty = builder.getInt32Ty ();
128
+ auto i64Ty = builder.getInt64Ty ();
129
+ auto ptrTy = builder.getPtrTy (0 );
130
+ auto voidTy = builder.getVoidTy ();
131
+
132
+ // Embed the module as a global object.
133
+ auto *modulePtr = new GlobalVariable (
134
+ module, ptrTy, /* isConstant=*/ false , GlobalValue::InternalLinkage,
135
+ /* Initializer=*/ ConstantPointerNull::get (ptrTy),
136
+ getModuleIdentifier (moduleName));
137
+
138
+ auto *loadFn = Function::Create (FunctionType::get (voidTy, /* IsVarArg=*/ false ),
139
+ GlobalValue::InternalLinkage,
140
+ moduleName + " _load" , module);
141
+ loadFn->setSection (" .text.startup" );
142
+ auto *loadBlock = BasicBlock::Create (module.getContext (), " entry" , loadFn);
143
+ builder.SetInsertPoint (loadBlock);
144
+ Value *moduleObj = [&] {
145
+ if (object.getFormat () == gpu::CompilationTarget::Assembly) {
146
+ FunctionCallee moduleLoadFn = module.getOrInsertFunction (
147
+ " mgpuModuleLoadJIT" , FunctionType::get (ptrTy, {ptrTy, i32Ty}, false ));
148
+ Constant *optValue = ConstantInt::get (i32Ty, optLevel);
149
+ return builder.CreateCall (moduleLoadFn, {serializedObj, optValue});
150
+ } else {
151
+ FunctionCallee moduleLoadFn = module.getOrInsertFunction (
152
+ " mgpuModuleLoad" , FunctionType::get (ptrTy, {ptrTy, i64Ty}, false ));
153
+ Constant *binarySize =
154
+ ConstantInt::get (i64Ty, serializedStr.size () + (addNull ? 1 : 0 ));
155
+ return builder.CreateCall (moduleLoadFn, {serializedObj, binarySize});
156
+ }
157
+ }();
158
+ builder.CreateStore (moduleObj, modulePtr);
159
+ builder.CreateRetVoid ();
160
+ appendToGlobalCtors (module, loadFn, /* Priority=*/ 123 );
161
+
162
+ auto *unloadFn = Function::Create (
163
+ FunctionType::get (voidTy, /* IsVarArg=*/ false ),
164
+ GlobalValue::InternalLinkage, moduleName + " _unload" , module);
165
+ unloadFn->setSection (" .text.startup" );
166
+ auto *unloadBlock =
167
+ BasicBlock::Create (module.getContext (), " entry" , unloadFn);
168
+ builder.SetInsertPoint (unloadBlock);
169
+ FunctionCallee moduleUnloadFn = module.getOrInsertFunction (
170
+ " mgpuModuleUnload" , FunctionType::get (voidTy, ptrTy, false ));
171
+ builder.CreateCall (moduleUnloadFn, builder.CreateLoad (ptrTy, modulePtr));
172
+ builder.CreateRetVoid ();
173
+ appendToGlobalDtors (module, unloadFn, /* Priority=*/ 123 );
174
+
175
+ return success ();
176
+ }
177
+ } // namespace llvm
178
+
99
179
LogicalResult SelectObjectAttrImpl::embedBinary (
100
180
Attribute attribute, Operation *operation, llvm::IRBuilderBase &builder,
101
181
LLVM::ModuleTranslation &moduleTranslation) const {
@@ -113,29 +193,8 @@ LogicalResult SelectObjectAttrImpl::embedBinary(
113
193
if (!object)
114
194
return failure ();
115
195
116
- llvm::Module *module = moduleTranslation.getLLVMModule ();
117
-
118
- // Embed the object as a global string.
119
- // Add null for assembly output for JIT paths that expect null-terminated
120
- // strings.
121
- bool addNull = (object.getFormat () == gpu::CompilationTarget::Assembly);
122
- llvm::Constant *binary = llvm::ConstantDataArray::getString (
123
- builder.getContext (), object.getObject ().getValue (), addNull);
124
- llvm::GlobalVariable *serializedObj =
125
- new llvm::GlobalVariable (*module, binary->getType (), true ,
126
- llvm::GlobalValue::LinkageTypes::InternalLinkage,
127
- binary, getBinaryIdentifier (op.getName ()));
128
-
129
- if (object.getProperties ()) {
130
- if (auto section = mlir::dyn_cast_or_null<mlir::StringAttr>(
131
- object.getProperties ().get (gpu::elfSectionName))) {
132
- serializedObj->setSection (section.getValue ());
133
- }
134
- }
135
- serializedObj->setLinkage (llvm::GlobalValue::LinkageTypes::InternalLinkage);
136
- serializedObj->setAlignment (llvm::MaybeAlign (8 ));
137
- serializedObj->setUnnamedAddr (llvm::GlobalValue::UnnamedAddr::None);
138
- return success ();
196
+ return embedBinaryImpl (op.getName (), object,
197
+ *moduleTranslation.getLLVMModule ());
139
198
}
140
199
141
200
namespace llvm {
@@ -153,15 +212,6 @@ class LaunchKernel {
153
212
// Get the module function callee.
154
213
FunctionCallee getModuleFunctionFn ();
155
214
156
- // Get the module load callee.
157
- FunctionCallee getModuleLoadFn ();
158
-
159
- // Get the module load JIT callee.
160
- FunctionCallee getModuleLoadJITFn ();
161
-
162
- // Get the module unload callee.
163
- FunctionCallee getModuleUnloadFn ();
164
-
165
215
// Get the stream create callee.
166
216
FunctionCallee getStreamCreateFn ();
167
217
@@ -261,24 +311,6 @@ llvm::FunctionCallee llvm::LaunchKernel::getModuleFunctionFn() {
261
311
FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, ptrTy}), false ));
262
312
}
263
313
264
- llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadFn () {
265
- return module.getOrInsertFunction (
266
- " mgpuModuleLoad" ,
267
- FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, i64Ty}), false ));
268
- }
269
-
270
- llvm::FunctionCallee llvm::LaunchKernel::getModuleLoadJITFn () {
271
- return module.getOrInsertFunction (
272
- " mgpuModuleLoadJIT" ,
273
- FunctionType::get (ptrTy, ArrayRef<Type *>({ptrTy, i32Ty}), false ));
274
- }
275
-
276
- llvm::FunctionCallee llvm::LaunchKernel::getModuleUnloadFn () {
277
- return module.getOrInsertFunction (
278
- " mgpuModuleUnload" ,
279
- FunctionType::get (voidTy, ArrayRef<Type *>({ptrTy}), false ));
280
- }
281
-
282
314
llvm::FunctionCallee llvm::LaunchKernel::getStreamCreateFn () {
283
315
return module.getOrInsertFunction (" mgpuStreamCreate" ,
284
316
FunctionType::get (ptrTy, false ));
@@ -301,9 +333,9 @@ llvm::FunctionCallee llvm::LaunchKernel::getStreamSyncFn() {
301
333
llvm::Value *llvm::LaunchKernel::getOrCreateFunctionName (StringRef moduleName,
302
334
StringRef kernelName) {
303
335
std::string globalName =
304
- std::string (formatv (" {0}_{1}_kernel_name " , moduleName, kernelName));
336
+ std::string (formatv (" {0}_{1}_name " , moduleName, kernelName));
305
337
306
- if (GlobalVariable *gv = module.getGlobalVariable (globalName))
338
+ if (GlobalVariable *gv = module.getGlobalVariable (globalName, true ))
307
339
return gv;
308
340
309
341
return builder.CreateGlobalString (kernelName, globalName);
@@ -346,16 +378,13 @@ llvm::LaunchKernel::createKernelArgArray(mlir::gpu::LaunchFuncOp op) {
346
378
}
347
379
348
380
// Emits LLVM IR to launch a kernel function:
349
- // %0 = call %binarygetter
350
- // %1 = call %moduleLoad(%0)
351
- // %2 = <see generateKernelNameConstant>
352
- // %3 = call %moduleGetFunction(%1, %2)
353
- // %4 = call %streamCreate()
354
- // %5 = <see generateParamsArray>
355
- // call %launchKernel(%3, <launchOp operands 0..5>, 0, %4, %5, nullptr)
356
- // call %streamSynchronize(%4)
357
- // call %streamDestroy(%4)
358
- // call %moduleUnload(%1)
381
+ // %1 = load %global_module_object
382
+ // %2 = call @mgpuModuleGetFunction(%1, %global_kernel_name)
383
+ // %3 = call @mgpuStreamCreate()
384
+ // %4 = <see createKernelArgArray()>
385
+ // call @mgpuLaunchKernel(%2, ..., %3, %4, ...)
386
+ // call @mgpuStreamSynchronize(%3)
387
+ // call @mgpuStreamDestroy(%3)
359
388
llvm::LogicalResult
360
389
llvm::LaunchKernel::createKernelLaunch (mlir::gpu::LaunchFuncOp op,
361
390
mlir::gpu::ObjectAttr object) {
@@ -385,58 +414,29 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
385
414
// Create the argument array.
386
415
Value *argArray = createKernelArgArray (op);
387
416
388
- // Default JIT optimization level.
389
- llvm::Constant *optV = llvm::ConstantInt::get (i32Ty, 0 );
390
- // Check if there's an optimization level embedded in the object.
391
- DictionaryAttr objectProps = object.getProperties ();
392
- mlir::Attribute optAttr;
393
- if (objectProps && (optAttr = objectProps.get (" O" ))) {
394
- auto optLevel = dyn_cast<IntegerAttr>(optAttr);
395
- if (!optLevel)
396
- return op.emitError (" the optimization level must be an integer" );
397
- optV = llvm::ConstantInt::get (i32Ty, optLevel.getValue ());
398
- }
399
-
400
- // Load the kernel module.
401
- StringRef moduleName = op.getKernelModuleName ().getValue ();
402
- std::string binaryIdentifier = getBinaryIdentifier (moduleName);
403
- Value *binary = module.getGlobalVariable (binaryIdentifier, true );
404
- if (!binary)
405
- return op.emitError () << " Couldn't find the binary: " << binaryIdentifier;
406
-
407
- auto binaryVar = dyn_cast<llvm::GlobalVariable>(binary);
408
- if (!binaryVar)
409
- return op.emitError () << " Binary is not a global variable: "
410
- << binaryIdentifier;
411
- llvm::Constant *binaryInit = binaryVar->getInitializer ();
412
- auto binaryDataSeq =
413
- dyn_cast_if_present<llvm::ConstantDataSequential>(binaryInit);
414
- if (!binaryDataSeq)
415
- return op.emitError () << " Couldn't find binary data array: "
416
- << binaryIdentifier;
417
- llvm::Constant *binarySize =
418
- llvm::ConstantInt::get (i64Ty, binaryDataSeq->getNumElements () *
419
- binaryDataSeq->getElementByteSize ());
420
-
421
- Value *moduleObject =
422
- object.getFormat () == gpu::CompilationTarget::Assembly
423
- ? builder.CreateCall (getModuleLoadJITFn (), {binary, optV})
424
- : builder.CreateCall (getModuleLoadFn (), {binary, binarySize});
425
-
426
417
// Load the kernel function.
427
- Value *moduleFunction = builder.CreateCall (
428
- getModuleFunctionFn (),
429
- {moduleObject,
430
- getOrCreateFunctionName (moduleName, op.getKernelName ().getValue ())});
418
+ StringRef moduleName = op.getKernelModuleName ().getValue ();
419
+ Twine moduleIdentifier = getModuleIdentifier (moduleName);
420
+ Value *modulePtr = module.getGlobalVariable (moduleIdentifier.str (), true );
421
+ if (!modulePtr)
422
+ return op.emitError () << " Couldn't find the binary: " << moduleIdentifier;
423
+ Value *moduleObj = builder.CreateLoad (ptrTy, modulePtr);
424
+ Value *functionName = getOrCreateFunctionName (moduleName, op.getKernelName ());
425
+ Value *moduleFunction =
426
+ builder.CreateCall (getModuleFunctionFn (), {moduleObj, functionName});
431
427
432
428
// Get the stream to use for execution. If there's no async object then create
433
429
// a stream to make a synchronous kernel launch.
434
430
Value *stream = nullptr ;
435
- bool handleStream = false ;
431
+ // Sync & destroy the stream, for synchronous launches.
432
+ auto destroyStream = make_scope_exit ([&]() {
433
+ builder.CreateCall (getStreamSyncFn (), {stream});
434
+ builder.CreateCall (getStreamDestroyFn (), {stream});
435
+ });
436
436
if (mlir::Value asyncObject = op.getAsyncObject ()) {
437
437
stream = llvmValue (asyncObject);
438
+ destroyStream.release ();
438
439
} else {
439
- handleStream = true ;
440
440
stream = builder.CreateCall (getStreamCreateFn (), {});
441
441
}
442
442
@@ -462,14 +462,12 @@ llvm::LaunchKernel::createKernelLaunch(mlir::gpu::LaunchFuncOp op,
462
462
argArray, nullPtr, paramsCount}));
463
463
}
464
464
465
- // Sync & destroy the stream, for synchronous launches.
466
- if (handleStream) {
467
- builder.CreateCall (getStreamSyncFn (), {stream});
468
- builder.CreateCall (getStreamDestroyFn (), {stream});
469
- }
470
-
471
- // Unload the kernel module.
472
- builder.CreateCall (getModuleUnloadFn (), {moduleObject});
473
-
474
465
return success ();
475
466
}
467
+
468
+ void mlir::gpu::registerOffloadingLLVMTranslationInterfaceExternalModels (
469
+ DialectRegistry ®istry) {
470
+ registry.addExtension (+[](MLIRContext *ctx, gpu::GPUDialect *dialect) {
471
+ SelectObjectAttr::attachInterface<SelectObjectAttrImpl>(*ctx);
472
+ });
473
+ }
0 commit comments