Skip to content

Commit a60e8a2

Browse files
authored
[mlir] Python: write bytecode to a file path (llvm#127118)
The current `write_bytecode` implementation necessarily requires the serialized module to be duplicated in memory when the python `bytes` object is created and sent over the binding. For modules with large resources, we may want to avoid this in-memory copy by serializing directly to a file instead of sending bytes across the boundary.
1 parent 62ec7b8 commit a60e8a2

File tree

4 files changed

+59
-18
lines changed

4 files changed

+59
-18
lines changed

mlir/lib/Bindings/Python/IRCore.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,25 @@
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#include <optional>
10-
#include <utility>
11-
129
#include "Globals.h"
1310
#include "IRModule.h"
1411
#include "NanobindUtils.h"
12+
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
1513
#include "mlir-c/BuiltinAttributes.h"
1614
#include "mlir-c/Debug.h"
1715
#include "mlir-c/Diagnostics.h"
1816
#include "mlir-c/IR.h"
1917
#include "mlir-c/Support.h"
2018
#include "mlir/Bindings/Python/Nanobind.h"
2119
#include "mlir/Bindings/Python/NanobindAdaptors.h"
22-
#include "mlir-c/Bindings/Python/Interop.h" // This is expected after nanobind.
20+
#include "nanobind/nanobind.h"
2321
#include "llvm/ADT/ArrayRef.h"
2422
#include "llvm/ADT/SmallVector.h"
23+
#include "llvm/Support/raw_ostream.h"
24+
25+
#include <optional>
26+
#include <system_error>
27+
#include <utility>
2528

2629
namespace nb = nanobind;
2730
using namespace nb::literals;
@@ -1329,11 +1332,11 @@ void PyOperationBase::print(PyAsmState &state, nb::object fileObject,
13291332
accum.getUserData());
13301333
}
13311334

