Skip to content

Commit fa51c17

Browse files
committed
Introduce mlir::tracing::ExecutionContext
This component acts as an action handler that can be registered in the MLIRContext. It is the main orchestration of the infrastructure, and implements support for clients to hook there and snoop on or control the execution. This is the basis to build tracing as well as a "gdb-like" control of the compilation flow. The ExecutionContext acts as a handler in the MLIRContext for executing an Action. When an action is dispatched, it'll query its set of Breakpoints managers for a breakpoint matching this action. If a breakpoint is hit, it passes the action and the breakpoint information to a callback. The callback is responsible for controlling the execution of the action through an enum value it returns. Optionally, observers can be registered to be notified before and after the callback is executed. Differential Revision: https://reviews.llvm.org/D144812
1 parent 4356228 commit fa51c17

File tree

7 files changed

+743
-0
lines changed

7 files changed

+743
-0
lines changed
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//===- BreakpointManager.h - Breakpoint Manager Support ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRACING_BREAKPOINTMANAGER_H
10+
#define MLIR_TRACING_BREAKPOINTMANAGER_H
11+
12+
#include "mlir/IR/Action.h"
13+
#include "llvm/ADT/MapVector.h"
14+
15+
namespace mlir {
16+
namespace tracing {
17+
18+
/// This abstract class represents a breakpoint.
19+
class Breakpoint {
20+
public:
21+
virtual ~Breakpoint() = default;
22+
23+
/// TypeID for the subclass, used for casting purpose.
24+
TypeID getTypeID() const { return typeID; }
25+
26+
bool isEnabled() const { return enableStatus; }
27+
void enable() { enableStatus = true; }
28+
void disable() { enableStatus = false; }
29+
virtual void print(raw_ostream &os) const = 0;
30+
31+
protected:
32+
Breakpoint(TypeID typeID) : enableStatus(true), typeID(typeID) {}
33+
34+
private:
35+
/// The current state of the breakpoint. A breakpoint can be either enabled
36+
/// or disabled.
37+
bool enableStatus;
38+
TypeID typeID;
39+
};
40+
41+
inline raw_ostream &operator<<(raw_ostream &os, const Breakpoint &breakpoint) {
42+
breakpoint.print(os);
43+
return os;
44+
}
45+
46+
/// This class provides a CRTP wrapper around a base breakpoint class to define
47+
/// a few necessary utility methods.
48+
template <typename Derived>
49+
class BreakpointBase : public Breakpoint {
50+
public:
51+
/// Support isa/dyn_cast functionality for the derived pass class.
52+
static bool classof(const Breakpoint *breakpoint) {
53+
return breakpoint->getTypeID() == TypeID::get<Derived>();
54+
}
55+
56+
protected:
57+
BreakpointBase() : Breakpoint(TypeID::get<Derived>()) {}
58+
};
59+
60+
/// A breakpoint manager is responsible for managing a set of breakpoints and
61+
/// matching them to a given action.
62+
class BreakpointManager {
63+
public:
64+
virtual ~BreakpointManager() = default;
65+
66+
/// TypeID for the subclass, used for casting purpose.
67+
TypeID getTypeID() const { return typeID; }
68+
69+
/// Try to match a Breakpoint to a given Action. If there is a match and
70+
/// the breakpoint is enabled, return the breakpoint. Otherwise, return
71+
/// nullptr.
72+
virtual Breakpoint *match(const Action &action) const = 0;
73+
74+
protected:
75+
BreakpointManager(TypeID typeID) : typeID(typeID) {}
76+
77+
TypeID typeID;
78+
};
79+
80+
/// CRTP base class for BreakpointManager implementations.
81+
template <typename Derived>
82+
class BreakpointManagerBase : public BreakpointManager {
83+
public:
84+
BreakpointManagerBase() : BreakpointManager(TypeID::get<Derived>()) {}
85+
86+
/// Provide classof to allow casting between breakpoint manager types.
87+
static bool classof(const BreakpointManager *breakpointManager) {
88+
return breakpointManager->getTypeID() == TypeID::get<Derived>();
89+
}
90+
};
91+
92+
} // namespace tracing
93+
} // namespace mlir
94+
95+
#endif // MLIR_TRACING_BREAKPOINTMANAGER_H
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
//===- TagBreakpointManager.h - Simple breakpoint Support -------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
10+
#define MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
11+
12+
#include "mlir/Debug/BreakpointManager.h"
13+
#include "mlir/Debug/ExecutionContext.h"
14+
#include "mlir/IR/Action.h"
15+
#include "llvm/ADT/MapVector.h"
16+
17+
namespace mlir {
18+
namespace tracing {
19+
20+
/// Simple breakpoint matching an action "tag".
21+
class TagBreakpoint : public BreakpointBase<TagBreakpoint> {
22+
public:
23+
TagBreakpoint(StringRef tag) : tag(tag) {}
24+
25+
void print(raw_ostream &os) const override { os << "Tag: `" << tag << '`'; }
26+
27+
private:
28+
/// A tag to associate the TagBreakpoint with.
29+
std::string tag;
30+
31+
/// Allow access to `tag`.
32+
friend class TagBreakpointManager;
33+
};
34+
35+
/// This is a manager to store a collection of breakpoints that trigger
36+
/// on tags.
37+
class TagBreakpointManager
38+
: public BreakpointManagerBase<TagBreakpointManager> {
39+
public:
40+
Breakpoint *match(const Action &action) const override {
41+
auto it = breakpoints.find(action.getTag());
42+
if (it != breakpoints.end() && it->second->isEnabled())
43+
return it->second.get();
44+
return {};
45+
}
46+
47+
/// Add a breakpoint to the manager for the given tag and return it.
48+
/// If a breakpoint already exists for the given tag, return the existing
49+
/// instance.
50+
TagBreakpoint *addBreakpoint(StringRef tag) {
51+
auto result = breakpoints.insert({tag, nullptr});
52+
auto &it = result.first;
53+
if (result.second)
54+
it->second = std::make_unique<TagBreakpoint>(tag.str());
55+
return it->second.get();
56+
}
57+
58+
private:
59+
llvm::StringMap<std::unique_ptr<TagBreakpoint>> breakpoints;
60+
};
61+
62+
} // namespace tracing
63+
} // namespace mlir
64+
65+
#endif // MLIR_DEBUG_BREAKPOINTMANAGERS_TAGBREAKPOINTMANAGER_H
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
//===- ExecutionContext.h - Execution Context Support *- C++ -*-=============//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef MLIR_TRACING_EXECUTIONCONTEXT_H
10+
#define MLIR_TRACING_EXECUTIONCONTEXT_H
11+
12+
#include "mlir/Debug/BreakpointManager.h"
13+
#include "mlir/IR/Action.h"
14+
#include "llvm/ADT/SmallVector.h"
15+
16+
namespace mlir {
17+
namespace tracing {
18+
19+
/// This class is used to keep track of the active actions in the stack.
20+
/// It provides the current action but also access to the parent entry in the
21+
/// stack. This allows to keep track of the nested nature in which actions may
22+
/// be executed.
23+
struct ActionActiveStack {
24+
public:
25+
ActionActiveStack(const ActionActiveStack *parent, const Action &action,
26+
int depth)
27+
: parent(parent), action(action), depth(depth) {}
28+
const ActionActiveStack *getParent() const { return parent; }
29+
const Action &getAction() const { return action; }
30+
int getDepth() const { return depth; }
31+
32+
private:
33+
const ActionActiveStack *parent;
34+
const Action &action;
35+
int depth;
36+
};
37+
38+
/// The ExecutionContext is the main orchestration of the infrastructure, it
39+
/// acts as a handler in the MLIRContext for executing an Action. When an action
40+
/// is dispatched, it'll query its set of Breakpoints managers for a breakpoint
41+
/// matching this action. If a breakpoint is hit, it passes the action and the
42+
/// breakpoint information to a callback. The callback is responsible for
43+
/// controlling the execution of the action through an enum value it returns.
44+
/// Optionally, observers can be registered to be notified before and after the
45+
/// callback is executed.
46+
class ExecutionContext {
47+
public:
48+
/// Enum that allows the client of the context to control the execution of the
49+
/// action.
50+
/// - Apply: The action is executed.
51+
/// - Skip: The action is skipped.
52+
/// - Step: The action is executed and the execution is paused before the next
53+
/// action, including for nested actions encountered before the
54+
/// current action finishes.
55+
/// - Next: The action is executed and the execution is paused after the
56+
/// current action finishes before the next action.
57+
/// - Finish: The action is executed and the execution is paused only when we
58+
/// reach the parent/enclosing operation. If there are no enclosing
59+
/// operation, the execution continues without stopping.
60+
enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 };
61+
62+
/// The type of the callback that is used to control the execution.
63+
/// The callback is passed the current action.
64+
using CallbackTy = function_ref<Control(const ActionActiveStack *)>;
65+
66+
/// Create an ExecutionContext with a callback that is used to control the
67+
/// execution.
68+
ExecutionContext(CallbackTy callback) { setCallback(callback); }
69+
ExecutionContext() = default;
70+
71+
/// Set the callback that is used to control the execution.
72+
void setCallback(CallbackTy callback);
73+
74+
/// This abstract class defines the interface used to observe an Action
75+
/// execution. It allows to be notified before and after the callback is
76+
/// processed, but can't affect the execution.
77+
struct Observer {
78+
virtual ~Observer() = default;
79+
/// This method is called before the Action is executed
80+
/// If a breakpoint was hit, it is passed as an argument to the callback.
81+
/// The `willExecute` argument indicates whether the action will be executed
82+
/// or not.
83+
/// Note that this method will be called from multiple threads concurrently
84+
/// when MLIR multi-threading is enabled.
85+
virtual void beforeExecute(const ActionActiveStack *action,
86+
Breakpoint *breakpoint, bool willExecute) {}
87+
88+
/// This method is called after the Action is executed, if it was executed.
89+
/// It is not called if the action is skipped.
90+
/// Note that this method will be called from multiple threads concurrently
91+
/// when MLIR multi-threading is enabled.
92+
virtual void afterExecute(const ActionActiveStack *action) {}
93+
};
94+
95+
/// Register a new `Observer` on this context. It'll be notified before and
96+
/// after executing an action. Note that this method is not thread-safe: it
97+
/// isn't supported to add a new observer while actions may be executed.
98+
void registerObserver(Observer *observer);
99+
100+
/// Register a new `BreakpointManager` on this context. It'll have a chance to
101+
/// match an action before it gets executed. Note that this method is not
102+
/// thread-safe: it isn't supported to add a new manager while actions may be
103+
/// executed.
104+
void addBreakpointManager(BreakpointManager *manager) {
105+
breakpoints.push_back(manager);
106+
}
107+
108+
/// Process the given action. This is the operator called by MLIRContext on
109+
/// `executeAction()`.
110+
void operator()(function_ref<void()> transform, const Action &action);
111+
112+
private:
113+
/// Callback that is executed when a breakpoint is hit and allows the client
114+
/// to control the execution.
115+
CallbackTy onBreakpointControlExecutionCallback;
116+
117+
/// Next point to stop execution as describe by `Control` enum.
118+
/// This is handle by indicating at which levels of depth the next
119+
/// break should happen.
120+
Optional<int> depthToBreak;
121+
122+
/// Observers that are notified before and after the callback is executed.
123+
SmallVector<Observer *> observers;
124+
125+
/// The list of managers that are queried for breakpoints.
126+
SmallVector<BreakpointManager *> breakpoints;
127+
};
128+
129+
} // namespace tracing
130+
} // namespace mlir
131+
132+
#endif // MLIR_TRACING_EXECUTIONCONTEXT_H

