Skip to content

Commit

Permalink
Migrate Triton ensemble and serving code from NVTabular (#23)
Browse files Browse the repository at this point in the history
* Migrate Triton ensemble and serving code from NVTabular

* Remove unnecessary dev dependencies
  • Loading branch information
karlhigley committed Mar 14, 2022
1 parent 74f37be commit 71b230f
Show file tree
Hide file tree
Showing 39 changed files with 10,470 additions and 16 deletions.
21 changes: 21 additions & 0 deletions merlin/systems/dag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# alias submodules here to avoid breaking everything with moving to submodules
# flake8: noqa
from .ensemble import Ensemble
from .node import Node
from .op_runner import OperatorRunner
125 changes: 125 additions & 0 deletions merlin/systems/dag/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

from merlin.dag import postorder_iter_nodes

# this needs to be before any modules that import protobuf
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from google.protobuf import text_format # noqa

import merlin.systems.triton.model_config_pb2 as model_config # noqa
from merlin.dag import Graph # noqa
from merlin.systems.triton.export import _convert_dtype # noqa


class Ensemble:
def __init__(self, ops, schema, name="ensemble_model", label_columns=None):
self.graph = Graph(ops)
self.graph.construct_schema(schema)
self.name = name
self.label_columns = label_columns or []

def export(self, export_path, version=1):
# Create ensemble config
ensemble_config = model_config.ModelConfig(
name=self.name,
platform="ensemble",
# max_batch_size=configs[0].max_batch_size
)

for col_name, col_schema in self.graph.input_schema.column_schemas.items():
ensemble_config.input.append(
model_config.ModelInput(
name=col_name, data_type=_convert_dtype(col_schema.dtype), dims=[-1, -1]
)
)

for col_name, col_schema in self.graph.output_schema.column_schemas.items():
ensemble_config.output.append(
model_config.ModelOutput(
name=col_name, data_type=_convert_dtype(col_schema.dtype), dims=[-1, -1]
)
)

# Build node id lookup table
postorder_nodes = list(postorder_iter_nodes(self.graph.output_node))

node_idx = 0
node_id_lookup = {}
for node in postorder_nodes:
if node.exportable:
node_id_lookup[node] = node_idx
node_idx += 1

node_configs = []
# Export node configs and add ensemble steps
for node in postorder_nodes:
if node.exportable:
node_id = node_id_lookup.get(node, None)
node_name = f"{node_id}_{node.export_name}"

found = False
for step in ensemble_config.ensemble_scheduling.step:
if step.model_name == node_name:
found = True
if found:
continue

node_config = node.export(export_path, node_id=node_id, version=version)

config_step = model_config.ModelEnsembling.Step(
model_name=node_name, model_version=-1
)

for input_col_name in node.input_schema.column_names:
source = _find_column_source(node.parents_with_dependencies, input_col_name)
source_id = node_id_lookup.get(source, None)
in_suffix = f"_{source_id}" if source_id is not None else ""
config_step.input_map[input_col_name] = input_col_name + in_suffix

for output_col_name in node.output_schema.column_names:
out_suffix = (
f"_{node_id}" if node_id is not None and node_id < node_idx - 1 else ""
)
config_step.output_map[output_col_name] = output_col_name + out_suffix

ensemble_config.ensemble_scheduling.step.append(config_step)
node_configs.append(node_config)

# Write the ensemble config file
ensemble_path = os.path.join(export_path, self.name)
os.makedirs(ensemble_path, exist_ok=True)
os.makedirs(os.path.join(ensemble_path, str(version)), exist_ok=True)

with open(os.path.join(ensemble_path, "config.pbtxt"), "w") as o:
text_format.PrintMessage(ensemble_config, o)

return (ensemble_config, node_configs)


def _find_column_source(upstream_nodes, column_name):
source_node = None
for upstream_node in upstream_nodes:
if column_name in upstream_node.output_columns.names:
source_node = upstream_node
break

if source_node and not source_node.exportable:
return _find_column_source(source_node.parents_with_dependencies, column_name)
else:
return source_node
62 changes: 62 additions & 0 deletions merlin/systems/dag/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from merlin.dag import Node
from merlin.schema import Schema


class InferenceNode(Node):
def export(self, output_path, node_id=None, version=1):
return self.op.export(
output_path, self.input_schema, self.output_schema, node_id=node_id, version=version
)

@property
def export_name(self):
return self.op.export_name

def match_descendant_dtypes(self, source_node):
self.output_schema = _match_dtypes(source_node.input_schema, self.output_schema)
return self

def match_ancestor_dtypes(self, source_node):
self.input_schema = _match_dtypes(source_node.output_schema, self.input_schema)
return self

def validate_schemas(self, root_schema, strict_dtypes=False):
super().validate_schemas(root_schema, strict_dtypes)

if self.children:
childrens_schema = Schema()
for elem in self.children:
childrens_schema += elem.input_schema

for col_name, col_schema in self.output_schema.column_schemas.items():
sink_col_schema = childrens_schema.get(col_name)

if not sink_col_schema:
raise ValueError(
f"Output column '{col_name}' not detected in any "
f"child inputs for '{self.op.__class__.__name__}'."
)


def _match_dtypes(source_schema, dest_schema):
matched = Schema()
for col_name, col_schema in dest_schema.column_schemas.items():
source_dtype = source_schema.get(col_name, col_schema).dtype
matched[col_name] = col_schema.with_dtype(source_dtype)

return matched
43 changes: 43 additions & 0 deletions merlin/systems/dag/op_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import json


class OperatorRunner:
def __init__(self, config, repository="./", version=1, kind=""):
operator_names = self.fetch_json_param(config, "operator_names")
op_configs = [self.fetch_json_param(config, op_name) for op_name in operator_names]

self.operators = []
for op_config in op_configs:
module_name = op_config["module_name"]
class_name = op_config["class_name"]

op_module = importlib.import_module(module_name)
op_class = getattr(op_module, class_name)

operator = op_class.from_config(op_config)
self.operators.append(operator)

def execute(self, tensors):
for operator in self.operators:
tensors = operator.transform(tensors)
return tensors

def fetch_json_param(self, model_config, param_name):
string_value = model_config["parameters"][param_name]["string_value"]
return json.loads(string_value)
15 changes: 15 additions & 0 deletions merlin/systems/dag/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
98 changes: 98 additions & 0 deletions merlin/systems/dag/ops/faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

import faiss
import numpy as np

from merlin.dag import ColumnSelector
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ops.operator import InferenceDataFrame, PipelineableInferenceOperator


class QueryFaiss(PipelineableInferenceOperator):
def __init__(self, index_path, topk=10):
self.index_path = str(index_path)
self.topk = topk
self._index = None
super().__init__()

@classmethod
def from_config(cls, config):
parameters = json.loads(config.get("params", ""))
index_path = parameters["index_path"]
topk = parameters["topk"]

operator = QueryFaiss(index_path, topk=topk)
operator._index = faiss.read_index(str(index_path))

return operator

def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
params = params or {}

# TODO: Copy the index into the export directory

self_params = {
# TODO: Write the (relative) path from inside the export directory
"index_path": self.index_path,
"topk": self.topk,
}
self_params.update(params)
return super().export(path, input_schema, output_schema, self_params, node_id, version)

def transform(self, df: InferenceDataFrame):
user_vector = list(df.tensors.values())[0]

_, indices = self._index.search(user_vector, self.topk)
# distances, indices = self.index.search(user_vector, self.topk)

candidate_ids = np.array(indices).T.astype(np.int32)

return InferenceDataFrame({"candidate_ids": candidate_ids})

def compute_input_schema(
self,
root_schema: Schema,
parents_schema: Schema,
deps_schema: Schema,
selector: ColumnSelector,
) -> Schema:
input_schema = super().compute_input_schema(
root_schema, parents_schema, deps_schema, selector
)
if len(input_schema.column_schemas) > 1:
raise ValueError(
"More than one input has been detected for this node,"
/ f"inputs received: {input_schema.column_names}"
)
return input_schema

def compute_output_schema(
self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
) -> Schema:
return Schema(
[
ColumnSchema("candidate_ids", dtype=np.int32),
]
)


def setup_faiss(item_vector, output_path):
index = faiss.IndexFlatL2(item_vector[0].shape[0])
index.add(item_vector)
faiss.write_index(index, str(output_path))

0 comments on commit 71b230f

Please sign in to comment.