Skip to content

Commit

Permalink
Add basic JIT Python Bindings
Browse files Browse the repository at this point in the history
This offers the ability to create a JIT and invoke a function by passing
ctypes pointers to the argument and the result.

Differential Revision: https://reviews.llvm.org/D97523
  • Loading branch information
joker-eph committed Mar 3, 2021
1 parent 86c8a78 commit 13cb431
Show file tree
Hide file tree
Showing 14 changed files with 333 additions and 2 deletions.
24 changes: 24 additions & 0 deletions mlir/include/mlir-c/Bindings/Python/Interop.h
Expand Up @@ -25,6 +25,7 @@

#include "mlir-c/AffineExpr.h"
#include "mlir-c/AffineMap.h"
#include "mlir-c/ExecutionEngine.h"
#include "mlir-c/IR.h"
#include "mlir-c/IntegerSet.h"
#include "mlir-c/Pass.h"
Expand All @@ -33,6 +34,8 @@
#define MLIR_PYTHON_CAPSULE_AFFINE_MAP "mlir.ir.AffineMap._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_ATTRIBUTE "mlir.ir.Attribute._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_CONTEXT "mlir.ir.Context._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE \
"mlir.execution_engine.ExecutionEngine._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_INTEGER_SET "mlir.ir.IntegerSet._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_LOCATION "mlir.ir.Location._CAPIPtr"
#define MLIR_PYTHON_CAPSULE_MODULE "mlir.ir.Module._CAPIPtr"
Expand Down Expand Up @@ -261,6 +264,27 @@ static inline MlirIntegerSet mlirPythonCapsuleToIntegerSet(PyObject *capsule) {
return integerSet;
}

/** Creates a capsule object encapsulating the raw C-API MlirExecutionEngine.
* The returned capsule does not extend or affect ownership of any Python
* objects that reference the set in any way. */
static inline PyObject *
mlirPythonExecutionEngineToCapsule(MlirExecutionEngine jit) {
return PyCapsule_New(MLIR_PYTHON_GET_WRAPPED_POINTER(jit),
MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE, NULL);
}

/** Extracts an MlirExecutionEngine from a capsule as produced from
* mlirPythonIntegerSetToCapsule. If the capsule is not of the right type, then
* a null set is returned (as checked via mlirExecutionEngineIsNull). In such a
* case, the Python APIs will have already set an error. */
static inline MlirExecutionEngine
mlirPythonCapsuleToExecutionEngine(PyObject *capsule) {
void *ptr =
PyCapsule_GetPointer(capsule, MLIR_PYTHON_CAPSULE_EXECUTION_ENGINE);
MlirExecutionEngine jit = {ptr};
return jit;
}

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir-c/ExecutionEngine.h
Expand Up @@ -56,6 +56,11 @@ static inline bool mlirExecutionEngineIsNull(MlirExecutionEngine jit) {
MLIR_CAPI_EXPORTED MlirLogicalResult mlirExecutionEngineInvokePacked(
MlirExecutionEngine jit, MlirStringRef name, void **arguments);

/// Lookup a native function in the execution engine by name, returns nullptr
/// if the name can't be looked-up.
MLIR_CAPI_EXPORTED void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
MlirStringRef name);

#ifdef __cplusplus
}
#endif
Expand Down
5 changes: 4 additions & 1 deletion mlir/lib/Bindings/Python/CMakeLists.txt
Expand Up @@ -8,11 +8,12 @@ add_custom_target(MLIRBindingsPythonExtension)
set(PY_SRC_FILES
mlir/__init__.py
mlir/_dlloader.py
mlir/ir.py
mlir/conversions/__init__.py
mlir/dialects/__init__.py
mlir/dialects/_linalg.py
mlir/dialects/_builtin.py
mlir/ir.py
mlir/execution_engine.py
mlir/passmanager.py
mlir/transforms/__init__.py
)
Expand Down Expand Up @@ -74,6 +75,7 @@ add_mlir_python_extension(MLIRCoreBindingsPythonExtension _mlir
IRModules.cpp
PybindUtils.cpp
Pass.cpp
ExecutionEngine.cpp
)
add_dependencies(MLIRBindingsPythonExtension MLIRCoreBindingsPythonExtension)

Expand Down Expand Up @@ -114,3 +116,4 @@ if (NOT LLVM_ENABLE_IDE)
endif()