1332-
void PyOperationBase::writeBytecode(const nb::object &fileObject,
1335+
void PyOperationBase::writeBytecode(const nb::object &fileOrStringObject,
13331336
std::optional<int64_t> bytecodeVersion) {
13341337
PyOperation &operation = getOperation();
13351338
operation.checkValid();
1336-
PyFileAccumulator accum(fileObject, /*binary=*/true);
1339+
PyFileAccumulator accum(fileOrStringObject, /*binary=*/true);
13371340

13381341
if (!bytecodeVersion.has_value())
13391342
return mlirOperationWriteBytecode(operation, accum.getCallback(),

mlir/lib/Bindings/Python/NanobindUtils.h

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,13 @@
1313
#include "mlir-c/Support.h"
1414
#include "mlir/Bindings/Python/Nanobind.h"
1515
#include "llvm/ADT/STLExtras.h"
16+
#include "llvm/ADT/StringRef.h"
1617
#include "llvm/ADT/Twine.h"
1718
#include "llvm/Support/DataTypes.h"
19+
#include "llvm/Support/raw_ostream.h"
20+
21+
#include <string>
22+
#include <variant>
1823

1924
template <>
2025
struct std::iterator_traits<nanobind::detail::fast_iterator> {
@@ -128,33 +133,59 @@ struct PyPrintAccumulator {
128133
}
129134
};
130135

131-
/// Accumulates int a python file-like object, either writing text (default)
132-
/// or binary.
136+
/// Accumulates into a file, either writing text (default)
137+
/// or binary. The file may be a Python file-like object or a path to a file.
133138
class PyFileAccumulator {
134139
public:
135-
PyFileAccumulator(const nanobind::object &fileObject, bool binary)
136-
: pyWriteFunction(fileObject.attr("write")), binary(binary) {}
140+
PyFileAccumulator(const nanobind::object &fileOrStringObject, bool binary)
141+
: binary(binary) {
142+
std::string filePath;
143+
if (nanobind::try_cast<std::string>(fileOrStringObject, filePath)) {
144+
std::error_code ec;
145+
writeTarget.emplace<llvm::raw_fd_ostream>(filePath, ec);
146+
if (ec) {
147+
throw nanobind::value_error(
148+
(std::string("Unable to open file for writing: ") + ec.message())
149+
.c_str());
150+
}
151+
} else {
152+
writeTarget.emplace<nanobind::object>(fileOrStringObject.attr("write"));
153+
}
154+
}
155+
156+
MlirStringCallback getCallback() {
157+
return writeTarget.index() == 0 ? getPyWriteCallback()
158+
: getOstreamCallback();
159+
}
137160

138161
void *getUserData() { return this; }
139162

140-
MlirStringCallback getCallback() {
163+
private:
164+
MlirStringCallback getPyWriteCallback() {
141165
return [](MlirStringRef part, void *userData) {
142166
nanobind::gil_scoped_acquire acquire;
143167
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
144168
if (accum->binary) {
145169
// Note: Still has to copy and not avoidable with this API.
146170
nanobind::bytes pyBytes(part.data, part.length);
147-
accum->pyWriteFunction(pyBytes);
171+
std::get<nanobind::object>(accum->writeTarget)(pyBytes);
148172
} else {
149173
nanobind::str pyStr(part.data,
150174
part.length); // Decodes as UTF-8 by default.
151-
accum->pyWriteFunction(pyStr);
175+
std::get<nanobind::object>(accum->writeTarget)(pyStr);
152176
}
153177
};
154178
}
155179

156-
private:
157-
nanobind::object pyWriteFunction;
180+
MlirStringCallback getOstreamCallback() {
181+
return [](MlirStringRef part, void *userData) {
182+
PyFileAccumulator *accum = static_cast<PyFileAccumulator *>(userData);
183+
std::get<llvm::raw_fd_ostream>(accum->writeTarget)
184+
.write(part.data, part.length);
185+
};
186+
}
187+
188+
std::variant<nanobind::object, llvm::raw_fd_ostream> writeTarget;
158189
bool binary;
159190
};
160191

mlir/python/mlir/_mlir_libs/_mlir/ir.pyi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ import collections
4747
from collections.abc import Callable, Sequence
4848
import io
4949
from pathlib import Path
50-
from typing import Any, ClassVar, TypeVar, overload
50+
from typing import Any, BinaryIO, ClassVar, TypeVar, overload
5151

5252
__all__ = [
5353
"AffineAddExpr",
@@ -285,12 +285,12 @@ class _OperationBase:
285285
"""
286286
Verify the operation. Raises MLIRError if verification fails, and returns true otherwise.
287287
"""
288-
def write_bytecode(self, file: Any, desired_version: int | None = None) -> None:
288+
def write_bytecode(self, file: BinaryIO | str, desired_version: int | None = None) -> None:
289289
"""
290290
Write the bytecode form of the operation to a file like object.
291291
292292
Args:
293-
file: The file like object to write to.
293+
file: The file like object or path to write to.
294294
desired_version: The version of bytecode to emit.
295295
Returns:
296296
The bytecode writer status.

mlir/test/python/ir/operation.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import gc
44
import io
55
import itertools
6+
from tempfile import NamedTemporaryFile
67
from mlir.ir import *
78
from mlir.dialects.builtin import ModuleOp
89
from mlir.dialects import arith
@@ -617,6 +618,12 @@ def testOperationPrint():
617618
module.operation.write_bytecode(bytecode_stream, desired_version=1)
618619
bytecode = bytecode_stream.getvalue()
619620
assert bytecode.startswith(b"ML\xefR"), "Expected bytecode to start with MLïR"
621+
with NamedTemporaryFile() as tmpfile:
622+
module.operation.write_bytecode(str(tmpfile.name), desired_version=1)
623+
tmpfile.seek(0)
624+
assert tmpfile.read().startswith(
625+
b"ML\xefR"
626+
), "Expected bytecode to start with MLïR"
620627
ctx2 = Context()
621628
module_roundtrip = Module.parse(bytecode, ctx2)
622629
f = io.StringIO()

0 commit comments

Comments
 (0)