Skip to content

Commit

Permalink
Fixed the implementation of spanning a data set across multiple
Browse files Browse the repository at this point in the history
models.  The implementation now properly handles the tail of a data
set that needs to be evenly distributed across all instances of the
model.  The code now computes how many complete mini-batches can be
spread across all models and readers, and then distributes the remaining
samples equally as partial mini-batches.
  • Loading branch information
bvanessen committed Aug 13, 2016
1 parent 62f4130 commit 6be4b03
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 20 deletions.
87 changes: 80 additions & 7 deletions include/lbann/data_readers/lbann_data_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@
#ifndef LBANN_DATA_READER_HPP
#define LBANN_DATA_READER_HPP

#include "lbann_base.hpp"
#include "lbann/lbann_base.hpp"
#include "lbann/utils/lbann_random.hpp"
#include "lbann/utils/lbann_exception.hpp"
#include "lbann/lbann_comm.hpp"
#include <assert.h>
#include <algorithm>
#include <string>
Expand All @@ -48,14 +49,20 @@ namespace lbann
public:
/// Construct a reader for mini-batches of the given size.
/// The stride defaults to the mini-batch size, all offsets start at zero,
/// and the alternate last-mini-batch handling stays disabled until setup()
/// turns it on for multi-model runs.
DataReader(int batchSize, bool shuffle) :
  BatchSize(batchSize),
  CurrentPos(0),
  m_shuffle(shuffle),
  m_stride(batchSize),
  m_base_offset(0),
  m_model_offset(0),
  m_use_alt_last_mini_batch_size(false),
  m_last_mini_batch_threshold(0),
  m_last_mini_batch_size(batchSize),
  m_last_mini_batch_stride(0)
{}
/// Convenience constructor: shuffling is enabled by default.
DataReader(int batchSize) : DataReader(batchSize, true) {}

DataReader(const DataReader& source) :
BatchSize(source.BatchSize), CurrentPos(source.CurrentPos), m_shuffle(source.m_shuffle),
m_stride(source.m_stride), m_base_offset(source.m_base_offset), m_model_offset(source.m_model_offset),
ShuffledIndices(source.ShuffledIndices), m_unused_indices(source.m_unused_indices) {}
m_use_alt_last_mini_batch_size(source.m_use_alt_last_mini_batch_size),
m_last_mini_batch_threshold(source.m_last_mini_batch_threshold), m_last_mini_batch_size(source.m_last_mini_batch_size), m_last_mini_batch_stride(source.m_last_mini_batch_stride),
ShuffledIndices(source.ShuffledIndices), m_unused_indices(source.m_unused_indices)
{}

virtual ~DataReader() {}

