| 
 | 1 | +//===- ExecutionContext.h -  Execution Context Support *- C++ -*-=============//  | 
 | 2 | +//  | 
 | 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.  | 
 | 4 | +// See https://llvm.org/LICENSE.txt for license information.  | 
 | 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception  | 
 | 6 | +//  | 
 | 7 | +//===----------------------------------------------------------------------===//  | 
 | 8 | + | 
 | 9 | +#ifndef MLIR_TRACING_EXECUTIONCONTEXT_H  | 
 | 10 | +#define MLIR_TRACING_EXECUTIONCONTEXT_H  | 
 | 11 | + | 
 | 12 | +#include "mlir/Debug/BreakpointManager.h"  | 
 | 13 | +#include "mlir/IR/Action.h"  | 
 | 14 | +#include "llvm/ADT/SmallVector.h"  | 
 | 15 | + | 
 | 16 | +namespace mlir {  | 
 | 17 | +namespace tracing {  | 
 | 18 | + | 
 | 19 | +/// This class is used to keep track of the active actions in the stack.  | 
 | 20 | +/// It provides the current action but also access to the parent entry in the  | 
 | 21 | +/// stack. This allows to keep track of the nested nature in which actions may  | 
 | 22 | +/// be executed.  | 
 | 23 | +struct ActionActiveStack {  | 
 | 24 | +public:  | 
 | 25 | +  ActionActiveStack(const ActionActiveStack *parent, const Action &action,  | 
 | 26 | +                    int depth)  | 
 | 27 | +      : parent(parent), action(action), depth(depth) {}  | 
 | 28 | +  const ActionActiveStack *getParent() const { return parent; }  | 
 | 29 | +  const Action &getAction() const { return action; }  | 
 | 30 | +  int getDepth() const { return depth; }  | 
 | 31 | + | 
 | 32 | +private:  | 
 | 33 | +  const ActionActiveStack *parent;  | 
 | 34 | +  const Action &action;  | 
 | 35 | +  int depth;  | 
 | 36 | +};  | 
 | 37 | + | 
 | 38 | +/// The ExecutionContext is the main orchestration of the infrastructure, it  | 
 | 39 | +/// acts as a handler in the MLIRContext for executing an Action. When an action  | 
 | 40 | +/// is dispatched, it'll query its set of Breakpoints managers for a breakpoint  | 
 | 41 | +/// matching this action. If a breakpoint is hit, it passes the action and the  | 
 | 42 | +/// breakpoint information to a callback. The callback is responsible for  | 
 | 43 | +/// controlling the execution of the action through an enum value it returns.  | 
 | 44 | +/// Optionally, observers can be registered to be notified before and after the  | 
 | 45 | +/// callback is executed.  | 
 | 46 | +class ExecutionContext {  | 
 | 47 | +public:  | 
 | 48 | +  /// Enum that allows the client of the context to control the execution of the  | 
 | 49 | +  /// action.  | 
 | 50 | +  /// - Apply: The action is executed.  | 
 | 51 | +  /// - Skip: The action is skipped.  | 
 | 52 | +  /// - Step: The action is executed and the execution is paused before the next  | 
 | 53 | +  ///         action, including for nested actions encountered before the  | 
 | 54 | +  ///         current action finishes.  | 
 | 55 | +  /// - Next: The action is executed and the execution is paused after the  | 
 | 56 | +  ///         current action finishes before the next action.  | 
 | 57 | +  /// - Finish: The action is executed and the execution is paused only when we  | 
 | 58 | +  ///           reach the parent/enclosing operation. If there are no enclosing  | 
 | 59 | +  ///           operation, the execution continues without stopping.  | 
 | 60 | +  enum Control { Apply = 1, Skip = 2, Step = 3, Next = 4, Finish = 5 };  | 
 | 61 | + | 
 | 62 | +  /// The type of the callback that is used to control the execution.  | 
 | 63 | +  /// The callback is passed the current action.  | 
 | 64 | +  using CallbackTy = function_ref<Control(const ActionActiveStack *)>;  | 
 | 65 | + | 
 | 66 | +  /// Create an ExecutionContext with a callback that is used to control the  | 
 | 67 | +  /// execution.  | 
 | 68 | +  ExecutionContext(CallbackTy callback) { setCallback(callback); }  | 
 | 69 | +  ExecutionContext() = default;  | 
 | 70 | + | 
 | 71 | +  /// Set the callback that is used to control the execution.  | 
 | 72 | +  void setCallback(CallbackTy callback);  | 
 | 73 | + | 
 | 74 | +  /// This abstract class defines the interface used to observe an Action  | 
 | 75 | +  /// execution. It allows to be notified before and after the callback is  | 
 | 76 | +  /// processed, but can't affect the execution.  | 
 | 77 | +  struct Observer {  | 
 | 78 | +    virtual ~Observer() = default;  | 
 | 79 | +    /// This method is called before the Action is executed  | 
 | 80 | +    /// If a breakpoint was hit, it is passed as an argument to the callback.  | 
 | 81 | +    /// The `willExecute` argument indicates whether the action will be executed  | 
 | 82 | +    /// or not.  | 
 | 83 | +    /// Note that this method will be called from multiple threads concurrently  | 
 | 84 | +    /// when MLIR multi-threading is enabled.  | 
 | 85 | +    virtual void beforeExecute(const ActionActiveStack *action,  | 
 | 86 | +                               Breakpoint *breakpoint, bool willExecute) {}  | 
 | 87 | + | 
 | 88 | +    /// This method is called after the Action is executed, if it was executed.  | 
 | 89 | +    /// It is not called if the action is skipped.  | 
 | 90 | +    /// Note that this method will be called from multiple threads concurrently  | 
 | 91 | +    /// when MLIR multi-threading is enabled.  | 
 | 92 | +    virtual void afterExecute(const ActionActiveStack *action) {}  | 
 | 93 | +  };  | 
 | 94 | + | 
 | 95 | +  /// Register a new `Observer` on this context. It'll be notified before and  | 
 | 96 | +  /// after executing an action. Note that this method is not thread-safe: it  | 
 | 97 | +  /// isn't supported to add a new observer while actions may be executed.  | 
 | 98 | +  void registerObserver(Observer *observer);  | 
 | 99 | + | 
 | 100 | +  /// Register a new `BreakpointManager` on this context. It'll have a chance to  | 
 | 101 | +  /// match an action before it gets executed. Note that this method is not  | 
 | 102 | +  /// thread-safe: it isn't supported to add a new manager while actions may be  | 
 | 103 | +  /// executed.  | 
 | 104 | +  void addBreakpointManager(BreakpointManager *manager) {  | 
 | 105 | +    breakpoints.push_back(manager);  | 
 | 106 | +  }  | 
 | 107 | + | 
 | 108 | +  /// Process the given action. This is the operator called by MLIRContext on  | 
 | 109 | +  /// `executeAction()`.  | 
 | 110 | +  void operator()(function_ref<void()> transform, const Action &action);  | 
 | 111 | + | 
 | 112 | +private:  | 
 | 113 | +  /// Callback that is executed when a breakpoint is hit and allows the client  | 
 | 114 | +  /// to control the execution.  | 
 | 115 | +  CallbackTy onBreakpointControlExecutionCallback;  | 
 | 116 | + | 
 | 117 | +  /// Next point to stop execution as describe by `Control` enum.  | 
 | 118 | +  /// This is handle by indicating at which levels of depth the next  | 
 | 119 | +  /// break should happen.  | 
 | 120 | +  Optional<int> depthToBreak;  | 
 | 121 | + | 
 | 122 | +  /// Observers that are notified before and after the callback is executed.  | 
 | 123 | +  SmallVector<Observer *> observers;  | 
 | 124 | + | 
 | 125 | +  /// The list of managers that are queried for breakpoints.  | 
 | 126 | +  SmallVector<BreakpointManager *> breakpoints;  | 
 | 127 | +};  | 
 | 128 | + | 
 | 129 | +} // namespace tracing  | 
 | 130 | +} // namespace mlir  | 
 | 131 | + | 
 | 132 | +#endif // MLIR_TRACING_EXECUTIONCONTEXT_H  | 
0 commit comments