add_subdirectory(Transforms)
add_subdirectory(Conversions)
10 changes: 10 additions & 0 deletions mlir/lib/Bindings/Python/Conversions/CMakeLists.txt
@@ -0,0 +1,10 @@
################################################################################
# Build python extension
################################################################################

add_mlir_python_extension(MLIRConversionsBindingsPythonExtension _mlirConversions
INSTALL_DIR
python
SOURCES
Conversions.cpp
)
24 changes: 24 additions & 0 deletions mlir/lib/Bindings/Python/Conversions/Conversions.cpp
@@ -0,0 +1,24 @@
//===- Conversions.cpp - Pybind module for the Conversionss library -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir-c/Conversion.h"

#include <pybind11/pybind11.h>

namespace py = pybind11;

// -----------------------------------------------------------------------------
// Module initialization.
// -----------------------------------------------------------------------------

PYBIND11_MODULE(_mlirConversions, m) {
m.doc() = "MLIR Conversions library";

// Register all the passes in the Conversions library on load.
mlirRegisterConversionPasses();
}
87 changes: 87 additions & 0 deletions mlir/lib/Bindings/Python/ExecutionEngine.cpp
@@ -0,0 +1,87 @@
//===- ExecutionEngine.cpp - Python MLIR ExecutionEngine Bindings ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "ExecutionEngine.h"

#include "IRModules.h"
#include "mlir-c/Bindings/Python/Interop.h"
#include "mlir-c/ExecutionEngine.h"

namespace py = pybind11;
using namespace mlir;
using namespace mlir::python;

namespace {

/// Owning Wrapper around an ExecutionEngine.
class PyExecutionEngine {
public:
PyExecutionEngine(MlirExecutionEngine executionEngine)
: executionEngine(executionEngine) {}
PyExecutionEngine(PyExecutionEngine &&other)
: executionEngine(other.executionEngine) {
other.executionEngine.ptr = nullptr;
}
~PyExecutionEngine() {
if (!mlirExecutionEngineIsNull(executionEngine))
mlirExecutionEngineDestroy(executionEngine);
}
MlirExecutionEngine get() { return executionEngine; }

void release() { executionEngine.ptr = nullptr; }
pybind11::object getCapsule() {
return py::reinterpret_steal<py::object>(
mlirPythonExecutionEngineToCapsule(get()));
}

static pybind11::object createFromCapsule(pybind11::object capsule) {
MlirExecutionEngine rawPm =
mlirPythonCapsuleToExecutionEngine(capsule.ptr());
if (mlirExecutionEngineIsNull(rawPm))
throw py::error_already_set();
return py::cast(PyExecutionEngine(rawPm), py::return_value_policy::move);
}

private:
MlirExecutionEngine executionEngine;
};

} // anonymous namespace

/// Create the `mlir.execution_engine` module here.
void mlir::python::populateExecutionEngineSubmodule(py::module &m) {
//----------------------------------------------------------------------------
// Mapping of the top-level PassManager
//----------------------------------------------------------------------------
py::class_<PyExecutionEngine>(m, "ExecutionEngine")
.def(py::init<>([](PyModule &module) {
MlirExecutionEngine executionEngine =
mlirExecutionEngineCreate(module.get());
if (mlirExecutionEngineIsNull(executionEngine))
throw std::runtime_error(
"Failure while creating the ExecutionEngine.");
return new PyExecutionEngine(executionEngine);
}),
"Create a new ExecutionEngine instance for the given Module. The "
"module must "
"contain only dialects that can be translated to LLVM.")
.def_property_readonly(MLIR_PYTHON_CAPI_PTR_ATTR,
&PyExecutionEngine::getCapsule)
.def("_testing_release", &PyExecutionEngine::release,
"Releases (leaks) the backing ExecutionEngine (for testing purpose)")
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyExecutionEngine::createFromCapsule)
.def(
"raw_lookup",
[](PyExecutionEngine &executionEngine, const std::string &func) {
auto *res = mlirExecutionEngineLookup(
executionEngine.get(),
mlirStringRefCreate(func.c_str(), func.size()));
return (int64_t)res;
},
"Lookup function `func` in the ExecutionEngine.");
}
22 changes: 22 additions & 0 deletions mlir/lib/Bindings/Python/ExecutionEngine.h
@@ -0,0 +1,22 @@
//===- ExecutionEngine.h - ExecutionEngine submodule of pybind module -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H
#define MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H

#include "PybindUtils.h"

