
Add OpenMP support for CorrMM (simple paths) #3689

Closed
wants to merge 6 commits into from

Conversation

@jonhoo jonhoo commented Nov 25, 2015

Adds OpenMP support for two of the three paths in CorrMM. This patch results in a ~5x performance gain for my MNIST-like CNN when run on an 80-core machine. The last path, backpropagation with weights, is slightly trickier because of the write-sharing of weights, and will be handled separately in #3653.

This covers the forward pass and backprop w.r.t. inputs.

@jonhoo
Author

jonhoo commented Nov 26, 2015

Hmm, the failures here are interesting. @JesseLivezey, can you take a look and see that I'm doing the indexing into the new col correctly? I thought GETPTR3 would give me a pointer to the sub-matrix for each batch, but maybe not?

npy_intp col_dim[3];
col_dim[0] = (npy_intp)(batchSize);
col_dim[1] = (npy_intp)(topHeight * topWidth);
PyArrayObject* col = (PyArrayObject*)PyArray_EMPTY(2,
Member

This is OK for testing, but it uses too much memory. You only need min(batch_size, number_of_threads) buffers.
I think the maximum number of threads can be obtained with a call to OpenMP. This will need adjustments in how col is used.

Member

Maybe this will need to be aligned? I'm not sure about that; I think alignment is not mandatory, but it might give a speedup.

@nouiz
Member

nouiz commented Nov 27, 2015

I'm not able to see the problem by code review. Does this work on your machine? Are you using a parallel BLAS?

@JesseLivezey
Contributor

@jonhoo, @nouiz, also looks reasonable to me. I can try using it and see what happens.

@nouiz
Member

nouiz commented Dec 3, 2015

I'm still puzzled by the Travis error. If a few people test it outside of Travis and it works well, maybe we could just disable OpenMP on Travis?

// First, im2col
im2col((%(float_type)s*)PyArray_DATA(bottom) + n * bottom_stride, nChannels, bottomHeight,
-       bottomWidth, kH, kW, padH, padW, dH, dW, (%(float_type)s*)PyArray_DATA(col));
+       bottomWidth, kH, kW, padH, padW, dH, dW, (%(float_type)s*)PyArray_GETPTR3(col, colidx, 0, 0));
Member

Could it be PyArray_DATA or PyArray_GETPTR3 that causes this problem? Could the PyArray_DATA call be moved outside the loop, and could the PyArray_GETPTR3 offset be computed manually? These are the only Python C-API calls inside the parallel loop, so I don't see what else could cause the Python-related double free in the error message.

Author

Technically PyArray_GETPTR3 should just be doing pointer arithmetic to move past the first colidx matrices, so we could, although it'll make the code much less readable. In fact, looking at the source, that's pretty much exactly what it does. PyArray_BYTES and PyArray_STRIDES both just do a pointer dereference, so I don't see how the problem can originate from there :/

@JesseLivezey
Contributor

@jonhoo @nouiz , I get

error: omp_get_max_threads was not declared in this scope.      max = omp_get_max_threads();

error: omp_get_thread_num was not declared in this scope.              colidx = omp_get_thread_num();

.theanorc looks like

[global]
root = /usr/local/cuda
floatX = float32
device = cpu
ldflags = -lmkl_rt
openmp = True
allow_gc = False

[lib]
cnmem = .45

[nvcc]
fastmath = True

I'm just calling CorrMM directly. The legacy CPU ConvOp works with OpenMP for me, so I think something is still missing for BaseCorrMM.

@JesseLivezey
Contributor

Adding #include <omp.h> to corr_gemm.c gets it through compilation for me but it still errors when it runs

*** Error in `python': free(): invalid pointer: 0x00000000043edd50 ***

# Add the -fopenmp flags
ret += super(BaseCorrMM, self).c_compile_args()

return ret
Member

The c_headers method must also be updated so that the missing include contributed by the parent class gets added.
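A minimal Python sketch of the pattern this describes, using stand-in classes rather than Theano's real OpenMPOp and BaseCorrMM: the subclass must extend c_headers the same way it already extends c_compile_args, otherwise the parent's <omp.h> never reaches the generated C file:

```python
# Stand-in classes (not Theano's actual implementation) showing why
# overriding c_compile_args without also extending c_headers leaves
# omp_get_max_threads() undeclared in the generated C code.
class OpenMPOp(object):  # stand-in for Theano's OpenMPOp superclass
    def c_headers(self):
        return ['<omp.h>']  # the include the parent contributes

    def c_compile_args(self):
        return ['-fopenmp']


class BaseCorrMM(OpenMPOp):
    def c_headers(self):
        # Extend, don't replace, the parent's header list.
        return super(BaseCorrMM, self).c_headers() + ['<numpy/arrayobject.h>']

    def c_compile_args(self):
        ret = []
        # Add the -fopenmp flags
        ret += super(BaseCorrMM, self).c_compile_args()
        return ret
```

With this pattern, BaseCorrMM().c_headers() contains '<omp.h>', which resolves the "omp_get_max_threads was not declared in this scope" error above.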

Contributor

This fixes the first problem!

error: omp_get_max_threads was not declared in this scope.

@nouiz
Member

nouiz commented Dec 7, 2015

By code review I'm not able to find the problem. Can you comment out the Py_DECREF calls and check whether that fixes the problem? It would help to identify which object is causing trouble (this would probably cause a memory leak, so it isn't a long-term solution).

max = omp_get_max_threads();
#endif
npy_intp col_dim[3];
col_dim[0] = (npy_intp)(max);
Contributor

@jonhoo, @nouiz, K_ and N_ below still reference col_dim[0] and col_dim[1]. jonhoo, I made a PR to your branch with fixes.

Author

Good catch! Following up in https://github.com/jonhoo/Theano/pull/1.

@jonhoo
Author

jonhoo commented Dec 9, 2015

Sorry I've been so slow to follow up on this; I've been quite busy the past week. I will hopefully have time to look at this more closely on Thursday. @JesseLivezey, thanks, I'll have a look.

@JesseLivezey
Contributor

@nouiz, I think this is ready for review. Is there a standard way to test OpenMP ops in Theano? There isn't a test for the OpenMP CorrMM currently.

@nouiz, @jonhoo, should the grads inherit the openmp option from the original op?
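One possible shape for "the grads inherit the openmp option", sketched with hypothetical class and parameter names (not Theano's actual API): the forward op passes its own flag when constructing the gradient ops instead of letting them fall back to the global config default:

```python
# Hypothetical names throughout; only the flag-propagation pattern is
# the point. The forward op hands its own openmp setting to the ops it
# builds for the gradient.
class CorrMM(object):
    def __init__(self, openmp=True):
        self.openmp = openmp

    def grad_ops(self):
        # Propagate the forward op's flag to both gradient ops.
        return (CorrMM_gradInputs(openmp=self.openmp),
                CorrMM_gradWeights(openmp=self.openmp))


class CorrMM_gradInputs(object):
    def __init__(self, openmp=True):
        self.openmp = openmp


class CorrMM_gradWeights(object):
    def __init__(self, openmp=True):
        self.openmp = openmp
```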

@nouiz
Member

nouiz commented Dec 21, 2015

I think this PR is good to merge. But since we won't be available in case of problems and it wasn't tested on Mac/Windows, I will postpone the merge until we are back from the holidays. Those problems should already be handled by the superclass OpenMPOp, but just in case.

To test this, I ran the test_corr.py file with OMP_NUM_THREADS=1 and 2.

Has someone tested it with bigger matrices to check that there is a speedup? Just to be sure.

thanks

@JesseLivezey
Contributor

I ran some tests on convolutions similar to the first layer of ImageNet model convolutions, and also on some convolutions closer to 1D convolutions on a spectrogram. This was all done on a 4-core Xeon processor.

For both sets of shapes, running with MKL_NUM_THREADS=4 (previous version) was faster than no parallelization, and running with OMP_NUM_THREADS=4 and MKL_NUM_THREADS=1 (new openmp version) was faster than just having MKL use 4 threads!

I'll also try timing on 8 and 12 core processors today. The previous version (no openmp, MKL using all cores) didn't scale well past ~6 cores, so hopefully this will help.

@nouiz
Member

nouiz commented Dec 21, 2015

It would be great to add this timing information to doc/tutorial/multi_cores.txt (and mention there that the CPU convolution is now parallelized for the forward pass and the gradient w.r.t. inputs).

thanks


@JesseLivezey
Contributor

I can make another PR to this PR with the timing + docs.

Timing done on 12-core Intel Xeon X5650 with 24GB RAM.
BLAS was MKL from Anaconda Python.

Image shape, filter shape
(128, 3, 128, 128), (96, 3, 5, 5)
Imagenet-like
MKL     OpenMP  Speedup  Notes
1       12      1        Legacy, OpenMP, baseline
1       1       .097     Legacy, single thread
1       1       .75      CorrMM, MKL only
2       1       1.3      “
4       1       1.9      “
8       1       2.0      “
12      1       1.7      “
1       2       1.5      CorrMM, MKL single thread + OpenMP
1       4       2.7      “
1       8       4.5      “
1       12      5.3      “

Image shape, filter shape
(128, 85, 1, 258), (64, 85, 1, 20)
Spectrogram-like
MKL     OpenMP  Speedup  Notes
1       12      1        Legacy, OpenMP, baseline
1       1       .095     Legacy, single thread
1       1       .50      CorrMM, MKL only
2       1       .78      “
4       1       1.2      “
8       1       1.5      “
12      1       1.3      “
1       2       .94      CorrMM, MKL single thread + OpenMP
1       4       2.0      “
1       8       4.0      “
1       12      5.2      “

@JesseLivezey
Contributor

@jonhoo, @nouiz the test_corr.py tests DO NOT pass on OSX El Capitan 10.11.2 with Homebrew gcc 5.3.0 --without-multilib --enable-cxx. This is the current default installed with gcc through Homebrew. The default OSX clang compiler doesn't support OpenMP, but tests pass; OpenMP just gets disabled.

The tests DO pass on OSX (same machine as above) with Homebrew gcc48 4.8.4 --without-multilib --enable-cxx.

The tests DO pass on Ubuntu with g++ version Ubuntu 4.8.4-2ubuntu1~14.04.

Any ideas? I can try and run some of the other ops with OpenMP (elementwise, I guess) and see if they pass/fail the same way.

@jonhoo
Author

jonhoo commented Jan 6, 2016

That's very interesting... I don't know what might be causing it. Can you try with gcc 5.X on Ubuntu?

@JesseLivezey
Contributor

Tests also fail on Ubuntu with g++ 5.3 and with 4.9. Maybe something changed in OpenMP from 3.1 to 4.0? I think g++ 4.8 only supports OpenMP 3.1, whereas 4.9+ supports 4.0.

@nouiz, does Theano keep track of the g++ version?

Elementwise tests seem okay for all versions.

@nouiz
Member

nouiz commented Feb 12, 2016

As the ICML deadline has passed, we can try to merge stuff that isn't 100% certain.

Theano just records the g++ version; except in a few places, we don't use it. When we do use the version, it is only to work around buggy versions.

@JesseLivezey which tests failed? Can you show the error?

This PR would need a rebase.

We are trying to clean things up to make a release. It would be great to have this in. Thanks.

@JesseLivezey
Contributor

@nouiz With g++ 5 on OSX I get the following errors (I think the errors are the same on Ubuntu with g++ 4.9+). I looked into it a bit more a while ago and it looked like when OMP_NUM_THREADS > 1 the computations were wrong. I think it might have been doing the computation for only the first element of the batch on every loop iteration, but I'd have to go back and look more carefully.

Tests that basic correlations work for odd and even ... FAIL
Tests basic correlation in full mode and case where filter ... FAIL
test_img_kernel_same_shape (theano.tensor.nnet.tests.test_corr.TestCorr2D) ... FAIL
test_infer_shape_forward (theano.tensor.nnet.tests.test_corr.TestCorr2D) ... ok
test_infer_shape_gradI (theano.tensor.nnet.tests.test_corr.TestCorr2D) ... ok
test_infer_shape_gradW (theano.tensor.nnet.tests.test_corr.TestCorr2D) ... ok
Tests scenario where filter_shape[1] != input_shape[1] ... ok
test_non_contiguous (theano.tensor.nnet.tests.test_corr.TestCorr2D) ... ERROR
Tests correlation where the {image,filter}_shape is a Constant tensor. ... FAIL
Tests correlation where subsampling != (1,1) ... FAIL
Make sure errors are raised when image and kernel are not 4D tensors ... ok

======================================================================
ERROR: test_non_contiguous (theano.tensor.nnet.tests.test_corr.TestCorr2D)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 331, in test_non_contiguous
    self.validate((2, 2, 3, 3), (2, 2, 2, 2), 'valid', non_contiguous=True)
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 128, in validate
    utt.verify_grad(sym_CorrMM, [orig_image_data, filter_data])
  File "/Users/grndfthrprdx/Development/Theano/theano/tests/unittest_tools.py", line 83, in verify_grad
    T.verify_grad(op, pt, n_tests, rng, *args, **kwargs)
  File "/Users/grndfthrprdx/Development/Theano/theano/gradient.py", line 1709, in verify_grad
    abs_tol, rel_tol)
GradientError: GradientError: numeric gradient and analytic gradient exceed tolerance:
        At position 32 of argument 0,
            abs. error = 5.979729,  abs. tolerance = 0.010000
            rel. error = 0.600072,  rel. tolerance = 0.010000
Exception args: 
The error happened with the following inputs:, [array([[[[ 0.70043713,  0.84418666,  0.67651433],
         [ 0.72785807,  0.95145798,  0.0127032 ],
         [ 0.41358769,  0.0488128 ,  0.09992856]],

        [[ 0.5080663 ,  0.20024754,  0.74415416],
         [ 0.192892  ,  0.70084476,  0.29322812],
         [ 0.77447945,  0.00510884,  0.11285765]]],


       [[[ 0.11095367,  0.24766822,  0.0232363 ],
         [ 0.72732115,  0.34003493,  0.19750315],
         [ 0.90917957,  0.978347  ,  0.53280252]],

        [[ 0.25913185,  0.58381259,  0.32569066],
         [ 0.88889933,  0.62640452,  0.8188737 ],
         [ 0.5473454 ,  0.41671202,  0.74304718]]]], dtype=float32), array([[[[ 0.36959639,  0.07516655],
         [ 0.77519298,  0.21940924]],

        [[ 0.07934213,  0.48678052],
         [ 0.1536739 ,  0.8284651 ]]],


       [[[ 0.19136856,  0.27040896],
         [ 0.56103444,  0.90238041]],

        [[ 0.85178834,  0.41808197],
         [ 0.39347628,  0.01622051]]]], dtype=float32)], 
The value of eps is:, None, 
The out_type is:, None

======================================================================
FAIL: Tests that basic correlations work for odd and even
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 146, in test_basic
    self.validate(img, fil, border_mode, verify_grad=False)
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 124, in validate
    self.assertTrue(_allclose(theano_output, ref_output))
AssertionError: False is not true

======================================================================
FAIL: Tests basic correlation in full mode and case where filter
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 212, in test_full_mode
    self.validate((3, 2, 5, 5), (4, 2, 8, 8), 'full')
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 124, in validate
    self.assertTrue(_allclose(theano_output, ref_output))
AssertionError: False is not true

======================================================================
FAIL: test_img_kernel_same_shape (theano.tensor.nnet.tests.test_corr.TestCorr2D)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 152, in test_img_kernel_same_shape
    self.validate((3, 2, 3, 3), (4, 2, 3, 3), 'full')
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 124, in validate
    self.assertTrue(_allclose(theano_output, ref_output))
AssertionError: False is not true

======================================================================
FAIL: Tests correlation where the {image,filter}_shape is a Constant tensor.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 191, in test_shape_Constant_tensor
    (5, 2, 2, 3), border_mode)
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 124, in validate
    self.assertTrue(_allclose(theano_output, ref_output))
AssertionError: False is not true

======================================================================
FAIL: Tests correlation where subsampling != (1,1)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 163, in test_subsample
    self.validate((3, 2, 7, 5), (5, 2, 2, 3), 'valid', subsample=(2, 2))
  File "/Users/grndfthrprdx/Development/Theano/theano/tensor/nnet/tests/test_corr.py", line 124, in validate
    self.assertTrue(_allclose(theano_output, ref_output))
AssertionError: False is not true

----------------------------------------------------------------------
Ran 11 tests in 27.950s

FAILED (errors=1, failures=5)

@JesseLivezey
Contributor

Bumping this again in case either of you have any ideas why this fails with different gcc versions @jonhoo @nouiz

I can investigate more tomorrow.

@nouiz
Member

nouiz commented Mar 9, 2016

Sorry, no time. Thanks for investigating.


@nouiz
Member

nouiz commented Sep 19, 2016

This was finished in #4591.

@nouiz nouiz closed this Sep 19, 2016