Skip to content

Commit

Permalink
Parametrize NumVal with type
Browse files Browse the repository at this point in the history
  • Loading branch information
izeigerman committed Apr 4, 2020
1 parent e13bea3 commit 242165e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion m2cgen/assemblers/boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _assemble_tree(self, tree):
if "leaf" in tree:
return ast.NumVal(tree["leaf"])

threshold = ast.NumVal(np.float32(tree["split_condition"]))
threshold = ast.NumVal(tree["split_condition"], dtype=np.float32)
split = tree["split"]
feature_idx = self._feature_name_to_idx.get(split, split)
feature_ref = ast.FeatureRef(feature_idx)
Expand Down
5 changes: 4 additions & 1 deletion m2cgen/ast.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
from enum import Enum


Expand Down Expand Up @@ -30,7 +31,9 @@ class NumExpr(Expr):


class NumVal(NumExpr):
def __init__(self, value):
def __init__(self, value, dtype=None):
if dtype:
value = dtype(value)
self.value = value

def __str__(self):
Expand Down

0 comments on commit 242165e

Please sign in to comment.