Commit dac9eb1

Implement new save_raw in Python. (dmlc#7572)
* Expose the new C API function to Python.
* Remove old document and helper script.
* Small optimization to the `save_raw` and Json ctors.
1 parent 9f20a33 commit dac9eb1

8 files changed: 104 additions, 150 deletions

doc/python/convert_090to100.py

Lines changed: 0 additions & 79 deletions
This file was deleted.

doc/tutorials/saving_model.rst

Lines changed: 16 additions & 25 deletions
@@ -2,16 +2,18 @@
 Introduction to Model IO
 ########################

-In XGBoost 1.0.0, we introduced experimental support of using `JSON
+In XGBoost 1.0.0, we introduced support of using `JSON
 <https://www.json.org/json-en.html>`_ for saving/loading XGBoost models and related
 hyper-parameters for training, aiming to replace the old binary internal format with an
-open format that can be easily reused. The support for binary format will be continued in
-the future until JSON format is no-longer experimental and has satisfying performance.
-This tutorial aims to share some basic insights into the JSON serialisation method used in
-XGBoost. Without explicitly mentioned, the following sections assume you are using the
-JSON format, which can be enabled by providing the file name with ``.json`` as file
-extension when saving/loading model: ``booster.save_model('model.json')``. More details
-below.
+open format that can be easily reused. Later in XGBoost 1.6.0, additional support for
+`Universal Binary JSON <https://ubjson.org/>`__ is added as an optimization for more
+efficient model IO. They have the same document structure with different representations,
+and we will refer them collectively as the JSON format. This tutorial aims to share some
+basic insights into the JSON serialisation method used in XGBoost. Without explicitly
+mentioned, the following sections assume you are using the one of the 2 outputs formats,
+which can be enabled by providing the file name with ``.json`` (or ``.ubj`` for binary
+JSON) as file extension when saving/loading model: ``booster.save_model('model.json')``.
+More details below.

 Before we get started, XGBoost is a gradient boosting library with focus on tree model,
 which means inside XGBoost, there are 2 distinct parts:
@@ -53,7 +55,8 @@ Other language bindings are still working in progress.
 based serialisation methods.

 To enable JSON format support for model IO (saving only the trees and objective), provide
-a filename with ``.json`` as file extension:
+a filename with ``.json`` or ``.ubj`` as file extension, the latter is the extension for
+`Universal Binary JSON <https://ubjson.org/>`__

 .. code-block:: python
   :caption: Python
@@ -65,7 +68,7 @@ a filename with ``.json`` as file extension:

   xgb.save(bst, 'model_file_name.json')

-While for memory snapshot, JSON is the default starting with xgboost 1.3.
+While for memory snapshot, UBJSON is the default starting with xgboost 1.6.

 ***************************************************************
 A note on backward compatibility of models and memory snapshots
@@ -105,15 +108,10 @@ Loading pickled file from different version of XGBoost

 As noted, pickled model is neither portable nor stable, but in some cases the pickled
 models are valuable. One way to restore it in the future is to load it back with that
-specific version of Python and XGBoost, export the model by calling `save_model`. To help
-easing the mitigation, we created a simple script for converting pickled XGBoost 0.90
-Scikit-Learn interface object to XGBoost 1.0.0 native model. Please note that the script
-suits simple use cases, and it's advised not to use pickle when stability is needed. It's
-located in ``xgboost/doc/python`` with the name ``convert_090to100.py``. See comments in
-the script for more details.
+specific version of Python and XGBoost, export the model by calling `save_model`.

-A similar procedure may be used to recover the model persisted in an old RDS file. In R, you are
-able to install an older version of XGBoost using the ``remotes`` package:
+A similar procedure may be used to recover the model persisted in an old RDS file. In R,
+you are able to install an older version of XGBoost using the ``remotes`` package:

 .. code-block:: r

@@ -244,10 +242,3 @@ leaf directly, instead it saves the weights as a separated array.

 .. include:: ../model.schema
   :code: json
-
-************
-Future Plans
-************
-
-Right now using the JSON format incurs longer serialisation time, we have been working on
-optimizing the JSON implementation to close the gap between binary format and JSON format.
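The tutorial changes above describe choosing the output format through the file extension. A minimal sketch of that idea follows. The helper name `format_from_filename` and the lookup table are hypothetical, and the fallback to the legacy binary format for any other extension is an assumption made for illustration, not something this diff states.

```python
from pathlib import PurePath

# Hypothetical lookup table for illustration; not XGBoost's internal code.
_EXTENSION_TO_FORMAT = {".json": "json", ".ubj": "ubj"}


def format_from_filename(fname: str) -> str:
    """Return the model format implied by the extension of `fname`."""
    suffix = PurePath(fname).suffix
    # Assumption: any other extension means the old binary format,
    # labelled "deprecated" to match the format names used in this commit.
    return _EXTENSION_TO_FORMAT.get(suffix, "deprecated")


for name in ("model.json", "model.ubj", "model.bin"):
    print(name, "->", format_from_filename(name))
```

With a real booster, the same choice is made simply by the name passed to `save_model`, as in `booster.save_model('model.json')`.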

include/xgboost/json.h

Lines changed: 15 additions & 17 deletions
@@ -89,9 +89,10 @@ class JsonString : public Value {
   JsonString(std::string const& str) :  // NOLINT
       Value(ValueKind::kString), str_{str} {}
   JsonString(std::string&& str) noexcept :  // NOLINT
-      Value(ValueKind::kString), str_{std::move(str)} {}
-  JsonString(JsonString&& str) noexcept :  // NOLINT
-      Value(ValueKind::kString), str_{std::move(str.str_)} {}
+      Value(ValueKind::kString), str_{std::forward<std::string>(str)} {}
+  JsonString(JsonString&& str) noexcept : Value(ValueKind::kString) {  // NOLINT
+    std::swap(str.str_, this->str_);
+  }

   void Save(JsonWriter* writer) const override;

@@ -111,8 +112,8 @@ class JsonArray : public Value {

  public:
   JsonArray() : Value(ValueKind::kArray) {}
-  JsonArray(std::vector<Json>&& arr) noexcept :  // NOLINT
-      Value(ValueKind::kArray), vec_{std::move(arr)} {}
+  JsonArray(std::vector<Json>&& arr) noexcept  // NOLINT
+      : Value(ValueKind::kArray), vec_{std::forward<std::vector<Json>>(arr)} {}
   JsonArray(std::vector<Json> const& arr) :  // NOLINT
       Value(ValueKind::kArray), vec_{arr} {}
   JsonArray(JsonArray const& that) = delete;
@@ -381,10 +382,9 @@ class Json {
     return *this;
   }
   // array
-  explicit Json(JsonArray list) :
-      ptr_ {new JsonArray(std::move(list))} {}
-  Json& operator=(JsonArray array) {
-    ptr_.reset(new JsonArray(std::move(array)));
+  explicit Json(JsonArray&& list) : ptr_{new JsonArray(std::forward<JsonArray>(list))} {}
+  Json& operator=(JsonArray&& array) {
+    ptr_.reset(new JsonArray(std::forward<JsonArray>(array)));
     return *this;
   }
   // typed array
@@ -397,17 +397,15 @@ class Json {
     return *this;
   }
   // object
-  explicit Json(JsonObject object) :
-      ptr_{new JsonObject(std::move(object))} {}
-  Json& operator=(JsonObject object) {
-    ptr_.reset(new JsonObject(std::move(object)));
+  explicit Json(JsonObject&& object) : ptr_{new JsonObject(std::forward<JsonObject>(object))} {}
+  Json& operator=(JsonObject&& object) {
+    ptr_.reset(new JsonObject(std::forward<JsonObject>(object)));
     return *this;
   }
   // string
-  explicit Json(JsonString str) :
-      ptr_{new JsonString(std::move(str))} {}
-  Json& operator=(JsonString str) {
-    ptr_.reset(new JsonString(std::move(str)));
+  explicit Json(JsonString&& str) : ptr_{new JsonString(std::forward<JsonString>(str))} {}
+  Json& operator=(JsonString&& str) {
+    ptr_.reset(new JsonString(std::forward<JsonString>(str)));
     return *this;
   }
   // bool

include/xgboost/learner.h

Lines changed: 2 additions & 0 deletions
@@ -45,6 +45,8 @@ enum class PredictionType : std::uint8_t {  // NOLINT
 struct XGBAPIThreadLocalEntry {
   /*! \brief result holder for returning string */
   std::string ret_str;
+  /*! \brief result holder for returning raw buffer */
+  std::vector<char> ret_char_vec;
   /*! \brief result holder for returning strings */
   std::vector<std::string> ret_vec_str;
   /*! \brief result holder for returning string pointers */

python-package/xgboost/core.py

Lines changed: 32 additions & 10 deletions
@@ -2135,9 +2135,15 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None:

         The model is saved in an XGBoost internal format which is universal among the
         various XGBoost interfaces. Auxiliary attributes of the Python Booster object
-        (such as feature_names) will not be saved when using binary format. To save those
-        attributes, use JSON instead. See :doc:`Model IO </tutorials/saving_model>` for
-        more info.
+        (such as feature_names) will not be saved when using binary format. To save
+        those attributes, use JSON/UBJ instead. See :doc:`Model IO
+        </tutorials/saving_model>` for more info.
+
+        .. code-block:: python
+
+          model.save_model("model.json")
+          # or
+          model.save_model("model.ubj")

         Parameters
         ----------
@@ -2152,18 +2158,28 @@ def save_model(self, fname: Union[str, os.PathLike]) -> None:
         else:
             raise TypeError("fname must be a string or os PathLike")

-    def save_raw(self) -> bytearray:
+    def save_raw(self, raw_format: str = "deprecated") -> bytearray:
         """Save the model to a in memory buffer representation instead of file.

+        Parameters
+        ----------
+        raw_format :
+            Format of output buffer. Can be `json`, `ubj` or `deprecated`. Right now
+            the default is `deprecated` but it will be changed to `ubj` (univeral binary
+            json) in the future.
+
         Returns
         -------
-        a in memory buffer representation of the model
+        An in memory buffer representation of the model
         """
         length = c_bst_ulong()
         cptr = ctypes.POINTER(ctypes.c_char)()
-        _check_call(_LIB.XGBoosterGetModelRaw(self.handle,
-                                              ctypes.byref(length),
-                                              ctypes.byref(cptr)))
+        config = from_pystr_to_cstr(json.dumps({"format": raw_format}))
+        _check_call(
+            _LIB.XGBoosterSaveModelToBuffer(
+                self.handle, config, ctypes.byref(length), ctypes.byref(cptr)
+            )
+        )
         return ctypes2buffer(cptr, length.value)

     def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
@@ -2173,8 +2189,14 @@ def load_model(self, fname: Union[str, bytearray, os.PathLike]) -> None:
         The model is loaded from XGBoost format which is universal among the various
         XGBoost interfaces. Auxiliary attributes of the Python Booster object (such as
         feature_names) will not be loaded when using binary format. To save those
-        attributes, use JSON instead. See :doc:`Model IO </tutorials/saving_model>` for
-        more info.
+        attributes, use JSON/UBJ instead. See :doc:`Model IO </tutorials/saving_model>`
+        for more info.
+
+        .. code-block:: python
+
+          model.load_model("model.json")
+          # or
+          model.load_model("model.ubj")

         Parameters
         ----------
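In the `save_raw` change above, the requested format travels to the C API as a small JSON document, `{"format": raw_format}`. The sketch below reproduces only that config-building step using the standard library. `build_save_config` is a hypothetical helper, and the early validation is an addition for illustration; in the commit itself, unknown formats are rejected on the C++ side.

```python
import json

# Format names taken from the `save_raw` docstring in this commit.
VALID_RAW_FORMATS = ("json", "ubj", "deprecated")


def build_save_config(raw_format: str = "deprecated") -> bytes:
    """Build the JSON config string for the requested buffer format.

    Hypothetical helper: the real method passes the string through
    `from_pystr_to_cstr` and does not validate the name in Python.
    """
    if raw_format not in VALID_RAW_FORMATS:
        raise ValueError(f"Unknown format: `{raw_format}`")
    # Assumption: the config is handed over as UTF-8 encoded bytes.
    return json.dumps({"format": raw_format}).encode("utf-8")


print(build_save_config("ubj"))
```

On a trained booster, the corresponding call from this commit is `booster.save_raw(raw_format="ubj")`, which returns a `bytearray`.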

src/c_api/c_api.cc

Lines changed: 16 additions & 10 deletions
@@ -971,28 +971,34 @@ XGB_DLL int XGBoosterSaveModelToBuffer(BoosterHandle handle, char const *json_co
   auto format = RequiredArg<String>(config, "format", __func__);

   auto *learner = static_cast<Learner *>(handle);
-  std::string &raw_str = learner->GetThreadLocal().ret_str;
-  raw_str.clear();
-
   learner->Configure();
+
+  auto save_json = [&](std::ios::openmode mode) {
+    std::vector<char> &raw_char_vec = learner->GetThreadLocal().ret_char_vec;
+    Json out{Object{}};
+    learner->SaveModel(&out);
+    Json::Dump(out, &raw_char_vec, mode);
+    *out_dptr = dmlc::BeginPtr(raw_char_vec);
+    *out_len = static_cast<xgboost::bst_ulong>(raw_char_vec.size());
+  };
+
   Json out{Object{}};
   if (format == "json") {
-    learner->SaveModel(&out);
-    Json::Dump(out, &raw_str);
+    save_json(std::ios::out);
   } else if (format == "ubj") {
-    learner->SaveModel(&out);
-    Json::Dump(out, &raw_str, std::ios::binary);
+    save_json(std::ios::binary);
   } else if (format == "deprecated") {
     WarnOldModel();
+    auto &raw_str = learner->GetThreadLocal().ret_str;
+    raw_str.clear();
     common::MemoryBufferStream fo(&raw_str);
     learner->SaveModel(&fo);
+    *out_dptr = dmlc::BeginPtr(raw_str);
+    *out_len = static_cast<xgboost::bst_ulong>(raw_str.size());
   } else {
     LOG(FATAL) << "Unknown format: `" << format << "`";
   }

-  *out_dptr = dmlc::BeginPtr(raw_str);
-  *out_len = static_cast<xgboost::bst_ulong>(raw_str.length());
-
   API_END();
 }

src/common/json.cc

Lines changed: 10 additions & 9 deletions
@@ -195,11 +195,12 @@ Json& Value::operator[](int) {
 }

 // Json Object
-JsonObject::JsonObject(JsonObject && that) noexcept :
-    Value(ValueKind::kObject), object_{std::move(that.object_)} {}
+JsonObject::JsonObject(JsonObject&& that) noexcept : Value(ValueKind::kObject) {
+  std::swap(that.object_, this->object_);
+}

-JsonObject::JsonObject(std::map<std::string, Json> &&object) noexcept
-    : Value(ValueKind::kObject), object_{std::move(object)} {}
+JsonObject::JsonObject(std::map<std::string, Json>&& object) noexcept
+    : Value(ValueKind::kObject), object_{std::forward<std::map<std::string, Json>>(object)} {}

 bool JsonObject::operator==(Value const& rhs) const {
   if (!IsA<JsonObject>(&rhs)) {
@@ -220,8 +221,9 @@ bool JsonString::operator==(Value const& rhs) const {
 void JsonString::Save(JsonWriter* writer) const { writer->Visit(this); }

 // Json Array
-JsonArray::JsonArray(JsonArray && that) noexcept :
-    Value(ValueKind::kArray), vec_{std::move(that.vec_)} {}
+JsonArray::JsonArray(JsonArray&& that) noexcept : Value(ValueKind::kArray) {
+  std::swap(that.vec_, this->vec_);
+}

 bool JsonArray::operator==(Value const& rhs) const {
   if (!IsA<JsonArray>(&rhs)) {
@@ -696,6 +698,7 @@ void Json::Dump(Json json, std::string* str, std::ios::openmode mode) {
 }

 void Json::Dump(Json json, std::vector<char>* str, std::ios::openmode mode) {
+  str->clear();
   if (mode & std::ios::binary) {
     UBJWriter writer{str};
     writer.Save(json);
@@ -768,9 +771,7 @@ std::string UBJReader::DecodeStr() {
   str.resize(bsize);
   auto ptr = raw_str_.c_str() + cursor_.Pos();
   std::memcpy(&str[0], ptr, bsize);
-  for (int64_t i = 0; i < bsize; ++i) {
-    this->cursor_.Forward();
-  }
+  this->cursor_.Forward(bsize);
   return str;
 }
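The last hunk replaces a byte-by-byte cursor loop with a single jump of `bsize`. The Python sketch below illustrates the same idea on a length-prefixed string. The layout used here (an `L` marker, an 8-byte big-endian signed length, then the UTF-8 payload) is an assumption based on the UBJSON convention, and `decode_str` is a hypothetical stand-in rather than a port of `UBJReader::DecodeStr`.

```python
import struct
from typing import Tuple


def decode_str(raw: bytes, pos: int) -> Tuple[str, int]:
    """Decode a length-prefixed UTF-8 string starting at `pos`.

    Returns the decoded text and the cursor position just past the payload.
    """
    if raw[pos:pos + 1] != b"L":
        raise ValueError("expected an int64 length marker")
    # 8-byte big-endian signed integer holding the payload size.
    (bsize,) = struct.unpack_from(">q", raw, pos + 1)
    start = pos + 1 + 8
    end = start + bsize
    if bsize < 0 or end > len(raw):
        raise ValueError("truncated string payload")
    # Advance the cursor by `bsize` in one step instead of looping per byte.
    return raw[start:end].decode("utf-8"), end


payload = "gbtree".encode("utf-8")
buf = b"L" + struct.pack(">q", len(payload)) + payload + b"}"
text, cursor = decode_str(buf, 0)
print(text, cursor)
```

Returning the new cursor position alongside the text keeps the function free of shared state, which makes the single-step advance easy to check.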
