Skip to content

Commit 8db31e9

Browse files
committed
[NVPTX] Do not addrspacecast AS-specific kernel arguments.
Fixes llvm#46954 The assumption that generic pointers passed to a CUDA kernel point to the global address space is CUDA-specific and should not be applied to non-CUDA compilations. Addrspacecasts to global AS and back should never be applied to AS-specific pointers. In order to make the tests actually do the testing for non-CUDA compilation, we need to get the TargetMachine from the TargetPassConfig instead of passing it explicitly as a pass constructor argument. Differential Revision: https://reviews.llvm.org/D142581
1 parent 6185246 commit 8db31e9

File tree

4 files changed

+76
-32
lines changed

4 files changed

+76
-32
lines changed

llvm/lib/Target/NVPTX/NVPTX.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ FunctionPass *createNVVMReflectPass(unsigned int SmVersion);
4444
MachineFunctionPass *createNVPTXPrologEpilogPass();
4545
MachineFunctionPass *createNVPTXReplaceImageHandlesPass();
4646
FunctionPass *createNVPTXImageOptimizerPass();
47-
FunctionPass *createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM);
47+
FunctionPass *createNVPTXLowerArgsPass();
4848
FunctionPass *createNVPTXLowerAllocaPass();
4949
MachineFunctionPass *createNVPTXPeephole();
5050
MachineFunctionPass *createNVPTXProxyRegErasurePass();

llvm/lib/Target/NVPTX/NVPTXLowerArgs.cpp

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,12 @@
9393
#include "NVPTXTargetMachine.h"
9494
#include "NVPTXUtilities.h"
9595
#include "llvm/Analysis/ValueTracking.h"
96+
#include "llvm/CodeGen/TargetPassConfig.h"
9697
#include "llvm/IR/Function.h"
9798
#include "llvm/IR/Instructions.h"
9899
#include "llvm/IR/Module.h"
99100
#include "llvm/IR/Type.h"
101+
#include "llvm/InitializePasses.h"
100102
#include "llvm/Pass.h"
101103
#include <numeric>
102104
#include <queue>
@@ -113,11 +115,11 @@ namespace {
113115
class NVPTXLowerArgs : public FunctionPass {
114116
bool runOnFunction(Function &F) override;
115117

116-
bool runOnKernelFunction(Function &F);
117-
bool runOnDeviceFunction(Function &F);
118+
bool runOnKernelFunction(const NVPTXTargetMachine &TM, Function &F);
119+
bool runOnDeviceFunction(const NVPTXTargetMachine &TM, Function &F);
118120

119121
// handle byval parameters
120-
void handleByValParam(Argument *Arg);
122+
void handleByValParam(const NVPTXTargetMachine &TM, Argument *Arg);
121123
// Knowing Ptr must point to the global address space, this function
122124
// addrspacecasts Ptr to global and then back to generic. This allows
123125
// NVPTXInferAddressSpaces to fold the global-to-generic cast into
@@ -126,21 +128,23 @@ class NVPTXLowerArgs : public FunctionPass {
126128

127129
public:
128130
static char ID; // Pass identification, replacement for typeid
129-
NVPTXLowerArgs(const NVPTXTargetMachine *TM = nullptr)
130-
: FunctionPass(ID), TM(TM) {}
131+
NVPTXLowerArgs() : FunctionPass(ID) {}
131132
StringRef getPassName() const override {
132133
return "Lower pointer arguments of CUDA kernels";
133134
}
134-
135-
private:
136-
const NVPTXTargetMachine *TM;
135+
void getAnalysisUsage(AnalysisUsage &AU) const override {
136+
AU.addRequired<TargetPassConfig>();
137+
}
137138
};
138139
} // namespace
139140

140141
char NVPTXLowerArgs::ID = 1;
141142

142-
INITIALIZE_PASS(NVPTXLowerArgs, "nvptx-lower-args",
143-
"Lower arguments (NVPTX)", false, false)
143+
INITIALIZE_PASS_BEGIN(NVPTXLowerArgs, "nvptx-lower-args",
144+
"Lower arguments (NVPTX)", false, false)
145+
INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
146+
INITIALIZE_PASS_END(NVPTXLowerArgs, "nvptx-lower-args",
147+
"Lower arguments (NVPTX)", false, false)
144148

145149
// =============================================================================
146150
// If the function had a byval struct ptr arg, say foo(%struct.x* byval %d),
@@ -310,7 +314,8 @@ static void adjustByValArgAlignment(Argument *Arg, Value *ArgInParamAS,
310314
}
311315
}
312316

