Skip to content

Commit

Permalink
update early-stop
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy committed Nov 10, 2018
1 parent 068b8dc commit bef24ca
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 6 deletions.
4 changes: 2 additions & 2 deletions demo/regression/house_price/run_house_no_cv.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

# Training task:
# -s : 4 (use fm model for regression)
# -x : mae (use MAE metric)
# -x : rmsd (use RMSD metric)
# The model will be stored in house_price_train.txt.model
../../xlearn_train ./house_price_train.txt -s 4 -v ./house_price_test.txt -x mae
../../xlearn_train ./house_price_train.txt -s 4 -v ./house_price_test.txt -x rmsd
# Prediction task:
# The output result will be stored in house_price_test.txt.out
../../xlearn_predict ./house_price_test.txt ./house_price_train.txt.model
43 changes: 43 additions & 0 deletions src/loss/metric.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class Metric {
// Return the metric type.
virtual std::string metric_type() = 0;

// Compare two metric value, which is used in early-stop.
virtual bool cmp(const real_t a, const real_t b) = 0;

protected:
/* Pointer of thread pool */
ThreadPool* pool_;
Expand Down Expand Up @@ -163,6 +166,11 @@ class AccMetric : public Metric {
return "Accuarcy";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a >= b ? true : false;
}

protected:
index_t total_example_;
index_t true_pred_;
Expand Down Expand Up @@ -251,6 +259,11 @@ class PrecMetric : public Metric {
return "Precision";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a >= b ? true : false;
}

protected:
index_t true_positive_;
index_t false_positive_;
Expand Down Expand Up @@ -339,6 +352,11 @@ class RecallMetric : public Metric {
return "Recall";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a >= b ? true : false;
}

protected:
index_t true_positive_;
index_t false_negative_;
Expand Down Expand Up @@ -431,6 +449,11 @@ class F1Metric : public Metric {
return "F1";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a >= b ? true : false;
}

protected:
index_t total_example_;
index_t true_positive_;
Expand Down Expand Up @@ -534,6 +557,11 @@ class AUCMetric : public Metric {
return "AUC";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a >= b ? true : false;
}

protected:
std::vector<index_t> all_positive_number_;
std::vector<index_t> all_negative_number_;
Expand Down Expand Up @@ -635,6 +663,11 @@ class MAEMetric : public Metric {
return "MAE";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a <= b ? true : false;
}

protected:
real_t error_;
index_t total_example_;
Expand Down Expand Up @@ -709,6 +742,11 @@ class MAPEMetric : public Metric {
return "MAPE";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a <= b ? true : false;
}

protected:
real_t error_;
index_t total_example_;
Expand Down Expand Up @@ -785,6 +823,11 @@ class RMSDMetric : public Metric {
return "RMSD";
}

// Compare two metric value
bool cmp(const real_t a, const real_t b) {
return a <= b ? true : false;
}

protected:
real_t error_;
index_t total_example_;
Expand Down
30 changes: 26 additions & 4 deletions src/solver/trainer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,28 @@ void Trainer::train(std::vector<Reader*>& train_reader,
std::vector<Reader*>& test_reader) {
int best_epoch = 0;
int stop_window = 0;
real_t best_result = metric_ == nullptr ? kFloatMax : kFloatMin;
real_t prev_result = metric_ == nullptr ? kFloatMin : kFloatMax;
real_t best_result = 0;
real_t prev_result = 0;
if (metric_ == nullptr) {
best_result = kFloatMax;
prev_result = kFloatMin;
} else {
std::string metric_type = metric_->metric_type();
// Classification
if (metric_type.compare("Accuarcy") == 0 ||
metric_type.compare("Precision") == 0 ||
metric_type.compare("Recall") == 0 ||
metric_type.compare("F1") == 0 ||
metric_type.compare("AUC") == 0) {
best_result = kFloatMin;
prev_result = kFloatMax;
} else if (metric_type.compare("MAE") == 0 ||
metric_type.compare("MAPE") == 0 ||
metric_type.compare("RMSD") == 0) { // regression
best_result = kFloatMax;
prev_result = kFloatMin;
}
}
MetricInfo te_info;
// Show header info
if (!quiet_) {
Expand All @@ -190,14 +210,16 @@ void Trainer::train(std::vector<Reader*>& train_reader,
// Early-stopping
if (early_stop_) {
if ((metric_ == nullptr && te_info.loss_val <= best_result) ||
(metric_ != nullptr && te_info.metric_val >= best_result)) {
(metric_ != nullptr && metric_->cmp(te_info.metric_val,
best_result))) {
best_result = metric_ == nullptr ?
te_info.loss_val : te_info.metric_val;
best_epoch = n;
model_->SetBestModel();
}
if ((metric_ == nullptr && te_info.loss_val > prev_result) ||
(metric_ != nullptr && te_info.metric_val < prev_result)) {
(metric_ != nullptr && !metric_->cmp(te_info.metric_val,
prev_result))) {
// If the validation loss goes up conntinuously
// in stop_window epoch, we stop training
if (stop_window == stop_window_) { break; }
Expand Down

0 comments on commit bef24ca

Please sign in to comment.