Skip to content

Commit

Permalink
Use NumVal with specified 'dtype' instead of converting the threshold…
Browse files Browse the repository at this point in the history
… value in place when assembling a sklearn tree (#190)
  • Loading branch information
izeigerman committed May 25, 2020
1 parent ca8bf0e commit 8b3f657
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions m2cgen/assemblers/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@ def _assemble_cond(self, node_id):
# sklearn's trees internally work with float32 numbers, so in order
# to have consistent results across all supported languages, we convert
# all thresholds into float32.
threshold = threshold.astype(np.float32)
threshold_num_val = ast.NumVal(threshold, dtype=np.float32)

return utils.lte(ast.FeatureRef(feature_idx), ast.NumVal(threshold))
return utils.lte(ast.FeatureRef(feature_idx), threshold_num_val)

0 comments on commit 8b3f657

Please sign in to comment.