Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature weights (#5962) #3

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
#include "../src/learner.cc"
#include "../src/logging.cc"
#include "../src/common/common.cc"
#include "../src/common/random.cc"
#include "../src/common/charconv.cc"
#include "../src/common/timer.cc"
#include "../src/common/host_device_vector.cc"
Expand Down
49 changes: 49 additions & 0 deletions demo/guide-python/feature_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
'''Using feature weight to change column sampling.

.. versionadded:: 1.3.0
'''

import numpy as np
import xgboost
from matplotlib import pyplot as plt
import argparse


def main(args):
rng = np.random.RandomState(1994)

kRows = 1000
kCols = 10

X = rng.randn(kRows, kCols)
y = rng.randn(kRows)
fw = np.ones(shape=(kCols,))
for i in range(kCols):
fw[i] *= float(i)

dtrain = xgboost.DMatrix(X, y)
dtrain.set_info(feature_weights=fw)

bst = xgboost.train({'tree_method': 'hist',
'colsample_bynode': 0.5},
dtrain, num_boost_round=10,
evals=[(dtrain, 'd')])
featue_map = bst.get_fscore()
# feature zero has 0 weight
assert featue_map.get('f0', None) is None
assert max(featue_map.values()) == featue_map.get('f9')

if args.plot:
xgboost.plot_importance(bst)
plt.show()


if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--plot',
type=int,
default=1,
help='Set to 0 to disable plotting the evaluation history.')
args = parser.parse_args()
main(args)
2 changes: 1 addition & 1 deletion demo/json-model/json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __str__(self):

