-
Notifications
You must be signed in to change notification settings - Fork 25
/
HipPrintf.cpp
182 lines (149 loc) · 5.94 KB
/
HipPrintf.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// LLVM IR pass to convert calls to the CUDA/HIP printf() to OpenCL/SPIR-V
// compatible ones.
//
// (c) 2021 Pekka Jääskeläinen / Parmance for Argonne National Laboratory
//
// The SPIRV-LLVM translator emits a wrong pointer address space for the printf
// format string if the input does not already use the correct (constant) one.
// This pass moves the format string to the constant address space before the
// IR is passed on to SPIR-V emission.
#include "HipPrintf.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Passes/PassPlugin.h"
using namespace llvm;
// Legacy pass manager flavour of the printf conversion. It only provides a
// name and an ID: runOnModule() reports the module as unmodified without
// touching it.
class HipPrintfToOpenCLPrintfLegacyPass : public ModulePass {
public:
  // Identifier handed to the ModulePass base class constructor.
  static char ID;

  HipPrintfToOpenCLPrintfLegacyPass() : ModulePass(ID) {}

  // Human readable description of the pass.
  StringRef getPassName() const override {
    return "Convert printf calls to OpenCL compatible ones.";
  }

  // Performs no transformation; always returns false (nothing changed).
  bool runOnModule(Module &) override { return false; }
};
// Out-of-class definition of the ID member that the pass constructor forwards
// to ModulePass.
char HipPrintfToOpenCLPrintfLegacyPass::ID = 0;

// File-local registration object exposing the legacy pass under the
// "hip-printf" argument.
static RegisterPass<HipPrintfToOpenCLPrintfLegacyPass> LegacyPassRegistration(
    "hip-printf", "Convert printf calls to OpenCL compatible ones.");