Expand All @@ -65,11 +72,16 @@ namespace lbann
* If the base offset is not specified set it to 0
* If the stride is not specified set it to batch size
*/
void setup(int base_offset, int stride, int model_offset = 0) {
void setup(int base_offset, int stride, int model_offset = 0, lbann_comm *comm = NULL) {
m_model_offset = model_offset;
m_base_offset = base_offset;
m_stride = stride;

if(comm != NULL) {
calculate_multi_model_data_distribution(comm);
m_use_alt_last_mini_batch_size = true;
}

CurrentPos = m_base_offset + m_model_offset;
if (m_shuffle) {
std::shuffle(ShuffledIndices.begin(), ShuffledIndices.end(),
Expand Down Expand Up @@ -101,7 +113,11 @@ namespace lbann
* around, then reshuffle the data indices.
*/
virtual bool update() {
CurrentPos += m_stride;
if(m_use_alt_last_mini_batch_size && CurrentPos+m_stride > m_last_mini_batch_threshold) {
CurrentPos += m_last_mini_batch_stride;
}else {
CurrentPos += m_stride;
}
if (CurrentPos < (int)ShuffledIndices.size()) {
return true;
} else {
Expand All @@ -119,9 +135,21 @@ namespace lbann
virtual int get_linearized_label_size() { return 0; }

bool position_valid() { return (CurrentPos < (int)ShuffledIndices.size()); }
int getBatchSize() { return BatchSize; }
/// Return the size of the mini-batch at the current position.
/// Once the reader has crossed the threshold after which only the tail of
/// the data set remains, report the (possibly smaller) last-mini-batch size.
int getBatchSize() {
  const bool in_tail = m_use_alt_last_mini_batch_size
    && CurrentPos >= m_last_mini_batch_threshold;
  return in_tail ? m_last_mini_batch_size : BatchSize;
}
int getPosition() { return CurrentPos; }
int get_next_position() { return CurrentPos + m_stride; }
/// Predict where the next call to update() will move the position,
/// without actually advancing.  Mirrors update(): if a regular stride
/// would cross the last-mini-batch threshold, the special tail stride
/// is applied instead.
int get_next_position() {
  const bool crosses_tail = m_use_alt_last_mini_batch_size
    && (CurrentPos + m_stride > m_last_mini_batch_threshold);
  return CurrentPos + (crosses_tail ? m_last_mini_batch_stride : m_stride);
}
int* getIndices() { return &ShuffledIndices[0]; }
int getNumData() { return (int)ShuffledIndices.size(); }
int get_num_unused_data() { return (int)m_unused_indices.size(); }
Expand Down Expand Up @@ -171,6 +199,11 @@ namespace lbann
this->m_stride = source.m_stride;
this->m_base_offset = source.m_base_offset;
this->m_model_offset = source.m_model_offset;
this->m_use_alt_last_mini_batch_size = source.m_use_alt_last_mini_batch_size;
this->m_last_mini_batch_threshold = source.m_last_mini_batch_threshold;
this->m_last_mini_batch_size = source.m_last_mini_batch_size;
this->m_last_mini_batch_stride = source.m_last_mini_batch_stride;

// Vectors implement a deep copy
this->ShuffledIndices = source.ShuffledIndices;
this->m_unused_indices = source.m_unused_indices;
Expand All @@ -188,6 +221,40 @@ namespace lbann
return getNumData();
}

/// Compute how this data set is partitioned when it spans multiple models.
/// Determines how many whole mini-batches can be spread across every model's
/// parallel readers, then splits the remaining samples into equal partial
/// mini-batches; the world master additionally absorbs any indivisible
/// remainder.  Results are stored in m_last_mini_batch_threshold,
/// m_last_mini_batch_size and m_last_mini_batch_stride.
void calculate_multi_model_data_distribution(lbann_comm *comm) {
  int max_mini_batch_size = BatchSize;
  // Number of readers in each model that advance in lock-step per step.
  // NOTE(review): assumes m_stride is exactly num_models * readers_per_model
  // * max_mini_batch_size (and non-zero) -- confirm with the input-layer
  // callers; a zero result here would divide by zero below.
  int num_parallel_readers_per_model = (m_stride / comm->get_num_models()) / max_mini_batch_size;

  // Integer division already truncates these non-negative quantities, so the
  // previous rint() call on an already-integer quotient was a no-op and has
  // been removed.
  int num_whole_mini_batches = getNumData() / m_stride;
  // Equal share of the leftover samples for each reader across all models.
  int partial_mini_batch_size = (getNumData() - (num_whole_mini_batches*m_stride))/(comm->get_num_models() * num_parallel_readers_per_model);
  int world_master_remainder_data = 0;

  // Samples still left over after every reader takes an equal partial batch.
  int world_master_remainder_adjustment = getNumData()
    - (num_whole_mini_batches * m_stride)
    - (partial_mini_batch_size * comm->get_num_models() * num_parallel_readers_per_model);
  if(comm->am_world_master()) {
    // Only the world master consumes the indivisible remainder itself.
    world_master_remainder_data = world_master_remainder_adjustment;
    world_master_remainder_adjustment = 0;
  }
  partial_mini_batch_size += world_master_remainder_data;

  m_last_mini_batch_threshold = m_stride * num_whole_mini_batches;
  m_last_mini_batch_size = partial_mini_batch_size;

  /// Note that comm->get_model_rank() + comm->get_rank_in_model() is not equivalent to comm->get_world_rank() from a parallel I/O perspective
  /// Given the data reader's rank, how many readers have a higher rank
  int num_readers_at_full_stride = (comm->get_num_models() - comm->get_model_rank()) * num_parallel_readers_per_model - comm->get_rank_in_model();
  /// Given the data reader's rank, how many readers have a lower rank
  int num_readers_at_last_stride = comm->get_model_rank() * num_parallel_readers_per_model + comm->get_rank_in_model();
  /// Compute how big the stride should be assuming that each higher ranked parallel reader has completed a full mini-batch
  /// and each lower ranked parallel reader has completed a partial mini-batch
  m_last_mini_batch_stride = max_mini_batch_size * num_readers_at_full_stride
    + (partial_mini_batch_size * num_readers_at_last_stride) + world_master_remainder_adjustment;

  // cout << "[" << comm->get_rank_in_world() << "] " << comm->get_model_rank() << " model rank, num_whole_mini_batches=" << num_whole_mini_batches << " partial_mini_batch_size=" << partial_mini_batch_size << " world_master_remainder_data=" << world_master_remainder_data << " threshold " << m_last_mini_batch_threshold << " with a last stride of " << m_last_mini_batch_stride << " and stride of " << m_stride << " and there are " << num_parallel_readers_per_model << " parallel readers per model " <<endl;

  return;
}

protected:
int BatchSize;
Expand All @@ -198,6 +265,12 @@ namespace lbann
/// then it may not reset to zero
int m_model_offset; /// If there are multiple models with multiple instances of the reader,
/// each model's set of readers may not reset to zero
/// Provide a set of size, strides, and thresholds to handle the last mini batch of a dataset
bool m_use_alt_last_mini_batch_size;
int m_last_mini_batch_threshold;
int m_last_mini_batch_size;
int m_last_mini_batch_stride;

std::vector<int> ShuffledIndices;
std::vector<int> m_unused_indices; /// Record of the indices that are not being used for training
};
Expand Down
1 change: 0 additions & 1 deletion include/lbann/utils/lbann_dataset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ namespace lbann
DataReader *data_reader;
long num_samples_processed;
long total_samples;

};
}

Expand Down
6 changes: 4 additions & 2 deletions src/data_readers/lbann_data_reader_imagenet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,10 @@ int lbann::DataReader_ImageNet::fetch_data(Mat& X)
}

int pixelcount = m_image_width * m_image_height;
int current_batch_size = getBatchSize();

int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; n++) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; n++) {
if (n >= (int)ShuffledIndices.size())
break;

Expand Down Expand Up @@ -100,8 +101,9 @@ int lbann::DataReader_ImageNet::fetch_label(Mat& Y)
return 0;
}

