93
93
#include " NVPTXTargetMachine.h"
94
94
#include " NVPTXUtilities.h"
95
95
#include " llvm/Analysis/ValueTracking.h"
96
+ #include " llvm/CodeGen/TargetPassConfig.h"
96
97
#include " llvm/IR/Function.h"
97
98
#include " llvm/IR/Instructions.h"
98
99
#include " llvm/IR/Module.h"
99
100
#include " llvm/IR/Type.h"
101
+ #include " llvm/InitializePasses.h"
100
102
#include " llvm/Pass.h"
101
103
#include < numeric>
102
104
#include < queue>
@@ -113,11 +115,11 @@ namespace {
113
115
class NVPTXLowerArgs : public FunctionPass {
114
116
bool runOnFunction (Function &F) override ;
115
117
116
- bool runOnKernelFunction (Function &F);
117
- bool runOnDeviceFunction (Function &F);
118
+ bool runOnKernelFunction (const NVPTXTargetMachine &TM, Function &F);
119
+ bool runOnDeviceFunction (const NVPTXTargetMachine &TM, Function &F);
118
120
119
121
// handle byval parameters
120
- void handleByValParam (Argument *Arg);
122
+ void handleByValParam (const NVPTXTargetMachine &TM, Argument *Arg);
121
123
// Knowing Ptr must point to the global address space, this function
122
124
// addrspacecasts Ptr to global and then back to generic. This allows
123
125
// NVPTXInferAddressSpaces to fold the global-to-generic cast into
@@ -126,21 +128,23 @@ class NVPTXLowerArgs : public FunctionPass {
126
128
127
129
public:
128
130
static char ID; // Pass identification, replacement for typeid
129
- NVPTXLowerArgs (const NVPTXTargetMachine *TM = nullptr )
130
- : FunctionPass(ID), TM(TM) {}
131
+ NVPTXLowerArgs () : FunctionPass(ID) {}
131
132
StringRef getPassName () const override {
132
133
return " Lower pointer arguments of CUDA kernels" ;
133
134
}
134
-
135
- private:
136
- const NVPTXTargetMachine *TM;
135
+ void getAnalysisUsage (AnalysisUsage &AU) const override {
136
+ AU. addRequired <TargetPassConfig>();
137
+ }
137
138
};
138
139
} // namespace
139
140
140
141
char NVPTXLowerArgs::ID = 1 ;
141
142
142
- INITIALIZE_PASS (NVPTXLowerArgs, " nvptx-lower-args" ,
143
- " Lower arguments (NVPTX)" , false , false )
143
+ INITIALIZE_PASS_BEGIN (NVPTXLowerArgs, " nvptx-lower-args" ,
144
+ " Lower arguments (NVPTX)" , false , false )
145
+ INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
146
+ INITIALIZE_PASS_END(NVPTXLowerArgs, " nvptx-lower-args" ,
147
+ " Lower arguments (NVPTX)" , false , false )
144
148
145
149
// =============================================================================
146
150
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
@@ -310,7 +314,8 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
310
314
}
311
315
}
312
316
313
- void NVPTXLowerArgs::handleByValParam (Argument *Arg) {
317
+ void NVPTXLowerArgs::handleByValParam (const NVPTXTargetMachine &TM,
318
+ Argument *Arg) {
314
319
Function *Func = Arg->getParent ();
315
320
Instruction *FirstInst = &(Func->getEntryBlock ().front ());
316
321
Type *StructType = Arg->getParamByValType ();
@@ -354,12 +359,8 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
354
359
convertToParamAS (V, ArgInParamAS);
355
360
LLVM_DEBUG (dbgs () << " No need to copy " << *Arg << " \n " );
356
361
357
- // Further optimizations require target lowering info.
358
- if (!TM)
359
- return ;
360
-
361
362
const auto *TLI =
362
- cast<NVPTXTargetLowering>(TM-> getSubtargetImpl ()->getTargetLowering ());
363
+ cast<NVPTXTargetLowering>(TM. getSubtargetImpl ()->getTargetLowering ());
363
364
364
365
adjustByValArgAlignment (Arg, ArgInParamAS, TLI);
365
366
@@ -390,7 +391,7 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
390
391
}
391
392
392
393
void NVPTXLowerArgs::markPointerAsGlobal (Value *Ptr) {
393
- if (Ptr->getType ()->getPointerAddressSpace () == ADDRESS_SPACE_GLOBAL )
394
+ if (Ptr->getType ()->getPointerAddressSpace () != ADDRESS_SPACE_GENERIC )
394
395
return ;
395
396
396
397
// Deciding where to emit the addrspacecast pair.
@@ -420,8 +421,9 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
420
421
// =============================================================================
421
422
// Main function for this pass.
422
423
// =============================================================================
423
- bool NVPTXLowerArgs::runOnKernelFunction (Function &F) {
424
- if (TM && TM->getDrvInterface () == NVPTX::CUDA) {
424
+ bool NVPTXLowerArgs::runOnKernelFunction (const NVPTXTargetMachine &TM,
425
+ Function &F) {
426
+ if (TM.getDrvInterface () == NVPTX::CUDA) {
425
427
// Mark pointers in byval structs as global.
426
428
for (auto &B : F) {
427
429
for (auto &I : B) {
@@ -444,28 +446,29 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
444
446
for (Argument &Arg : F.args ()) {
445
447
if (Arg.getType ()->isPointerTy ()) {
446
448
if (Arg.hasByValAttr ())
447
- handleByValParam (&Arg);
448
- else if (TM && TM-> getDrvInterface () == NVPTX::CUDA)
449
+ handleByValParam (TM, &Arg);
450
+ else if (TM. getDrvInterface () == NVPTX::CUDA)
449
451
markPointerAsGlobal (&Arg);
450
452
}
451
453
}
452
454
return true ;
453
455
}
454
456
455
457
// Device functions only need to copy byval args into local memory.
456
- bool NVPTXLowerArgs::runOnDeviceFunction (Function &F) {
458
+ bool NVPTXLowerArgs::runOnDeviceFunction (const NVPTXTargetMachine &TM,
459
+ Function &F) {
457
460
LLVM_DEBUG (dbgs () << " Lowering function args of " << F.getName () << " \n " );
458
461
for (Argument &Arg : F.args ())
459
462
if (Arg.getType ()->isPointerTy () && Arg.hasByValAttr ())
460
- handleByValParam (&Arg);
463
+ handleByValParam (TM, &Arg);
461
464
return true ;
462
465
}
463
466
464
467
bool NVPTXLowerArgs::runOnFunction (Function &F) {
465
- return isKernelFunction (F) ? runOnKernelFunction (F) : runOnDeviceFunction (F);
466
- }
468
+ auto &TM = getAnalysis<TargetPassConfig>().getTM <NVPTXTargetMachine>();
467
469
468
- FunctionPass *
469
- llvm::createNVPTXLowerArgsPass (const NVPTXTargetMachine *TM) {
470
- return new NVPTXLowerArgs (TM);
470
+ return isKernelFunction (F) ? runOnKernelFunction (TM, F)
471
+ : runOnDeviceFunction (TM, F);
471
472
}
473
+
474
+ FunctionPass *llvm::createNVPTXLowerArgsPass () { return new NVPTXLowerArgs (); }
0 commit comments