Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Wrap up SyncedMem resize from @kloudkl; make train/test nets share data blobs #355

Closed
wants to merge 19 commits into from

Conversation

jeffdonahue
Copy link
Contributor

This PR wraps up the SyncedMem changes from @kloudkl in #250 (thanks for doing most of the work on this @kloudkl!). I ended up removing the thrust host/device_vector dependency as I couldn't get them working at all (almost all tests segfaulted), and instead kept the void* pointers that we were already using to store the data (and still syncing using explicit cudaMemcpy calls). I added private fields cpu_capacity_ and gpu_capacity_ which keep track of the current space allocation, and methods cpu_resize() and gpu_resize() allocate additional memory (copying any current memory) if the current allocation is insufficient.

The final commit makes use of this new functionality and my ShareData and ShareDiff methods from #332 to share all the data blobs between the training and test net (#332 only shared weight blobs). Because of this change I've expanded the imagenet_val batchsize from 50 to 250 (as close to the train batch size of 256 as possible while still evenly dividing 50K...) for slightly faster test time with little memory cost.

I ran some tests of the ImageNet model on a Tesla K40 and found little to no difference in training/test time (not counting the test time improvement due to the expanded batch size), while memory use after the first test iteration improves significantly:

test batch size 50:

    @ dev:
    before first test iteration: 3014 MB (GPU memory use based on 'nvidia-smi')
     after first test iteration: 3255 MB
    -> 241 MB additional memory used by test net

    @ capacity-aware-memory:
    before first test iteration: 3014 MB
     after first test iteration: 3053 MB
    -> 39 MB additional memory used by test net, 84% reduction vs. dev


test batch size 250:

    @ dev:
    before first test iteration: 3014 MB
     after first test iteration: 4143 MB
    -> 1129 MB additional memory used by test net

    @ capacity-aware-memory:
    before first test iteration: 3014 MB
     after first test iteration: 3141 MB
    -> 127 MB additional memory used by test net, 89% reduction vs. dev

Note that the memory use of the test net is non-zero, because additional memory is still allocated for the test net in at least these three cases:

