
Support one-vs-all (OVA) multi-class classification.

guolinke committed Mar 29, 2017
1 parent 1419587 commit 841a8987d68811b89b9139b42bac7b5713dbea33
@@ -138,7 +138,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match
- bool objective_type_multiclass = (objective_type == std::string("multiclass"));
+ bool objective_type_multiclass = (objective_type == std::string("multiclass")
+ || objective_type == std::string("multiclassova"));
int num_class_check = boosting_config.num_class;
if (objective_type_multiclass) {
if (num_class_check <= 1) {
@@ -151,11 +152,19 @@ void OverallConfig::CheckParamConflict() {
}
if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
for (std::string metric_type : metric_types) {
- bool metric_type_multiclass = (metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
+ bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
+ || metric_type == std::string("multi_error")
+ || metric_type == std::string("multi_loglossova"));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) {
Log::Fatal("Objective and metrics don't match");
}
+ if (objective_type == std::string("multiclassova") && metric_type == std::string("multi_logloss")) {
+ Log::Fatal("Wrong metric. For Multi-class with OVA, you should use multi_loglossova metric.");
+ }
+ if (objective_type == std::string("multiclass") && metric_type == std::string("multi_loglossova")) {
+ Log::Fatal("Wrong metric. For Multi-class with softmax, you should use multi_logloss metric.");
+ }
}
}
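
The check above enforces a simple pairing: multi_error is accepted with either multiclass objective, while each logloss metric is tied to its own objective. A minimal standalone sketch of that rule (editor's illustration, not part of the commit; MetricMatchesObjective is a hypothetical helper):

#include <string>

// Pairing enforced by CheckParamConflict: multi_error works with both
// multiclass objectives; each logloss metric is bound to one objective.
bool MetricMatchesObjective(const std::string& objective, const std::string& metric) {
  if (metric == "multi_error") {
    return objective == "multiclass" || objective == "multiclassova";
  }
  if (metric == "multi_logloss") return objective == "multiclass";
  if (metric == "multi_loglossova") return objective == "multiclassova";
  return false;  // not a multiclass metric
}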
@@ -29,7 +29,9 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
} else if (type == std::string("map")) {
return new MapMetric(config);
} else if (type == std::string("multi_logloss")) {
- return new MultiLoglossMetric(config);
+ return new MultiSoftmaxLoglossMetric(config);
+ } else if (type == std::string("multi_loglossova")) {
+ return new MultiOVALoglossMetric(config);
} else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config);
}
@@ -79,8 +79,6 @@ class MulticlassMetric: public Metric {
}
private:
- /*! \brief Output frequency */
- int output_freq_;
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
@@ -116,9 +114,9 @@ class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
};
/*! \brief Logloss for multiclass task */
-class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
+class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
public:
- explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}
+ explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(float label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label);
@@ -135,5 +133,84 @@ class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
}
};
+class MultiOVALoglossMetric: public Metric {
+public:
+ explicit MultiOVALoglossMetric(const MetricConfig& config) {
+ num_class_ = config.num_class;
+ sigmoid_ = config.sigmoid;
+ }
+
+ virtual ~MultiOVALoglossMetric() {
+
+ }
+
+ void Init(const Metadata& metadata, data_size_t num_data) override {
+
+ name_.emplace_back("multi_loglossova");
+ num_data_ = num_data;
+ // get label
+ label_ = metadata.label();
+ // get weights
+ weights_ = metadata.weights();
+ if (weights_ == nullptr) {
+ sum_weights_ = static_cast<double>(num_data_);
+ } else {
+ sum_weights_ = 0.0f;
+ for (data_size_t i = 0; i < num_data_; ++i) {
+ sum_weights_ += weights_[i];
+ }
+ }
+ }
+
+ const std::vector<std::string>& GetName() const override {
+ return name_;
+ }
+
+ double factor_to_bigger_better() const override {
+ return -1.0f;
+ }
+
+ std::vector<double> Eval(const double* score) const override {
+ double sum_loss = 0.0;
+ if (weights_ == nullptr) {
+ #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+ for (data_size_t i = 0; i < num_data_; ++i) {
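+ // index of the true class's raw score; scores are stored class-major as score[num_data_ * k + i]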
+ size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
+ double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[idx]));
+ if (prob < kEpsilon) { prob = kEpsilon; }
+ // add loss
+ sum_loss += -std::log(prob);
+ }
+ } else {
+ #pragma omp parallel for schedule(static) reduction(+:sum_loss)
+ for (data_size_t i = 0; i < num_data_; ++i) {
+ size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
+ double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[idx]));
+ if (prob < kEpsilon) { prob = kEpsilon; }
+ // add loss
+ sum_loss += -std::log(prob) * weights_[i];
+ }
+ }
+ double loss = sum_loss / sum_weights_;
+ return std::vector<double>(1, loss);
+ }
+
+private:
+ /*! \brief Number of data */
+ data_size_t num_data_;
+ /*! \brief Number of classes */
+ int num_class_;
+ /*! \brief Pointer of label */
+ const float* label_;
+ /*! \brief Pointer of weights */
+ const float* weights_;
+ /*! \brief Sum weights */
+ double sum_weights_;
+ /*! \brief Name of this test set */
+ std::vector<std::string> name_;
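+ /*! \brief Sigmoid parameter */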
+ double sigmoid_;
+};
+
} // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
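
The new metric evaluates only the column that belongs to each row's true class: with LightGBM's class-major layout, score[num_data * k + i] is class k's raw score for row i, so the metric applies a sigmoid to the true class's score and accumulates the negative log. A minimal sketch of the unweighted case (editor's illustration; OvaLogloss is a hypothetical free function and eps stands in for kEpsilon):

#include <cmath>
#include <cstddef>
#include <vector>

// Unweighted OVA logloss over a class-major score buffer.
double OvaLogloss(const std::vector<double>& score, const std::vector<int>& label,
                  double sigmoid, double eps = 1e-15) {
  const std::size_t num_data = label.size();
  double sum_loss = 0.0;
  for (std::size_t i = 0; i < num_data; ++i) {
    // only the true class's column contributes to the loss
    std::size_t idx = num_data * static_cast<std::size_t>(label[i]) + i;
    double prob = 1.0 / (1.0 + std::exp(-sigmoid * score[idx]));
    if (prob < eps) prob = eps;  // clamp, like kEpsilon in the metric above
    sum_loss += -std::log(prob);
  }
  return sum_loss / static_cast<double>(num_data);
}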
@@ -12,15 +12,21 @@ namespace LightGBM {
*/
class BinaryLogloss: public ObjectiveFunction {
public:
- explicit BinaryLogloss(const ObjectiveConfig& config) {
+ explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(float)> is_pos = nullptr) {
is_unbalance_ = config.is_unbalance;
sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
}
scale_pos_weight_ = static_cast<double>(config.scale_pos_weight);
+ is_pos_ = is_pos;
+ if (is_pos_ == nullptr) {
+ is_pos_ = [](float label) {return label > 0; };
+ }
}
+
~BinaryLogloss() {}
+
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
label_ = metadata.label();
@@ -30,7 +36,7 @@ class BinaryLogloss: public ObjectiveFunction {
// count for positive and negative samples
#pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative)
for (data_size_t i = 0; i < num_data_; ++i) {
- if (label_[i] > 0) {
+ if (is_pos_(label_[i])) {
++cnt_positive;
} else {
++cnt_negative;
@@ -61,7 +67,7 @@ class BinaryLogloss: public ObjectiveFunction {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
- const int is_pos = label_[i] > 0;
+ const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
@@ -74,7 +80,7 @@ class BinaryLogloss: public ObjectiveFunction {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
- const int is_pos = label_[i] > 0;
+ const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
@@ -106,6 +112,7 @@ class BinaryLogloss: public ObjectiveFunction {
/*! \brief Weights for data */
const float* weights_;
double scale_pos_weight_;
+ std::function<bool(float)> is_pos_;
};
} // namespace LightGBM
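
The only behavioral change to BinaryLogloss is the injected predicate: by default any positive label counts as the positive class, but a caller may redefine positivity, which is what the OVA objective below does for each class. A small sketch of the idea (editor's illustration; CountPositive is a hypothetical helper mirroring the counting loop in Init):

#include <functional>
#include <vector>

// Count positive labels under a configurable predicate.
int CountPositive(const std::vector<float>& labels,
                  const std::function<bool(float)>& is_pos) {
  int cnt = 0;
  for (float label : labels) {
    if (is_pos(label)) ++cnt;
  }
  return cnt;
}

// Default binary behavior: any positive label is positive.
//   CountPositive(labels, [](float l) { return l > 0; });
// One-vs-all for class k: only label == k is positive.
//   CountPositive(labels, [k](float l) { return static_cast<int>(l) == k; });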
@@ -5,19 +5,22 @@
#include <cstring>
#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "binary_objective.hpp"
namespace LightGBM {
/*!
-* \brief Objective function for multiclass classification
+* \brief Objective function for multiclass classification, using softmax as the objective function
*/
-class MulticlassLogloss: public ObjectiveFunction {
+class MulticlassSoftmax: public ObjectiveFunction {
public:
- explicit MulticlassLogloss(const ObjectiveConfig& config) {
+ explicit MulticlassSoftmax(const ObjectiveConfig& config) {
num_class_ = config.num_class;
- is_unbalance_ = config.is_unbalance;
}
- ~MulticlassLogloss() {
+ ~MulticlassSoftmax() {
+
}
void Init(const Metadata& metadata, data_size_t num_data) override {
@@ -32,18 +35,6 @@ class MulticlassLogloss: public ObjectiveFunction {
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
}
}
- label_pos_weights_ = std::vector<float>(num_class_, 1);
- if (is_unbalance_) {
- std::vector<int> cnts(num_class_, 0);
- for (int i = 0; i < num_data_; ++i) {
- ++cnts[label_int_[i]];
- }
- for (int i = 0; i < num_class_; ++i) {
- int cnt_cur = cnts[i];
- int cnt_other = (num_data_ - cnts[i]);
- label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
- }
- }
}
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
@@ -52,7 +43,7 @@ class MulticlassLogloss: public ObjectiveFunction {
#pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_);
- for (int k = 0; k < num_class_; ++k){
+ for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]);
}
@@ -61,20 +52,19 @@ class MulticlassLogloss: public ObjectiveFunction {
auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
- gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
- hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
+ gradients[idx] = static_cast<score_t>(p - 1.0f);
} else {
gradients[idx] = static_cast<score_t>(p);
- hessians[idx] = static_cast<score_t>(p * (1.0f - p));
}
+ hessians[idx] = static_cast<score_t>(p * (1.0f - p));
}
}
} else {
std::vector<double> rec;
#pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_);
- for (int k = 0; k < num_class_; ++k){
+ for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]);
}
@@ -83,13 +73,11 @@ class MulticlassLogloss: public ObjectiveFunction {
auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
- gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
- hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
+ gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
} else {
gradients[idx] = static_cast<score_t>(p * weights_[i]);
- hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
}
-
+ hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
}
}
}
@@ -110,9 +98,49 @@ class MulticlassLogloss: public ObjectiveFunction {
std::vector<int> label_int_;
/*! \brief Weights for data */
const float* weights_;
- /*! \brief Weights for label */
- std::vector<float> label_pos_weights_;
- bool is_unbalance_;
+};
+
+/*!
+* \brief Objective function for multiclass classification, using one-vs-all binary objective functions
+*/
+class MulticlassOVA: public ObjectiveFunction {
+public:
+ explicit MulticlassOVA(const ObjectiveConfig& config) {
+ num_class_ = config.num_class;
+ for (int i = 0; i < num_class_; ++i) {
+ binary_loss_.emplace_back(
+ new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
+ }
+ }
+
+ ~MulticlassOVA() {
+
+ }
+
+ void Init(const Metadata& metadata, data_size_t num_data) override {
+ num_data_ = num_data;
+ for (int i = 0; i < num_class_; ++i) {
+ binary_loss_[i]->Init(metadata, num_data);
+ }
+ }
+
+ void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
+ for (int i = 0; i < num_class_; ++i) {
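+ // class i's scores/gradients/hessians occupy the contiguous block starting at num_data_ * i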
+ int64_t bias = static_cast<int64_t>(num_data_) * i;
+ binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
+ }
+ }
+
+ const char* GetName() const override {
+ return "multiclassova";
+ }
+
+private:
+ /*! \brief Number of data */
+ data_size_t num_data_;
+ /*! \brief Number of classes */
+ int num_class_;
+ std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
};
} // namespace LightGBM
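
The refactored GetGradients also makes the softmax math explicit: with p = softmax(scores) for one row, the gradient is p_k - 1 for the true class and p_k otherwise, and the diagonal hessian approximation is p_k(1 - p_k) for every class. A per-row sketch (editor's illustration; SoftmaxGradOneRow is a hypothetical helper):

#include <algorithm>
#include <cmath>
#include <vector>

// Per-row softmax gradient/hessian: grad_k = p_k - [k == label],
// hess_k = p_k * (1 - p_k).
void SoftmaxGradOneRow(const std::vector<double>& score, int label,
                       std::vector<double>* gradients, std::vector<double>* hessians) {
  const int num_class = static_cast<int>(score.size());
  // numerically stable softmax: shift scores by their maximum
  const double max_score = *std::max_element(score.begin(), score.end());
  std::vector<double> p(num_class);
  double sum = 0.0;
  for (int k = 0; k < num_class; ++k) {
    p[k] = std::exp(score[k] - max_score);
    sum += p[k];
  }
  for (int k = 0; k < num_class; ++k) p[k] /= sum;
  gradients->resize(num_class);
  hessians->resize(num_class);
  for (int k = 0; k < num_class; ++k) {
    (*gradients)[k] = p[k] - (k == label ? 1.0 : 0.0);
    (*hessians)[k] = p[k] * (1.0 - p[k]);
  }
}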
@@ -23,7 +23,9 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(config);
} else if (type == std::string("multiclass")) {
- return new MulticlassLogloss(config);
+ return new MulticlassSoftmax(config);
+ } else if (type == std::string("multiclassova")) {
+ return new MulticlassOVA(config);
}
return nullptr;
}
