diff --git a/m2cgen/assemblers/xgboost.py b/m2cgen/assemblers/xgboost.py index 4472f9d8..e83b0150 100644 --- a/m2cgen/assemblers/xgboost.py +++ b/m2cgen/assemblers/xgboost.py @@ -77,6 +77,9 @@ def _assemble_tree(self, tree): feature_idx = self._feature_name_to_idx[tree["split"]] feature_ref = ast.FeatureRef(feature_idx) + # Since comparison with NaN (missing) value always returns false we + # should make sure that the node ID specified in the "missing" field + # always ends up in the "else" branch of the ast.IfExpr. use_lt_comp = tree["missing"] == tree["no"] if use_lt_comp: comp_op = ast.CompOpType.LT diff --git a/m2cgen/interpreters/interpreter.py b/m2cgen/interpreters/interpreter.py index fa256c9c..3997fa27 100644 --- a/m2cgen/interpreters/interpreter.py +++ b/m2cgen/interpreters/interpreter.py @@ -16,6 +16,8 @@ def interpret(self, expr): self._reset_reused_expr_cache() return self._do_interpret(expr) + # Private methods implementing Visitor pattern + def _pre_interpret_hook(self, expr, **kwargs): return None, kwargs