namespace mlir {
namespace python {

void populateExecutionEngineSubmodule(pybind11::module &m);

} // namespace python
} // namespace mlir

#endif // MLIR_BINDINGS_PYTHON_EXECUTIONENGINE_H
6 changes: 6 additions & 0 deletions mlir/lib/Bindings/Python/MainModule.cpp
Expand Up @@ -10,6 +10,7 @@

#include "PybindUtils.h"

#include "ExecutionEngine.h"
#include "Globals.h"
#include "IRModules.h"
#include "Pass.h"
Expand Down Expand Up @@ -216,4 +217,9 @@ PYBIND11_MODULE(_mlir, m) {
auto passModule =
m.def_submodule("passmanager", "MLIR Pass Management Bindings");
populatePassManagerSubmodule(passModule);

// Define and populate ExecutionEngine submodule.
auto executionEngineModule =
m.def_submodule("execution_engine", "MLIR JIT Execution Engine");
populateExecutionEngineSubmodule(executionEngineModule);
}
3 changes: 2 additions & 1 deletion mlir/lib/Bindings/Python/mlir/__init__.py
Expand Up @@ -10,6 +10,7 @@

__all__ = [
"ir",
"execution_engine",
"passmanager",
]

Expand Down Expand Up @@ -61,7 +62,7 @@ def _reexport_cext(cext_module_name, target_module_name):

# Import sub-modules. Since these may import from here, this must come after
# any exported definitions.
from . import ir, passmanager
from . import ir, execution_engine, passmanager

# Add our 'dialects' parent module to the search path for implementations.
_cext.globals.append_dialect_search_prefix("mlir.dialects")
8 changes: 8 additions & 0 deletions mlir/lib/Bindings/Python/mlir/conversions/__init__.py
@@ -0,0 +1,8 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Expose the corresponding C-Extension module with a well-known name at this
# level.
from .. import _load_extension
_cextConversions = _load_extension("_mlirConversions")
31 changes: 31 additions & 0 deletions mlir/lib/Bindings/Python/mlir/execution_engine.py
@@ -0,0 +1,31 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

# Simply a wrapper around the extension module of the same name.
from . import _cext
import ctypes

class ExecutionEngine(_cext.execution_engine.ExecutionEngine):

def lookup(self, name):
"""Lookup a function emitted with the `llvm.emit_c_interface`
attribute and returns a ctype callable.
Raise a RuntimeError if the function isn't found.
"""
func = self.raw_lookup("_mlir_ciface_" + name)
if not func:
raise RuntimeError("Unknown function " + name)
prototype = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
return prototype(func)

def invoke(self, name, *ctypes_args):
"""Invoke a function with the list of ctypes arguments.
All arguments must be pointers.
Raise a RuntimeError if the function isn't found.
"""
func = self.lookup(name)
packed_args = (ctypes.c_void_p * len(ctypes_args))()
for argNum in range(len(ctypes_args)):
packed_args[argNum] = ctypes.cast(ctypes_args[argNum], ctypes.c_void_p)
func(packed_args)
10 changes: 10 additions & 0 deletions mlir/lib/CAPI/ExecutionEngine/ExecutionEngine.cpp
Expand Up @@ -10,6 +10,7 @@
#include "mlir/CAPI/ExecutionEngine.h"
#include "mlir/CAPI/IR.h"
#include "mlir/CAPI/Support.h"
#include "mlir/Target/LLVMIR.h"
#include "llvm/Support/TargetSelect.h"

using namespace mlir;
Expand All @@ -22,6 +23,7 @@ extern "C" MlirExecutionEngine mlirExecutionEngineCreate(MlirModule op) {
}();
(void)init_once;

mlir::registerLLVMDialectTranslation(*unwrap(op)->getContext());
auto jitOrError = ExecutionEngine::create(unwrap(op));
if (!jitOrError) {
consumeError(jitOrError.takeError());
Expand All @@ -44,3 +46,11 @@ mlirExecutionEngineInvokePacked(MlirExecutionEngine jit, MlirStringRef name,
return wrap(failure());
return wrap(success());
}

extern "C" void *mlirExecutionEngineLookup(MlirExecutionEngine jit,
MlirStringRef name) {
auto expectedFPtr = unwrap(jit)->lookup(unwrap(name));
if (!expectedFPtr)
return nullptr;
return reinterpret_cast<void *>(*expectedFPtr);
}

0 comments on commit 13cb431

Please sign in to comment.