Skip to content

Commit

Permalink
revert new train/test splitting routine
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Aug 4, 2020
1 parent 87ee455 commit aeb4301
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 211 deletions.
26 changes: 17 additions & 9 deletions m2cgen/assemblers/linear.py
@@ -1,7 +1,7 @@
import numpy as np

from m2cgen import ast
from m2cgen.assemblers import fallback_expressions, utils
from m2cgen.assemblers import utils
from m2cgen.assemblers.base import ModelAssembler


Expand Down Expand Up @@ -95,7 +95,14 @@ def _get_supported_inversed_funs(self):
raise NotImplementedError

def _logit_inversed(self, ast_to_transform):
return fallback_expressions.sigmoid(ast_to_transform)
return utils.div(
ast.NumVal(1.0),
utils.add(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform))))

def _power_inversed(self, ast_to_transform):
power = self._get_power()
Expand Down Expand Up @@ -139,15 +146,16 @@ def _cloglog_inversed(self, ast_to_transform):

def _negativebinomial_inversed(self, ast_to_transform):
alpha = self._get_alpha()
res = utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))
return utils.div(
ast.NumVal(-1.0),
utils.mul(ast.NumVal(alpha), res) if alpha != 1.0 else res)
utils.mul(
ast.NumVal(alpha),
utils.sub(
ast.NumVal(1.0),
ast.ExpExpr(
utils.sub(
ast.NumVal(0.0),
ast_to_transform)))))

def _get_power(self):
raise NotImplementedError
Expand Down
86 changes: 43 additions & 43 deletions tests/assemblers/test_lightgbm.py
Expand Up @@ -26,15 +26,15 @@ def test_binary_classification():
ast.FeatureRef(23),
ast.NumVal(868.2000000000002),
ast.CompOpType.GT),
ast.NumVal(0.26400127816506497),
ast.NumVal(0.633133056485969)),
ast.NumVal(0.25986931215073095),
ast.NumVal(0.6237178414050242)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(22),
ast.NumVal(105.95000000000002),
ast.FeatureRef(7),
ast.NumVal(0.05142),
ast.CompOpType.GT),
ast.NumVal(-0.18744882409486507),
ast.NumVal(0.13458899352064668)),
ast.NumVal(-0.1909605544006228),
ast.NumVal(0.1293965108676673)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down Expand Up @@ -85,17 +85,17 @@ def test_regression():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.837500000000001),
ast.NumVal(6.918),
ast.CompOpType.GT),
ast.NumVal(23.961356387224317),
ast.NumVal(22.32858336612959)),
ast.NumVal(24.011454621684155),
ast.NumVal(22.289277544391084)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725000000000003),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.5031712645462916),
ast.NumVal(0.6885501354513913)),
ast.NumVal(-0.49461212269771115),
ast.NumVal(0.7174324413014594)),
ast.BinNumOpType.ADD)

assert utils.cmp_exprs(actual, expected)
Expand All @@ -114,18 +114,18 @@ def test_regression_random_forest():
ast.BinNumExpr(
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(5.200000000000001),
ast.FeatureRef(5),
ast.NumVal(6.954000000000001),
ast.CompOpType.GT),
ast.NumVal(20.206688945020474),
ast.NumVal(38.30000037757679)),
ast.NumVal(37.24347877367631),
ast.NumVal(19.936999995530854)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.837500000000001),
ast.NumVal(6.971500000000001),
ast.CompOpType.GT),
ast.NumVal(36.40634951405711),
ast.NumVal(19.57067132709245)),
ast.NumVal(38.48600037864964),
ast.NumVal(20.183783757300255)),
ast.BinNumOpType.ADD),
ast.NumVal(0.5),
ast.BinNumOpType.MUL)
Expand Down Expand Up @@ -154,15 +154,15 @@ def test_simple_sigmoid_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(4.0050691250),
ast.NumVal(4.0914737728)),
ast.NumVal(4.0026305187),
ast.NumVal(4.0880438137)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(15.065),
ast.NumVal(14.895),
ast.CompOpType.GT),
ast.NumVal(-0.0420531079),
ast.NumVal(0.0202891577)),
ast.NumVal(-0.0412703078),
ast.NumVal(0.0208393767)),
ast.BinNumOpType.ADD),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand All @@ -188,15 +188,15 @@ def test_log1p_exp_output_transform():
ast.FeatureRef(12),
ast.NumVal(19.23),
ast.CompOpType.GT),
ast.NumVal(0.6623996001),
ast.NumVal(0.6684477608)),
ast.NumVal(0.6623502468),
ast.NumVal(0.6683497987)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(15.065),
ast.NumVal(15.145),
ast.CompOpType.GT),
ast.NumVal(0.1405782705),
ast.NumVal(0.1453764991)),
ast.NumVal(0.1405181490),
ast.NumVal(0.1453602134)),
ast.BinNumOpType.ADD)))

assert utils.cmp_exprs(actual, expected)
Expand All @@ -216,17 +216,17 @@ def test_maybe_sqr_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(11.655),
ast.NumVal(9.905),
ast.CompOpType.GT),
ast.NumVal(4.5671830654),
ast.NumVal(4.6516575813)),
ast.NumVal(4.5658116817),
ast.NumVal(4.6620790482)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725),
ast.NumVal(9.77),
ast.CompOpType.GT),
ast.NumVal(-0.0348178434),
ast.NumVal(0.0549301624)),
ast.NumVal(-0.0340889740),
ast.NumVal(0.0543687153)),
ast.BinNumOpType.ADD),
to_reuse=True)

Expand All @@ -251,17 +251,17 @@ def test_exp_output_transform():
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(5),
ast.NumVal(6.8375),
ast.NumVal(6.918),
ast.CompOpType.GT),
ast.NumVal(3.1481886430),
ast.NumVal(3.1123367238)),
ast.NumVal(3.1480683932),
ast.NumVal(3.1101554907)),
ast.IfExpr(
ast.CompExpr(
ast.FeatureRef(12),
ast.NumVal(9.725),
ast.NumVal(9.63),
ast.CompOpType.GT),
ast.NumVal(-0.0113689739),
ast.NumVal(0.0153551274)),
ast.NumVal(-0.0111969636),
ast.NumVal(0.0160298303)),
ast.BinNumOpType.ADD))

assert utils.cmp_exprs(actual, expected)
Expand Down Expand Up @@ -289,8 +289,8 @@ def test_bin_class_sigmoid_output_transform():
ast.FeatureRef(23),
ast.NumVal(868.2),
ast.CompOpType.GT),
ast.NumVal(0.5280025563),
ast.NumVal(1.2662661130)),
ast.NumVal(0.5197386243),
ast.NumVal(1.2474356828)),
ast.BinNumOpType.MUL),
ast.BinNumOpType.SUB)),
ast.BinNumOpType.ADD),
Expand Down

0 comments on commit aeb4301

Please sign in to comment.