diff --git a/merlin/systems/dag/ops/compat.py b/merlin/systems/dag/ops/compat.py
index 96e59c294..af7286197 100644
--- a/merlin/systems/dag/ops/compat.py
+++ b/merlin/systems/dag/ops/compat.py
@@ -31,3 +31,7 @@
     import cuml.ensemble as cuml_ensemble
 except ImportError:
     cuml_ensemble = None
+try:
+    import triton_python_backend_utils as pb_utils
+except ImportError:
+    pb_utils = None
diff --git a/merlin/systems/dag/ops/fil.py b/merlin/systems/dag/ops/fil.py
index 43288802b..7338490fc 100644
--- a/merlin/systems/dag/ops/fil.py
+++ b/merlin/systems/dag/ops/fil.py
@@ -1,3 +1,18 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
 import json
 import pathlib
 import pickle
@@ -9,8 +24,142 @@
 from merlin.dag import ColumnSelector  # noqa
 from merlin.schema import ColumnSchema, Schema  # noqa
-from merlin.systems.dag.ops.compat import cuml_ensemble, lightgbm, sklearn_ensemble, xgboost
-from merlin.systems.dag.ops.operator import InferenceOperator
+from merlin.systems.dag.ops.compat import (
+    cuml_ensemble,
+    lightgbm,
+    pb_utils,
+    sklearn_ensemble,
+    xgboost,
+)
+from merlin.systems.dag.ops.operator import (
+    InferenceDataFrame,
+    InferenceOperator,
+    PipelineableInferenceOperator,
+)
+
+
+class PredictForest(PipelineableInferenceOperator):
+    """Operator for running inference on forest models.
+
+    This works for gradient-boosted decision trees (GBDTs) and random forests (RF).
+    While RF and GBDT algorithms differ in the way they train the models,
+    they both produce a decision forest as their output.
+
+    Uses the Forest Inference Library (FIL) backend for inference.
+    """
+
+    def __init__(self, model, input_schema, *, backend="python", **fil_params):
+        """Instantiate a FIL inference operator.
+
+        Parameters
+        ----------
+        model : Forest Model Instance
+            A trained forest model. Supports XGBoost, LightGBM, and Scikit-Learn.
+        input_schema : merlin.schema.Schema
+            The schema representing the input columns expected by the model.
+        backend : str
+            The Triton backend to use when running this operator.
+        **fil_params
+            The parameters to pass to the FIL operator.
+ """ + if model is not None: + self.fil_op = FIL(model, **fil_params) + self.backend = backend + self.input_schema = input_schema + self._fil_model_name = None + + def compute_output_schema( + self, + input_schema: Schema, + col_selector: ColumnSelector, + prev_output_schema: Schema = None, + ) -> Schema: + """Return the output schema representing the columns this operator returns.""" + return self.fil_op.compute_output_schema( + input_schema, col_selector, prev_output_schema=prev_output_schema + ) + + def compute_input_schema( + self, + root_schema: Schema, + parents_schema: Schema, + deps_schema: Schema, + selector: ColumnSelector, + ) -> Schema: + """Return the input schema representing the input columns this operator expects to use.""" + return self.input_schema + + def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1): + """Export the class and related files to the path specified.""" + fil_model_config = self.fil_op.export( + path, + input_schema, + output_schema, + params=params, + node_id=node_id, + version=version, + ) + params = params or {} + params = {**params, "fil_model_name": fil_model_config.name} + return super().export( + path, + input_schema, + output_schema, + params=params, + node_id=node_id, + version=version, + backend=self.backend, + ) + + @classmethod + def from_config(cls, config: dict) -> "PredictForest": + """Instantiate the class from a dictionary representation. + + Expected structure: + { + "input_dict": str # JSON dict with input names and schemas + "params": str # JSON dict with params saved at export + } + + """ + column_schemas = [ + ColumnSchema(name, **schema_properties) + for name, schema_properties in json.loads(config["input_dict"]).items() + ] + input_schema = Schema(column_schemas) + cls_instance = cls(None, input_schema) + params = json.loads(config["params"]) + cls_instance.set_fil_model_name(params["fil_model_name"]) + return cls_instance + + @property + def fil_model_name(self): + return self._fil_model_name + + def set_fil_model_name(self, fil_model_name): + self._fil_model_name = fil_model_name + + def transform(self, df: InferenceDataFrame) -> InferenceDataFrame: + """Transform the dataframe by applying this FIL operator to the set of input columns. + + Parameters + ----------- + df: InferenceDataFrame + A pandas or cudf dataframe that this operator will work on + + Returns + ------- + InferenceDataFrame + Returns a transformed dataframe for this operator""" + input0 = np.array([x.ravel() for x in df.tensors.values()]).astype(np.float32).T + inference_request = pb_utils.InferenceRequest( + model_name=self.fil_model_name, + requested_output_names=["output__0"], + inputs=[pb_utils.Tensor("input__0", input0)], + ) + inference_response = inference_request.exec() + output0 = pb_utils.get_output_tensor_by_name(inference_response, "output__0") + return InferenceDataFrame({"output__0": output0}) class FIL(InferenceOperator): @@ -32,6 +181,7 @@ def __init__( threads_per_tree=1, blocks_per_sm=0, transfer_threshold=0, + instance_group="AUTO", ): """Instantiate a FIL inference operator. @@ -88,6 +238,9 @@ def __init__( to the GPU for processing) will provide optimal latency and throughput, but for low-latency deployments with the use_experimental_optimizations flag set to true, higher values may be desirable. + instance_group : str + One of "AUTO", "GPU", "CPU". Default value is "AUTO". Specifies whether + inference will take place on the GPU or CPU. 
""" self.max_batch_size = max_batch_size self.parameters = dict( @@ -98,6 +251,7 @@ def __init__( blocks_per_sm=blocks_per_sm, storage_type=storage_type, threshold=threshold, + instance_group=instance_group, ) self.fil_model = get_fil_model(model) super().__init__() @@ -121,7 +275,15 @@ def compute_output_schema( """Returns output schema for FIL op""" return Schema([ColumnSchema("output__0", dtype=np.float32)]) - def export(self, path, input_schema, output_schema, node_id=None, version=1): + def export( + self, + path, + input_schema, + output_schema, + params: dict = None, + node_id=None, + version=1, + ): """Export the model to the supplied path. Returns the config""" node_name = f"{node_id}_{self.export_name}" if node_id is not None else self.export_name node_export_path = pathlib.Path(path) / node_name @@ -391,6 +553,7 @@ def fil_config( blocks_per_sm=0, threads_per_tree=1, transfer_threshold=0, + instance_group="AUTO", ) -> model_config.ModelConfig: """Construct and return a FIL ModelConfig protobuf object. @@ -453,6 +616,9 @@ def fil_config( to the GPU for processing) will provide optimal latency and throughput, but for low-latency deployments with the use_experimental_optimizations flag set to true, higher values may be desirable. + instance_group : str + One of "AUTO", "GPU", "CPU". Default value is "AUTO". Specifies whether + inference will take place on the GPU or CPU. Returns model_config.ModelConfig @@ -485,6 +651,17 @@ def fil_config( "transfer_threshold": f"{transfer_threshold:d}", } + supported_instance_groups = {"auto", "cpu", "gpu"} + instance_group = instance_group.lower() if isinstance(instance_group, str) else instance_group + if instance_group == "auto": + instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_AUTO + elif instance_group == "cpu": + instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_CPU + elif instance_group == "gpu": + instance_group_kind = model_config.ModelInstanceGroup.Kind.KIND_GPU + else: + raise ValueError(f"instance_group must be one of {supported_instance_groups}") + config = model_config.ModelConfig( name=name, backend="fil", @@ -501,9 +678,7 @@ def fil_config( name="output__0", data_type=model_config.TYPE_FP32, dims=[output_dim] ) ], - instance_group=[ - model_config.ModelInstanceGroup(kind=model_config.ModelInstanceGroup.Kind.KIND_AUTO) - ], + instance_group=[model_config.ModelInstanceGroup(kind=instance_group_kind)], ) for parameter_key, parameter_value in parameters.items(): diff --git a/merlin/systems/dag/ops/operator.py b/merlin/systems/dag/ops/operator.py index 70ef01aa9..a181a6619 100644 --- a/merlin/systems/dag/ops/operator.py +++ b/merlin/systems/dag/ops/operator.py @@ -166,6 +166,7 @@ def export( params: dict = None, node_id: int = None, version: int = 1, + backend: str = "python", ): """ Export the class object as a config and all related files to the user-defined path. 
@@ -200,7 +201,7 @@ def export(
         node_export_path = pathlib.Path(path) / node_name
         node_export_path.mkdir(parents=True, exist_ok=True)
 
-        config = model_config.ModelConfig(name=node_name, backend="python", platform="op_runner")
+        config = model_config.ModelConfig(name=node_name, backend=backend, platform="op_runner")
 
         config.parameters["operator_names"].string_value = json.dumps([node_name])
diff --git a/merlin/systems/triton/oprunner_model.py b/merlin/systems/triton/oprunner_model.py
index 2eaeac3ab..54e132d14 100644
--- a/merlin/systems/triton/oprunner_model.py
+++ b/merlin/systems/triton/oprunner_model.py
@@ -42,11 +42,48 @@
 class TritonPythonModel:
+    """Model for Triton Python Backend.
+
+    Every Python model must have "TritonPythonModel" as the class name.
+    """
+
     def initialize(self, args):
+        """Called only once when the model is being loaded, allowing
+        the model to initialize any state associated with this model.
+
+        Parameters
+        ----------
+        args : dict
+            Both keys and values are strings. The dictionary keys and values are:
+            * model_config: A JSON string containing the model configuration
+            * model_instance_kind: A string containing model instance kind
+            * model_instance_device_id: A string containing model instance device ID
+            * model_repository: Model repository path
+            * model_version: Model version
+            * model_name: Model name
+        """
         self.model_config = json.loads(args["model_config"])
         self.runner = OperatorRunner(self.model_config)
 
     def execute(self, requests: List[InferenceRequest]) -> List[InferenceResponse]:
+        """Receives a list of pb_utils.InferenceRequest as the only argument. This
+        function is called when an inference is requested for this model. Depending on the
+        batching configuration (e.g. Dynamic Batching) used, `requests` may contain
+        multiple requests. Every Python model must create one pb_utils.InferenceResponse
+        for every pb_utils.InferenceRequest in `requests`. If there is an error, you can
+        set the error argument when creating a pb_utils.InferenceResponse.
+
+        Parameters
+        ----------
+        requests : list
+            A list of pb_utils.InferenceRequest
+
+        Returns
+        -------
+        list
+            A list of pb_utils.InferenceResponse. The length of this list must
+            be the same as `requests`
+        """
         params = self.model_config["parameters"]
         op_names = json.loads(params["operator_names"]["string_value"])
         first_operator_name = op_names[0]
@@ -67,14 +104,15 @@ def execute(self, requests: List[InferenceRequest]) -> List[InferenceResponse]:
 
             raw_tensor_tuples = self.runner.execute(inf_df)
 
-            tensors = {
-                name: (data.get() if hasattr(data, "get") else data)
-                for name, data in raw_tensor_tuples
-            }
-
-            result = [Tensor(name, data) for name, data in tensors.items()]
+            output_tensors = []
+            for name, data in raw_tensor_tuples:
+                if isinstance(data, Tensor):
+                    output_tensors.append(data)
+                    continue
+                data = data.get() if hasattr(data, "get") else data
+                output_tensors.append(Tensor(name, data))
 
-            responses.append(InferenceResponse(result))
+            responses.append(InferenceResponse(output_tensors))
 
         except Exception:  # pylint: disable=broad-except
             exc_type, exc_value, exc_traceback = sys.exc_info()
diff --git a/tests/unit/systems/fil/test_forest.py b/tests/unit/systems/fil/test_forest.py
new file mode 100644
index 000000000..8e2d546c4
--- /dev/null
+++ b/tests/unit/systems/fil/test_forest.py
@@ -0,0 +1,147 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import json
+
+import numpy as np
+import pandas as pd
+import pytest
+import sklearn.datasets
+import xgboost
+from google.protobuf import text_format
+
+from merlin.dag import ColumnSelector
+from merlin.io import Dataset
+from merlin.schema import ColumnSchema, Schema
+from merlin.systems.dag.ensemble import Ensemble
+from merlin.systems.dag.ops.fil import PredictForest
+from merlin.systems.dag.ops.workflow import TransformWorkflow
+from nvtabular import Workflow
+from nvtabular import ops as wf_ops
+
+tritonclient = pytest.importorskip("tritonclient")
+import tritonclient.grpc.model_config_pb2 as model_config  # noqa
+
+
+def test_load_from_config(tmpdir):
+    rows = 200
+    num_features = 16
+    X, y = sklearn.datasets.make_regression(
+        n_samples=rows,
+        n_features=num_features,
+        n_informative=num_features // 3,
+        random_state=0,
+    )
+    model = xgboost.XGBRegressor()
+    model.fit(X, y)
+    feature_names = [str(i) for i in range(num_features)]
+    input_schema = Schema([ColumnSchema(col, dtype=np.float32) for col in feature_names])
+    output_schema = Schema([ColumnSchema("output__0", dtype=np.float32)])
+    config = PredictForest(model, input_schema).export(
+        tmpdir, input_schema, output_schema, node_id=2
+    )
+    node_config = json.loads(config.parameters[config.name].string_value)
+
+    assert json.loads(node_config["output_dict"]) == {
+        "output__0": {"dtype": "float32", "is_list": False, "is_ragged": False}
+    }
+
+    cls = PredictForest.from_config(node_config)
+    assert cls.fil_model_name == "2_fil"
+
+
+def read_config(config_path):
+    with open(config_path, "rb") as f:
+        config = model_config.ModelConfig()
+        raw_config = f.read()
+        return text_format.Parse(raw_config, config)
+
+
+def test_export(tmpdir):
+    rows = 200
+    num_features = 16
+    X, y = sklearn.datasets.make_regression(
+        n_samples=rows,
+        n_features=num_features,
+        n_informative=num_features // 3,
+        random_state=0,
+    )
+    model = xgboost.XGBRegressor()
+    model.fit(X, y)
+    feature_names = [str(i) for i in range(num_features)]
+    input_schema = Schema([ColumnSchema(col, dtype=np.float32) for col in feature_names])
+    output_schema = Schema([ColumnSchema("output__0", dtype=np.float32)])
+    _ = PredictForest(model, input_schema).export(tmpdir, input_schema, output_schema, node_id=2)
+
+    config_path = tmpdir / "2_predictforest" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "2_predictforest"
+    assert parsed_config.backend == "python"
+
+    config_path = tmpdir / "2_fil" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "2_fil"
+    assert parsed_config.backend == "fil"
+
+
+def test_ensemble(tmpdir):
+    rows = 200
+    num_features = 16
+    X, y = sklearn.datasets.make_regression(
+        n_samples=rows,
+        n_features=num_features,
+        n_informative=num_features // 3,
+        random_state=0,
+    )
+    feature_names = [str(i) for i in range(num_features)]
+    df = pd.DataFrame(X, columns=feature_names)
+    dataset = Dataset(df)
+
+    # Fit GBDT Model
+    model = xgboost.XGBRegressor()
+    model.fit(X, y)
+
+    input_schema = Schema([ColumnSchema(col, dtype=np.float32) for col in feature_names])
+    selector = ColumnSelector(feature_names)
+
+    workflow_ops = ["0", "1", "2"] >> wf_ops.LogOp()
+    workflow = Workflow(workflow_ops)
+    workflow.fit(dataset)
+
+    triton_chain = selector >> TransformWorkflow(workflow) >> PredictForest(model, input_schema)
+
+    triton_ens = Ensemble(triton_chain, input_schema)
+
+    triton_ens.export(tmpdir)
+
+    config_path = tmpdir / "1_predictforest" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "1_predictforest"
+    assert parsed_config.backend == "python"
+
+    config_path = tmpdir / "1_fil" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "1_fil"
+    assert parsed_config.backend == "fil"
+
+    config_path = tmpdir / "0_transformworkflow" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "0_transformworkflow"
+    assert parsed_config.backend == "python"
+
+    config_path = tmpdir / "ensemble_model" / "config.pbtxt"
+    parsed_config = read_config(config_path)
+    assert parsed_config.name == "ensemble_model"
+    assert parsed_config.platform == "ensemble"
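[Reviewer note] A natural follow-up test (not part of this diff) would check that the new instance_group option survives export. A sketch, reusing read_config and the import set from test_forest.py above; the test name and node_id are illustrative:

    def test_export_instance_group_cpu(tmpdir):
        # Same setup as test_export, but pin the FIL model to CPU.
        X, y = sklearn.datasets.make_regression(
            n_samples=200, n_features=16, n_informative=5, random_state=0
        )
        model = xgboost.XGBRegressor()
        model.fit(X, y)

        feature_names = [str(i) for i in range(16)]
        input_schema = Schema([ColumnSchema(col, dtype=np.float32) for col in feature_names])
        output_schema = Schema([ColumnSchema("output__0", dtype=np.float32)])

        # instance_group is forwarded PredictForest -> FIL -> fil_config.
        PredictForest(model, input_schema, instance_group="CPU").export(
            tmpdir, input_schema, output_schema, node_id=2
        )

        # The generated FIL config should carry a KIND_CPU instance group.
        parsed_config = read_config(tmpdir / "2_fil" / "config.pbtxt")
        assert (
            parsed_config.instance_group[0].kind
            == model_config.ModelInstanceGroup.Kind.KIND_CPU
        )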