(1) blobs used inside of layers to store intermediate computations (the solver has no way of knowing about these; would have to change each layer's code to share these),
(2) blobs in the test net which do not share a name with any blob in the train net (e.g. the accuracy top blob in the ImageNet test net -- which isn't a great example because it's only 8 bytes, but you get the idea)
(3) blobs in the test net which are larger (in terms of count) than correspondingly named blobs in the train net (there aren't any of these cases in any of Caffe's sample models, but it would obviously happen if you, e.g., used a larger batch size in the test net, or did something weird like swapping the names of conv1 and conv5 between the train and test net, etc.)

@jeffdonahue
Copy link
Contributor Author

I just edited history to not rename syncedmem.cpp to syncedmem.cu (was done in @kloudkl's first commit), as github's diff is pretty useless for renamed files. I can redo that change if this gets merged though.

@@ -305,6 +305,10 @@ $(TEST_BIN_DIR)/%.testbin: $(TEST_BUILD_DIR)/%.o $(GTEST_OBJ) $(STATIC_NAME) \
-o $@ $(CXXFLAGS) $(LDFLAGS) $(WARNINGS)
@ echo

$(GTEST_OBJ): $(GTEST_SRC) | $(GTEST_BUILD_DIR)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's up with the Makefile changes introduced in 1d4ea4be7e77203306a157580a44f70fb475093e?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kloudkl had added the +$(OBJ_BUILD_DIR)/%.cuo: rule (presumably due to the renaming of syncedmem.cpp->syncedmem.cu -- we had no rule for building *.cu files in the top-level src/caffe directory), but I moved it down in the list because older versions of make (including the one I use, the default Ubuntu installation) will match the first rule matching instead of the most specific one, so this rule matched *.cu files in subdirs also, which I fixed by moving it after all other *.cu rules.

The rest of the changes were basically style changes and adding a dependency on the header files $(HXX_SRCS) where I happened to notice there wasn't one before (in one case changing from $(PROTO_GEN_HEADER), which is a subset of $(HXX_SRCS)) Sorry for mixing these changes into an unrelated PR...I can remove them from history if desired.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fine to keep it in this PR for convenience, but could you split the Makefile changes into their own commit (or at least mention them in the message for 1d4ea4b)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done - moved into different commit right before the "use handwritten resize methods" one (github displays it as being the last commit though)

@shelhamer
Copy link
Member

Nice memory savings, and reducing the number of allocations is good too!

This all looks good to me, and the tests cover the key uses, but I'd rather have another reviewer check this too since memory is slightly important.

Thanks @kloudkl for raising the original PR!

Thanks Jeff for this and all your recent sharing commits!

@shelhamer
Copy link
Member

@longjon @sergeyk @sguada one of you please review.

@jeffdonahue
Copy link
Contributor Author

Thanks for taking the time to look over the changes Evan! Agreed that this isn't a PR we want to merge hastily.

void Reshape(const int num, const int channels, const int height,
const int width);
explicit Blob(const int num = 0, const int channels = 0,
const int height = 0, const int width = 0);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe remove the default arguments? It would be odd to set some being 0 and some being not 0, but if all are 0, it is effectively just the default constructor Blob().

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@kloudkl got rid of the no-arg blob constructor -- the one with the default args sets member variables and does Reshape, so I don't think it's quite the same?

@Yangqing
Copy link
Member

There are a few things that I am a little worried about on this, but maybe I am overly cautious, so please kindly check with me:

(1) Assuming that we have two blobs referring to the same syncedmem, and one does a reshape or w/e that changes the syncedmem storage, either by expanding or shrinking and zero-fitting. Will the other blob play nicely with it?

(2) It seems that we never shrink the allocated size, which means that "once expanded, always expanded". I am wondering if this would cause certain problems in the future (for now things seem to be fine in our limited usage of this feature), so it is probably better to document it in the comments so one may be aware of this effect?

const int width);
explicit Blob(const int num = 0, const int channels = 0,
const int height = 0, const int width = 0);
explicit Blob(Blob* memory_share_blob);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMHO we could do const reference

explicit Blob(const Blob& memory_share_blob);

? :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks, done

@jeffdonahue
Copy link
Contributor Author

Hey Yangqing -- no need to explain your hesitation, greatly appreciate feedback from you!

(1) The way I implemented this was indeed (as you said in (2)) to only ever expand the SyncedMemory allocations (cpu_data_ and gpu_data_), and do so lazily as with the rest of the SyncedMemory implementation. SyncedMemory maintains the variables size_, cpu_capacity_, and gpu_capacity_, with the size_ variable being the only publicly mutable one via the set_size method which simply updates the size_ variable. Blob calls to the underlying memory accessors (cpu_data(), gpu_data(), mutable_cpu_data(), etc.) now always call set_size(count_) and then will simply return the data pointer as before (copying to/from device/host as needed, just like before) if it already has adequate storage; i.e., if cpu_capacity_ >= size_ (or gpu_capacity_ >= size_ for a gpu_data call). If the SyncedMemory doesn't have sufficient storage, it then allocates the necessary amount of additional memory (keeping any existing data as is via realloc for the host or explicit cudaMemcpy for GPU).

For the use in the current codebase, we would only ever expand the original memory allocation once -- specifically during the first training iteration (since as of the dev branch today we first perform a test pass before training, and training blobs are typically larger than test blobs -- if they were <= the size of the test blobs we'd never expand the original allocation).

I can't think of any scenarios in which this would cause problems with Blobs interfering with one another, as they always explicitly set the SyncedMemory size before any data accessor/mutator calls. Obviously this can't be used haphazardly to share SyncedMemory's among any arbitrary set of Blobs, but I think it seems to work for sharing only among pairs of blobs between the train and test net. Let me know if you think there's something I may be missing though.

(2) Agreed on adding a comment to note that SyncedMemory is only ever expanded -- will do.

@shelhamer
Copy link
Member

Needs rebase, then let's merge? I think the reasoning in #355 (comment) is sound, but if there are any reservations let me know!

Re: (1), the "other" blob is actually a shared pointer to a blob, so the reshape/expansion/whatever is done to "both" blobs. Or have I missed something? Blobs are better thought of as virtual views of a SyncedMemory with concrete size. See follow-up comment for details.

@jeffdonahue
Copy link
Contributor Author

I think that's not quite right. With this PR, you can think of Blobs as having a "virtual" size, while the SyncedMemory (still) has a "physical" size which is the maximum of all the virtual sizes of the Blobs sharing it, assuming they've all called cpu_data or another data accessor/mutator.

So both Blobs have their own "size", but the size is just an int member variable they each hold which does not actually affect the state of the SyncedMemory on its own. When a Blob's {mutable_,}cpu_data (or other data accessor/mutator is called), the Blob does a set_size on its SyncedMemory which it may share with other Blobs. This set_size always takes essentially 0 time, just updating the SyncedMemory's size variable. Then when a SyncedMemory's cpu/gpu_data is acccessed the SyncedMemory checks whether its current allocation is sufficient and allocates additional space if necessary.

Please do let me know if that doesn't make sense or you have any other questions.

I forgot to add the comment suggested by Yangqing that set_size will never actually deallocate any memory -- I'll do that, rebase, and let you know that this is ready to merge. I'd like to believe my implementation works and won't cause any problems anywhere, but I'm not sure about immediately merging this into master -- I think it might not be a bad idea to let this simmer in dev a while to make sure nothing funny happens that I/we haven't thought of.

@shelhamer
Copy link
Member

Alright, no need for a trial by fire by sending this to master soon. Thanks for the explanation–I'd somehow convinced myself we were sharing pointers, although that's clearly impossible with shareData() being a member function of Blob. Oh well.

@kloudkl
Copy link
Contributor

kloudkl commented May 3, 2014

One of the most highlighted features of the CUDA 6 is the unified memory which automatically manages the data transfer between the CPU and the GPU. It provides a much more fundamental solution to memory synchronization. Shall we shift to this new GPU memory management paradigm?

@kloudkl
Copy link
Contributor

kloudkl commented Jun 10, 2014

@jeffdonahue, do you have the time to rebase again? I want to fork this branch to continue the development of #195. If you don't, I will do it in my fork. But I'm afraid to introduce extra conflicts if we resolve the merge conflicts in different ways.

@kloudkl
Copy link
Contributor

kloudkl commented Jun 10, 2014

The quality of memory cards and management systems are usually assured by repeatedly writing, reading and checking random data with the memtest or stress test utilities. The stability of this PR can be tested similarly.

@jeffdonahue
Copy link
Contributor Author

The quality of memory cards and management systems are usually assured by repeatedly writing, reading and checking random data with the memtest or stress test utilities. The stability of this PR can be tested similarly.

Agreed that that would be good, but I'm not familiar with those tools and don't really have the spare time to learn them, unless there's something very simple and straightforward I can do? Otherwise if you or anyone else would like to stress test this yourselves, by all means feel free (not that I could stop you...).

I'm actually not really sure my implementation is ready for "primetime". I have occasionally seen cases when a cudaMalloc fails on a gpu_resize call when there is a lot of extra memory being allocated (e.g., when the first imagenet training iteration begins, since this now happens after the first test iteration when the batch size is smaller -- when I originally wrote this PR that problem wasn't exposed because the training net was allocated first since we didn't do a "0th iteration" test then).

I think I might know how to fix it, but probably won't have time to work on it for a few weeks, unfortunately. If you'd like, feel free to take over development yourself again (either with my implementation or yours).

@shelhamer
Copy link
Member

This is worth revising in light of #594 especially with regards to

  1. reallocation for reshapes to larger sizes
  2. sharing all blobs between train and test nets (and not only the trained parameters)

in case anyone in the community with a focus on memory usage is interested in taking up this PR.

@jeffdonahue
Copy link
Contributor Author

Closing this as I've abandoned it; @tnarihi feel free to cherry-pick or reference for #1985 if useful.

@jeffdonahue jeffdonahue closed this Mar 9, 2015
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants