File tree Expand file tree Collapse file tree 2 files changed +10
-0
lines changed
jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark Expand file tree Collapse file tree 2 files changed +10
-0
lines changed Original file line number Diff line number Diff line change @@ -22,6 +22,7 @@ import scala.collection.mutable
2222
2323import ml .dmlc .xgboost4j .java .Rabit
2424import ml .dmlc .xgboost4j .scala .{Booster , DMatrix , XGBoost => SXGBoost }
25+ import ml .dmlc .xgboost4j .scala .{EvalTrait , ObjectiveTrait }
2526import ml .dmlc .xgboost4j .scala .spark .params ._
2627import ml .dmlc .xgboost4j .{LabeledPoint => XGBLabeledPoint }
2728
@@ -134,6 +135,10 @@ class XGBoostClassifier (
134135
135136 def setNumEarlyStoppingRounds (value : Int ): this .type = set(numEarlyStoppingRounds, value)
136137
138+ def setCustomObj (value : ObjectiveTrait ): this .type = set(customObj, value)
139+
140+ def setCustomEval (value : EvalTrait ): this .type = set(customEval, value)
141+
137142 // called at the start of fit/train when 'eval_metric' is not defined
138143 private def setupDefaultEvalMetric (): String = {
139144 require(isDefined(objective), " Users must set \' objective\' via xgboostParams." )
Original file line number Diff line number Diff line change @@ -23,6 +23,7 @@ import ml.dmlc.xgboost4j.java.Rabit
2323import ml .dmlc .xgboost4j .{LabeledPoint => XGBLabeledPoint }
2424import ml .dmlc .xgboost4j .scala .spark .params .{DefaultXGBoostParamsReader , _ }
2525import ml .dmlc .xgboost4j .scala .{Booster , DMatrix , XGBoost => SXGBoost }
26+ import ml .dmlc .xgboost4j .scala .{EvalTrait , ObjectiveTrait }
2627
2728import org .apache .hadoop .fs .Path
2829import org .apache .spark .TaskContext
@@ -136,6 +137,10 @@ class XGBoostRegressor (
136137
137138 def setNumEarlyStoppingRounds (value : Int ): this .type = set(numEarlyStoppingRounds, value)
138139
140+ def setCustomObj (value : ObjectiveTrait ): this .type = set(customObj, value)
141+
142+ def setCustomEval (value : EvalTrait ): this .type = set(customEval, value)
143+
139144 // called at the start of fit/train when 'eval_metric' is not defined
140145 private def setupDefaultEvalMetric (): String = {
141146 require(isDefined(objective), " Users must set \' objective\' via xgboostParams." )
You can’t perform that action at this time.
0 commit comments