Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate Triton ensemble and serving code from NVTabular #23

Merged
merged 2 commits into from
Mar 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
21 changes: 21 additions & 0 deletions merlin/systems/dag/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# alias submodules here to avoid breaking everything with moving to submodules
# flake8: noqa
from .ensemble import Ensemble
from .node import Node
from .op_runner import OperatorRunner
125 changes: 125 additions & 0 deletions merlin/systems/dag/ensemble.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os

from merlin.dag import postorder_iter_nodes

# this needs to be before any modules that import protobuf
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

from google.protobuf import text_format # noqa

import merlin.systems.triton.model_config_pb2 as model_config # noqa
from merlin.dag import Graph # noqa
from merlin.systems.triton.export import _convert_dtype # noqa


class Ensemble:
def __init__(self, ops, schema, name="ensemble_model", label_columns=None):
self.graph = Graph(ops)
self.graph.construct_schema(schema)
self.name = name
self.label_columns = label_columns or []

def export(self, export_path, version=1):
# Create ensemble config
ensemble_config = model_config.ModelConfig(
name=self.name,
platform="ensemble",
# max_batch_size=configs[0].max_batch_size
)

for col_name, col_schema in self.graph.input_schema.column_schemas.items():
ensemble_config.input.append(
model_config.ModelInput(
name=col_name, data_type=_convert_dtype(col_schema.dtype), dims=[-1, -1]
)
)

for col_name, col_schema in self.graph.output_schema.column_schemas.items():
ensemble_config.output.append(
model_config.ModelOutput(
name=col_name, data_type=_convert_dtype(col_schema.dtype), dims=[-1, -1]
)
)

# Build node id lookup table
postorder_nodes = list(postorder_iter_nodes(self.graph.output_node))

node_idx = 0
node_id_lookup = {}
for node in postorder_nodes:
if node.exportable:
node_id_lookup[node] = node_idx
node_idx += 1

node_configs = []
# Export node configs and add ensemble steps
for node in postorder_nodes:
if node.exportable:
node_id = node_id_lookup.get(node, None)
node_name = f"{node_id}_{node.export_name}"

found = False
for step in ensemble_config.ensemble_scheduling.step:
if step.model_name == node_name:
found = True
if found:
continue

node_config = node.export(export_path, node_id=node_id, version=version)

config_step = model_config.ModelEnsembling.Step(
model_name=node_name, model_version=-1
)

for input_col_name in node.input_schema.column_names:
source = _find_column_source(node.parents_with_dependencies, input_col_name)
source_id = node_id_lookup.get(source, None)
in_suffix = f"_{source_id}" if source_id is not None else ""
config_step.input_map[input_col_name] = input_col_name + in_suffix

for output_col_name in node.output_schema.column_names:
out_suffix = (
f"_{node_id}" if node_id is not None and node_id < node_idx - 1 else ""
)
config_step.output_map[output_col_name] = output_col_name + out_suffix

ensemble_config.ensemble_scheduling.step.append(config_step)
node_configs.append(node_config)

# Write the ensemble config file
ensemble_path = os.path.join(export_path, self.name)
os.makedirs(ensemble_path, exist_ok=True)
os.makedirs(os.path.join(ensemble_path, str(version)), exist_ok=True)

with open(os.path.join(ensemble_path, "config.pbtxt"), "w") as o:
text_format.PrintMessage(ensemble_config, o)

return (ensemble_config, node_configs)


def _find_column_source(upstream_nodes, column_name):
source_node = None
for upstream_node in upstream_nodes:
if column_name in upstream_node.output_columns.names:
source_node = upstream_node
break

if source_node and not source_node.exportable:
return _find_column_source(source_node.parents_with_dependencies, column_name)
else:
return source_node
62 changes: 62 additions & 0 deletions merlin/systems/dag/node.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from merlin.dag import Node
from merlin.schema import Schema


class InferenceNode(Node):
def export(self, output_path, node_id=None, version=1):
return self.op.export(
output_path, self.input_schema, self.output_schema, node_id=node_id, version=version
)

@property
def export_name(self):
return self.op.export_name

