// CheriStackInvalidatePass.cpp
#define DEBUG_TYPE "capability-register-invalidate"
#include "Mips.h"
#include "MipsTargetMachine.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineInstrBuilder.h"
#include "llvm/CodeGen/TargetInstrInfo.h"
#include "llvm/CodeGen/TargetRegisterInfo.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Target/TargetMachine.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
using namespace llvm;
namespace {
class CheriInvalidatePass : public MachineFunctionPass {
const MipsInstrInfo *InstrInfo = nullptr;
SmallVector<MachineInstr *, 16> StackStores;
SmallVector<MachineInstr *, 4> Returns;
public:
static char ID;
CheriInvalidatePass() : MachineFunctionPass(ID) {}
StringRef getPassName() const override { return "CHERI invalidate pass"; }
void runOnMachineBasicBlock(MachineBasicBlock &MBB) {
if (!InstrInfo)
InstrInfo = MBB.getParent()->getSubtarget<MipsSubtarget>().getInstrInfo();
for (MachineBasicBlock::iterator I = MBB.instr_begin();
I != MBB.instr_end(); ++I) {
int FI;
MachineInstr &Inst = *I;
if (InstrInfo->isStoreToStackSlot(Inst, FI)) {
StackStores.push_back(&Inst);
} else if (I->isReturn()) {
Returns.push_back(&Inst);
}
}
}
bool runOnMachineFunction(MachineFunction &F) override{
if (!InstrInfo)
InstrInfo = F.getSubtarget<MipsSubtarget>().getInstrInfo();
// Metadata nodes are no longer allowed to refer to functions, so we need
// another mechanism for identifying them. We should do it properly by adding
// a function attribute.
#if 0
const Function *IRFunction = F.getFunction();
const Module *Mod = IRFunction->getParent();
NamedMDNode *SensitiveFunctions =
Mod->getNamedMetadata("cheri.sensitive.functions");
if (!SensitiveFunctions) return false;
bool foundFunction = false;
for (unsigned i=0 ; i<SensitiveFunctions->getNumOperands() ; i++) {
Value *SensitiveFunction =
cast<MDNode>(SensitiveFunctions->getOperand(i))->getOperand(0);
if (SensitiveFunction == IRFunction) {
foundFunction = true;
break;
}
}
if (!foundFunction) return false;
LLVM_DEBUG(dbgs() << "Zeroing stack spills\n");
StackStores.clear();
Returns.clear();
SmallSet<int, 16> ZeroedLocations;
for (MachineFunction::iterator FI = F.begin(), FE = F.end();
FI != FE; ++FI)
runOnMachineBasicBlock(*FI);
if (StackStores.size() == 0) return false;
for (SmallVector<MachineInstr*, 4>::iterator i=Returns.begin(),
e=Returns.end() ; i!=e ; ++i) {
MachineInstr *Ret = *i;
ZeroedLocations.clear();
for (SmallVector<MachineInstr*, 16>::iterator si=StackStores.begin(),
se=StackStores.end() ; si!=se ; ++si) {
MachineInstr *Store = *si;
unsigned Opc = Store->getOpcode();
MachineBasicBlock &MBB = *Ret->getParent();
int FI = Store->getOperand(1).getIndex();
// If we've already zeroed this location, skip it.
if (ZeroedLocations.count(FI)) continue;
ZeroedLocations.insert(FI);
// If this is a capability store, then we just do a 64-bit integer
// write. This leaks information, but invalidates the capability.
if (Opc == Mips::STORECAP) {
MachineInstrBuilder MIB(F, Ret);
BuildMI(MBB, Ret, Ret->getDebugLoc(), InstrInfo->get(Mips::SD))
.addReg(Mips::ZERO)
.addFrameIndex(FI).addImm(Store->getOperand(2).getImm())
.addMemOperand(InstrInfo->GetMemOperand(MBB, FI, MachineMemOperand::MOStore));
LLVM_DEBUG(dbgs() << "Zeroing capability spill\n");
} else {
// For other stores, we do the same type of store as was used for the spill, now with zeros.
BuildMI(MBB, Ret, Ret->getDebugLoc(), InstrInfo->get(Opc))
.addReg(Mips::ZERO)
.addFrameIndex(FI).addImm(Store->getOperand(2).getImm())
.addMemOperand(InstrInfo->GetMemOperand(MBB, FI, MachineMemOperand::MOStore));
LLVM_DEBUG(dbgs() << "Zeroing non-capability spill\n");
}
}
}
return true;
#else
return false;
#endif
}
};
}
// Out-of-class definition of the static ID member that the constructor hands
// to MachineFunctionPass. Explicitly zero, matching its static-storage default.
char CheriInvalidatePass::ID = 0;
/// Factory for the CHERI invalidate pass: heap-allocates a new
/// CheriInvalidatePass and returns it as a FunctionPass pointer.
/// Nothing in this file frees the object, so the caller is responsible for
/// it (presumably the pass manager that the pass is added to — confirm).
FunctionPass *llvm::createCheriInvalidatePass() {
return new CheriInvalidatePass();
}