Skip to content

Commit c004cea

Browse files
yanboliangCodingCat
authored andcommitted
Expose setCustomObj & setCustomEval for XGBoostClassifier & XGBoostRegressor. (dmlc#3486)
1 parent b6dcbf0 commit c004cea

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostClassifier.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import scala.collection.mutable
2222

2323
import ml.dmlc.xgboost4j.java.Rabit
2424
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
25+
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
2526
import ml.dmlc.xgboost4j.scala.spark.params._
2627
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
2728

@@ -134,6 +135,10 @@ class XGBoostClassifier (
134135

135136
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
136137

138+
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
139+
140+
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
141+
137142
// called at the start of fit/train when 'eval_metric' is not defined
138143
private def setupDefaultEvalMetric(): String = {
139144
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostRegressor.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import ml.dmlc.xgboost4j.java.Rabit
2323
import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint}
2424
import ml.dmlc.xgboost4j.scala.spark.params.{DefaultXGBoostParamsReader, _}
2525
import ml.dmlc.xgboost4j.scala.{Booster, DMatrix, XGBoost => SXGBoost}
26+
import ml.dmlc.xgboost4j.scala.{EvalTrait, ObjectiveTrait}
2627

2728
import org.apache.hadoop.fs.Path
2829
import org.apache.spark.TaskContext
@@ -136,6 +137,10 @@ class XGBoostRegressor (
136137

137138
def setNumEarlyStoppingRounds(value: Int): this.type = set(numEarlyStoppingRounds, value)
138139

140+
def setCustomObj(value: ObjectiveTrait): this.type = set(customObj, value)
141+
142+
def setCustomEval(value: EvalTrait): this.type = set(customEval, value)
143+
139144
// called at the start of fit/train when 'eval_metric' is not defined
140145
private def setupDefaultEvalMetric(): String = {
141146
require(isDefined(objective), "Users must set \'objective\' via xgboostParams.")

0 commit comments

Comments
 (0)