int current_batch_size = getBatchSize();
int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; n++) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; n++) {
if (n >= (int)ShuffledIndices.size())
break;

Expand Down
6 changes: 4 additions & 2 deletions src/data_readers/lbann_data_reader_mnist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ int lbann::DataReader_MNIST::fetch_data(Mat& X)
}

int pixelcount = ImageWidth * ImageHeight;
int current_batch_size = getBatchSize();

int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; n++) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; n++) {
if (n >= (int)ShuffledIndices.size())
break;

Expand All @@ -98,8 +99,9 @@ int lbann::DataReader_MNIST::fetch_label(Mat& Y)
return 0;
}

int current_batch_size = getBatchSize();
int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; n++) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; n++) {
if (n >= (int)ShuffledIndices.size())
break;

Expand Down
6 changes: 4 additions & 2 deletions src/data_readers/lbann_data_reader_nci.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,12 +69,13 @@ int lbann::data_reader_nci::fetch_data(Mat& X)
return 0;
}

int current_batch_size = getBatchSize();
ifstream ifs(m_infile.c_str());
if (!ifs) { std::cout << "\n In load: can't open file : " << m_infile; exit(1); }

string line;
int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; ++n) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; ++n) {
if (n >= (int)ShuffledIndices.size())
break;

Expand Down Expand Up @@ -106,8 +107,9 @@ int lbann::data_reader_nci::fetch_label(Mat& Y)
if(!DataReader::position_valid()) {
return 0;
}
int current_batch_size = getBatchSize();
int n = 0;
for (n = CurrentPos; n < CurrentPos + BatchSize; ++n) {
for (n = CurrentPos; n < CurrentPos + current_batch_size; ++n) {
if (n >= (int)ShuffledIndices.size())
break;

Expand Down
2 changes: 1 addition & 1 deletion src/layers/lbann_input_layer_distributed_minibatch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ lbann::input_layer_distributed_minibatch::input_layer_distributed_minibatch(lban

void lbann::input_layer_distributed_minibatch::setup(int num_prev_neurons) {
if(io_layer::m_data_sets_span_models) {
io_layer::setup_data_readers(0, Layer::comm->get_num_models() * Layer::m_mini_batch_size,
io_layer::setup_data_readers(Layer::comm->get_num_models() * Layer::m_mini_batch_size,
Layer::comm->get_model_rank() * Layer::m_mini_batch_size);
}else {
io_layer::setup_data_readers(0, m_mini_batch_size);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ void lbann::input_layer_distributed_minibatch_parallel_io::setup(int num_prev_ne
if(io_layer::m_data_sets_span_models) {
int stride = Layer::comm->get_num_models() * m_num_parallel_readers_training * Layer::m_mini_batch_size;
int model_offset = Layer::comm->get_model_rank() * m_num_parallel_readers_training * Layer::m_mini_batch_size;
cout << "["<< Layer::comm->get_rank_in_world() << "] Setting up input layer, with " << Layer::comm->get_num_models() << " models and " << m_num_parallel_readers_training << " parallel readers and " << Layer::m_mini_batch_size << " mb size, which gives a stride of " << stride << " and my model offset is " << model_offset << " and my base offset is " << (Layer::comm->get_rank_in_model() * Layer::m_mini_batch_size) << endl;
// cout << "["<< Layer::comm->get_rank_in_world() << "] Setting up input layer, with " << Layer::comm->get_num_models() << " models and " << m_num_parallel_readers_training << " parallel readers and " << Layer::m_mini_batch_size << " mb size, which gives a stride of " << stride << " and my model offset is " << model_offset << " and my base offset is " << (Layer::comm->get_rank_in_model() * Layer::m_mini_batch_size) << endl;
io_layer::setup_data_readers(Layer::comm->get_rank_in_model() * Layer::m_mini_batch_size,
stride,
Layer::comm->get_model_rank() * m_num_parallel_readers_training * Layer::m_mini_batch_size);
Expand Down
6 changes: 3 additions & 3 deletions src/layers/lbann_io_layer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,15 +192,15 @@ long lbann::io_layer::get_linearized_label_size() {

void lbann::io_layer::setup_data_readers(int base_offset, int stride, int model_offset) {
if(m_training_dataset.data_reader != NULL) {
m_training_dataset.data_reader->setup(base_offset, stride, model_offset);
m_training_dataset.data_reader->setup(base_offset, stride, model_offset, comm);
}

if(m_validation_dataset.data_reader != NULL) {
m_validation_dataset.data_reader->setup(base_offset, stride, model_offset);
m_validation_dataset.data_reader->setup(base_offset, stride, model_offset, comm);
}

if(m_testing_dataset.data_reader != NULL) {
m_testing_dataset.data_reader->setup(base_offset, stride, model_offset);
m_testing_dataset.data_reader->setup(base_offset, stride, model_offset, comm);
}
return;
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ bool lbann::target_layer_distributed_minibatch_parallel_io::update_data_reader()
DataReader *data_reader = target_layer::select_data_reader();
if(m_shared_data_reader) {
/// If the data reader is shared with an input layer, don't update the reader just check to see if the epoch is done
/// or will be done on the next update of the input layer (which includes adding the stride)
/// or will be done on the next update of the input layer (which includes adding the stride).
/// Note that target layers are always updated before input layers, which is why the position
/// is not up to date yet.
return (data_reader->get_next_position() < data_reader->getNumData());
}else {
return data_reader->update();
Expand Down

0 comments on commit 6be4b03

Please sign in to comment.