mlir/lib/Debug/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
add_mlir_library(MLIRDebug
22
DebugCounter.cpp
3+
ExecutionContext.cpp
34

45
ADDITIONAL_HEADER_DIRS
56
${MLIR_MAIN_INCLUDE_DIR}/mlir/Debug
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
//===- ExecutionContext.cpp - Debug Execution Context Support -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Debug/ExecutionContext.h"
10+
11+
#include "llvm/ADT/ScopeExit.h"
12+
13+
#include <cstddef>
14+
15+
using namespace mlir;
16+
using namespace mlir::tracing;
17+
18+
//===----------------------------------------------------------------------===//
19+
// ExecutionContext
20+
//===----------------------------------------------------------------------===//
21+
22+
static const thread_local ActionActiveStack *actionStack = nullptr;
23+
24+
void ExecutionContext::setCallback(CallbackTy callback) {
25+
onBreakpointControlExecutionCallback = callback;
26+
}
27+
28+
void ExecutionContext::registerObserver(Observer *observer) {
29+
observers.push_back(observer);
30+
}
31+
32+
void ExecutionContext::operator()(llvm::function_ref<void()> transform,
33+
const Action &action) {
34+
// Update the top of the stack with the current action.
35+
int depth = 0;
36+
if (actionStack)
37+
depth = actionStack->getDepth() + 1;
38+
ActionActiveStack info{actionStack, action, depth};
39+
actionStack = &info;
40+
auto raii = llvm::make_scope_exit([&]() { actionStack = info.getParent(); });
41+
Breakpoint *breakpoint = nullptr;
42+
43+
// Invoke the callback here and handles control requests here.
44+
auto handleUserInput = [&]() -> bool {
45+
if (!onBreakpointControlExecutionCallback)
46+
return true;
47+
auto todoNext = onBreakpointControlExecutionCallback(actionStack);
48+
switch (todoNext) {
49+
case ExecutionContext::Apply:
50+
depthToBreak = std::nullopt;
51+
return true;
52+
case ExecutionContext::Skip:
53+
depthToBreak = std::nullopt;
54+
return false;
55+
case ExecutionContext::Step:
56+
depthToBreak = depth + 1;
57+
return true;
58+
case ExecutionContext::Next:
59+
depthToBreak = depth;
60+
return true;
61+
case ExecutionContext::Finish:
62+
depthToBreak = depth - 1;
63+
return true;
64+
}
65+
llvm::report_fatal_error("Unknown control request");
66+
};
67+
68+
// Try to find a breakpoint that would hit on this action.
69+
// Right now there is no way to collect them all, we stop at the first one.
70+
for (auto *breakpointManager : breakpoints) {
71+
breakpoint = breakpointManager->match(action);
72+
if (breakpoint)
73+
break;
74+
}
75+
76+
bool shouldExecuteAction = true;
77+
// If we have a breakpoint, or if `depthToBreak` was previously set and the
78+
// current depth matches, we invoke the user-provided callback.
79+
if (breakpoint || (depthToBreak && depth <= depthToBreak))
80+
shouldExecuteAction = handleUserInput();
81+
82+
// Notify the observers about the current action.
83+
for (auto *observer : observers)
84+
observer->beforeExecute(actionStack, breakpoint, shouldExecuteAction);
85+
86+
if (shouldExecuteAction) {
87+
// Execute the action here.
88+
transform();
89+
90+
// Notify the observers about completion of the action.
91+
for (auto *observer : observers)
92+
observer->afterExecute(actionStack);
93+
}
94+
95+
if (depthToBreak && depth <= depthToBreak)
96+
handleUserInput();
97+
}

0 commit comments

Comments
 (0)