Skip to content

Commit

Permalink
Merge pull request #46 from SoftwareAG/bugfix/lightgbm
Browse files Browse the repository at this point in the history
Fix for #39 : LightGBM exporter fails when feature name has suffix 'f'
  • Loading branch information
Nirmal-Neel committed Jul 14, 2021
2 parents 60b2d07 + e77cd69 commit 6bd017c
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions nyoka/lgbm/lgb_to_pmml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
import sys, os
BASE_DIR = os.path.dirname(os.path.dirname(__file__))
sys.path.append(BASE_DIR)
import PMML44 as pml
import nyoka.PMML44 as pml
import nyoka.skl.skl_to_pmml as sklToPmml
import nyoka.xgboost.xgboost_to_pmml as xgboostToPmml
import json
from skl import pre_process as pp
from nyoka.skl import pre_process as pp
from datetime import datetime
from base.constants import *
from nyoka.base.constants import *


def lgb_to_pmml(pipeline, col_names, target_name, pmml_f_name='from_lgbm.pmml',model_name=None, description=None):
Expand Down Expand Up @@ -326,16 +326,16 @@ def create_node(obj, main_node,derived_col_names):
def create_left_node(obj,derived_col_names):
nd = pml.Node()
nd.set_SimplePredicate(
pml.SimplePredicate(field=xgboostToPmml.replace_name_with_derivedColumnNames(derived_col_names[int(obj['split_feature'])],\
derived_col_names), operator=SIMPLE_PREDICATE_OPERATOR.LESS_OR_EQUAL, value="{:.16f}".format(obj['threshold'])))
pml.SimplePredicate(field=derived_col_names[int(obj['split_feature'])],
operator=SIMPLE_PREDICATE_OPERATOR.LESS_OR_EQUAL, value="{:.16f}".format(obj['threshold'])))
create_node(obj['left_child'], nd, derived_col_names)
return nd

def create_right_node(obj,derived_col_names):
nd = pml.Node()
nd.set_SimplePredicate(
pml.SimplePredicate(field=xgboostToPmml.replace_name_with_derivedColumnNames(derived_col_names[int(obj['split_feature'])],\
derived_col_names), operator=SIMPLE_PREDICATE_OPERATOR.GREATER_THAN, value="{:.16f}".format(obj['threshold'])))
pml.SimplePredicate(field=derived_col_names[int(obj['split_feature'])]
, operator=SIMPLE_PREDICATE_OPERATOR.GREATER_THAN, value="{:.16f}".format(obj['threshold'])))
create_node(obj['right_child'], nd, derived_col_names)
return nd

Expand Down

0 comments on commit 6bd017c

Please sign in to comment.