Skip to content

Commit 6db4e60

Browse files
committed
Lazy BB Embeddings
1 parent a0d699a commit 6db4e60

File tree

3 files changed

+38
-15
lines changed

3 files changed

+38
-15
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class Embedder {
9191
/// the embeddings is specific to the kind of embeddings being computed.
9292
virtual void computeEmbeddings() const = 0;
9393

94+
/// Helper function to compute the embedding for a given basic block.
95+
/// Specific to the kind of embeddings being computed.
96+
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
97+
9498
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
9599
/// zero vector.
96100
Embedding lookupVocab(const std::string &Key) const;
@@ -121,6 +125,11 @@ class Embedder {
121125
/// for the function and returns the map.
122126
const BBEmbeddingsMap &getBBVecMap() const;
123127

128+
/// Returns the embedding for a given basic block in the function F if it has
129+
/// been computed. If not, it computes the embedding for the basic block and
130+
/// returns it.
131+
const Embedding &getBBVector(const BasicBlock &BB) const;
132+
124133
/// Computes and returns the embedding for the current function.
125134
const Embedding &getFunctionVector() const;
126135
};
@@ -130,16 +139,14 @@ class Embedder {
130139
/// representations obtained from the Vocabulary.
131140
class SymbolicEmbedder : public Embedder {
132141
private:
133-
/// Utility function to compute the embedding for a given basic block.
134-
Embedding computeBB2Vec(const BasicBlock &BB) const;
135-
136142
/// Utility function to compute the embedding for a given type.
137143
Embedding getTypeEmbedding(const Type *Ty) const;
138144

139145
/// Utility function to compute the embedding for a given operand.
140146
Embedding getOperandEmbedding(const Value *Op) const;
141147

142148
void computeEmbeddings() const override;
149+
void computeEmbeddings(const BasicBlock &BB) const override;
143150

144151
public:
145152
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ const BBEmbeddingsMap &Embedder::getBBVecMap() const {
116116
return BBVecMap;
117117
}
118118

119+
/// Returns the embedding cached for \p BB. If no embedding has been cached
/// yet, it is computed via computeEmbeddings(BB) and then returned.
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
  // Fast path: the embedding for this basic block is already cached.
  auto It = BBVecMap.find(&BB);
  if (It != BBVecMap.end())
    return It->second;

  // Cache miss: `It` equals end() here and must not be dereferenced. Compute
  // the embedding and read the map again rather than reusing the stale
  // iterator.
  computeEmbeddings(BB);
  // NOTE(review): operator[] default-inserts an entry if the
  // computeEmbeddings(BB) override did not populate BBVecMap for BB; confirm
  // that this fallback is acceptable for all Embedder subclasses.
  return BBVecMap[&BB];
}
125+
119126
const Embedding &Embedder::getFunctionVector() const {
120127
// Currently, we always (re)compute the embeddings for the function.
121128
// This is cheaper than caching the vector.
@@ -152,17 +159,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
152159

153160
#undef RETURN_LOOKUP_IF
154161

155-
void SymbolicEmbedder::computeEmbeddings() const {
156-
if (F.isDeclaration())
157-
return;
158-
for (const auto &BB : F) {
159-
auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
160-
assert(WasInserted && "Basic block already exists in the map");
161-
addVectors(FuncVector, It->second);
162-
}
163-
}
164-
165-
Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
162+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
166163
Embedding BBVector(Dimension, 0);
167164

168165
for (const auto &I : BB) {
@@ -184,7 +181,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
184181
InstVecMap[&I] = InstVector;
185182
addVectors(BBVector, InstVector);
186183
}
187-
return BBVector;
184+
BBVecMap[&BB] = BBVector;
185+
}
186+
187+
void SymbolicEmbedder::computeEmbeddings() const {
188+
if (F.isDeclaration())
189+
return;
190+
for (const auto &BB : F) {
191+
computeEmbeddings(BB);
192+
addVectors(FuncVector, BBVecMap[&BB]);
193+
}
188194
}
189195

190196
// ==----------------------------------------------------------------------===//

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class TestableEmbedder : public Embedder {
3131
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
3232
: Embedder(F, V, Dim) {}
3333
void computeEmbeddings() const override {}
34+
void computeEmbeddings(const BasicBlock &BB) const override {}
3435
using Embedder::lookupVocab;
3536
static void addVectors(Embedding &Dst, const Embedding &Src) {
3637
Embedder::addVectors(Dst, Src);
@@ -229,6 +230,15 @@ TEST(IR2VecTest, GetBBVecMap) {
229230
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
230231
}
231232

233+
// Checks that getBBVector() returns the expected two-element embedding for
// the basic block provided by the getter test environment.
TEST(IR2VecTest, GetBBVector) {
  GetterTestEnv Fixture;
  const auto &Vec = Fixture.Emb->getBBVector(*Fixture.BB);

  EXPECT_EQ(Vec.size(), 2u);
  EXPECT_THAT(Vec,
              ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
}
241+
232242
TEST(IR2VecTest, GetFunctionVector) {
233243
GetterTestEnv Env;
234244
const auto &FuncVec = Env.Emb->getFunctionVector();

0 commit comments

Comments
 (0)