-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Migrate Triton ensemble and serving code from NVTabular (#23)
* Migrate Triton ensemble and serving code from NVTabular * Remove unnecessary dev dependencies
- Loading branch information
1 parent
74f37be
commit 71b230f
Showing
39 changed files
with
10,470 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
# Alias the submodules here so that moving code into submodules does not break existing imports | ||
# flake8: noqa | ||
from .ensemble import Ensemble | ||
from .node import Node | ||
from .op_runner import OperatorRunner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import os | ||
|
||
from merlin.dag import postorder_iter_nodes | ||
|
||
# this needs to be before any modules that import protobuf | ||
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" | ||
|
||
from google.protobuf import text_format # noqa | ||
|
||
import merlin.systems.triton.model_config_pb2 as model_config # noqa | ||
from merlin.dag import Graph # noqa | ||
from merlin.systems.triton.export import _convert_dtype # noqa | ||
|
||
|
||
class Ensemble:
    """Operator graph that can be exported as an ensemble model configuration.

    The constructor wraps ``ops`` in a ``Graph`` and constructs the graph's
    schema from ``schema``. ``export`` then asks every exportable node to
    export itself and writes a ``config.pbtxt`` describing how the nodes are
    chained together.
    """

    def __init__(self, ops, schema, name="ensemble_model", label_columns=None):
        """Build the graph for ``ops`` and record the ensemble settings.

        Parameters
        ----------
        ops
            Passed unchanged to ``Graph``.
        schema
            Passed to ``Graph.construct_schema``.
        name : str
            Name of the ensemble; also used as its export directory name.
        label_columns : list, optional
            Stored on the instance; any falsy value is replaced by ``[]``.
        """
        self.graph = Graph(ops)
        self.graph.construct_schema(schema)
        self.name = name
        self.label_columns = label_columns if label_columns else []

    def export(self, export_path, version=1):
        """Export every exportable node and write the ensemble config file.

        Parameters
        ----------
        export_path
            Directory under which the node exports and the ensemble
            directory (``<export_path>/<name>``) are created.
        version : int
            Version forwarded to each node export and used as the name of
            the (empty) version directory inside the ensemble directory.

        Returns
        -------
        tuple
            ``(ensemble_config, node_configs)`` where ``node_configs`` holds
            the value returned by each node's ``export`` call, in order.
        """
        # max_batch_size is not set on the ensemble config.
        ensemble_config = model_config.ModelConfig(name=self.name, platform="ensemble")

        # Declare one ensemble input per column of the graph's input schema.
        for col_name, col_schema in self.graph.input_schema.column_schemas.items():
            input_spec = model_config.ModelInput(
                name=col_name,
                data_type=_convert_dtype(col_schema.dtype),
                dims=[-1, -1],
            )
            ensemble_config.input.append(input_spec)

        # Declare one ensemble output per column of the graph's output schema.
        for col_name, col_schema in self.graph.output_schema.column_schemas.items():
            output_spec = model_config.ModelOutput(
                name=col_name,
                data_type=_convert_dtype(col_schema.dtype),
                dims=[-1, -1],
            )
            ensemble_config.output.append(output_spec)

        # Number the exportable nodes in post-order; the number becomes part
        # of each step's model name and of the intermediate tensor names.
        postorder_nodes = list(postorder_iter_nodes(self.graph.output_node))
        exportable_nodes = [node for node in postorder_nodes if node.exportable]
        node_id_lookup = {node: idx for idx, node in enumerate(exportable_nodes)}
        node_idx = len(exportable_nodes)

        node_configs = []
        for node in postorder_nodes:
            if not node.exportable:
                continue

            node_id = node_id_lookup.get(node)
            node_name = f"{node_id}_{node.export_name}"

            # Skip nodes whose step has already been added to the ensemble.
            already_added = any(
                step.model_name == node_name
                for step in ensemble_config.ensemble_scheduling.step
            )
            if already_added:
                continue

            node_config = node.export(export_path, node_id=node_id, version=version)

            config_step = model_config.ModelEnsembling.Step(
                model_name=node_name, model_version=-1
            )

            # Inputs produced by an upstream exportable node carry that
            # node's id as a suffix; other inputs keep their plain name.
            for input_col_name in node.input_schema.column_names:
                source = _find_column_source(node.parents_with_dependencies, input_col_name)
                source_id = node_id_lookup.get(source)
                if source_id is not None:
                    in_suffix = f"_{source_id}"
                else:
                    in_suffix = ""
                config_step.input_map[input_col_name] = input_col_name + in_suffix

            # Every node except the highest-numbered one tags its outputs
            # with its own id; the last node's outputs keep their plain names.
            is_intermediate = node_id is not None and node_id < node_idx - 1
            for output_col_name in node.output_schema.column_names:
                out_suffix = f"_{node_id}" if is_intermediate else ""
                config_step.output_map[output_col_name] = output_col_name + out_suffix

            ensemble_config.ensemble_scheduling.step.append(config_step)
            node_configs.append(node_config)

        # Create the ensemble directory with its version subdirectory, then
        # write the config in protobuf text format.
        ensemble_path = os.path.join(export_path, self.name)
        version_path = os.path.join(ensemble_path, str(version))
        os.makedirs(ensemble_path, exist_ok=True)
        os.makedirs(version_path, exist_ok=True)

        config_path = os.path.join(ensemble_path, "config.pbtxt")
        with open(config_path, "w") as config_file:
            text_format.PrintMessage(ensemble_config, config_file)

        return (ensemble_config, node_configs)
|
||
|
||
def _find_column_source(upstream_nodes, column_name):
    """Return the exportable upstream node that produces ``column_name``.

    The first node in ``upstream_nodes`` whose ``output_columns.names``
    contains ``column_name`` is selected. If that node is not exportable, the
    search continues recursively through its ``parents_with_dependencies``.

    Returns ``None`` when no upstream node lists the column.
    """
    producer = next(
        (
            candidate
            for candidate in upstream_nodes
            if column_name in candidate.output_columns.names
        ),
        None,
    )

    # Nothing found, or the producer is itself exportable: the search is over.
    if not producer or producer.exportable:
        return producer

    # Non-exportable producer: look through it to whatever feeds it.
    return _find_column_source(producer.parents_with_dependencies, column_name)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
from merlin.dag import Node | ||
from merlin.schema import Schema | ||
|
||
|
||
class InferenceNode(Node):
    """Graph node that exports through its operator and validates output usage."""

    def export(self, output_path, node_id=None, version=1):
        """Delegate the export to the node's operator.

        The operator receives ``output_path``, this node's input and output
        schemas, and the ``node_id`` / ``version`` keywords. Its return value
        is returned unchanged.
        """
        operator = self.op
        return operator.export(
            output_path,
            self.input_schema,
            self.output_schema,
            node_id=node_id,
            version=version,
        )

    @property
    def export_name(self):
        """Export name reported by the node's operator."""
        operator = self.op
        return operator.export_name

    def match_descendant_dtypes(self, source_node):
        """Align this node's output dtypes with ``source_node``'s input schema.

        Returns ``self`` so calls can be chained.
        """
        downstream_inputs = source_node.input_schema
        self.output_schema = _match_dtypes(downstream_inputs, self.output_schema)
        return self

    def match_ancestor_dtypes(self, source_node):
        """Align this node's input dtypes with ``source_node``'s output schema.

        Returns ``self`` so calls can be chained.
        """
        upstream_outputs = source_node.output_schema
        self.input_schema = _match_dtypes(upstream_outputs, self.input_schema)
        return self

    def validate_schemas(self, root_schema, strict_dtypes=False):
        """Run the base validation, then check every output column is consumed.

        Raises
        ------
        ValueError
            If the node has children and one of its output columns does not
            appear in the combined input schemas of those children.
        """
        super().validate_schemas(root_schema, strict_dtypes)

        # A node without children has no consumers to check against.
        if not self.children:
            return

        combined_inputs = Schema()
        for child in self.children:
            combined_inputs += child.input_schema

        for col_name in self.output_schema.column_schemas:
            consumer_col = combined_inputs.get(col_name)
            if consumer_col:
                continue
            raise ValueError(
                f"Output column '{col_name}' not detected in any "
                f"child inputs for '{self.op.__class__.__name__}'."
            )
|
||
|
||
def _match_dtypes(source_schema, dest_schema):
    """Return a copy of ``dest_schema`` whose dtypes follow ``source_schema``.

    Each destination column takes the dtype of the same-named column in
    ``source_schema``. Columns missing from the source fall back to their own
    dtype, so they are carried over unchanged in that respect.
    """
    result = Schema()
    for name, dest_col in dest_schema.column_schemas.items():
        reference_col = source_schema.get(name, dest_col)
        result[name] = dest_col.with_dtype(reference_col.dtype)
    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
import importlib | ||
import json | ||
|
||
|
||
class OperatorRunner:
    """Instantiate operators described in a model config and run them in order."""

    def __init__(self, config, repository="./", version=1, kind=""):
        """Load every operator listed under the ``operator_names`` parameter.

        Parameters
        ----------
        config : dict
            Model config whose ``parameters`` hold JSON-encoded strings: an
            ``operator_names`` list plus one entry per operator name giving
            at least ``module_name`` and ``class_name``.
        repository, version, kind
            Accepted but not used by this constructor.
        """
        names = self.fetch_json_param(config, "operator_names")
        # All operator configs are decoded before any module is imported.
        settings = [self.fetch_json_param(config, name) for name in names]

        self.operators = []
        for setting in settings:
            module_name = setting["module_name"]
            class_name = setting["class_name"]

            operator_cls = getattr(importlib.import_module(module_name), class_name)
            self.operators.append(operator_cls.from_config(setting))

    def execute(self, tensors):
        """Feed ``tensors`` through each operator's ``transform`` in turn.

        Returns the output of the last operator, or ``tensors`` itself when
        no operators were loaded.
        """
        result = tensors
        for op in self.operators:
            result = op.transform(result)
        return result

    def fetch_json_param(self, model_config, param_name):
        """Decode the JSON stored in a parameter's ``string_value`` field."""
        parameter = model_config["parameters"][param_name]
        return json.loads(parameter["string_value"])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import json | ||
|
||
import faiss | ||
import numpy as np | ||
|
||
from merlin.dag import ColumnSelector | ||
from merlin.schema import ColumnSchema, Schema | ||
from merlin.systems.dag.ops.operator import InferenceDataFrame, PipelineableInferenceOperator | ||
|
||
|
||
class QueryFaiss(PipelineableInferenceOperator):
    """Inference operator that looks up the ``topk`` nearest ids in a Faiss index.

    The operator takes a single input column and produces a single
    ``candidate_ids`` column of dtype ``int32``.
    """

    def __init__(self, index_path, topk=10):
        """Record the index location and the number of neighbours to return.

        The index itself is not loaded here; ``from_config`` loads it.

        Parameters
        ----------
        index_path
            Location of the serialized index; stored as a string.
        topk : int
            Number of results requested from each index search.
        """
        self.index_path = str(index_path)
        self.topk = topk
        self._index = None
        super().__init__()

    @classmethod
    def from_config(cls, config):
        """Rebuild the operator from an exported config and load its index.

        Parameters
        ----------
        config : dict
            Mapping whose ``"params"`` entry is a JSON string containing
            ``"index_path"`` and ``"topk"``.

        Returns
        -------
        QueryFaiss
            Operator with ``_index`` populated via ``faiss.read_index``.
        """
        parameters = json.loads(config.get("params", ""))
        index_path = parameters["index_path"]
        topk = parameters["topk"]

        operator = QueryFaiss(index_path, topk=topk)
        operator._index = faiss.read_index(str(index_path))

        return operator

    def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
        """Export the operator, adding ``index_path`` and ``topk`` to its params.

        Entries in ``params`` take precedence over the operator's own values
        because they are merged in afterwards. The result of the parent
        class's ``export`` is returned unchanged.
        """
        params = params or {}

        # TODO: Copy the index into the export directory

        self_params = {
            # TODO: Write the (relative) path from inside the export directory
            "index_path": self.index_path,
            "topk": self.topk,
        }
        self_params.update(params)
        return super().export(path, input_schema, output_schema, self_params, node_id, version)

    def transform(self, df: InferenceDataFrame):
        """Search the index with the first input tensor and return the matches.

        Requires ``_index`` to have been loaded (``from_config`` does this).
        """
        # Only the first tensor is used; compute_input_schema enforces a
        # single input column.
        user_vector = list(df.tensors.values())[0]

        # The first element of the search result is not needed here.
        _, indices = self._index.search(user_vector, self.topk)

        # NOTE(review): the result is transposed before being returned;
        # confirm the orientation downstream consumers expect.
        candidate_ids = np.array(indices).T.astype(np.int32)

        return InferenceDataFrame({"candidate_ids": candidate_ids})

    def compute_input_schema(
        self,
        root_schema: Schema,
        parents_schema: Schema,
        deps_schema: Schema,
        selector: ColumnSelector,
    ) -> Schema:
        """Compute the input schema and reject more than one input column.

        Raises
        ------
        ValueError
            If the schema computed by the parent class has more than one
            column.
        """
        input_schema = super().compute_input_schema(
            root_schema, parents_schema, deps_schema, selector
        )
        if len(input_schema.column_schemas) > 1:
            # Adjacent string literals are concatenated implicitly; an
            # operator between them would raise TypeError before the
            # ValueError could be constructed.
            raise ValueError(
                "More than one input has been detected for this node, "
                f"inputs received: {input_schema.column_names}"
            )
        return input_schema

    def compute_output_schema(
        self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
    ) -> Schema:
        """Return the fixed output schema: one ``int32`` ``candidate_ids`` column."""
        return Schema(
            [
                ColumnSchema("candidate_ids", dtype=np.int32),
            ]
        )
|
||
|
||
def setup_faiss(item_vector, output_path):
    """Build an ``IndexFlatL2`` over ``item_vector`` and write it to ``output_path``.

    The index dimension is taken from ``item_vector[0].shape[0]``, and
    ``output_path`` is converted to a string before writing.
    """
    dimension = item_vector[0].shape[0]
    flat_index = faiss.IndexFlatL2(dimension)
    flat_index.add(item_vector)

    destination = str(output_path)
    faiss.write_index(flat_index, destination)
Oops, something went wrong.