313-
void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
317+
void NVPTXLowerArgs::handleByValParam(const NVPTXTargetMachine &TM,
318+
Argument *Arg) {
314319
Function *Func = Arg->getParent();
315320
Instruction *FirstInst = &(Func->getEntryBlock().front());
316321
Type *StructType = Arg->getParamByValType();
@@ -354,12 +359,8 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
354359
convertToParamAS(V, ArgInParamAS);
355360
LLVM_DEBUG(dbgs() << "No need to copy " << *Arg << "\n");
356361

357-
// Further optimizations require target lowering info.
358-
if (!TM)
359-
return;
360-
361362
const auto *TLI =
362-
cast<NVPTXTargetLowering>(TM->getSubtargetImpl()->getTargetLowering());
363+
cast<NVPTXTargetLowering>(TM.getSubtargetImpl()->getTargetLowering());
363364

364365
adjustByValArgAlignment(Arg, ArgInParamAS, TLI);
365366

@@ -390,7 +391,7 @@ void NVPTXLowerArgs::handleByValParam(Argument *Arg) {
390391
}
391392

392393
void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
393-
if (Ptr->getType()->getPointerAddressSpace() == ADDRESS_SPACE_GLOBAL)
394+
if (Ptr->getType()->getPointerAddressSpace() != ADDRESS_SPACE_GENERIC)
394395
return;
395396

396397
// Deciding where to emit the addrspacecast pair.
@@ -420,8 +421,9 @@ void NVPTXLowerArgs::markPointerAsGlobal(Value *Ptr) {
420421
// =============================================================================
421422
// Main function for this pass.
422423
// =============================================================================
423-
bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
424-
if (TM && TM->getDrvInterface() == NVPTX::CUDA) {
424+
bool NVPTXLowerArgs::runOnKernelFunction(const NVPTXTargetMachine &TM,
425+
Function &F) {
426+
if (TM.getDrvInterface() == NVPTX::CUDA) {
425427
// Mark pointers in byval structs as global.
426428
for (auto &B : F) {
427429
for (auto &I : B) {
@@ -444,28 +446,29 @@ bool NVPTXLowerArgs::runOnKernelFunction(Function &F) {
444446
for (Argument &Arg : F.args()) {
445447
if (Arg.getType()->isPointerTy()) {
446448
if (Arg.hasByValAttr())
447-
handleByValParam(&Arg);
448-
else if (TM && TM->getDrvInterface() == NVPTX::CUDA)
449+
handleByValParam(TM, &Arg);
450+
else if (TM.getDrvInterface() == NVPTX::CUDA)
449451
markPointerAsGlobal(&Arg);
450452
}
451453
}
452454
return true;
453455
}
454456

455457
// Device functions only need to copy byval args into local memory.
456-
bool NVPTXLowerArgs::runOnDeviceFunction(Function &F) {
458+
bool NVPTXLowerArgs::runOnDeviceFunction(const NVPTXTargetMachine &TM,
459+
Function &F) {
457460
LLVM_DEBUG(dbgs() << "Lowering function args of " << F.getName() << "\n");
458461
for (Argument &Arg : F.args())
459462
if (Arg.getType()->isPointerTy() && Arg.hasByValAttr())
460-
handleByValParam(&Arg);
463+
handleByValParam(TM, &Arg);
461464
return true;
462465
}
463466

464467
bool NVPTXLowerArgs::runOnFunction(Function &F) {
465-
return isKernelFunction(F) ? runOnKernelFunction(F) : runOnDeviceFunction(F);
466-
}
468+
auto &TM = getAnalysis<TargetPassConfig>().getTM<NVPTXTargetMachine>();
467469

468-
FunctionPass *
469-
llvm::createNVPTXLowerArgsPass(const NVPTXTargetMachine *TM) {
470-
return new NVPTXLowerArgs(TM);
470+
return isKernelFunction(F) ? runOnKernelFunction(TM, F)
471+
: runOnDeviceFunction(TM, F);
471472
}
473+
474+
FunctionPass *llvm::createNVPTXLowerArgsPass() { return new NVPTXLowerArgs(); }

llvm/lib/Target/NVPTX/NVPTXTargetMachine.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ void NVPTXPassConfig::addIRPasses() {
326326

327327
// NVPTXLowerArgs is required for correctness and should be run right
328328
// before the address space inference passes.
329-
addPass(createNVPTXLowerArgsPass(&getNVPTXTargetMachine()));
329+
addPass(createNVPTXLowerArgsPass());
330330
if (getOptLevel() != CodeGenOpt::None) {
331331
addAddressSpaceInferencePasses();
332332
addStraightLineScalarOptimizationPasses();

llvm/test/CodeGen/NVPTX/lower-args.ll

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
1-
; RUN: opt < %s -S -nvptx-lower-args | FileCheck %s --check-prefix IR
2-
; RUN: llc < %s -mcpu=sm_20 | FileCheck %s --check-prefix PTX
1+
; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes COMMON,IR,IRC
2+
; RUN: opt < %s -S -nvptx-lower-args --mtriple nvptx64-nvidia-nvcl | FileCheck %s --check-prefixes COMMON,IR,IRO
3+
; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-cuda | FileCheck %s --check-prefixes COMMON,PTX,PTXC
4+
; RUN: llc < %s -mcpu=sm_20 --mtriple nvptx64-nvidia-nvcl | FileCheck %s --check-prefixes COMMON,PTX,PTXO
35
; RUN: %if ptxas %{ llc < %s -mcpu=sm_20 | %ptxas-verify %}
46

57
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
@@ -9,6 +11,7 @@ target triple = "nvptx64-nvidia-cuda"
911
%class.inner = type { ptr, ptr }
1012

1113
; Check that nvptx-lower-args preserves arg alignment
14+
; COMMON-LABEL: load_alignment
1215
define void @load_alignment(ptr nocapture readonly byval(%class.outer) align 8 %arg) {
1316
entry:
1417
; IR: load %class.outer, ptr addrspace(101)
@@ -30,5 +33,43 @@ entry:
3033
ret void
3134
}
3235

36+
37+
; COMMON-LABEL: ptr_generic
38+
define void @ptr_generic(ptr %out, ptr %in) {
39+
; IRC: %in3 = addrspacecast ptr %in to ptr addrspace(1)
40+
; IRC: %in4 = addrspacecast ptr addrspace(1) %in3 to ptr
41+
; IRC: %out1 = addrspacecast ptr %out to ptr addrspace(1)
42+
; IRC: %out2 = addrspacecast ptr addrspace(1) %out1 to ptr
43+
; PTXC: cvta.to.global.u64
44+
; PTXC: cvta.to.global.u64
45+
; PTXC: ld.global.u32
46+
; PTXC: st.global.u32
47+
48+
; OpenCL can't make assumptions about incoming pointer, so we should generate
49+
; generic pointers load/store.
50+
; IRO-NOT: addrspacecast
51+
; PTXO-NOT: cvta.to.global
52+
; PTXO: ld.u32
53+
; PTXO: st.u32
54+
%v = load i32, ptr %in, align 4
55+
store i32 %v, ptr %out, align 4
56+
ret void
57+
}
58+
59+
; COMMON-LABEL: ptr_nongeneric
60+
define void @ptr_nongeneric(ptr addrspace(1) %out, ptr addrspace(4) %in) {
61+
; IR-NOT: addrspacecast
62+
; PTX-NOT: cvta.to.global
63+
; PTX: ld.const.u32
64+
; PTX: st.global.u32
65+
%v = load i32, ptr addrspace(4) %in, align 4
66+
store i32 %v, ptr addrspace(1) %out, align 4
67+
ret void
68+
}
69+
70+
3371
; Function Attrs: convergent nounwind
3472
declare dso_local ptr @escape(ptr) local_unnamed_addr
73+
!nvvm.annotations = !{!0, !1}
74+
!0 = !{ptr @ptr_generic, !"kernel", i32 1}
75+
!1 = !{ptr @ptr_nongeneric, !"kernel", i32 1}

0 commit comments

Comments
 (0)