// Converts the address space of the format string to OpenCL compatible one
// by creating a copy of it in the module global scope.
//
// Also counts the number of format args for replacing the return
// value of the printf() call with it for CUDA emulation.
Value* convertFormatString(Value *HipFmtStrArg, Instruction *Before,
unsigned &NumberOfFormatSpecs) {
Module *M = Before->getParent()->getParent()->getParent();
Type *Int8Ty = IntegerType::get(M->getContext(), 8);
ConstantExpr *CE = cast<ConstantExpr>(HipFmtStrArg);
Value *FmtStrOpr = CE->getOperand(0);
GlobalVariable *OrigFmtStr =
isa<GetElementPtrInst>(FmtStrOpr) ?
cast<GlobalVariable>(
cast<GetElementPtrInst>(FmtStrOpr)->getPointerOperand()) :
cast<GlobalVariable>(FmtStrOpr);
ConstantDataSequential *FmtStrData =
cast<ConstantDataSequential>(OrigFmtStr->getInitializer());
NumberOfFormatSpecs =
FmtStrData->getAsString().count("%") -
FmtStrData->getAsString().count("%%");
GlobalVariable *NewFmtStr = new GlobalVariable(
*M, OrigFmtStr->getValueType(), true, OrigFmtStr->getLinkage(),
FmtStrData, OrigFmtStr->getName() + ".cl",
(GlobalVariable *)nullptr, OrigFmtStr->getThreadLocalMode(),
SPIRV_OPENCL_PRINTF_FMT_ARG_AS);
NewFmtStr->copyAttributesFrom(OrigFmtStr);
IntegerType *Int64Ty = Type::getInt64Ty(M->getContext());
ConstantInt *Zero = ConstantInt::get(Int64Ty, 0);
std::array<Constant*, 2> Indices = {Zero, Zero};
PointerType *PtrTy =
cast<PointerType>(NewFmtStr->getType()->getScalarType());
return llvm::ConstantExpr::getGetElementPtr(
PtrTy->getElementType(), NewFmtStr, Indices);
}
// Rewrites every direct call to the HIP/CUDA printf() in F into a call to an
// OpenCL style printf() whose format string lives in the
// SPIRV_OPENCL_PRINTF_FMT_ARG_AS address space.
//
// On the first invocation for a module the existing "printf" is renamed to
// "_hip_printf" and a new "printf" with the OpenCL signature is declared;
// later invocations find both names and reuse them.
//
// The value returned by the original call is replaced with the number of
// format specifications in the format string.
//
// Returns PreservedAnalyses::none() if at least one call was rewritten,
// PreservedAnalyses::all() otherwise.
//
// NOTE(review): this function pass renames and creates module-level globals;
// confirm that this is acceptable for the pass pipelines it is used in.
PreservedAnalyses HipPrintfToOpenCLPrintfPass::run(
    Function &F,
    FunctionAnalysisManager &AM) {

  Module *M = F.getParent();

  GlobalValue *Printf = M->getNamedValue("printf");
  GlobalValue *HipPrintf = M->getNamedValue("_hip_printf");

  if (Printf == nullptr) {
    // No printf decl in module, no printf calls in the function.
    return PreservedAnalyses::all();
  }

  auto *Int8Ty = IntegerType::get(F.getContext(), 8);
  auto *Int32Ty = IntegerType::get(F.getContext(), 32);

  PointerType *OCLPrintfFmtArgT =
      PointerType::get(Int8Ty, SPIRV_OPENCL_PRINTF_FMT_ARG_AS);

  FunctionType *OpenCLPrintfTy =
      FunctionType::get(Int32Ty, {OCLPrintfFmtArgT}, true);

  FunctionCallee OpenCLPrintf;

  if (HipPrintf == nullptr) {
    // Create the OpenCL printf which will be used instead. Rename the
    // old one away to _hip_printf.
    Printf->setName("_hip_printf");
    HipPrintf = Printf;
    OpenCLPrintf =
        M->getOrInsertFunction(
            "printf", OpenCLPrintfTy,
            cast<Function>(HipPrintf)->getAttributes());
  } else {
    // The module was already prepared: "printf" is the OpenCL variant.
    OpenCLPrintf = FunctionCallee(OpenCLPrintfTy, Printf);
  }

  SmallPtrSet<Instruction *, 8> EraseList;
  for (auto &BB : F) {
    for (auto &I : BB) {
      auto *OrigCall = dyn_cast<CallInst>(&I);
      if (OrigCall == nullptr)
        continue;

      // getCalledFunction() yields null for indirect calls; those cannot be
      // calls to _hip_printf that this pass knows how to rewrite.
      Function *Callee = OrigCall->getCalledFunction();
      if (Callee == nullptr || Callee->getName() != "_hip_printf")
        continue;

      // Copy the arguments, swapping the format string (always the first
      // argument) for its copy in the OpenCL address space.
      std::vector<Value *> Args;
      unsigned FmtSpecCount = 0;
      for (Value *OrigArg : OrigCall->args()) {
        if (Args.empty())
          Args.push_back(convertFormatString(OrigArg, OrigCall, FmtSpecCount));
        else
          Args.push_back(OrigArg);
      }

      // The new call is inserted right before the original one. Inserting
      // before the current instruction and rewriting uses leaves the original
      // call linked into the block, so the iteration stays valid; the actual
      // erasure is deferred until the walk is over.
      CallInst::Create(OpenCLPrintf, Args, "", OrigCall);

      // Instead of returning the success/failure from the OpenCL printf(),
      // assume that the parsing succeeds and return the number of format
      // strings. A slight improvement would be to return 0 in case of a
      // failure, but it still would not necessarily conform to CUDA nor HIP
      // since it should return the number of valid format replacements?
      ConstantInt *RV = ConstantInt::get(Int32Ty, FmtSpecCount);
      OrigCall->replaceAllUsesWith(RV);

      EraseList.insert(OrigCall);
    }
  }

  for (auto I : EraseList)
    I->eraseFromParent();

  return EraseList.empty() ?
      PreservedAnalyses::all() : PreservedAnalyses::none();
}
// Plugin entry point for the new pass manager. It registers a pipeline
// parsing callback that adds HipPrintfToOpenCLPrintfPass to a function pass
// manager whenever the pipeline element is named "hip-printf".
//
// The function is deliberately defined at global scope: a function declared
// inside an unnamed namespace has internal linkage even when marked
// extern "C", which conflicts with the weak attribute and keeps the symbol
// from being exported for the plugin loader to look up.
extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo() {
  return {LLVM_PLUGIN_API_VERSION, "hip-printf",
          LLVM_VERSION_STRING, [](PassBuilder &PB) {
            PB.registerPipelineParsingCallback(
                [](StringRef Name, FunctionPassManager &FPM,
                   ArrayRef<PassBuilder::PipelineElement>) {
                  // Returning false tells the parser the name is not ours.
                  if (Name != "hip-printf")
                    return false;
                  FPM.addPass(HipPrintfToOpenCLPrintfPass());
                  return true;
                });
          }};
}