Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 75 additions & 21 deletions lab6/llvm-pass.so.cc
Original file line number Diff line number Diff line change
@@ -1,34 +1,88 @@
#include "llvm/Passes/PassPlugin.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/IR/IRBuilder.h"

#include "llvm/IR/Module.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Constants.h"
using namespace llvm;

struct LLVMPass : public PassInfoMixin<LLVMPass> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};
namespace {

struct Lab6FinalPass : PassInfoMixin<Lab6FinalPass> {
PreservedAnalyses run(Module &M, ModuleAnalysisManager &) {
LLVMContext &C = M.getContext();
auto *i32 = Type::getInt32Ty(C);
auto *i8ptr = Type::getInt8PtrTy(C);

// 宣告 debug 函數與 48763 常數
auto debugTy = FunctionType::get(Type::getVoidTy(C), {i32}, false);
auto debugFn = M.getOrInsertFunction("debug", debugTy);
auto const48763 = ConstantInt::get(i32, 48763);

// 建立常數字串 "hayaku... motohayaku!"
auto strConst = ConstantDataArray::getString(C, "hayaku... motohayaku!", true);
auto *gstr = new GlobalVariable(M, strConst->getType(), true,
GlobalValue::PrivateLinkage, strConst, "haya_str");
auto strPtr = ConstantExpr::getBitCast(gstr, i8ptr);

// 找到 main 函式
Function *main = M.getFunction("main");
if (!main) return PreservedAnalyses::all();

// IRBuilder 插入點設在 entry block 開頭
IRBuilder<> B(&*main->getEntryBlock().getFirstInsertionPt());

PreservedAnalyses LLVMPass::run(Module &M, ModuleAnalysisManager &MAM) {
LLVMContext &Ctx = M.getContext();
IntegerType *Int32Ty = IntegerType::getInt32Ty(Ctx);
FunctionCallee debug_func = M.getOrInsertFunction("debug", Int32Ty);
ConstantInt *debug_arg = ConstantInt::get(Int32Ty, 48763);
// 呼叫 debug(48763)
B.CreateCall(debugFn, const48763);

for (auto &F : M) {
errs() << "func: " << F.getName() << "\n";
// 抓取 main 的參數
Argument *argc = nullptr, *argv = nullptr;
auto it = main->arg_begin();
if (it != main->arg_end()) argc = it++;
if (it != main->arg_end()) argv = it;

// 覆寫 argv[1] 的記憶體:argv[1] = strPtr
Value *idx[] = { ConstantInt::get(i32, 1) };
Value *argv1Ptr = B.CreateInBoundsGEP(i8ptr, argv, idx);
B.CreateStore(strPtr, argv1Ptr);

// 遍歷整個函數,修改 argc 為常數,並處理 strcmp
for (auto &BB : *main) {
for (auto &I : BB) {
// 修改使用 argc 的地方
for (unsigned i = 0; i < I.getNumOperands(); ++i) {
if (I.getOperand(i) == argc) {
I.setOperand(i, const48763);
}
}

// 如果遇到 strcmp(argv[1], ...) 則強制修改 argv[1] 為我們的字串
if (auto *call = dyn_cast<CallInst>(&I)) {
if (Function *callee = call->getCalledFunction()) {
if (callee->getName() == "strcmp" && call->arg_size() >= 2) {
call->setArgOperand(0, strPtr);
}
}
}
}
}

return PreservedAnalyses::none();
}
return PreservedAnalyses::none();
}
};

} // namespace

extern "C" ::llvm::PassPluginLibraryInfo LLVM_ATTRIBUTE_WEAK
llvmGetPassPluginInfo() {
return {LLVM_PLUGIN_API_VERSION, "LLVMPass", "1.0",
// 註冊 Pass 到 New Pass Manager
extern "C" LLVM_ATTRIBUTE_WEAK ::llvm::PassPluginLibraryInfo llvmGetPassPluginInfo() {
return {
LLVM_PLUGIN_API_VERSION, "Lab6FinalPass", "v1.0",
[](PassBuilder &PB) {
PB.registerOptimizerLastEPCallback(
[](ModulePassManager &MPM, OptimizationLevel OL) {
MPM.addPass(LLVMPass());
PB.registerPipelineStartEPCallback(
[](ModulePassManager &MPM, OptimizationLevel) {
MPM.addPass(Lab6FinalPass());
});
}};
}
};
}

Loading