/
xgboost_wrapper.py
81 lines (71 loc) · 2.77 KB
/
xgboost_wrapper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import warnings
from typing import Dict, Optional
import numpy
from coffea.ml_tools.helper import nonserializable_attribute, numpy_call_wrapper
_xgboost_import_error = None
try:
import xgboost
except (ImportError, ModuleNotFoundError) as err:
_xgboost_import_error = err
class xgboost_wrapper(numpy_call_wrapper, nonserializable_attribute):
"""
Very simple wrapper for xgbooster inference. The xgboost.Booster object is
nonserializable, so the users should pass in the xgboost model file.
"""
def __init__(self, fname):
if _xgboost_import_error is not None:
warnings.warn(
"Users should make sure the xgboost package is installed before proceeding!\n"
"> pip install xgboost==1.5.1\n"
"or\n"
"> conda install xgboost==1.5.1",
UserWarning,
)
raise _xgboost_import_error
nonserializable_attribute.__init__(self, ["xgbooster"])
self.xgboost_file = fname
def _create_xgbooster(self) -> xgboost.Booster:
# Automatic detection of compressed model file
return xgboost.Booster(model_file=self.xgboost_file)
def validate_numpy_input(
self,
data: numpy.ndarray,
dmat_args: Optional[Dict] = None,
predict_args: Optional[Dict] = None,
):
"""
The inner most dimension of the data array should be smaller than the
number of features of the xgboost model. (Will raise a warning if
mismatched). We will not attempt to parse the kwargs passed to the
construction of a DMatrix, or the predict call, as those advanced
features are expected to be properly handled by the user.
"""
ndims = data.shape[-1]
nfeat = self.xgbooster.num_features()
if ndims > nfeat:
raise ValueError(
f"Input shape {data.shape} exceeded number of features ({nfeat})"
)
elif ndims < nfeat:
warnings.warn(
f"Input shape {data.shape} smaller than number of features ({nfeat})",
UserWarning,
)
def numpy_call(
self,
data: numpy.ndarray,
dmat_args: Optional[Dict] = None,
predict_args: Optional[Dict] = None,
):
"""
Passing the numpy array data as-is to the construction of an
xgboost.DMatrix constructor (with additional keyword arguments should
they be specified), the run the xgboost.Booster.predict method (with
additional keyword arguments).
"""
if dmat_args is None:
dmat_args = {}
if predict_args is None:
predict_args = {}
mat = xgboost.DMatrix(data, **dmat_args)
return self.xgbooster.predict(mat, **predict_args)