Skip to content

Commit 494c82e

Browse files
authored
[IR2Vec] Support for lazy computation of BB Embeddings (#142033)
This PR exposes interfaces to compute embeddings at BB level. This would be necessary for delta patching the embeddings in MLInliner (#141836). (Tracking issue - #141817)
1 parent 8a44cd7 commit 494c82e

File tree

3 files changed

+39
-15
lines changed

3 files changed

+39
-15
lines changed

llvm/include/llvm/Analysis/IR2Vec.h

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ class Embedder {
9191
/// the embeddings is specific to the kind of embeddings being computed.
9292
virtual void computeEmbeddings() const = 0;
9393

94+
/// Helper function to compute the embedding for a given basic block.
95+
/// Specific to the kind of embeddings being computed.
96+
virtual void computeEmbeddings(const BasicBlock &BB) const = 0;
97+
9498
/// Lookup vocabulary for a given Key. If the key is not found, it returns a
9599
/// zero vector.
96100
Embedding lookupVocab(const std::string &Key) const;
@@ -121,6 +125,11 @@ class Embedder {
121125
/// for the function and returns the map.
122126
const BBEmbeddingsMap &getBBVecMap() const;
123127

128+
/// Returns the embedding for a given basic block in the function F if it has
129+
/// been computed. If not, it computes the embedding for the basic block and
130+
/// returns it.
131+
const Embedding &getBBVector(const BasicBlock &BB) const;
132+
124133
/// Computes and returns the embedding for the current function.
125134
const Embedding &getFunctionVector() const;
126135
};
@@ -130,16 +139,14 @@ class Embedder {
130139
/// representations obtained from the Vocabulary.
131140
class SymbolicEmbedder : public Embedder {
132141
private:
133-
/// Utility function to compute the embedding for a given basic block.
134-
Embedding computeBB2Vec(const BasicBlock &BB) const;
135-
136142
/// Utility function to compute the embedding for a given type.
137143
Embedding getTypeEmbedding(const Type *Ty) const;
138144

139145
/// Utility function to compute the embedding for a given operand.
140146
Embedding getOperandEmbedding(const Value *Op) const;
141147

142148
void computeEmbeddings() const override;
149+
void computeEmbeddings(const BasicBlock &BB) const override;
143150

144151
public:
145152
SymbolicEmbedder(const Function &F, const Vocab &Vocabulary,

llvm/lib/Analysis/IR2Vec.cpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,14 @@ const BBEmbeddingsMap &Embedder::getBBVecMap() const {
116116
return BBVecMap;
117117
}
118118

119+
const Embedding &Embedder::getBBVector(const BasicBlock &BB) const {
120+
auto It = BBVecMap.find(&BB);
121+
if (It != BBVecMap.end())
122+
return It->second;
123+
computeEmbeddings(BB);
124+
return BBVecMap[&BB];
125+
}
126+
119127
const Embedding &Embedder::getFunctionVector() const {
120128
// Currently, we always (re)compute the embeddings for the function.
121129
// This is cheaper than caching the vector.
@@ -152,17 +160,7 @@ Embedding SymbolicEmbedder::getOperandEmbedding(const Value *Op) const {
152160

153161
#undef RETURN_LOOKUP_IF
154162

155-
void SymbolicEmbedder::computeEmbeddings() const {
156-
if (F.isDeclaration())
157-
return;
158-
for (const auto &BB : F) {
159-
auto [It, WasInserted] = BBVecMap.try_emplace(&BB, computeBB2Vec(BB));
160-
assert(WasInserted && "Basic block already exists in the map");
161-
addVectors(FuncVector, It->second);
162-
}
163-
}
164-
165-
Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
163+
void SymbolicEmbedder::computeEmbeddings(const BasicBlock &BB) const {
166164
Embedding BBVector(Dimension, 0);
167165

168166
for (const auto &I : BB) {
@@ -184,7 +182,16 @@ Embedding SymbolicEmbedder::computeBB2Vec(const BasicBlock &BB) const {
184182
InstVecMap[&I] = InstVector;
185183
addVectors(BBVector, InstVector);
186184
}
187-
return BBVector;
185+
BBVecMap[&BB] = BBVector;
186+
}
187+
188+
void SymbolicEmbedder::computeEmbeddings() const {
189+
if (F.isDeclaration())
190+
return;
191+
for (const auto &BB : F) {
192+
computeEmbeddings(BB);
193+
addVectors(FuncVector, BBVecMap[&BB]);
194+
}
188195
}
189196

190197
// ==----------------------------------------------------------------------===//

llvm/unittests/Analysis/IR2VecTest.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class TestableEmbedder : public Embedder {
3131
TestableEmbedder(const Function &F, const Vocab &V, unsigned Dim)
3232
: Embedder(F, V, Dim) {}
3333
void computeEmbeddings() const override {}
34+
void computeEmbeddings(const BasicBlock &BB) const override {}
3435
using Embedder::lookupVocab;
3536
static void addVectors(Embedding &Dst, const Embedding &Src) {
3637
Embedder::addVectors(Dst, Src);
@@ -229,6 +230,15 @@ TEST(IR2VecTest, GetBBVecMap) {
229230
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
230231
}
231232

233+
TEST(IR2VecTest, GetBBVector) {
234+
GetterTestEnv Env;
235+
const auto &BBVec = Env.Emb->getBBVector(*Env.BB);
236+
237+
EXPECT_EQ(BBVec.size(), 2u);
238+
EXPECT_THAT(BBVec,
239+
ElementsAre(DoubleNear(1.29, 1e-6), DoubleNear(2.31, 1e-6)));
240+
}
241+
232242
TEST(IR2VecTest, GetFunctionVector) {
233243
GetterTestEnv Env;
234244
const auto &FuncVec = Env.Emb->getFunctionVector();

0 commit comments

Comments
 (0)