Skip to content

Commit

Permalink
Support not only task-default objective functions in LightGBM (#221)
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Jul 6, 2020
1 parent fe9fb04 commit 65f8dbf
Show file tree
Hide file tree
Showing 9 changed files with 405 additions and 86 deletions.
98 changes: 90 additions & 8 deletions m2cgen/assemblers/boosting.py
Expand Up @@ -13,7 +13,7 @@ class BaseBoostingAssembler(ModelAssembler):
classifier_names = {}
multiclass_params_seq_len = 1

def __init__(self, model, estimator_params, base_score=0):
def __init__(self, model, estimator_params, base_score=0.0):
super().__init__(model)
self._all_estimator_params = estimator_params
self._base_score = base_score
Expand All @@ -36,11 +36,12 @@ def assemble(self):
return self._assemble_multi_class_output(
self._all_estimator_params)
else:
return self._assemble_single_output(
result_ast = self._assemble_single_output(
self._all_estimator_params, base_score=self._base_score)
return self._single_convert_output(result_ast)

def _assemble_single_output(self, estimator_params,
base_score=0, split_idx=0):
base_score=0.0, split_idx=0):
estimators_ast = self._assemble_estimators(estimator_params, split_idx)

tmp_ast = utils.apply_op_to_expressions(
Expand Down Expand Up @@ -71,21 +72,21 @@ def _assemble_multi_class_output(self, estimator_params):
for i, e in enumerate(splits)
]

proba_exprs = fallback_expressions.softmax(exprs)
proba_exprs = self._multi_class_convert_output(exprs)
return ast.VectorVal(proba_exprs)

def _assemble_bin_class_output(self, estimator_params):
# Base score is calculated based on
# https://github.com/dmlc/xgboost/blob/8de7f1928e4815843fbf8773a5ac7ecbc37b2e15/src/objective/regression_loss.h#L91
# return -logf(1.0f / base_score - 1.0f);
base_score = 0
if self._base_score != 0:
base_score = 0.0
if self._base_score != 0.0:
base_score = -math.log(1.0 / self._base_score - 1.0)

expr = self._assemble_single_output(
estimator_params, base_score=base_score)

proba_expr = fallback_expressions.sigmoid(expr, to_reuse=True)
proba_expr = self._bin_class_convert_output(expr)

return ast.VectorVal([
ast.BinNumExpr(ast.NumVal(1), proba_expr, ast.BinNumOpType.SUB),
Expand All @@ -95,13 +96,22 @@ def _assemble_bin_class_output(self, estimator_params):
def _final_transform(self, ast_to_transform):
    """Return the assembled AST without modifying it.

    This implementation is an identity step: whatever is passed in is
    handed back as-is.
    """
    unchanged_ast = ast_to_transform
    return unchanged_ast

def _multi_class_convert_output(self, exprs):
    """Convert the per-class raw expressions into the final outputs.

    The conversion is delegated entirely to
    ``fallback_expressions.softmax``; its result is returned unchanged.
    """
    softmax = fallback_expressions.softmax
    return softmax(exprs)

def _bin_class_convert_output(self, expr, to_reuse=True):
    """Convert the raw binary-classification expression.

    Delegates to ``fallback_expressions.sigmoid``, forwarding the
    ``to_reuse`` flag as a keyword argument.
    """
    converted = fallback_expressions.sigmoid(expr, to_reuse=to_reuse)
    return converted

def _single_convert_output(self, expr):
    """Return the single-output expression as-is (identity conversion)."""
    converted = expr
    return converted

def _assemble_estimators(self, estimator_params, split_idx):
    """Placeholder with no implementation here.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError()


class BaseTreeBoostingAssembler(BaseBoostingAssembler):

def __init__(self, model, trees, base_score=0, tree_limit=None):
def __init__(self, model, trees, base_score=0.0, tree_limit=None):
super().__init__(model, trees, base_score=base_score)
assert tree_limit is None or tree_limit > 0, "Unexpected tree limit"
self._tree_limit = tree_limit
Expand Down Expand Up @@ -212,6 +222,9 @@ def __init__(self, model):

self.n_iter = len(trees) // model_dump["num_tree_per_iteration"]
self.average_output = model_dump.get("average_output", False)
self.objective_config_parts = model_dump.get(
"objective", "custom").split(" ")
self.objective_name = self.objective_config_parts[0]

super().__init__(model, trees)

Expand All @@ -225,6 +238,75 @@ def _final_transform(self, ast_to_transform):
else:
return super()._final_transform(ast_to_transform)

def _multi_class_convert_output(self, exprs):
    """Convert multi-class raw outputs according to the model objective.

    Dispatch on ``self.objective_name``:

    * ``"multiclass"``    -> the parent class's multi-class conversion
    * ``"multiclassova"`` -> ``_multi_class_sigmoid_transform``
    * ``"custom"``        -> the parent class's single-output conversion

    Raises:
        ValueError: if ``self.objective_name`` is none of the above.
    """
    converter_by_objective = {
        "multiclass": super()._multi_class_convert_output,
        "multiclassova": self._multi_class_sigmoid_transform,
        "custom": super()._single_convert_output,
    }
    converter = converter_by_objective.get(self.objective_name)
    if converter is None:
        raise ValueError(
            f"Unsupported objective function '{self.objective_name}'")
    return converter(exprs)

def _multi_class_sigmoid_transform(self, exprs):
    """Apply the binary sigmoid transform to every class expression.

    Each expression is converted independently with ``to_reuse=False``;
    the results are returned as a list in input order.
    """
    to_sigmoid = self._bin_class_sigmoid_transform
    return [to_sigmoid(class_expr, to_reuse=False) for class_expr in exprs]

def _bin_class_convert_output(self, expr, to_reuse=True):
    """Convert the raw binary-classification output per model objective.

    Dispatch on ``self.objective_name``:

    * ``"binary"`` -> ``_bin_class_sigmoid_transform``
    * ``"custom"`` -> the parent class's single-output conversion

    Args:
        expr: expression for the raw model output.
        to_reuse: forwarded to the sigmoid transform for the ``"binary"``
            objective. It was previously accepted but silently dropped,
            so a caller passing ``to_reuse=False`` still got the default.
            The ``"custom"`` conversion takes no such flag.

    Raises:
        ValueError: if ``self.objective_name`` is not supported.
    """
    objective = self.objective_name
    if objective == "binary":
        return self._bin_class_sigmoid_transform(expr, to_reuse=to_reuse)
    if objective == "custom":
        return super()._single_convert_output(expr)
    raise ValueError(
        f"Unsupported objective function '{self.objective_name}'")

def _bin_class_sigmoid_transform(self, expr, to_reuse=True):
    """Apply the parent binary conversion, scaled by the sigmoid coefficient.

    The coefficient is read from the first ``sigmoid:<value>`` entry in
    ``self.objective_config_parts`` and defaults to 1.0 when there is no
    such entry. The multiplication node is only emitted when the
    coefficient differs from 1.0.
    """
    split_parts = (part.split(":") for part in self.objective_config_parts)
    # Lazy on purpose: only the first matching entry is ever indexed.
    sigmoid_values = (
        entry[1] for entry in split_parts if entry[0] == "sigmoid")
    raw_coef = next(sigmoid_values, None)
    coef = 1.0 if raw_coef is None else np.float64(raw_coef)

    if coef != 1.0:
        scaled_expr = utils.mul(ast.NumVal(coef), expr)
    else:
        scaled_expr = expr
    return super()._bin_class_convert_output(scaled_expr, to_reuse=to_reuse)

def _single_convert_output(self, expr):
    """Apply the output transform matching a single-output objective.

    Dispatch on ``self.objective_name``:

    * ``cross_entropy``                 -> ``fallback_expressions.sigmoid``
    * ``cross_entropy_lambda``          -> ``_log1p_exp_transform``
    * ``regression``, ``regression_l1``, ``fair``, ``quantile``, ``mape``
                                        -> ``_maybe_sqr_transform``
    * ``poisson``, ``gamma``, ``tweedie`` -> ``_exp_transform``
    * ``huber``, ``custom``             -> parent conversion

    Raises:
        ValueError: if ``self.objective_name`` is not in the table.
    """
    parent_convert = super()._single_convert_output
    maybe_sqr = self._maybe_sqr_transform
    exp = self._exp_transform

    transform_by_objective = {
        "cross_entropy": fallback_expressions.sigmoid,
        "cross_entropy_lambda": self._log1p_exp_transform,
        "regression": maybe_sqr,
        "regression_l1": maybe_sqr,
        "huber": parent_convert,
        "fair": maybe_sqr,
        "poisson": exp,
        "quantile": maybe_sqr,
        "mape": maybe_sqr,
        "gamma": exp,
        "tweedie": exp,
        "custom": parent_convert,
    }
    transform = transform_by_objective.get(self.objective_name)
    if transform is None:
        raise ValueError(
            f"Unsupported objective function '{self.objective_name}'")
    return transform(expr)

def _log1p_exp_transform(self, expr):
    """Wrap ``expr`` as ``Log1pExpr(ExpExpr(expr))``."""
    exponent = ast.ExpExpr(expr)
    return ast.Log1pExpr(exponent)

def _maybe_sqr_transform(self, expr):
    """Return ``|expr| * expr`` when the objective config has ``sqrt``.

    Without a ``"sqrt"`` entry in ``self.objective_config_parts`` the
    expression is returned untouched. Otherwise it is wrapped in a
    reusable ``IdExpr`` because it appears twice in the product.
    """
    if "sqrt" not in self.objective_config_parts:
        return expr
    reusable = ast.IdExpr(expr, to_reuse=True)
    return utils.mul(ast.AbsExpr(reusable), reusable)

def _exp_transform(self, expr):
    """Wrap ``expr`` in an ``ExpExpr`` node."""
    exponentiated = ast.ExpExpr(expr)
    return exponentiated

def _assemble_tree(self, tree):
if "leaf_value" in tree:
return ast.NumVal(tree["leaf_value"])
Expand Down
45 changes: 23 additions & 22 deletions m2cgen/interpreters/haskell/log1p.hs
Expand Up @@ -12,28 +12,29 @@ log1p x
m_epsilon = encodeFloat (signif + 1) expo - 1.0
where (signif, expo) = decodeFloat (1.0::Double)
x' = abs x
coeffs = [0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15]
coeffs = [
0.10378693562743769800686267719098e+1,
-0.13364301504908918098766041553133e+0,
0.19408249135520563357926199374750e-1,
-0.30107551127535777690376537776592e-2,
0.48694614797154850090456366509137e-3,
-0.81054881893175356066809943008622e-4,
0.13778847799559524782938251496059e-4,
-0.23802210894358970251369992914935e-5,
0.41640416213865183476391859901989e-6,
-0.73595828378075994984266837031998e-7,
0.13117611876241674949152294345011e-7,
-0.23546709317742425136696092330175e-8,
0.42522773276034997775638052962567e-9,
-0.77190894134840796826108107493300e-10,
0.14075746481359069909215356472191e-10,
-0.25769072058024680627537078627584e-11,
0.47342406666294421849154395005938e-12,
-0.87249012674742641745301263292675e-13,
0.16124614902740551465739833119115e-13,
-0.29875652015665773006710792416815e-14,
0.55480701209082887983041321697279e-15,
-0.10324619158271569595141333961932e-15]
chebyshevBroucke i = fini . foldr step [0, 0, 0]
where
step k [b0, b1, _] = [(k + i * 2 * b0 - b1), b0, b1]
Expand Down
2 changes: 1 addition & 1 deletion m2cgen/interpreters/powershell/log1p.ps1
Expand Up @@ -4,7 +4,7 @@ function Log1p([double] $x) {
if ($x -lt -1.0) { return [double]::NaN }
[double] $xAbs = [math]::Abs($x)
if ($xAbs -lt 0.5 * [double]::Epsilon) { return $x }
if ((($x -gt 0.0) -and ($x -lt 1e-8))
if ((($x -gt 0.0) -and ($x -lt 1e-8)) `
-or (($x -gt -1e-9) -and ($x -lt 0.0))) {
return $x * (1.0 - $x * 0.5)
}
Expand Down
32 changes: 16 additions & 16 deletions m2cgen/interpreters/visual_basic/log1p.bas
Expand Up @@ -43,27 +43,27 @@ Function Log1p(ByVal x As Double) As Double
End If
If xAbs < 0.375 Then
Dim coeffs(22) As Double
coeffs(0) = 0.10378693562743769800686267719098e+1
coeffs(1) = -0.13364301504908918098766041553133e+0
coeffs(2) = 0.19408249135520563357926199374750e-1
coeffs(3) = -0.30107551127535777690376537776592e-2
coeffs(4) = 0.48694614797154850090456366509137e-3
coeffs(5) = -0.81054881893175356066809943008622e-4
coeffs(6) = 0.13778847799559524782938251496059e-4
coeffs(7) = -0.23802210894358970251369992914935e-5
coeffs(8) = 0.41640416213865183476391859901989e-6
coeffs(9) = -0.73595828378075994984266837031998e-7
coeffs(10) = 0.13117611876241674949152294345011e-7
coeffs(0) = 0.10378693562743769800686267719098e+1
coeffs(1) = -0.13364301504908918098766041553133e+0
coeffs(2) = 0.19408249135520563357926199374750e-1
coeffs(3) = -0.30107551127535777690376537776592e-2
coeffs(4) = 0.48694614797154850090456366509137e-3
coeffs(5) = -0.81054881893175356066809943008622e-4
coeffs(6) = 0.13778847799559524782938251496059e-4
coeffs(7) = -0.23802210894358970251369992914935e-5
coeffs(8) = 0.41640416213865183476391859901989e-6
coeffs(9) = -0.73595828378075994984266837031998e-7
coeffs(10) = 0.13117611876241674949152294345011e-7
coeffs(11) = -0.23546709317742425136696092330175e-8
coeffs(12) = 0.42522773276034997775638052962567e-9
coeffs(12) = 0.42522773276034997775638052962567e-9
coeffs(13) = -0.77190894134840796826108107493300e-10
coeffs(14) = 0.14075746481359069909215356472191e-10
coeffs(14) = 0.14075746481359069909215356472191e-10
coeffs(15) = -0.25769072058024680627537078627584e-11
coeffs(16) = 0.47342406666294421849154395005938e-12
coeffs(16) = 0.47342406666294421849154395005938e-12
coeffs(17) = -0.87249012674742641745301263292675e-13
coeffs(18) = 0.16124614902740551465739833119115e-13
coeffs(18) = 0.16124614902740551465739833119115e-13
coeffs(19) = -0.29875652015665773006710792416815e-14
coeffs(20) = 0.55480701209082887983041321697279e-15
coeffs(20) = 0.55480701209082887983041321697279e-15
coeffs(21) = -0.10324619158271569595141333961932e-15
Log1p = x * (1.0 - x * ChebyshevBroucke(x / 0.375, coeffs))
Exit Function
Expand Down

0 comments on commit 65f8dbf

Please sign in to comment.