diff --git a/tests/unit/systems/fil/test_forest.py b/tests/unit/systems/fil/test_forest.py index 22ae1c038..ea4792123 100644 --- a/tests/unit/systems/fil/test_forest.py +++ b/tests/unit/systems/fil/test_forest.py @@ -85,9 +85,9 @@ def test_export(tmpdir): output_schema = Schema([ColumnSchema("output__0", dtype=np.float32)]) _ = PredictForest(model, input_schema).export(tmpdir, input_schema, output_schema, node_id=2) - config_path = tmpdir / "2_forest" / "config.pbtxt" + config_path = tmpdir / "2_predictforest" / "config.pbtxt" parsed_config = read_config(config_path) - assert parsed_config.name == "2_forest" + assert parsed_config.name == "2_predictforest" assert parsed_config.backend == "python" config_path = tmpdir / "2_fil" / "config.pbtxt" @@ -126,9 +126,9 @@ def test_ensemble(tmpdir): triton_ens.export(tmpdir) - config_path = tmpdir / "1_forest" / "config.pbtxt" + config_path = tmpdir / "1_predictforest" / "config.pbtxt" parsed_config = read_config(config_path) - assert parsed_config.name == "1_forest" + assert parsed_config.name == "1_predictforest" assert parsed_config.backend == "python" config_path = tmpdir / "1_fil" / "config.pbtxt"