
Merge NVIDIA's NCCL multi-GPU, switch it to python #4563

Merged: 5 commits, Jan 17, 2017

Conversation

cypof
Member

@cypof cypof commented Aug 9, 2016

This is an attempt to merge NVIDIA's multi-GPU work on NCCL while fixing the open issues with Python support and the parallel data pipeline. Some of them are non-trivial, particularly around the GIL, so the idea is to switch to a Python implementation that uses processes instead. NCCL supports it, and transfers are direct GPU-to-GPU anyway, so performance should be the same. The code in train.py should have the same functionality as the command-line and C++ version, but it is much simpler and gives users more flexibility. Custom setups from users will also be easier to share with Caffe2.

To test it, you need to clone NCCL, build and install it (make install), and set the new flag in Makefile.config.

The GIL, and the fact that sub-processes must be forked before CUDA is initialized, made it tricky to get good performance, but it looks OK now. It should be as good as the NV branch with the new thread pool, and maybe a bit faster, since the current pipeline does an in-device copy as the last step whereas this one is fully zero-copy using CUDA IPC. I have tried it on a couple of 4- and 8-GPU boxes so far and everything seems to stay busy.

If NVIDIA is interested in helping with testing and benchmarking: it seems that layer-wise reduction, which overlaps NCCL communication with compute, is often slower than a single big allreduce at the end. I'm not sure whether I'm doing something wrong, or whether it can be fixed or improved. Also, I don't think layer-wise reduction works with shared weights; we would need to compute actual dependencies between layers, and that is better left to Caffe2. For now I simply disable layer-wise reduction if the network has shared weights.

If we are happy with this, I suggest we deprecate multi-GPU in the command-line version and remove all associated code before 1.0. There is a lot of code we could remove around the round-robin I/O, the shared-layers lock, root solvers, etc. The complexity has grown a bit out of control, and NV had to add another layer of thread pool in their branch to make it fast, with more Transformer functions, etc. I hope we can avoid that by switching to this.

TODO: break the PR down into simpler ones and modularize the Python code. For now it exposes a single function, train(), that emulates the command line. The different parts could be made easier to use and customize individually.
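
A rough sketch of the process-per-GPU scheme described above, loosely following train.py. It relies on the pycaffe calls mentioned in this thread (caffe.set_solver_count, caffe.set_solver_rank, caffe.set_multiprocess, caffe.NCCL with new_uid()); the exact NCCL constructor, bcast(), and add_callback() wiring are assumptions based on this discussion, not a committed API.

```python
# Sketch only: one process per GPU, NCCL allreduce on the gradients.
from multiprocessing import Process

import caffe


def solve(proto, gpus, uid, rank):
    # Each solver runs in its own process, so CUDA is initialized after the
    # fork and the GIL is never shared between solvers.
    caffe.set_mode_gpu()
    caffe.set_device(gpus[rank])
    caffe.set_solver_count(len(gpus))
    caffe.set_solver_rank(rank)
    caffe.set_multiprocess(True)

    solver = caffe.SGDSolver(proto)
    nccl = caffe.NCCL(solver, uid)   # per-solver NCCL communicator
    nccl.bcast()                     # broadcast weights from rank 0
    solver.add_callback(nccl)        # allreduce gradients every iteration
    solver.step(10000)


if __name__ == '__main__':
    gpus = [0, 1]
    uid = caffe.NCCL.new_uid()       # shared id so the ranks can rendezvous
    procs = [Process(target=solve, args=('solver.prototxt', gpus, uid, r))
             for r in range(len(gpus))]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```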

@cypof
Member Author

cypof commented Aug 27, 2016

New commit that allows training with any data layer. Lots of cleanup, and most of the old parallel code removed. I simplified train.py a lot, so it is really easy to customize now, and moved the advanced bits, like the multi-threaded pipeline, to a separate example.

@@ -409,7 +409,7 @@ CXXFLAGS += -MMD -MP
# Complete build flags.
COMMON_FLAGS += $(foreach includedir,$(INCLUDE_DIRS),-I$(includedir))
CXXFLAGS += -pthread -fPIC $(COMMON_FLAGS) $(WARNINGS)
NVCCFLAGS += -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
NVCCFLAGS += -D_FORCE_INLINES -ccbin=$(CXX) -Xcompiler -fPIC $(COMMON_FLAGS)
Member

This is only a workaround for an Ubuntu 16.04 issue that should be fixed upstream by Ubuntu, or at least that's what I remember. In any case it should not be committed as part of this patch, since it is a build detail and not about parallelism or NCCL.

Member

@shelhamer shelhamer left a comment

I did a short pass but I need to review this again with more coffee. In the meantime I made a few points that you could address. I can confirm that this builds and passes tests on multi-GPU machines.

static string new_uid();

/**
* Broadcast weigths from rank 0 other solvers.
Member

spellcheck: weights

@@ -125,34 +125,53 @@ void HDF5DataLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
}

template <typename Dtype>
bool HDF5DataLayer<Dtype>::Skip() {
Member

👍 on pulling this into Skip() and Next().

DISABLE_COPY_AND_ASSIGN(Solver);
};
// Timing information, handy to tune e.g. nbr of GPUs
Timer iteration_timer_;
Member

Can we take all of these timers back out? I don't think we need to expose the Caffe Timer to Python and the like, since existing profiling tools could be used instead.

Member Author

Timing is pretty handy for adjusting the number of GPUs, etc. I'm not sure how to do that conveniently without it. The only way to measure accurately is to insert events in the GPU stream, so people would have to use something like PyCUDA; is that OK?
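
For reference, a PyCUDA-based alternative could look roughly like the sketch below. It is untested and comes with a caveat: pycuda.autoinit creates its own CUDA context, so unless the events are recorded in the stream Caffe actually uses, this only gives a coarse per-step measurement.

```python
# Rough sketch: time solver.step() with CUDA events via PyCUDA.
import pycuda.autoinit  # noqa: F401  -- creates and activates a CUDA context
import pycuda.driver as cuda


def time_step(solver, iters=1):
    """Return the elapsed time of solver.step(iters) in milliseconds."""
    start, stop = cuda.Event(), cuda.Event()
    start.record()
    solver.step(iters)
    stop.record()
    stop.synchronize()
    return start.time_till(stop)
```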

Member

I've made my peace with this timer. It can stay this time.

@@ -68,15 +68,16 @@ class BasePrefetchingDataLayer :
const vector<Blob<Dtype>*>& top);

// Prefetches batches (asynchronously if to GPU memory)
static const int PREFETCH_COUNT = 3;
static const int PREFETCH_COUNT = 4; // same as proto
Member

If this is the same as the proto def, then drop this. No need to have fragile duplication.

Member Author

I haven't found a way to read the default value from the proto. One option would be to instantiate a data_param, read the value, and destroy it. That seems like overkill, so I just copied the constant here.
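
For what it's worth, the protobuf-generated classes do expose the declared default without keeping an instance around: C++ has DataParameter::default_instance(), and on the Python side something like the sketch below should work (assuming the new field is DataParameter.prefetch).

```python
# Sketch: read the declared default straight from the generated proto class.
from caffe.proto import caffe_pb2

# An unset field reports its declared default, so no instance state is needed.
default_prefetch = caffe_pb2.DataParameter().prefetch
print(default_prefetch)  # expected to print the proto default, e.g. 4
```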

@@ -448,7 +454,18 @@ endif

all: lib tools examples

lib: $(STATIC_NAME) $(DYNAMIC_NAME)
ifeq ($(CPU_ONLY), 1)
Member

These build checks should be spun out into another PR.

@@ -70,7 +70,7 @@ class GradientBasedSolverTest : public MultiDeviceTest<TypeParam> {
string RunLeastSquaresSolver(const Dtype learning_rate,
const Dtype weight_decay, const Dtype momentum, const int num_iters,
const int iter_size = 1, const int devices = 1,
const bool snapshot = false, const char* from_snapshot = NULL) {
const bool snapshot = false, const string from_snapshot = "") {
Member

@shelhamer shelhamer Oct 2, 2016

Why char* -> string and the need to now call c_str()?

@@ -565,7 +581,9 @@ class SGDSolverTest : public GradientBasedSolverTest<TypeParam> {

protected:
virtual void InitSolver(const SolverParameter& param) {
this->solver_.reset(new SGDSolver<Dtype>(param));
SolverParameter new_param = param;
new_param.set_type("SGD");
Member

Setting the type has no effect when instantiating the solver directly from its type as done here with SGDSolver and the others. It's only needed when making use of the solver registry.

Member Author

It used to work because the other workers did not apply gradients. Now each worker runs the full iteration, so the solvers need to be of the right type, including the ones created by parallel.cpp based on the param type.

Member

This should no longer be needed as of #5009. Please drop these set_type() calls to check.

@@ -51,7 +51,18 @@ const int NPY_DTYPE = NPY_FLOAT32;
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }

void set_random_seed(unsigned int seed) { Caffe::set_random_seed(seed); }
void InitLog(int level) {
Member

I'm not sure why this is here, and in particular in this PR.

Member Author

The Python code might want to initialize each process with a different log level, e.g. full logging for the master and warnings only for each worker, so that there is less log duplication.
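
For example, something like the sketch below, assuming the new binding is exposed to Python as caffe.init_log and takes a glog severity level:

```python
# Sketch: full logging on the master process, warnings only on the workers.
import caffe


def init_process_logging(rank):
    # glog severities: 0 = INFO, 1 = WARNING, 2 = ERROR, 3 = FATAL
    caffe.init_log(0 if rank == 0 else 1)
```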

Member

Logging could be its own commit.

Member

It'd be nice, but not strictly necessary, to figure out how to do this with boost python call policies instead of hardcoding it. (Like the Net constructor.)

@@ -359,10 +411,18 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::return_internal_reference<>()))
.def("setup", &Layer<Dtype>::LayerSetUp)
.def("reshape", &Layer<Dtype>::Reshape)
.add_property("type", bp::make_function(&Layer<Dtype>::type));
.add_property("type", bp::make_function(&Layer<Dtype>::type))
Member

Many of the members exposed in these changes don't seem specific to this PR. Could they be pulled out on their own?

@@ -89,6 +89,12 @@ const Dtype* Blob<Dtype>::cpu_data() const {
template <typename Dtype>
void Blob<Dtype>::set_cpu_data(Dtype* data) {
CHECK(data);
// Make sure CPU and GPU sizes remain equal
Member

@shelhamer shelhamer Oct 2, 2016

How can the sizes differ? This seems like it should never happen.

Member Author

It's a weird case. As an optimization, Blob doesn't release its buffers when its size is reduced; it only updates its size and uses part of the allocated buffers. Unfortunately, the underlying SyncedMemory has no notion of size vs. capacity and always copies its full length between CPU and GPU. If a pointer is set directly via set_cpu_data with the current blob size, SyncedMemory might try to copy an old, larger capacity into it.

Member

Right.

@@ -74,6 +74,7 @@ void DataLayer<Dtype>::Next() {
<< "Restarting data prefetching from start.";
cursor_->SeekToFirst();
}
offset_++;
Member

Shouldn't this reset when seeking back to the beginning? Otherwise it will add up endlessly.

Member Author

It's on purpose: the round robin can continue the same way even if the database has to reset. It's easier to reason about it on the unrolled dataset.
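
A small illustration of that unrolled view (just the arithmetic, not the committed C++): the offset keeps growing across database wrap-arounds, so the rank-based partition stays disjoint and balanced.

```python
# Rank-based round robin on the unrolled stream of records.
def skip(offset, solver_rank, solver_count):
    """True if this record belongs to another solver."""
    return offset % solver_count != solver_rank

db_size, solver_count = 5, 2
for rank in range(solver_count):
    consumed = [offset % db_size             # record index after cursor wraps
                for offset in range(12)      # 12 reads over the database
                if not skip(offset, rank, solver_count)]
    print(rank, consumed)  # 0: [0, 2, 4, 1, 3, 0]   1: [1, 3, 0, 2, 4, 1]
```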

std::random_shuffle(data_permutation_.begin(), data_permutation_.end());
for (int i = 0; i < batch_size; ++i) {
while (Skip()) {
Next();
Member

At a glance it looks like Next() is still doing the work of loading data. Can Skip() instead advance over the data w/o doing all the work of Next()?

Member Author

I checked the database code; it only moves the cursor, which should be cheap.

@cypof
Member Author

cypof commented Oct 3, 2016

Thanks for reviewing, I will update this tomorrow.

@junshi15

@cypof since you are getting rid of P2Psync, how would this affect RDMASync used in CaffeOnSpark?

@cypof
Member Author

cypof commented Oct 27, 2016

@junshi15 GPUParams has not changed; it still gathers all the weights in a single buffer, so the distributed version should not require any changes. You run the distributed allreduce on this buffer at the end of the local one, as before. The code will need to extend NCCL instead of P2PSync, but they are similar.

@junshi15

@cypof Thanks for the info. How much performance gain do you see with NCCL compared to the original P2PSync?

@cypof
Member Author

cypof commented Oct 27, 2016

I don't have numbers yet, but it's better. If you plan to switch I would be happy to know what you get.

@@ -51,7 +51,18 @@ const int NPY_DTYPE = NPY_FLOAT32;
void set_mode_cpu() { Caffe::set_mode(Caffe::CPU); }
void set_mode_gpu() { Caffe::set_mode(Caffe::GPU); }

void set_random_seed(unsigned int seed) { Caffe::set_random_seed(seed); }
Member

@shelhamer shelhamer Nov 19, 2016

This shouldn't go missing! Can't drop this from the API.

@@ -72,7 +72,8 @@ class Solver {
inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
return test_nets_;
}
int iter() { return iter_; }
int iter() const { return iter_; }
void set_iter(int value) { iter_ = value; }
Member

@shelhamer shelhamer Nov 19, 2016

Should not be necessary: please double-check and drop. Restore() should handle this.

@@ -105,6 +105,32 @@ class DataLayerTest : public MultiDeviceTest<TypeParam> {
}
}

void TestSkip() {
Member

ImageDataLayer needs a TestSkip() too.

inline static void set_root_solver(bool val) { Get().root_solver_ = val; }
inline static int solver_rank() { return Get().solver_rank_; }
inline static void set_solver_rank(int val) { Get().solver_rank_ = val; }
inline static bool multi_process() { return Get().multi_process_; }
Member

multi_process -> multiprocess?

@shelhamer
Member

@cypof Thanks for the updates. Still need to drop set_iter() and set_type() but this is looking ready otherwise!

@oyxhust

oyxhust commented Feb 20, 2017

I have the same question as @iacopomasi. When I use my custom Python layer I get this issue, so do I have to use train.py? @shelhamer

@shelhamer
Member

shelhamer commented Feb 20, 2017

@iacopomasi @oyxhust Right, train.py is the demonstration of the new pycaffe multi-GPU interface which includes training nets with Python layers (as each net is now in its own process, sidestepping earlier parallelization issues). For further usage questions please ask on the mailing list.

@cypof it might be helpful to include further documentation in the form of an ipython notebook?

@weiliu89

weiliu89 commented Mar 9, 2017

@cypof Thanks for including NCCL in Caffe, which speeds up multi-GPU training quite a bit. However, I am facing two issues:

  1. As you mentioned, shared weights are not supported by the layer-wise allreduce. Do you have any plan to support this, or can you give any hints on how to add it? When a net has shared weights it can be handled by setting layer_wise_reduce to false, but then training becomes much slower (about 2x). I think supporting shared weights would be very useful (recurrent networks, for example, need weight sharing).

  2. When I cancel a job trained with multi-GPU, it seems to hang at ncclCommDestroy. Have you encountered this issue?

Thanks!

@cypof
Member Author

cypof commented Mar 10, 2017

@shelhamer @weiliu89 Sorry, I missed the previous message. Yes, a notebook would be great; I will try to do that. About shared weights: it's a matter of ordering the graph so that weights are synced only once all the layers that use them are done. It might not be too difficult, but we need to look at the weight-sharing code, how to list the layers for a given blob, etc. I haven't seen the hang on shutdown; there is not much we can do if it hangs inside ncclCommDestroy, so it might be useful to open an issue on the NCCL repo.

CUDA_CHECK(cudaGetDevice(&device));
CHECK_EQ(device, device_);
#endif
param.set_type(rank0_->type());

Is this necessary? The type info has been copied from rank0_.

caffe.set_device(gpus[rank])
caffe.set_solver_count(len(gpus))
caffe.set_solver_rank(rank)
caffe.set_multiprocess(True)

Is this necessary? This has already been done in Caffe::set_multiprocess(true).

wqvbjhc added a commit to wqvbjhc/caffe-ssd that referenced this pull request Jul 16, 2019