Skip to content

Commit

Permalink
Switch to the Python Backend for the TransformWorkflow operator (#119)
Browse files Browse the repository at this point in the history
  • Loading branch information
oliverholworthy committed Jun 14, 2022
1 parent 4eee7aa commit c3f7d36
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 7 deletions.
9 changes: 3 additions & 6 deletions merlin/systems/dag/ops/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from merlin.dag import ColumnSelector
from merlin.schema import Schema
from merlin.systems.dag.ops.operator import InferenceOperator
from merlin.systems.triton.export import _generate_nvtabular_config
from merlin.systems.triton.export import generate_nvtabular_model


class TransformWorkflow(InferenceOperator):
Expand Down Expand Up @@ -73,6 +73,7 @@ def __init__(
def compute_output_schema(
self, input_schema: Schema, col_selector: ColumnSelector, prev_output_schema: Schema = None
) -> Schema:
"""Returns output schema of operator"""
return self.workflow.output_schema

def export(self, path, input_schema, output_schema, node_id=None, version=1):
Expand All @@ -84,14 +85,10 @@ def export(self, path, input_schema, output_schema, node_id=None, version=1):
node_export_path = pathlib.Path(path) / node_name
node_export_path.mkdir(parents=True, exist_ok=True)

workflow_export_path = node_export_path / str(version) / "workflow"
modified_workflow.save(str(workflow_export_path))

return _generate_nvtabular_config(
return generate_nvtabular_model(
modified_workflow,
node_name,
node_export_path,
backend="nvtabular",
sparse_max=self.sparse_max,
max_batch_size=self.max_batch_size,
cats=self.cats,
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/systems/test_inference_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,4 +79,4 @@ def test_workflow_op_exports_own_config(tmpdir, dataset, engine):

# The config file contents are correct
assert parsed.name == triton_op.export_name
assert parsed.backend == "nvtabular"
assert parsed.backend == "python"

0 comments on commit c3f7d36

Please sign in to comment.