def match_descendant_dtypes(self, source_node):
self.output_schema = _match_dtypes(source_node.input_schema, self.output_schema)
return self

def match_ancestor_dtypes(self, source_node):
self.input_schema = _match_dtypes(source_node.output_schema, self.input_schema)
return self

def validate_schemas(self, root_schema, strict_dtypes=False):
super().validate_schemas(root_schema, strict_dtypes)

if self.children:
childrens_schema = Schema()
for elem in self.children:
childrens_schema += elem.input_schema

for col_name, col_schema in self.output_schema.column_schemas.items():
sink_col_schema = childrens_schema.get(col_name)

if not sink_col_schema:
raise ValueError(
f"Output column '{col_name}' not detected in any "
f"child inputs for '{self.op.__class__.__name__}'."
)


def _match_dtypes(source_schema, dest_schema):
matched = Schema()
for col_name, col_schema in dest_schema.column_schemas.items():
source_dtype = source_schema.get(col_name, col_schema).dtype
matched[col_name] = col_schema.with_dtype(source_dtype)

return matched
43 changes: 43 additions & 0 deletions merlin/systems/dag/op_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import importlib
import json


class OperatorRunner:
def __init__(self, config, repository="./", version=1, kind=""):
operator_names = self.fetch_json_param(config, "operator_names")
op_configs = [self.fetch_json_param(config, op_name) for op_name in operator_names]

self.operators = []
for op_config in op_configs:
module_name = op_config["module_name"]
class_name = op_config["class_name"]

op_module = importlib.import_module(module_name)
op_class = getattr(op_module, class_name)

operator = op_class.from_config(op_config)
self.operators.append(operator)

def execute(self, tensors):
for operator in self.operators:
tensors = operator.transform(tensors)
return tensors

def fetch_json_param(self, model_config, param_name):
string_value = model_config["parameters"][param_name]["string_value"]
return json.loads(string_value)
15 changes: 15 additions & 0 deletions merlin/systems/dag/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
98 changes: 98 additions & 0 deletions merlin/systems/dag/ops/faiss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import json

import faiss
import numpy as np

from merlin.dag import ColumnSelector
from merlin.schema import ColumnSchema, Schema
from merlin.systems.dag.ops.operator import InferenceDataFrame, PipelineableInferenceOperator


class QueryFaiss(PipelineableInferenceOperator):
def __init__(self, index_path, topk=10):
self.index_path = str(index_path)
self.topk = topk
self._index = None
super().__init__()

@classmethod
def from_config(cls, config):
parameters = json.loads(config.get("params", ""))
index_path = parameters["index_path"]
topk = parameters["topk"]

operator = QueryFaiss(index_path, topk=topk)
operator._index = faiss.read_index(str(index_path))

return operator

def export(self, path, input_schema, output_schema, params=None, node_id=None, version=1):
params = params or {}

# TODO: Copy the index into the export directory

self_params = {
# TODO: Write the (relative) path from inside the export directory
"index_path": self.index_path,
"topk": self.topk,
}
self_params.update(params)
return super().export(path, input_schema, output_schema, self_params, node_id, version)

def transform(self, df: InferenceDataFrame):
user_vector = list(df.tensors.values())[0]

_, indices = self._index.search(user_vector, self.topk)
# distances, indices = self.index.search(user_vector, self.topk)

candidate_ids = np.array(indices).T.astype(np.int32)

return InferenceDataFrame({"candidate_ids": candidate_ids})

def compute_input_schema(
self,
root_schema: Schema,
parents_schema: Schema,
deps_schema: Schema,
selector: ColumnSelector,
) -> Schema:
input_schema = super().compute_input_schema(
root_schema, parents_schema, deps_schema, selector
)
if len(input_schema.column_schemas) > 1:
raise ValueError(
"More than one input has been detected for this node,"
/ f"inputs received: {input_schema.column_names}"
)
return input_schema

def compute_output_schema(
self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
) -> Schema:
return Schema(
[
ColumnSchema("candidate_ids", dtype=np.int32),
]
)


def setup_faiss(item_vector, output_path):
index = faiss.IndexFlatL2(item_vector[0].shape[0])
index.add(item_vector)
faiss.write_index(index, str(output_path))