Skip to content

Commit

Permalink
Fix external memory, gpu_hist and subsampling combination bug. (dmlc#…
Browse files Browse the repository at this point in the history
…7476)

- The error happens because when reading from external memory the batch is
  reassembled for every new iteration. The variable `original_page_` is
  initialized from the first batch, when the constructor of `GradiendBasedSample`
  is called. After iterating through the batches the original memory is not
  accessible, so when trying to access the memory pointed by `original_page_`
  causes an error.

- The solution is instead of accessing data from the `original_page_`, to access
  the data from the first page of the available batch.

fix dmlc#7476
  • Loading branch information
GinkoBalboa authored and trivialfis committed Dec 24, 2021
1 parent 7f399ea commit feebd92
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
32 changes: 16 additions & 16 deletions src/tree/gpu_hist/gradient_based_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -185,12 +185,10 @@ GradientBasedSample UniformSampling::Sample(common::Span<GradientPair> gpair, DM
return {dmat->Info().num_row_, page_, gpair};
}

ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl const* page,
size_t n_rows,
ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(size_t n_rows,
BatchParam batch_param,
float subsample)
: original_page_(page),
batch_param_(std::move(batch_param)),
: batch_param_(std::move(batch_param)),
subsample_(subsample),
sample_row_index_(n_rows) {}

Expand Down Expand Up @@ -218,15 +216,17 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
sample_row_index_.begin(),
ClearEmptyRows());

auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
auto first_page = (*batch_iterator.begin()).Impl();
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(
batch_param_.gpu_id, original_page_->Cuts(), original_page_->is_dense,
original_page_->row_stride, sample_rows));
batch_param_.gpu_id, first_page->Cuts(), first_page->is_dense,
first_page->row_stride, sample_rows));

// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
for (auto& batch : batch_iterator) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}

Expand Down Expand Up @@ -259,12 +259,10 @@ GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpa
}

ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
EllpackPageImpl const* page,
size_t n_rows,
BatchParam batch_param,
float subsample)
: original_page_(page),
batch_param_(std::move(batch_param)),
: batch_param_(std::move(batch_param)),
subsample_(subsample),
threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f),
Expand Down Expand Up @@ -300,15 +298,17 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
sample_row_index_.begin(),
ClearEmptyRows());

auto batch_iterator = dmat->GetBatches<EllpackPage>(batch_param_);
auto first_page = (*batch_iterator.begin()).Impl();
// Create a new ELLPACK page with empty rows.
page_.reset(); // Release the device memory first before reallocating
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, original_page_->Cuts(),
original_page_->is_dense,
original_page_->row_stride, sample_rows));
page_.reset(new EllpackPageImpl(batch_param_.gpu_id, first_page->Cuts(),
first_page->is_dense,
first_page->row_stride, sample_rows));

// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
for (auto& batch : batch_iterator) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}

Expand All @@ -329,15 +329,15 @@ GradientBasedSampler::GradientBasedSampler(EllpackPageImpl const* page,
switch (sampling_method) {
case TrainParam::kUniform:
if (is_external_memory) {
strategy_.reset(new ExternalMemoryUniformSampling(page, n_rows, batch_param, subsample));
strategy_.reset(new ExternalMemoryUniformSampling(n_rows, batch_param, subsample));
} else {
strategy_.reset(new UniformSampling(page, subsample));
}
break;
case TrainParam::kGradientBased:
if (is_external_memory) {
strategy_.reset(
new ExternalMemoryGradientBasedSampling(page, n_rows, batch_param, subsample));
new ExternalMemoryGradientBasedSampling(n_rows, batch_param, subsample));
} else {
strategy_.reset(new GradientBasedSampling(page, n_rows, batch_param, subsample));
}
Expand Down
8 changes: 2 additions & 6 deletions src/tree/gpu_hist/gradient_based_sampler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,12 @@ class UniformSampling : public SamplingStrategy {
/*! \brief No sampling in external memory mode. */
class ExternalMemoryUniformSampling : public SamplingStrategy {
public:
ExternalMemoryUniformSampling(EllpackPageImpl const* page,
size_t n_rows,
ExternalMemoryUniformSampling(size_t n_rows,
BatchParam batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;

private:
EllpackPageImpl const* original_page_;
BatchParam batch_param_;
float subsample_;
std::unique_ptr<EllpackPageImpl> page_;
Expand All @@ -100,14 +98,12 @@ class GradientBasedSampling : public SamplingStrategy {
/*! \brief Gradient-based sampling in external memory mode.. */
class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
public:
ExternalMemoryGradientBasedSampling(EllpackPageImpl const* page,
size_t n_rows,
ExternalMemoryGradientBasedSampling(size_t n_rows,
BatchParam batch_param,
float subsample);
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;

private:
EllpackPageImpl const* original_page_;
BatchParam batch_param_;
float subsample_;
dh::caching_device_vector<float> threshold_;
Expand Down

0 comments on commit feebd92

Please sign in to comment.