class Model:
'''Gradient boosted tree model.'''
def __init__(self, m: dict):
def __init__(self, model: dict):
'''Construct the Model from JSON object.

parameters
Expand Down
9 changes: 6 additions & 3 deletions doc/parameter.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ Parameters for Tree Booster
'colsample_bynode':0.5}`` with 64 features will leave 8 features to choose from at
each split.

On Python interface, one can set the ``feature_weights`` for DMatrix to define the
probability of each feature being selected when using column sampling. There's a
similar parameter for ``fit`` method in sklearn interface.

* ``lambda`` [default=1, alias: ``reg_lambda``]

- L2 regularization term on weights. Increasing this value will make model more conservative.
Expand Down Expand Up @@ -225,9 +229,8 @@ Parameters for Tree Booster
list is a group of indices of features that are allowed to interact with each other.
See tutorial for more information

Additional parameters for `hist` and 'gpu_hist' tree method
================================================

Additional parameters for ``hist`` and ``gpu_hist`` tree method
================================================================
* ``single_precision_histogram``, [default=``false``]

- Use single precision to build histograms instead of double precision.
Expand Down
28 changes: 28 additions & 0 deletions include/xgboost/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,34 @@ XGB_DLL int XGDMatrixGetStrFeatureInfo(DMatrixHandle handle, const char *field,
bst_ulong *size,
const char ***out_features);

/*!
* \brief Set meta info from dense matrix. Valid field names are:
*
* - label
* - weight
* - base_margin
* - group
* - label_lower_bound
* - label_upper_bound
* - feature_weights
*
* \param handle An instance of data matrix
* \param field Feild name
* \param data Pointer to consecutive memory storing data.
* \param size Size of the data, this is relative to size of type. (Meaning NOT number
* of bytes.)
* \param type Indicator of data type. This is defined in xgboost::DataType enum class.
*
* float = 1
* double = 2
* uint32_t = 3
* uint64_t = 4
*
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGDMatrixSetDenseInfo(DMatrixHandle handle, const char *field,
void *data, bst_ulong size, int type);

/*!
* \brief (deprecated) Use XGDMatrixSetUIntInfo instead. Set group of the training matrix
* \param handle a instance of data matrix
Expand Down
29 changes: 6 additions & 23 deletions include/xgboost/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,34 +89,17 @@ class MetaInfo {
* \brief Type of each feature. Automatically set when feature_type_names is specifed.
*/
HostDeviceVector<FeatureType> feature_types;
/*
* \brief Weight of each feature, used to define the probability of each feature being
* selected when using column sampling.
*/
HostDeviceVector<float> feature_weigths;

/*! \brief default constructor */
MetaInfo() = default;
MetaInfo(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo&& that) = default;
MetaInfo& operator=(MetaInfo const& that) {
this->num_row_ = that.num_row_;
this->num_col_ = that.num_col_;
this->num_nonzero_ = that.num_nonzero_;

this->labels_.Resize(that.labels_.Size());
this->labels_.Copy(that.labels_);

this->group_ptr_ = that.group_ptr_;

this->weights_.Resize(that.weights_.Size());
this->weights_.Copy(that.weights_);

this->base_margin_.Resize(that.base_margin_.Size());
this->base_margin_.Copy(that.base_margin_);

this->labels_lower_bound_.Resize(that.labels_lower_bound_.Size());
this->labels_lower_bound_.Copy(that.labels_lower_bound_);

this->labels_upper_bound_.Resize(that.labels_upper_bound_.Size());
this->labels_upper_bound_.Copy(that.labels_upper_bound_);
return *this;
}
MetaInfo& operator=(MetaInfo const& that) = delete;

/*!
* \brief Validate all metainfo.
Expand Down
2 changes: 1 addition & 1 deletion jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>pom</packaging>
<name>XGBoost JVM Package</name>
<description>JVM Package for XGBoost</description>
Expand Down
8 changes: 4 additions & 4 deletions jvm-packages/xgboost4j-example/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-example_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>jar</packaging>
<build>
<plugins>
Expand All @@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-spark_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
Expand All @@ -37,7 +37,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j-flink_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
6 changes: 3 additions & 3 deletions jvm-packages/xgboost4j-flink/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-flink_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<build>
<plugins>
<plugin>
Expand All @@ -26,7 +26,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Expand Down
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j-spark/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j-spark_2.12</artifactId>
<build>
Expand All @@ -24,7 +24,7 @@
<dependency>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost4j_${scala.binary.version}</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s

private val isLocal = sc.isLocal

private val overridedParams = overrideParams(rawParams, sc)
private val overridedParams: Map[String, Any] = overrideParams(rawParams, sc)

/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
Expand Down Expand Up @@ -213,7 +213,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
.asInstanceOf[Double]
val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
val featureWeights = overridedParams.getOrElse(
"feature_weights", new Array[Float](0)).asInstanceOf[Array[Float]]
"feature_weights", new Array[Double](0)).asInstanceOf[Array[Double]]
.map(_.toFloat)
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed, featureWeights)

val earlyStoppingRounds = overridedParams.getOrElse(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark.params

import scala.collection.immutable.HashSet

import org.apache.spark.ml.param.{DoubleParam, IntParam, BooleanParam, Param, Params}
import org.apache.spark.ml.param.{BooleanParam, DoubleArrayParam, DoubleParam, IntParam, Param, Params}

private[spark] trait BoosterParams extends Params {

Expand Down Expand Up @@ -110,6 +109,15 @@ private[spark] trait BoosterParams extends Params {

final def getSubsample: Double = $(subsample)

/**
* Probability distribution for column sampling. Doesn't have to be normalized
*/
final val featureWeights = new DoubleArrayParam(this, "featureWeights",
"probability distribution " +
"for feature sampling.", (value: Array[Double]) => true)

final def getFeatureWeights: Array[Double] = $(featureWeights)

/**
* subsample ratio of columns when constructing each tree. [default=1] range: (0,1]
*/
Expand Down Expand Up @@ -286,7 +294,8 @@ private[spark] trait BoosterParams extends Params {
subsample -> 1, colsampleBytree -> 1, colsampleBylevel -> 1,
lambda -> 1, alpha -> 0, treeMethod -> "auto", sketchEps -> 0.03,
scalePosWeight -> 1.0, sampleType -> "uniform", normalizeType -> "tree",
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0)
rateDrop -> 0.0, skipDrop -> 0.0, lambdaBias -> 0, treeLimit -> 0,
featureWeights -> new Array[Double](0))
}

private[spark] object BoosterParams {
Expand Down
4 changes: 2 additions & 2 deletions jvm-packages/xgboost4j/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
<parent>
<groupId>ml.dmlc</groupId>
<artifactId>xgboost-jvm_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
</parent>
<artifactId>xgboost4j_2.12</artifactId>
<version>1.2.4-al</version>
<version>1.2.6-al</version>
<packaging>jar</packaging>

<dependencies>
Expand Down
7 changes: 6 additions & 1 deletion python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,8 @@ def set_info(self,
label_lower_bound=None,
label_upper_bound=None,
feature_names=None,
feature_types=None):
feature_types=None,
feature_weights=None):
'''Set meta info for DMatrix.'''
if label is not None:
self.set_label(label)
Expand All @@ -473,6 +474,10 @@ def set_info(self,
self.feature_names = feature_names
if feature_types is not None:
self.feature_types = feature_types
if feature_weights is not None:
from .data import dispatch_meta_backend
dispatch_meta_backend(matrix=self, data=feature_weights,
name='feature_weights')

def get_float_info(self, field):
"""Get float property from the DMatrix.
Expand Down
45 changes: 31 additions & 14 deletions python-package/xgboost/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,22 +530,38 @@ def dispatch_data_backend(data, missing, threads,
raise TypeError('Not supported type for data.' + str(type(data)))


def _to_data_type(dtype: str, name: str):
dtype_map = {'float32': 1, 'float64': 2, 'uint32': 3, 'uint64': 4}
if dtype not in dtype_map.keys():
raise TypeError(
f'Expecting float32, float64, uint32, uint64, got {dtype} ' +
f'for {name}.')
return dtype_map[dtype]


def _validate_meta_shape(data):
if hasattr(data, 'shape'):
assert len(data.shape) == 1 or (
len(data.shape) == 2 and
(data.shape[1] == 0 or data.shape[1] == 1))


def _meta_from_numpy(data, field, dtype, handle):
data = _maybe_np_slice(data, dtype)
if dtype == 'uint32':
c_data = c_array(ctypes.c_uint32, data)
_check_call(_LIB.XGDMatrixSetUIntInfo(handle,
c_str(field),
c_array(ctypes.c_uint, data),
c_bst_ulong(len(data))))
elif dtype == 'float':
c_data = c_array(ctypes.c_float, data)
_check_call(_LIB.XGDMatrixSetFloatInfo(handle,
c_str(field),
c_data,
c_bst_ulong(len(data))))
else:
raise TypeError('Unsupported type ' + str(dtype) + ' for:' + field)
interface = data.__array_interface__
assert interface.get('mask', None) is None, 'Masked array is not supported'
size = data.shape[0]

c_type = _to_data_type(str(data.dtype), field)
ptr = interface['data'][0]
ptr = ctypes.c_void_p(ptr)
_check_call(_LIB.XGDMatrixSetDenseInfo(
handle,
c_str(field),
ptr,
c_bst_ulong(size),
c_type
))


def _meta_from_list(data, field, dtype, handle):
Expand Down Expand Up @@ -595,6 +611,7 @@ def _meta_from_dt(data, field, dtype, handle):
def dispatch_meta_backend(matrix: DMatrix, data, name: str, dtype: str = None):
'''Dispatch for meta info.'''
handle = matrix.handle
_validate_meta_shape(data)
if data is None:
return
if _is_list(data):
Expand Down