Faster algorithms and gradients for GpuCorrMM #2033

Merged
merged 22 commits on Sep 1, 2014


7 participants

@f0k
Contributor
f0k commented Aug 12, 2014

As a follow-up to #2023, this PR adds caffe's backward pass wrt. inputs to GpuCorrMM to implement border_mode="full". It passes all the tests and is about 2-4x faster than simulating a full convolution with a padded valid convolution. Along the way, it cleans up the code and adds more elaborate documentation.
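For reference, a minimal NumPy/SciPy sketch (not part of this PR; names are illustrative) of the identity behind the padded-valid simulation mentioned above: a full convolution equals a valid correlation of the zero-padded input with a flipped kernel.

import numpy as np
from scipy.signal import convolve2d, correlate2d

img = np.random.randn(7, 7).astype('float32')
kern = np.random.randn(3, 3).astype('float32')

full_conv = convolve2d(img, kern, mode='full')

# pad by (kernel_size - 1) on each border, flip the kernel, then run a
# valid cross-correlation -- this reproduces the full convolution
padded = np.pad(img, 2, mode='constant')  # 2 == 3 - 1
valid_corr = correlate2d(padded, kern[::-1, ::-1], mode='valid')

assert np.allclose(full_conv, valid_corr, atol=1e-5)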

There are some caveats, though, all stemming from the fact that the implementation in caffe is meant as a backward pass for a valid correlation:

  • With border_mode="valid", GpuCorrMM doesn't flip the kernels, but with border_mode="full", it does.
  • With border_mode="valid", subsampling is for subsampling the output, but with border_mode="full", it is for upsampling the input.
  • With border_mode="valid", pad is for padding the input, but with border_mode="full", it is for cropping the output.
  • With border_mode="full", it needs a different memory layout for the kernels.

Currently, GpuCorrMM directly wraps the underlying algorithm, and local_conv_gemm() copes with the different peculiarities, because I wasn't sure whether the GpuCorrMM Op should be inserting dimshuffles, kernel flips and gpu_contiguouses on its own.

Looking at the end result, although it's quite fast, local_conv_gemm() now basically undoes everything that ConvOp.grad() does. It might be possible to write an optimizer that replaces the gradient of a valid convolution with a properly parameterized GpuCorrMM Op, instead of just replacing the full convolution Op that is part of the gradient (introducing redundant dimshuffles etc. on the way). This way we would leverage the caffe implementation better.
The alternative would be to modify the CUDA code to perform a subsampled, padded, full correlation, as would be expected from a Gpu Correlation Op. This would be cleaner, but it would also mean a lot more work and we wouldn't profit as much from caffe's implementation for the case of subsampling != (1,1) or padding != (0,0) (granted, this may be an uninteresting case in practice anyway).
/edit: Another alternative would be splitting this into two Ops: GpuCorrMM for valid correlation with padding and subsampling (the caffe forward pass), and GpuConvMM for full convolution with upsampling and cropping (the caffe backward pass). This way it would be obvious how the operations differ, but the memory layout required for the second Op would still be unintuitive.
Yet another alternative would be splitting it into GpuCorrMM and GpuCorrMM_gradInput. This way it would be obvious how to use GpuCorrMM_gradInput for computing the gradient of GpuCorrMM. We could even add GpuCorrMM_gradKernel for the gradient wrt. weights; it seems caffe does things slightly differently there as well.
In both cases, GpuCorrMM should get a grad() method to define its gradient (so it can be used directly in a CNN to avoid kernel flipping and whatnot) and local_conv_gemm() should still be able to replace any GpuConv instances with a gemm-based Op (so it can be used with existing CNN implementations).

@f0k
Contributor
f0k commented Aug 12, 2014

PS: This is on top of #2023, which wasn't merged yet. If you merge #2023, I can rebase this PR on top of master and solve the merge conflicts.

@f0k f0k referenced this pull request Aug 12, 2014
Closed

Continue gemm convolution #2015

5 of 8 tasks complete
@abergeron
Member

It already says that it has conflicts. But I think you should wait until #2023 is merged before doing a rebase.

Apart from that, I think you have the right approach in the optimization. The redundant dimshuffles should get merged or eliminated by the optimizer.

@stencilman
Contributor

Thanks a lot @f0k!! This is the timing speed-up that I get:
after your changes:

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  89.9%    89.9%      80.250s       2.79e-01s     C      288       18   GpuCorrMM{valid, (1, 1), pad=0}
   5.6%    95.6%       5.022s       2.24e-02s     C      224       14   GpuCorrMM{full, (1, 1), pad=0}

before your changes:

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  57.1%    57.1%      83.547s       2.90e-01s     C      288       18   GpuCorrMM{valid, (1, 1), pad=(0, 0)}
  39.0%    96.1%      57.025s       1.02e-01s     C      560       28   GpuCorrMM{full, (1, 1), pad=(0, 0)}

Maybe there are redundant ops in my graph or I am doing something very wrong... Is there a way to fine-tune my graph? For me, the fprop is still very good, but the bprop is now about 3x slower than torch7. @f0k: did you try the convnet-benchmarks (https://github.com/soumith/convnet-benchmarks)?

@f0k
Contributor
f0k commented Aug 12, 2014

@stencilman: Hmm, so after my change it is about 5 times faster and only called half as often (upper table)? What's the problem then?
I already extended the convnet benchmark to include corrmm in a branch on my fork: https://github.com/f0k/convnet-benchmarks/tree/theano-benchmark-better (I'm postponing the pull request until GpuCorrMM is somewhat more mature). The bprop wrt. inputs was 2-4x faster with the new version. What do you get?
@abergeron: Yes, I'll wait for the merge. And I'll think about splitting the Op as detailed above.

@benanne
Contributor
benanne commented Aug 12, 2014

I have nothing of value to contribute here, but I'd just like to say that this is awesome :)

@madisonmay
Contributor

Seconded -- awesome work so far @f0k, @stencilman, @abergeron.

@abergeron
Member

I barely did anything.

@stencilman
Contributor

So @f0k, in your fork (https://github.com/f0k/convnet-benchmarks/tree/theano-benchmark-better), how does this compare to Torch's SpatialConvolutionMM? My bprop is 3x slower than Torch for some reason...

@f0k f0k referenced this pull request Aug 13, 2014
Merged

Conv gemm non-square kernel support #2023

2 of 2 tasks complete
@f0k
Contributor
f0k commented Aug 15, 2014

So I've rebased the PR onto master (the latest and greatest 9b3ea9e). @stencilman: I did a force push to my corrmm-faster-fullconv branch, if you still have it checked out, you should probably checkout master, delete the branch and check it out anew.

As discussed in #2023, I will refactor this implementation to have a GpuCorrMM op with two gradient ops, similar to the cuda-convnet wrapper in pylearn2. This will allow us to have a fast bprop wrt. weights in addition to the fast bprop wrt. inputs enabled by 121c1d4. So I'd suggest to not merge it yet.

@stencilman
Contributor

@f0k: you are my hero!! Thank you so much, looking forward to hearing from you on this. Really grateful for this :-)

@stencilman
Contributor

@f0k: Like I said earlier, do let me know if I can help you in any way to get this in quicker... my project actually depends on this :-). Thanks!

@f0k
Contributor
f0k commented Aug 17, 2014

@stencilman: Thanks for your offer, I planned to do it tomorrow (Monday). You can help by benchmarking against Torch again, and possibly with writing the tests... I'll let you know!

@stencilman
Contributor

@f0k: I would love to run it against Torch and write any needed tests. Thanks! 👍

@abergeron
Member

So is this considered finished?

Even if it's a bit slower than Torch, if it's faster than what we have and works correctly, it's good for merging. Unless you want to work more on it.

@f0k
Contributor
f0k commented Aug 18, 2014

@abergeron: Doesn't matter to me -- either you merge it and I'll file a new PR that refactors everything, or you don't merge it yet and I'll add a commit to this PR that refactors everything.

(I have the refactored version almost finished, there's just something flipped again...)

@abergeron
Member

I'll wait for the refactor then.

@f0k
Contributor
f0k commented Aug 18, 2014

OK, the refactored version is up. It passes all the previous tests (test_gemm_valid, test_gemm_full, test_gemm_subsample, test_gemm_directly). The optimizer doesn't use the new backprop wrt. weights yet, but the GpuCorrMM op now has gradients defined so it can be used instead of conv2d (and then the faster backprop wrt. weights will be used).
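For illustration, a hedged sketch of what "used instead of conv2d" could look like (assuming CUDA is available and that the import paths below match this PR's sandbox layout; treat it as a sketch, not tested code):

import theano
import theano.tensor as T
from theano.sandbox.cuda.basic_ops import gpu_contiguous
from theano.sandbox.cuda.blas import GpuCorrMM

img = T.ftensor4('img')    # (batch, channels, rows, cols)
kern = T.ftensor4('kern')  # (nfilters, channels, rows, cols)

# valid correlation via the new op; gpu_contiguous ensures C-contiguous inputs
out = GpuCorrMM(border_mode='valid', subsample=(1, 1))(
    gpu_contiguous(img), gpu_contiguous(kern))
cost = out.sum()

# because GpuCorrMM now defines grad(), this graph should contain the new
# gradient ops rather than going through ConvOp.grad()
g_img, g_kern = theano.grad(cost, [img, kern])
f = theano.function([img, kern], [g_img, g_kern])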

TODO:

  • Write tests for the gradients of GpuCorrMM. I didn't have time to test them at all yet.
  • Modify the conv_gemm optimizer to choose between GpuCorrMM and GpuCorrMM_gradWeights depending on the input shapes of any valid convolutions to be replaced. (Any suggestions?)

I'll be gone for today, but I'll happily accept pull requests to my pull request for any of the TODO items.

@stencilman
Contributor

@f0k: Wowowow!! I will take a stab at testing the gradients of GpuCorrMM. I will leave TODO 2 for someone with better knowledge of Theano (perhaps @nouiz). I will also report the speedup.

Thanks again, I am really grateful to you!

@f0k f0k and 2 others commented on an outdated diff Aug 18, 2014
theano/sandbox/cuda/opt.py
-        kern = gpu_contiguous(kern)
-        return [GpuCorrMM(node.op.border_mode, node.op.subsample)(img, kern)]
+        border_mode = node.op.border_mode
+        subsample = node.op.subsample
+        pad = (0,0)
+        if (border_mode == 'full') and (subsample != (1,1)):
+            # need to simulate this via a padded valid convolution
+            pad = 'auto'
+            border_mode = 'valid'
+        if (border_mode == 'valid'):
+            # need to flip the kernel for valid convolution
+            kern = kern[:, :, ::-1, ::-1]
+        # call GpuCorrMM
+        # TODO: call GpuCorrMM_gradWeights instead if appropriate
+        return [GpuCorrMM('valid', subsample, pad)(
+                gpu_contiguous(img), gpu_contiguous(kern))]
@f0k
f0k Aug 18, 2014 Contributor

Any suggestions on when to use GpuCorrMM_gradWeights instead?

@nouiz
nouiz Aug 19, 2014 Member

I'm not convinced we need to split that into 2 ops. Here is some info to help decide which one to use:
fprop: img shape 32x32, kernel 5x5, output 28x28.
In that case, if we reuse the fprop code for the weight gradient, the img will be 32x32, the kernel 28x28 and the output 5x5.

What about this: if the kernel is bigger than the output, we use the gradWeights case? The only good way to know this is at run time, which is why I think we should keep both valid cases in the same op. With your code, that should be easy. But before making this change, we should benchmark my choice of threshold. @stencilman, can you do that? Time GpuCorrMM vs GpuCorrMM_gradWeights for different input sizes, including my examples?

@stencilman
stencilman Aug 19, 2014 Contributor

Yes, I can do this now. @nouiz: what exactly do you mean by GpuCorrMM_gradWeights?

@stencilman
stencilman Aug 19, 2014 Contributor

@nouiz:
torch:

CONFIG: input = 64x32x32 * ker = 64x128x5x5 (bs = 32, stride = 1)
SpatialConvolutionMM:updateOutput(): (tm = 0.0090797543525696)
SpatialConvoltionMM:accGradParameters(): (tm = 0.011068224906921)
SpatialConvolutionMM:updateGradInput(): (tm = 0.0082367658615112)

CONFIG: input = 64x32x32 * ker = 64x128x28x28 (bs = 32, stride = 1)
SpatialConvolutionMM:updateOutput(): (tm = 0.032178223133087)
SpatialConvoltionMM:accGradParameters(): (tm = 0.032377779483795)
SpatialConvolutionMM:updateGradInput(): (tm = 0.02130252122879)

and theano:

CONFIG: input = 64 x 32 x 32 * ker = 64 x 128 x 5 x 5 ( bs = 32 , stride = 1 )
(experimental) theano.sandbox.cuda.blas.CorrMM fprop: 1323.00504803 GFLOP/s ( tm = 0.00776719999313 )
(experimental) theano.sandbox.cuda.blas.CorrMM bprop weights: 0.0 GFLOP/s ( tm = 0.0100939035416 )
(experimental) theano.sandbox.cuda.blas.CorrMM bprop inputs: 0.0 GFLOP/s ( tm = 0.00898323154449 )

CONFIG: input = 64 x 32 x 32 * ker = 64 x 128 x 28 x 28 ( bs = 32 , stride = 1 )
(experimental) theano.sandbox.cuda.blas.CorrMM fprop: 308.976205726 GFLOP/s ( tm = 0.0332583694458 )
(experimental) theano.sandbox.cuda.blas.CorrMM bprop weights: 0.0 GFLOP/s ( tm = 0.0303143196106 )
(experimental) theano.sandbox.cuda.blas.CorrMM bprop inputs: 0.0 GFLOP/s ( tm = 0.0210659198761 )

is this what you were looking for?

@nouiz
nouiz Aug 19, 2014 Member

No, I meant Theano's GpuCorrMM vs GpuCorrMM_gradWeights. Both do the same computation, just with a different algorithm.

@f0k
f0k Aug 19, 2014 Contributor

I think the batch size will also play a role, because in one case (fprop), the gemm is used to accumulate over the channels and the for loop to iterate over the minibatch, and in the other case (bprop weights), the for loop is used to accumulate over the channels (= minibatch of the forward pass) and gemm to iterate over the minibatch (= input channels of the forward pass).

I had a benchmark of the two implementations running overnight; I will try to figure out a formula to choose between them. A nice aspect is that this may allow us to be faster than caffe/Torch for certain configurations, because in caffe/Torch it's hard-coded which implementation to use in the forward and backward pass.

The problem I see with joining both implementations into one Op is that they require very different memory layouts for the input and output. To compare, these three perform the same computation (copied from test_gemm_directly()):

cpuval = py_conv(npy_img, npy_kern, 'valid', subsample)

op = theano.sandbox.cuda.blas.GpuCorrMM(border_mode='valid',
        subsample=subsample)(i, k)
f = theano.function([i, k], op, mode=theano_mode)
gpuval1 = f(npy_img, npy_kern[:,:,::-1,::-1])

op = theano.sandbox.cuda.blas.GpuCorrMM_gradWeights(border_mode='valid',
        subsample=subsample)(i, k)
f = theano.function([i, k], op, mode=theano_mode)
gpuval2 = numpy.array(f(npy_img.transpose(1, 0, 2, 3),
        npy_kern.transpose(1, 0, 2, 3)[:,:,::-1,::-1])).transpose(1, 0, 2, 3)

# cpuval, gpuval1 and gpuval2 are now approximately equal

If we do the necessary dimshuffles, kernel flippings and gpu_contiguous calls inside the C code of a merged Op, we will probably do unneeded copies. If we do them in Python, we might still do unneeded copies unless we create an optimizer that can merge chains of dimshuffles with gpu_contiguous calls in between. Besides, keeping the code similar to caffe makes it easier to port upstream changes to Theano in case there are any.

So I'd prefer to keep the Ops separate for now, and leave merging them for another PR if that's okay with you.

@f0k
f0k Aug 19, 2014 Contributor

Okay, from my benchmark the magic formula seems to be:

if batchsize * kernelHeight * kernelWidth < inputChannels * outputHeight * outputWidth:
    use GpuCorrMM
else:
    use GpuCorrMM_gradWeights

Unfortunately, even if ConvOp knows about the shapes, GpuConv only knows about image and kernel width and height and the number of input channels, but not about the batchsize. So for now I'll have the optimizer decide based on kernel and output size, following your suggestion -- this gets most cases correct already.
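As a hedged illustration (the function and argument names below are mine, not the optimizer's code), the selection rule amounts to:

def pick_valid_corr_op(batchsize, input_channels,
                       kernel_height, kernel_width,
                       output_height, output_width):
    """Heuristic from the overnight benchmark for choosing between the two
    equivalent implementations of a valid correlation."""
    if (batchsize * kernel_height * kernel_width <
            input_channels * output_height * output_width):
        return 'GpuCorrMM'              # caffe-style forward pass
    return 'GpuCorrMM_gradWeights'      # caffe-style weight-gradient pass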

@nouiz
nouiz Aug 19, 2014 Member

We use c_code_helper in a few other ops, so I would go with that name.

I'm good with the idea of merging the ops later. If this PR is stable and gives better performance than the current master, it can be merged. But I think we need to figure out how to fix @stencilman's bug.

@stencilman, I understand your script is complicated. Can you dump your Theano function and make a small Python script that loads it and initializes good input values for the function that generates the error?

http://www.deeplearning.net/software/theano/tutorial/debug_faq.html?highlight=dump%20function#dumping-a-function-to-help-debug

I suppose the problem happens in the first call to the Theano function. Is that the case?

@f0k, it is easy to add the batch size to the GpuConv op. You just need to add a parameter to GpuConv.__init__, store it as a member, modify __eq__ and __hash__ to check it, and modify the optimization in cuda/opt.py to pass it to the op. My sentence is long, but it is about 5-10 lines of code. If you want me to do it, just ask and I'll do it. That way, you can take the batch size into account when it is set. Just make sure the optimization doesn't crash if only some of the shapes are known. It is fine to do it only if we know all the shapes.

thanks


@f0k f0k commented on the diff Aug 18, 2014
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
@@ -186,7 +186,7 @@ def _params_allgood(ishape, kshape, mode, subsample=(1, 1), img_stride=(1, 1),
     f = theano.function([i, k], op, mode=theano_mode)
     if cls is not None:
         assert any([isinstance(node.op, cls)
-                    for node in f.maker.fgraph.toposort()]), f.maker.fgraph.toposort()
+                    for node in f.maker.fgraph.toposort()]), "Cannot find class %r in %r" % (cls, f.maker.fgraph.toposort())
@f0k
f0k Aug 18, 2014 Contributor

More enlightening printout on why the test failed.

@f0k f0k commented on the diff Aug 18, 2014
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
@@ -284,7 +284,7 @@ def exec_conv(version, shapes, verbose, random, mode,
                         cls=cls)
     except Exception, e:
         print ver, id, (ishape, kshape, subshape, istride, kstride)
-        print e
+        print "Exception", type(e), e
@f0k
f0k Aug 18, 2014 Contributor

Again, a bit more verbose information on why the test failed.

@f0k
Contributor
f0k commented Aug 18, 2014

Added tests for GpuCorrMM_gradWeights and GpuCorrMM_gradInputs to test_gemm_directly, and fixed a mistake in GpuCorrMM_gradWeights. Still missing: A test_gemm_grads test which tests if the gradients of GpuCorrMM are actually the correct gradients. I'll see if I can do it before I really leave, otherwise I'll leave it open for tomorrow or for somebody else!

@f0k
Contributor
f0k commented Aug 18, 2014

I added a test for the gradients (test_gemm_grads). It passes for mode="valid", subsample=(1,1) and pad=(0,0), so if you don't use subsampling or padding, the GpuCorrMM()(img, gpu_contiguous(kern[:,:,::-1,::-1])) from this PR can already serve as a replacement for conv2d(img, kern). Note that using it directly like this is currently required to get a speedup over the automatic optimization.
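For clarity, a hedged sketch of that replacement (assuming the default GpuCorrMM parameters correspond to border_mode='valid', subsample=(1, 1), pad=(0, 0), as used in the tests):

import theano
import theano.tensor as T
from theano.sandbox.cuda.basic_ops import gpu_contiguous
from theano.sandbox.cuda.blas import GpuCorrMM
from theano.tensor.nnet import conv2d

img = T.ftensor4('img')
kern = T.ftensor4('kern')

# correlation with a flipped kernel is a convolution
out_corrmm = GpuCorrMM()(gpu_contiguous(img),
                         gpu_contiguous(kern[:, :, ::-1, ::-1]))
out_conv = conv2d(img, kern, border_mode='valid')

f = theano.function([img, kern], [out_corrmm, out_conv])
# the two outputs should agree up to float32 rounding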

Additional TODOs:

  • GpuCorrMM needs to pass the input shape to GpuCorrMM_gradInputs and the weight shape to GpuCorrMM_gradWeights to make it work for all subsampling values. Currently, the shape is inferred incorrectly if the output image size is not divisible by the subsampling size.
  • At least one of the gradient ops fails for pad="auto".
  • GpuCorrMM_gradInputs and GpuCorrMM_gradWeights should probably get a grad() function as well.
@stencilman
Contributor
Timings (seconds; all configs use bs = 128, stride = 1):

                              input 64x64x64      input 128x32x32     input 3x128x128     input 384x13x13     input 128x16x16
                              ker 64x128x9x9      ker 128x128x9x9     ker 3x96x11x11      ker 384x384x3x3     ker 128x128x7x7
torch fprop                   0.24305301904678    0.16899198293686    0.10639077425003    0.055703997612      0.041813731193
GpuCorrMM fprop               0.225996734619      0.159330703735      0.0936857147217     0.0532916145325     0.04246251297
torch grad wrt input          0.30804550647736    0.13152647018433    0.091511249542      0.02750152349472    0.02301424741745
GpuCorrMM bprop wrt inputs    0.288951873779      0.121228462219      0.086506401062      0.0258434085846     0.0209931201935
torch grad wrt weights        0.37771952152252    0.1442049741745     0.18967247009277    0.03774124383       0.023773729801
GpuCorrMM bprop wrt weights   0.352622833252      0.132992980957      0.169927322388      0.0339831695557     0.0214208164215

@benanne
Contributor
benanne commented Aug 18, 2014

very nice :) @f0k out of curiosity, have you tried using this for 1D convolutions? If so, how's the performance? Would it make sense to make a specialized 1D version of this, or is the 1D-using-2D approach good enough?

@f0k
Contributor
f0k commented Aug 18, 2014

@stencilman: Cool, thanks for testing! Seems we're on the right track here!
@benanne: I haven't tried, but the theano benchmark in convnet-benchmarks accepts a custom config on the command line, so you can time the 1D-using-2D approach (but the bprop wrt. weights will not be fast yet). If by 1D convolution you refer to your approach of shifting a 2D patch over a spectrogram in the time dimension only, then I don't see a way to make it faster. If you refer to convolving a vector with another vector, then... hmm... do you still have batches? Multiple input and output channels?

@benanne
Contributor
benanne commented Aug 18, 2014

Yeah, I meant the former. I have a suspicion that for the CorrMM-approach to computing convolutions, the 1D-using-2D approach is automatically the best thing you can do, which is why I'm interested to find out how it performs :) If I find some time I guess I'll try it for myself using the benchmark.
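As a hedged sketch of that 1D-using-2D setup (plain conv2d for clarity; the gemm optimizer would then replace it): make the kernel span the full frequency axis, so a valid convolution only slides along time.

import theano
import theano.tensor as T
from theano.tensor.nnet import conv2d

x = T.ftensor4('x')  # (batch, channels, freq_bins, time_steps)
w = T.ftensor4('w')  # (nfilters, channels, freq_bins, filter_width)

# with kernel height == input height, 'valid' mode yields an output of shape
# (batch, nfilters, 1, time_steps - filter_width + 1), i.e. a 1D convolution
# over time implemented as a 2D convolution
y = conv2d(x, w, border_mode='valid')
f = theano.function([x, w], y)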

@stencilman
Contributor

During backprop, I get the error below when I use GpuCorrMM directly. However, when I use the optimizer=fast_compile option, I do not get this error.

Error when tring to find the memory information on the GPU: an illegal memory access was encountered
Error freeing device pointer 0x1116d80000 (an illegal memory access was encountered). Driver report 0 bytes free and 0 bytes total 
CudaNdarray_uninit: error freeing self->devdata. (self=0xd7614f0, self->devata=0x1116d80000)
Traceback (most recent call last):
  File "deep_nets.py", line 71, in <module>
    model.trainEpoch(data.tr, conf, epoch)
  File "/home/ajain/Projects/deep_nets/python/lib/machine.py", line 344, in trainEpoch
    err = self.train_model()
  File "/home/ajain/Theano/theano/compile/function_module.py", line 589, in __call__
    self.fn.thunks[self.fn.position_of_error])
  File "/home/ajain/Theano/theano/compile/function_module.py", line 579, in __call__
    outputs = self.fn()
RuntimeError: Cuda error: GpuElemwise node_600ed41afec07a555b0c58a139e2f767_0 Mul: an illegal memory access was encountered.
    n_blocks=30 threads_per_block=256
   Call: kernel_Mul_node_600ed41afec07a555b0c58a139e2f767_0_Ccontiguous<<<n_blocks, threads_per_block>>>(numEls, i0_data, i1_data, o0_data)

Apply node that caused the error: GpuElemwise{Mul}[(0, 1)](GpuElemwise{Composite{[Cast{float32}(EQ(i0, i1))]},no_inplace}.0, GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)}.0)
Inputs types: [CudaNdarrayType(float32, 4D), CudaNdarrayType(float32, 4D)]
Inputs shapes: [(16, 256, 90, 90), (16, 256, 90, 90)]
Inputs strides: [(2073600, 8100, 90, 1), (2073600, 8100, 90, 1)]
Inputs scalar values: ['not scalar', 'not scalar']

HINT: Re-running with most Theano optimization disabled could give you a back-traces when this node was created. This can be done with by setting the Theano flags optimizer=fast_compile
HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint of this apply node.

When I run with exception_verbosity=high, I get the error below.

Error when tring to find the memory information on the GPU: an illegal memory access was encountered
Error freeing device pointer 0x1116d80000 (an illegal memory access was encountered). Driver report 0 bytes free and 0 bytes total 
CudaNdarray_uninit: error freeing self->devdata. (self=0xcfd32f0, self->devata=0x1116d80000)
Traceback (most recent call last):
  File "deep_nets.py", line 71, in <module>
    model.trainEpoch(data.tr, conf, epoch)
  File "/home/ajain/Projects/deep_nets/python/lib/machine.py", line 344, in trainEpoch
    err = self.train_model()
  File "/home/ajain/Theano/theano/compile/function_module.py", line 589, in __call__
    self.fn.thunks[self.fn.position_of_error])
  File "/home/ajain/Theano/theano/gof/link.py", line 164, in raise_with_op
    print_type=True)
  File "/home/ajain/Theano/theano/printing.py", line 104, in debugprint
    stop_on_name=stop_on_name)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 613, in debugprint
    print >> file, '%s%s %s%s' % (prefix, r, id_str, type_str)
  File "/home/ajain/Theano/theano/sandbox/cuda/var.py", line 52, in __str__
    return "CudaNdarrayConstant{"+str(numpy.asarray(self.data))+"}"
  File "/usr/local/lib/python2.7/dist-packages/numpy/core/numeric.py", line 460, in asarray
    return array(a, dtype, copy=False, order=order)
RuntimeError: error copying data to host

Any idea what might be wrong?

@stencilman
Contributor

@nouiz @f0k: is it because I am missing a gpu_contiguous somewhere? I added it everywhere, but I still seem to have this problem... Any help would be great, thanks!

Update: It works when I use a smaller batch size.

@nouiz
Member
nouiz commented Aug 19, 2014

Can you provide code that reproduces the problem? I have a presentation this Friday and just started preparing it, so I won't have much time.

I would be surprised if gpu_contiguous caused the problem, but I didn't check the code of this PR yet. I'll try to check it tonight.

Can you report the error when you use the environment variable CUDA_LAUNCH_BLOCKING=1? By default, GPU kernels run asynchronously, so it is quite possible that the reported error doesn't point at the place where it actually happens.


@nouiz nouiz and 1 other commented on an outdated diff Aug 19, 2014
theano/sandbox/cuda/blas.py
@@ -591,55 +569,98 @@ def c_support_code_apply(self, node, nodename):
                  for f in files]
         return reduce(str.__add__, codes)
-    def c_code(self, node, nodename, inp, out_, sub):
-        img, kern = inp
-        out, = out_
-        dx = self.subsample[0]
-        dy = self.subsample[1]
-        sub = sub.copy()
-        pad = self.pad
-        if self.border_mode == "valid":
-            bmode = 1
+    def c_code(self, bottom, weights, top, direction, sub):
@nouiz
nouiz Aug 19, 2014 Member

This function doesn't use the same signature as the normal c_code. Can you rename it to make this clear?

@nouiz
nouiz Aug 19, 2014 Member

I have difficulty remembering what bottom and top mean. Can you document that? bottom and top can have different meanings depending on how you see the graph.

@f0k
f0k Aug 19, 2014 Contributor

Would renaming it to shared_c_code or c_code_helper be fine?

About the names, I took them from caffe as they are less ambiguous than inputs and outputs. I'll document them.

@stencilman
Contributor

If I use allow_gc=False, it works, but then it uses a lot of memory and I can't use a big model.

Please find below the error with CUDA_LAUNCH_BLOCKING=1.

Error when tring to find the memory information on the GPU: an illegal memory access was encountered
Error freeing device pointer 0x1162160000 (an illegal memory access was encountered). Driver report 0 bytes free and 0 bytes total
device_free: cudaFree() returned an error, but there is already an Python error set. This happen during the clean up when there is a first error and the CUDA dr
iver is in a so bad state that it don't work anymore. We keep the previous error set to help debugging it.CudaNdarray_uninit: error freeing self->devdata. (self
=0xda12730, self->devata=0x1162160000)
Traceback (most recent call last):
  File "deep_nets.py", line 71, in <module>
    model.trainEpoch(data.tr, conf, epoch)
  File "/home/ajain/Projects/deep_nets/python/lib/machine.py", line 344, in trainEpoch
    err = self.train_model()
  File "/home/ajain/Theano/theano/compile/function_module.py", line 589, in __call__
    self.fn.thunks[self.fn.position_of_error])
  File "/home/ajain/Theano/theano/gof/link.py", line 164, in raise_with_op
    print_type=True)
  File "/home/ajain/Theano/theano/printing.py", line 104, in debugprint
    stop_on_name=stop_on_name)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 609, in debugprint
    prefix_child=new_prefix_child)
  File "/home/ajain/Theano/theano/compile/debugmode.py", line 613, in debugprint
    print >> file, '%s%s %s%s' % (prefix, r, id_str, type_str)
  File "/home/ajain/Theano/theano/sandbox/cuda/var.py", line 52, in __str__
    return "CudaNdarrayConstant{"+str(numpy.asarray(self.data))+"}"
  File "/usr/local/lib/python2.7/dist-packages/numpy/core/numeric.py", line 460, in asarray
    return array(a, dtype, copy=False, order=order)
RuntimeError: error copying data to host
@nouiz
Member
nouiz commented Aug 19, 2014

Remove the exception_verbosity flag. It causes a second error that hides the real error.


@nouiz
Member
nouiz commented Aug 19, 2014

Can you give me the script that generates the error and tell me how to use it?

@stencilman
Contributor

Ah, a much better error! It is a big project; you would need my data etc. to run my script. I will try to generate a minimal example if needed. Do you know what the error might mean?

Error when tring to find the memory information on the GPU: an illegal memory access was encountered
Error freeing device pointer 0x1162160000 (an illegal memory access was encountered). Driver report 0 bytes free and 0 bytes total 
device_free: cudaFree() returned an error, but there is already an Python error set. This happen during the clean up when there is a first error and the CUDA driver is in a so bad state that it don't work anymore. We keep the previous error set to help debugging it.CudaNdarray_uninit: error freeing self->devdata. (self=0xd10f1f0, self->devata=0x1162160000)
Traceback (most recent call last):
  File "deep_nets.py", line 71, in <module>
    model.trainEpoch(data.tr, conf, epoch)
  File "/home/ajain/Projects/deep_nets/python/lib/machine.py", line 344, in trainEpoch
    err = self.train_model()
  File "/home/ajain/Theano/theano/compile/function_module.py", line 589, in __call__
    self.fn.thunks[self.fn.position_of_error])
  File "/home/ajain/Theano/theano/compile/function_module.py", line 579, in __call__
    outputs = self.fn()
RuntimeError: GpuCorrMM encountered a CUBLAS error: the function failed to launch on the GPU

Apply node that caused the error: GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)}(W_18, GpuContiguous.0)
Inputs types: [CudaNdarrayType(float32, 4D), CudaNdarrayType(float32, 4D)]
Inputs shapes: [(4, 256, 1, 1), (16, 4, 60, 90)]
Inputs strides: [(256, 1, 0, 0), (21600, 5400, 90, 1)]
Inputs scalar values: ['not scalar', 'not scalar']

Backtrace when the node is created:
  File "/home/ajain/Theano/theano/gradient.py", line 895, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/ajain/Theano/theano/gradient.py", line 1173, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ajain/Theano/theano/gradient.py", line 1034, in access_term_cache
    input_grads = node.op.grad(inputs, new_output_grads)
  File "/home/ajain/Theano/theano/sandbox/cuda/blas.py", line 746, in grad
    weights, top)

HINT: Use the Theano flag 'exception_verbosity=high' for a debugprint of this apply node.
@nouiz
Member
nouiz commented Aug 19, 2014

@stencilman: the new error message seems to indicate that the grad() method that introduces the GpuCorrMM_gradInputs op doesn't do it correctly, or that this case isn't handled well in the C code. @f0k, can you check the code that calls cublas in that case? I don't know that code well enough to do this quickly.

@stencilman
Contributor

@f0k @nouiz: Do let me know if you know why I have the error. Thank you.

@f0k
Contributor
f0k commented Aug 19, 2014

@stencilman: From your error message, it seems you have a kernel shape of 4x256x1x1 and an output shape of 16x4x60x90. This would be an input shape of 16x256x60x90. In other words, a batchsize of 16, 256 input channels, image size 60x90, 4 output channels (filters), kernel size 1x1. Is this correct?
When I run pylearn2_benchmark.py with this, it works fine, and much faster than both the standard and fft convolution: python pylearn2_benchmark.py i256x60x90,k4x1x1,b16
Maybe it really was a memory problem not tracked correctly because of missing error checks. Please try again with the current version to see if the error message changes.

@nouiz: Thanks for reviewing, I extended the documentation and added error checks as suggested. Also the optimizer can insert all three ops now depending on border mode and shapes (if shapes were provided to conv2d by the user). The gradients of GpuCorrMM are still wrong if subsampling or padding is used; I'll fix that later.

@stencilman
Contributor

@f0k, @nouiz: Unfortunately, I still get the same error. I will repeat: I don't get the error when I use allow_gc=False. It also works when I turn off the optimizations. @f0k: yes, the sizes where it seems to fail are correct. I have a feeling that it somehow breaks in a more complicated network due to the Theano optimizations. I have put the dump of the failing function, as @nouiz suggested, here (http://vajra.cs.nyu.edu/train.fun). I can load the function as

import cPickle as pickle
train = pickle.load(open('train.fun', 'rb'))

and train looks as shown below. It doesn't take any arguments. @nouiz: How can I call this function?

In [6]: train
Out[6]:
{'accept_inplace': False,
 'allow_input_downcast': None,
 'givens': {y: <CudaNdarrayType(float32, 4D)>,
  x_1: <CudaNdarrayType(float32, 4D)>,
  x_0: <CudaNdarrayType(float32, 4D)>,
  x_2: <CudaNdarrayType(float32, 4D)>},
 'inputs': [],
 'mode': None,
 'name': None,
 'no_default_updates': False,
'on_unused_input': None,
 'outputs': [Elemwise{true_div,no_inplace}.0],
 'profile': None,
 'rebuild_strict': True,
 'updates': [(<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_111, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_111, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_110, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_110, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_113, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_113, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_112, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_112, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_3, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_3, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_5, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_5, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_4, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_4, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_6, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_6, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_11, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_11, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_10, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_10, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_13, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_13, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_12, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_12, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_16, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_16, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, 4D)>, Elemwise{add,no_inplace}.0),
  (W_18, Elemwise{sub,no_inplace}.0),
  (<CudaNdarrayType(float32, vector)>, Elemwise{add,no_inplace}.0),
  (b_18, Elemwise{sub,no_inplace}.0)]}
@stencilman
Contributor

When I save this function and load it, it seems to work... I have shared a repository with you if it helps.

@f0k
Contributor
f0k commented Aug 20, 2014

Update: shapes are now propagated from GpuCorrMM to its gradients if needed (and only if needed). pad="auto" has been renamed to pad="full", also I've added pad="half" to simulate "same" convolution (due to the symmetric padding, this only does a "same" convolution for uneven kernel heights and widths). The previous bug with pad="auto" in the weight gradient has been fixed (the kernel shape was not inferred correctly with auto-padding). It passes test_gemm_{valid,full,directly,grads}, but the latter takes very long as it's comparing results to conv2d (assuming that its gradients are well-tested and correct).
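A hedged usage sketch of pad='half' as described above (a "same"-size output for odd kernel heights and widths), assuming the GpuCorrMM(border_mode, subsample, pad) signature used elsewhere in this PR:

import theano
import theano.tensor as T
from theano.sandbox.cuda.basic_ops import gpu_contiguous
from theano.sandbox.cuda.blas import GpuCorrMM

img = T.ftensor4('img')    # (batch, channels, h, w)
kern = T.ftensor4('kern')  # (nfilters, channels, 5, 5) -- odd kernel size

# half-padding adds 5 // 2 == 2 zeros on each border, so the output keeps the
# spatial size of the input (a "same" correlation)
out = GpuCorrMM(border_mode='valid', subsample=(1, 1), pad='half')(
    gpu_contiguous(img), gpu_contiguous(kern))
f = theano.function([img, kern], out)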

I think this is ready to be merged now, except that we could add grad() methods to the two gradients. The pylearn2 cuda-convnet wrapper has this for the gradient wrt. inputs, but it's untested, and it may never be used. We could just leave it for a separate PR, and also leave a possible merged valid convolution op for a separate PR. What do you think?

@stencilman: Please check if your problem magically disappears with the new version.

@stencilman
Contributor

Thanks a lot @f0k. I am checking now and will report in 15 min.

@stencilman
Contributor

Hey hey hey!!! Good news, it does magically work for me now. Why??!

Below you will see some results of running my network with profile=True. For some reason, Theano is 2x slower than the Torch version; I will investigate why. Thanks for adding pad='half' :-).

Ops
---
<% time> <sum %> <apply time> <time per call> <type> <#call> <#apply> <Op name>
  20.3%    20.3%       2.297s       3.19e-02s     C       72        9   GpuCorrMM{valid, (1, 1), pad=(4, 4)}
  16.9%    37.2%       1.917s       7.99e-02s     C       24        3   GpuCorrMM_gradWeights{valid, (1, 1), pad=(4, 4)}
  16.0%    53.2%       1.813s       7.55e-02s     C       24        3   GpuCorrMM_gradInputs{valid, (1, 1), pad=(4, 4)}
  15.5%    68.7%       1.760s       2.44e-02s     C       72        9   GpuCorrMM_gradWeights{valid, (1, 1), pad=(2, 2)}
  10.8%    79.5%       1.221s       2.54e-02s     C       48        6   GpuCorrMM_gradInputs{valid, (1, 1), pad=(2, 2)}
   9.5%    89.0%       1.078s       1.50e-02s     C       72        9   GpuCorrMM{valid, (1, 1), pad=(2, 2)}
   1.1%    90.1%       0.125s       7.78e-03s     C       16        2   GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)}
   1.1%    91.1%       0.121s       7.55e-03s     C       16        2   GpuCorrMM_gradWeights{valid, (1, 1), pad=(0, 0)}
   1.0%    92.1%       0.108s       6.73e-03s     C       16        2   GpuCorrMM{valid, (1, 1), pad=(0, 0)}
   0.8%    92.9%       0.088s       1.10e-03s     C       80       10   GpuFromHost
@f0k
Contributor
f0k commented Aug 20, 2014

Good news, it does magically work for me now. Why??!

Don't ask, it's magic.

For some reason, Theano is 2x slower than the Torch version; I will investigate why.

Please check if this is also true without padding. Do you know if Torch uses the padding parameter for "same" convolution or some other trick?

@stencilman
Contributor

In Torch, we use the same code and just pass in the pad parameter exposed by the CUDA function. So it should be exactly the same.

@f0k
Contributor
f0k commented Aug 20, 2014

But in the convnet-benchmarks, Theano performs the same as Torch now, so it can't be the convolution. Did you try with allow_gc=0?

@stencilman
Contributor

Yes, you are absolutely right there. It cannot be the convolution. I did try allow_gc=False; it hardly changes the timings.

@nouiz: After your presentation is done, I would be grateful if you could give me any tips on how I can check this.

I think this awesome fast convolution is ready to be merged, thanks to dear @f0k!! 👍

UPDATE: I have fixed my problem regarding the slowness compared to Torch: joblib was slowing it down. I can't thank you enough, @f0k, this really helps me a lot!

@lamblin lamblin and 3 others commented on an outdated diff Aug 20, 2014
theano/sandbox/cuda/conv_gemm.cu
+                weight + weight_offset * g, K_,
+                0.,
+                col_diff + col_offset * g, N_);
+        }
+        // col2im back to the data
+        col2im_gpu(col_diff, channels_, height_, width_,
+            kernel_h_, kernel_w_, pad_h_, pad_w_, stride_h_, stride_w_,
+            bottom_diff + (*bottom)[i]->offset(n));
+      }
+    }
+    */
+    }
+    // Free temporary columns
+    Py_DECREF(col);
+
+    return output;
@lamblin
lamblin Aug 20, 2014 Member

I'm pretty sure output needs to be Py_INCREFed at some point, since we are creating a new reference to an existing variable.
That may explain https://groups.google.com/forum/#!topic/theano-users/HuWy3DNh8lI

@stencilman
stencilman Aug 20, 2014 Contributor

Great catch @lamblin!! Thanks.

@f0k
f0k Aug 20, 2014 Contributor

Refcounting happens in the C code in blas.py. If the output has the correct shape, its refcount is not modified. If it has a wrong shape, its refcount is decreased and we obtain a new array with CudaNdarray_NewDims. We don't touch its refcount.

The net effect is equivalent to how GpuConv does it: If the output has the correct shape, its refcount is increased, otherwise its refcount is left as it is and a new array is created with CudaNdarray_NewDims. In the end, the refcount of the old output array is decreased (no matter whether it was reused or not; so if it was reused, we have refcount+1-1, otherwise we have refcount-1).
There is a TODO marker to decrease the refcount of the old array before allocating the new one in case the shape is not suitable, i.e., to do it like in GpuCorrMM.

So I'm moderately sure we would create a memory leak if we increased the refcount. I'm not exactly sure what could have fixed the error @stencilman was seeing, though. I fixed a wrong shape inference for the backprop wrt. weights in case of pad="auto", but I think he always used explicit padding by value anyway.

@nouiz
nouiz Aug 20, 2014 Member

If we incref here, we should decref it after the call to corrMM. Otherwise, the output of the node will have a refcount of 2 after the node runs.

@nouiz
nouiz Aug 20, 2014 Member

Doing the incref/decref is cleaner. This is what the numpy interface does.

@lamblin
lamblin Aug 20, 2014 Member

This looks correct indeed, this should not be a problem. My bad.

@f0k
f0k Aug 20, 2014 Contributor

@lamblin: No problem. Please feel free to review and question everything! I'll push another commit adding the gradients of gradients tomorrow (I wanted to make sure the tests run through, but the reference implementation is awfully slow, so this will take a while), and then I'll do a rebase onto master again, but the bulk of the code is finished!
@nouiz: To me an additional incref/decref would be rather confusing. In my mind, output is not a new reference, it's just set to one of the existing bottom, weights or top so we don't have to distinguish between the cases again. But if numpy does it, we can too... so you would suggest an incref of output, and then a decref of out2 here?

@f0k
Contributor
f0k commented Aug 20, 2014

Okay, the new test_gemm_grads tests for gradients of gradients passed on two machines, but they take 2-3 hours of compiling the reference conv2d implementations I'm comparing results with. The test should probably be restricted to a small collection of shapes for the future; any suggestions?

Anyway, everything is pushed and rebased onto master. Thanks in advance for reviewing! :)

@nouiz
Member
nouiz commented Aug 21, 2014

What is the reference implementation? The Conv2d op? scipy.convolve? A Python loop? If it is one of the last two, using the Conv2d op would speed it up.

For the incref/decref, I don't think you need to change it, as it isn't used by other code. But if that happens, it could be worthwhile. Maybe add a comment about this in the corrMM C function?


@f0k
Contributor
f0k commented Aug 21, 2014

What is the reference implementation? The Conv2d op? scipy.convolve? A Python loop? If it is one of the last two, using the Conv2d op would speed it up.

The reference implementation is the Conv2d op, and this is actually what slows it down, because it has to compile it (along with two gradients and two gradient-gradients) for every call (plus it switches to Conv3d for strided gradients). The python convolution would be a lot faster, but then I would have to manually define how to compute the gradients and gradient-gradients, and I wanted to rely on a well-tested reference implementation for that. So I think test_gemm_grads should use a short well-chosen list of shapes rather than the nested for loop I copied from test_subsample. I just find the whole get_basic_shapes, get_shapes, get_shapes2 and get_valid_shapes thing in the convolution test very confusing, maybe someone has a recommendation on what to use.

@nouiz
Member
nouiz commented Aug 21, 2014

Are the shapes passed to the conv2d implementation? If so, they could be removed, which would cause less compilation.


@f0k
Contributor
f0k commented Aug 21, 2014

Are the shapes passed to the conv2d implementation? If so, they could be removed, which would cause less compilation.

It's a little bit faster. I think it might have finished the test in 1.5 hours, but now it throws an exception in theano/tensor/nnet/conv.py:

830             if not all_shape and (self.dx != 1 or self.dy != 1):
831  ->             raise Exception("ConvOp.grad when dx!=1 or dy!=1 we must have all "
832                                 "the optional shape information")

So this is not the solution; it was still too slow anyway. I'd go for providing all shape information to conv2d, but only testing a smaller range of carefully-chosen shapes so the test can finish within, say, 5 minutes. I'm just at a loss as to which of the functions (if any) I should use to generate a suitable set of shapes that includes some with subsampling and some without -- just a good mix that should catch most possible problems.

@stencilman
Contributor

I am getting the dreaded error again and I have no clue why. If any of you guys want, I am happy to give you access to my server and tell you how to reproduce the error.... :-(

Error when tring to find the memory information on the GPU: an illegal memory access was encountered
Error freeing device pointer 0x111bba0000 (an illegal memory access was encountered). Driver report 0 bytes free and 0 bytes total 
device_free: cudaFree() returned an error, but there is already an Python error set. This happen during the clean up when there is a first error and the CUDA driver is in a so bad state that it don't work anymore. We keep the previous error set to help debugging it.CudaNdarray_uninit: error freeing self->devdata. (self=0xe8121f0, self->devata=0x111bba0000)
Traceback (most recent call last):
  File "deep_nets.py", line 71, in <module>
    model.trainEpoch(data.tr, conf, epoch)
  File "/home/ajain/Projects/deep_nets/python/lib/machine.py", line 367, in trainEpoch
    err = self.train_model(bs)
  File "/home/ajain/Theano/theano/compile/function_module.py", line 589, in __call__
    self.fn.thunks[self.fn.position_of_error])
  File "/home/ajain/Theano/theano/compile/function_module.py", line 579, in __call__
    outputs = self.fn()
RuntimeError: GpuCorrMM encountered a CUBLAS error: the function failed to launch on the GPU

Apply node that caused the error: GpuCorrMM_gradInputs{valid, (1, 1), pad='half'}(W_18, GpuContiguous.0)
Inputs types: [CudaNdarrayType(float32, 4D), CudaNdarrayType(float32, 4D)]
Inputs shapes: [(4, 256, 1, 1), (4, 4, 60, 90)]
Inputs strides: [(256, 1, 0, 0), (21600, 5400, 90, 1)]
Inputs scalar values: ['not scalar', 'not scalar']

Backtrace when the node is created:
  File "/home/ajain/Theano/theano/gradient.py", line 895, in access_term_cache
    output_grads = [access_grad_cache(var) for var in node.outputs]
  File "/home/ajain/Theano/theano/gradient.py", line 1173, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ajain/Theano/theano/gradient.py", line 1034, in access_term_cache
    input_grads = node.op.grad(inputs, new_output_grads)
  File "/home/ajain/Theano/theano/sandbox/cuda/blas.py", line 838, in grad
    weights, top, bottom.shape[-2:])
@f0k f0k changed the title from Faster full convolution for GpuCorrMM to Faster algorithms and gradients for GpuCorrMM Aug 22, 2014
@f0k
Contributor
f0k commented Aug 22, 2014

I've found that for some configurations I use, the choice between GpuCorrMM and GpuCorrMM_gradWeights needs to take into account the batchsize and input channels. I've extended GpuConv to memorize this information when it is replacing a ConvOp, so any other optimizer can choose an optimal replacement for GpuConv also based on batchsize and input channels (only if the user provided this information, of course). Now, if appropriate, the conv_gemm optimizer swaps the implementations for fprop and bprop wrt. weights compared to caffe and Torch and is faster :)

@stencilman: I'll send you an email. We need to find out if this is related to GpuCorrMM or a very different problem.

@stencilman
Contributor

@f0k: yes, we need to figure this out. Will wait for your email. I have a repro case ready.

@nouiz nouiz commented on an outdated diff Aug 22, 2014
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+ npy_kern.transpose(1, 0, 2, 3)[:,:,::-1,::-1])).transpose(1, 0, 2, 3)
+
+ if not numpy.allclose(cpuval, gpuval, rtol=1e-4):
+ print "Test failed for"
+ print "direction: ", direction
+ print "ishape: ", ishape
+ print "kshape: ", kshape
+ print "subsample: ", subsample
+ assert False
+
+
+def test_gemm_grads():
+ for mode in 'valid', 'full':
+ for bs in [1, 4, 5]:
+ for ch in range(1,4):
+ for nf in range(1,4):
@nouiz
nouiz Aug 22, 2014 Member

I think you could just test with ch fixed to 4 and nf to 3 (different than ch).

@nouiz nouiz commented on an outdated diff Aug 22, 2014
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+ f = theano.function([i, k], op, mode=theano_mode)
+ gpuval = numpy.array(f(npy_img.transpose(1, 0, 2, 3),
+ npy_kern.transpose(1, 0, 2, 3)[:,:,::-1,::-1])).transpose(1, 0, 2, 3)
+
+ if not numpy.allclose(cpuval, gpuval, rtol=1e-4):
+ print "Test failed for"
+ print "direction: ", direction
+ print "ishape: ", ishape
+ print "kshape: ", kshape
+ print "subsample: ", subsample
+ assert False
+
+
+def test_gemm_grads():
+ for mode in 'valid', 'full':
+ for bs in [1, 4, 5]:
@nouiz
nouiz Aug 22, 2014 Member

Just batch sizes 1 and 5 should be enough.

@nouiz
nouiz Aug 22, 2014 Member

I think you can just test 1 and 5.

@nouiz nouiz commented on an outdated diff Aug 22, 2014
theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
+ if not numpy.allclose(cpuval, gpuval, rtol=1e-4):
+ print "Test failed for"
+ print "direction: ", direction
+ print "ishape: ", ishape
+ print "kshape: ", kshape
+ print "subsample: ", subsample
+ assert False
+
+
+def test_gemm_grads():
+ for mode in 'valid', 'full':
+ for bs in [1, 4, 5]:
+ for ch in range(1,4):
+ for nf in range(1,4):
+ for rImg1 in [2, 5, 8]:
+ for rImg2 in [2, 5, 8]:
@nouiz
nouiz Aug 22, 2014 Member

For rImg1 I would just test values 2 and 5, and for rImg2, what about 2 and 8?

@nouiz
Member
nouiz commented Aug 22, 2014

I made comments in the PR on how to speed it up. You can test many fewer cases than with the old Theano convolution, as you use the same code path for most cases. The old convolution had many different code paths for different shapes, so it needed to test all of them.

On Thu, Aug 21, 2014 at 5:55 PM, Arjun Jain notifications@github.com
wrote:

And it only seems to happen for batch size = 4. It doesn't happen for batch size 8 or 16.



@f0k
Contributor
f0k commented Aug 22, 2014

Very good. In addition to your suggestions, I introduced another 25% saving (only testing subsample in [(1,1),(1,2),(2,2)]), now the test completes in under 5 minutes. Great!
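Roughly, the reduced parameter grid now looks like this (an illustrative sketch; the committed values may differ slightly):

def gemm_grad_cases():
    for mode in ('valid', 'full'):
        for bs in (1, 5):                       # batch sizes
            for ch, nf in ((4, 3),):            # fixed channel / filter counts
                for rImg1, rImg2 in ((2, 2), (5, 8)):          # image sizes
                    for subsample in ((1, 1), (1, 2), (2, 2)):
                        yield mode, bs, ch, nf, rImg1, rImg2, subsample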

@stencilman
Contributor

@nouiz: Not sure if you got Jan's email, bad news about the bug, he said "The error occurs on the 16th gemm call, so it already did 15 exactly same-shaped calls before. It is just using different "top->devdata + n *top_stride" (because n goes from 0 to 15). I tried to provoke a segfault by calling a memset with the top -> devdata part it uses, but there was no segfault, so it seems to be allocated fine."

@nouiz
Member
nouiz commented Aug 22, 2014

I saw that. I'm not sure what to think about it. Jan, can you add a sync after each loop iteration, just in case, to see if that fixes the problem?

@stencilman, are you able to reproduce this on another computer? It could be a hardware problem.

Also, it could be the nvcc version. If the above doesn't work, which version of nvcc is installed? If you have installed 6.0, you can install 5.5 in parallel and use it for the test, to know if it works. You don't need to update the driver. Otherwise, it would be great to test with CUDA 6.5, but this would probably require a driver update.


@stencilman
Contributor

I will try on another computer and update you guys tomorrow. I tried 5.5 and 6.0 and both have the problem. If the problem persists on the other computer, I will upgrade to 6.5 and let you guys know.

@nouiz
Member
nouiz commented Aug 23, 2014

Can you also try to add this call "cudaDeviceSynchronize()" after the
cublas call? Don't forget to clear the cache to have the code recompiled.


@stencilman
Contributor

I just tried it on another computer with cuda 5.5 and I still get the same error.

@f0k
Contributor
f0k commented Aug 26, 2014

So I investigated the problem on @stencilman's computer. The error occurs in a specific gemm call for the gradient wrt. inputs, but it disappears if one of the three matrix arguments to the gemm call (top) is copied to another location directly before the call and the copy is used instead of the original matrix. I don't know how to interpret this, but I speculated that it might indicate a memory access error for top that for some reason is only exposed by cublasSgemm, but not by cublasScopy.

Investigating further, on his computer, cuda-memcheck complains about one of the tests in test_gemm_full, specifically the one for shape[355]. Adding the line shapes = shapes[355:356] directly before the exec_conv call in test_full and then running cuda-memcheck nosetests test_conv_cuda_ndarray.py:test_gemm_full (from the theano/sandbox/cuda/tests subdirectory) gives 1440 out of bounds reads in the gemm kernel, but none in copying top to another location. This indicates that top is fine, but weights or col is wrong, contradicting the previous observations. cuda-memcheck does not complain about any of the tests in test_gemm_valid, which uses the algorithm for the forward pass. So the problem seems to be limited to the full convolution algorithm used for the backprop wrt. inputs.

Now on my desktop machine, cuda-memcheck is happy with all the tests in test_gemm_full, test_gemm_valid, test_gemm_subsample, and test_gemm_grads. On my server, it also gives 1440 invalid memory accesses in test_gemm_full for one of the two GPUs installed. It might be related to which architecture it is compiled for, or it might be a CUDA or driver issue.

So far I know the following:

  • machine A: Ubuntu 12.10, CUDA 6.0, driver 334.21, GTX TITAN Black: fails
  • machine A: Ubuntu 12.10, CUDA 6.0, driver 334.21, Tesla K40c: fails
  • machine B: Ubuntu 14.04, CUDA 5.5, driver 331.38, GT 640: works
  • machine C: Ubuntu 12.04, CUDA 5.5, driver 331.38, GTX 580: works
  • machine C: Ubuntu 12.04, CUDA 5.5, driver 331.38, GTX 780 Ti: fails

It would be helpful if some others could do:

cd your_theano_directory
git pull https://github.com/f0k/Theano corrmm-faster-fullconv
for x in full valid subsample grads; do cuda-memcheck nosetests theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py:test_gemm_$x; done

It should take about 7 minutes to complete. If it works, it will display "ERROR SUMMARY: 0 errors" after each test. If it fails, it will display "ERROR SUMMARY: 1440 errors" after the first test. (Maybe it would be enough to run the first test, that's 20 seconds only.)

Any other ideas on how to proceed are highly welcome.

@lamblin
Member
lamblin commented Aug 26, 2014

Thanks for the investigation!
Here are a couple of ideas that went through my head:

  • When you copy top, could it change its memory layout (from F-contiguous or non-contiguous to C-contiguous for instance)?
  • What happens if you copy one of the other matrices (weights or col if I followed correctly) and use the copy?
  • I think there are two copies of the strides and shapes of CudaNdarrays, one on CPU and one on GPU, @nouiz could confirm. Do these match? Are there any zero strides?
  • Does the "invalid read" error message give pointer addresses? Would it be possible to guess which element it tries to access?

I'll try to run the test on a couple of machines in the lab this afternoon.

@f0k
Contributor
f0k commented Aug 26, 2014

Thanks for your quick reply!

When you copy top, could it change its memory layout (from F-contiguous or non-contiguous to C-contiguous for instance)?

In @stencilman's original error case, the error occurred in the last iteration of this loop: https://github.com/f0k/Theano/blob/corrmm-faster-fullconv/theano/sandbox/cuda/conv_gemm.cu#L379
batchsize was 16, the error occurred for n = 15, but for none of the 15 iterations before, so the only difference was the top->devdata + n * top_stride argument. A cudaDeviceSynchronize before the gemm call confirmed that the error occurred directly in gemm. Forcing n = 14 for the last iteration made the error disappear. Copying top->devdata + n * top_stride into a temporary matrix at each iteration and using that temporary matrix made the error disappear as well. I did the copy by initializing an empty CudaNdarray of correct size and copying the data either via cublasScopy or cudaMemcpy as an opaque memory block, and cublasSgemm also treats it as an opaque memory block, so the memory layout does not really play a role here.

I think there are two copies of the strides and shapes of CudaNdarrays, one on CPU and one on GPU, @nouiz could confirm. Do these match? Are there any zero strides?

From the exceptions quoted further above by @stencilman, all matrices seem to be C-contiguous, but I don't know if the information on CPU and GPU could differ. We can try to find out.

Does the "invalid read" error message give pointer addresses? Would it be possible to guess which element it tries to access?

They give pointer addresses, but very large ones, very different from the base addresses of the three matrices involved. They also give the thread numbers and block indices of the errors; thread indices are from 160 to 255, indicating that not all of the threads do something evil. I will try to add some other memory accesses involving only one of the matrices hoping to trigger an error outside of the gemm call.

@nouiz
Member
nouiz commented Aug 26, 2014

I think there are two copies of the strides and shapes of CudaNdarrays, one on CPU and one on GPU, @nouiz could confirm. Do these match? Are there any zero strides?

Not by default. Now, by default, there is only a copy on the CPU. They should match in the case where we have both.

@nouiz
Member
nouiz commented Aug 26, 2014

Could it be that the last batch doesn't have the same size as the others, but we pass the others' size? I would be surprised by this, but asking just in case.

I tried on my computer with a GTX 470:

  • oolong: Fedora Core 14, CUDA 5.5, driver 340.24, GTX 470: works


@stencilman
Contributor

Maybe the error only happens for newer cards such as the 780, Titan and K40, and not the older ones like the 640?

@f0k
Contributor
f0k commented Aug 26, 2014

Could it be that the last batch doesn't have the same size as the others?

But how? It's a 4-tensor, the first dimension of which indicates the batch. So all batches must be the same size. If not enough memory is allocated to accommodate the last batch, then there is a serious bug somewhere else.

Maybe the error only happens for newer cards such as the 780, Titan and K40, and not the older ones like the 640?

Possible. We're here to find out :)

@lamblin
Member
lamblin commented Aug 26, 2014

They give pointer addresses, but very large ones, very different from the base addresses of the three matrices involved

Could it be a problem of signed vs. unsigned pointer arithmetic? @nouiz and @abergeron encountered some at some point, where the type of Cuda's variables (like threadIdx.x, etc.) somehow made pointer differences overflow, and addressed completely different memory instead.
Would those large pointers modulo 2**32 look closer to the original base addresses?
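For example, a quick check along those lines (the address is just an example value):

addr = 0xb00508000
print(hex(addr % 2**32))   # -> '0x508000': only the low 32 bits would remain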

@lamblin
Member
lamblin commented Aug 26, 2014

And maybe the error is hidden on cards with less than 4G of RAM. Does that fit the success/failure list to date?

@f0k
Contributor
f0k commented Aug 26, 2014

The GTX 780 Ti which fails only has 3 GiB, but still, that's a very good idea; the limit could be at 2 GiB. I get an access fault with a cudaMemset call on top->devdata + n*top_strides as well. Let's see...

@lamblin
Member
lamblin commented Aug 26, 2014

So, to date:

  • GTX TITAN Black: fails, 6 GiB
  • Tesla K40c: fails, 12 GiB
  • GT 640: works, 2 GiB
  • GTX 580: works, 1.5 GiB
  • GTX 780 Ti: fails, 3 GiB
  • GTX 470: works, 1.2 GiB

So a limit at 2 GiB makes sense.

@f0k
Contributor
f0k commented Aug 26, 2014

Err, do you happen to add 0xb0000000 to devdata somewhere in Theano to mark something? Because depending on which operations I enable or disable, I get different errors (out of bounds, cannot free device pointer, etc) with addresses that look like the correct addresses of bottom, weights and top but with 0xb and some zeros in front.

@lamblin
Member
lamblin commented Aug 26, 2014

No, I don't think we do...

@f0k
Contributor
f0k commented Aug 26, 2014

Ah, nevermind, the addresses all start with 0xb, but my format string had used %x instead of %p to display them. So with some more synchronization calls, I get:

  • There are no errors in any of the cudaMemset calls that set bottom, weights and top to zero, so I assume they are allocated correctly
  • When I throw an exception before scheduling col2im, the number of errors detected by cuda-memcheck increases from 1440 to 1471 because it includes a range of Program hit error 30 on CUDA API call to ... for cudaStreamDestroy, cudaEventDestroy, cudaFree and so on. Whatever this means.
  • In the test_gemm_full test case, the error occurs in the very first gemm call. With CUDA_LAUNCH_BLOCKING=1, the gemm call returns CUBLAS_STATUS_EXECUTION_FAILED. Without blocking, a subsequent deviceSynchronize returns "unknown error". Both could indicate a segfault.

cuda-memcheck says that sgemm_sm_heavy_nt_ldg tried to read addresses between 0xb00508000 and 0xb0050807c. The call in question was the following:

cublasSgemm(handle=0x75f1bc0, transa='N', transb='T',
     m=N_=4096, n=K_=512, k=M_=5,
     alpha=1.0f,
     A=top[n]=0xb00700000, lda=N_=4096,
     B=weight=0xb00504800, ldb=K_=512,
     beta=0.0f,
     C=col=0xb00ac0000, ldc=N_=4096)
HOST_DIMS(top): 16 5 64 64
STRIDES(top): 20480 4096 64 1
HOST_DIMS(weight): 5 8 8 8
STRIDES(weight): 512 64 8 1
HOST_DIMS(col): 512 4096
STRIDES(col): 4096 1

So weight has 5*8*8*8 = 2560 floats. We tell cublasSgemm that matrix B has n*k = K_*M_ = 512*5 = 2560 floats, with a leading dimension of ldb = K_ = 512, and to be transposed. Still it tries to access 0xb00508000 and above, which is float number (0xb00508000 - 0xb00504800) / 4 = 3584, clearly too much.
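The same check as a few lines of Python (all numbers copied from the log above):

weight_base = 0xb00504800
bad_read = 0xb00508000
weight_floats = 5 * 8 * 8 * 8                     # 2560 floats allocated for weight
offending_index = (bad_read - weight_base) // 4   # 3584, well past the end of weight
print(offending_index, weight_floats)             # 3584 2560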

Bonus: Running the test with nvprof instead of cuda-memcheck shows the difference between old and new GPUs: On the GTX 580, it uses gemm_kernel1x1_core, and on the GTX 780Ti, it uses sgemm_sm35_ldg_nt_128x16x64x16x16 and sgemm_sm_heavy_nt_ldg.

Is this a bug in CUBLAS after all? Or can you spot anything wrong with the cublasSgemm call? Puzzling extra question: In @stencilman's original test case, why was the error dependent on the address of A rather than B? (Note that we didn't use cuda-memcheck there, that error came up by itself.)

@abergeron
Member

I think your call is incorrect, since the documentation for cublasSgemm states that B is of size ldb x n (so 512 x 512) and your calculations assume it is of size 512 x 5.

@abergeron
Member

Oh sorry, I just noticed that the transpose version means that it's ldb x k, so your calculations are ok.

@lamblin
Member
lamblin commented Aug 26, 2014

@abergeron I think the documentation is mistaken, it should state "if transb == CUBLAS_OP_N" instead of "if transa == CUBLAS_OP_N".

@abergeron
Member

On another note, the problem does not show up on a 750 Ti, because that also uses the gemm_kernel1x1_core kernel, even though it is sm 5.0.

@lamblin
Member
lamblin commented Aug 26, 2014

Indeed, it looks more and more like a bug in cublas.
Maybe we can force the execution of a different kernel, for instance by calling gemmBatched with only one matrix.
Or simpler: what if we compile for architecture 3.0 or 2.0 instead of 3.5?

@stencilman
Contributor

So, I think Torch compiles for 2.0, and I do not have this problem with Torch.

@f0k
Contributor
f0k commented Aug 26, 2014

Good find @lamblin. This has caused us a lot of headaches now, seems I had placed too much trust in CUBLAS...

Maybe we can force the execution of a different kernel, for instance by calling gemmBatched with only one matrix.

That seems very hackish... also I imagine it would either be slower or treated as a special case and mapped to one of the standard gemm kernels.

Or simpler: what if we compile for architecture 3.0 or 2.0 instead of 3.5?

Setting THEANO_FLAGS=nvcc.flags=-arch=sm_30 doesn't change anything, unfortunately, because that doesn't recompile CUBLAS. cublasSgemm still decides based on the capabilities of the GPU and goes on to call sgemm_sm35_ldg_nt_128x16x64x16x16 and sgemm_sm_heavy_nt_ldg.

@stencilman: I'm not sure why it works with Torch, but as we saw the error occurrence being dependent on the address of A, it could just be luck. For the test_gemm_full case, the result of the gemm call was correct, so the invalid reads haven't been used in the computation. We can probably conclude that it will either work correctly despite invalid reads, or crash because of invalid reads, but not produce an incorrect result.

Does anybody have a chance to test this with CUDA 6.5 (I don't)? Depending on the result I will update the documentation for GpuCorrMM and then do a final rebase because something introduced a merge conflict again...

@benanne
Contributor
benanne commented Aug 27, 2014

Hey guys,

I have a machine running with the 340 drivers, so I ran @f0k's tests with CUDA 6.0 and with CUDA 6.5. The machine has a "superclocked" 780 Ti (3GB) and the 340.24 drivers. It looks a lot like this is a bug that was fixed in CUDA 6.5.

with CUDA 6.0: https://gist.github.com/benanne/4128e5998122295f992b
with CUDA 6.5: https://gist.github.com/benanne/2f3b90a8eb2649082541

@stencilman
Contributor

Great, thanks for the report @benanne! This is good news. I am sorry I have not been able to free a machine to be able to reboot it, but hopefully I will be able to do this today.

@f0k
Contributor
f0k commented Aug 27, 2014

Perfect, thanks a lot!

@nouiz
Member
nouiz commented Aug 27, 2014

Can you make a better error message in that case, one that explains the problem and tells how to fix it? Use another GPU, don't use CUDA 6.0, or use different shapes to work around the CUBLAS bug.


@benanne
Contributor
benanne commented Aug 27, 2014

From previous reports it looks like CUDA 5.5 is also affected. I guess it only affects the 700 series GPUs though (including the K40 and the Titans).

@f0k
Contributor
f0k commented Aug 27, 2014

@nouiz: Good idea. I will add it to the documentation of GpuCorrMM anyway, but I can refer to this documentation in all CUDA/CUBLAS-related error messages in conv_gemm.cu (we don't know where the error will occur because everything is asynchronous, it can also appear outside of conv_gemm.cu, but hopefully people will learn to use CUDA_LAUNCH_BLOCKING=1 whenever there is a problem).

@benanne: Yes, it also happens on CUDA 5.5. I will test CUDA 5.0 for comparison.

@abergeron
Member

From the sm35 part of the kernel name, I would guess it more specifically affects cards that support sm 3.5 (which is not all 700 series).

This would make the affected cards: Tesla K40, Tesla K20, Quadro K6000, GeForce GT 640 (DDR5), GeForce 780, GeForce 780 Ti, GeForce TITAN, GeForce TITAN Black, GeForce TITAN Z.

@f0k
Contributor
f0k commented Aug 27, 2014

Same problem for CUDA 5.0.
CUDA 4.2 only supports up to -arch=sm_30 and needs a proper nvcc.flags setting for Theano, but then it also works for the 780 Ti.

@f0k
Contributor
f0k commented Aug 27, 2014

Added documentation about the CUBLAS bug and rebased onto master. The tests pass, Travis passes, should be finally good for merging. Thanks to everyone involved in reviewing and testing!

@nouiz

Can you mention that this is shape-specific, so that changing the shapes can work around that bug?

Owner
f0k replied Aug 29, 2014

But it already says "for some input and filter shapes", how would you formulate it?

You are right, but I missed it, so some people will miss it. Why not just add a sentence like "Changing shapes could also work around this CUBLAS bug"?

Anyway, I'll let you decide whether to do it or not.

Owner
f0k replied Aug 30, 2014

Okay, added a note saying that changing the batch size or channel count may also help.

@nouiz

You should return NULL after setting the Python error. Do this for all changes in this PR, but there are also some other PyErr calls in the source that need this.

Owner
f0k replied Aug 30, 2014

Ooops. Yes, the "4D" checks were missing it, so I missed it as well. Fixed.

@nouiz nouiz and 1 other commented on an outdated diff Aug 29, 2014
theano/sandbox/cuda/opt.py
+ (None not in node.op.imshp[-2:]) and
+ (node.op.kshp is not None) and
+ (None not in node.op.kshp)):
+ # we know the kernel and output size
+ prod1 = node.op.kshp[0] * node.op.kshp[1]
+ prod2 = ((node.op.imshp[-2] - node.op.kshp[0] + 1) *
+ (node.op.imshp[-1] - node.op.kshp[1] + 1))
+ if ((node.op.bsize is not None) and
+ (len(node.op.imshp) == 3) and
+ (node.op.imshp[0] is not None)):
+ # we also know batchsize and input channels
+ prod1 *= node.op.bsize
+ prod2 *= node.op.imshp[0]
+ # compare to decide
+ if prod1 > prod2:
+ return [gpu_contiguous(GpuCorrMM_gradWeights('valid', subsample, pad)(
@nouiz
nouiz Aug 29, 2014 Member

Why the gpu_contiguous on that line? I think it isn't needed and can be removed.

@f0k
f0k Aug 30, 2014 Contributor

Hmm, you're right. I thought it should have the same memory layout as before the replacement. Removed it.
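For concreteness, plugging some made-up shapes into the heuristic from the diff above (just a sketch of which branch gets taken):

kshp = (5, 5)          # kernel rows, cols
imshp = (3, 32, 32)    # input channels, rows, cols
bsize = 64             # batch size
prod1 = kshp[0] * kshp[1]                                       # 25
prod2 = (imshp[-2] - kshp[0] + 1) * (imshp[-1] - kshp[1] + 1)   # 28 * 28 = 784
prod1 *= bsize                                                  # 1600
prod2 *= imshp[0]                                               # 2352
use_grad_weights = prod1 > prod2   # False here, so the plain GpuCorrMM is used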

@nouiz nouiz commented on the diff Aug 29, 2014
theano/sandbox/cuda/caffe_common.hpp
@@ -1,47 +0,0 @@
-/*
@nouiz
nouiz Aug 29, 2014 Member

This is good; on Windows, a user reported a problem where this file wasn't found.

@nouiz
Member
nouiz commented Aug 29, 2014

I did a global review. There are just 2 things left before merging:

  • The return NULL
  • Optional but welcome: remove the unneeded gpu_contiguous.

Thanks for the good work!

@f0k
Contributor
f0k commented Aug 30, 2014

Thanks for the review, @nouiz! Fixed everything you noticed.
/edit: please hold off on merging; removing the gpu_contiguous was wrong, I'm just figuring it out...

@f0k
Contributor
f0k commented Sep 1, 2014

Removing the gpu_contiguous as per @nouiz's suggestion introduced a problem: The dimshuffle() call on the output inserts a DimShuffle op, which returns a CPU output. This changes the type of the output, which is not allowed and results in:

TypeError: ('The type of the replacement must be the same as the type of the original Variable.', GpuConv{valid, (1, 1), None, (6, 6), True, (1, 287, 29), (6, 6)}.0, DimShuffle{1,0,2,3}.0, CudaNdarrayType(float32, 4D), TensorType(float32, 4D), 'local_conv_gemm')

We could use GpuDimShuffle instead, or wrap it in an as_cuda_ndarray_variable call to promise that it's going to be moved to the GPU. I opted for the latter.
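A minimal sketch of the fix (variable names are illustrative, not the actual ones in opt.py):

from theano.sandbox.cuda import CudaNdarrayType
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable

out = CudaNdarrayType(broadcastable=(False,) * 4)()  # stands in for the op output
# dimshuffle() alone inserts a host-side DimShuffle and turns the CudaNdarrayType
# variable into a TensorType one; the wrapper promises it goes back to the GPU.
rval = as_cuda_ndarray_variable(out.dimshuffle(1, 0, 2, 3))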

Interestingly, this wasn't triggered in the tests, I only noticed when running the Theano benchmark script from convnet-benchmarks. I guess that's a general problem with the tests which should be addressed in a separate issue.

@f0k
Contributor
f0k commented Sep 1, 2014

I've got another optimization question. This is not actually part of this PR, but related to it, so I'll just ask here.

When using GpuCorrMM directly (on a flipped kernel), I get the following graph and timing for the gradient wrt. inputs:

GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)} [@A] ''   3
 |GpuContiguous [@B] ''   2
 | |GpuSubtensor{::, ::, ::int64, ::int64} [@C] ''   1
 |   |<CudaNdarrayType(float32, 4D)> [@D]
 |   |Constant{-1} [@E]
 |   |Constant{-1} [@E]
 |GpuContiguous [@F] ''   0
   |<CudaNdarrayType(float32, 4D)> [@G]
(experimental) theano.sandbox.cuda.blas.CorrMM bprop inputs: 0.0 GFLOP/s ( tm = 0.012757648468 )

When using the standard conv2d and letting the optimizer replace it, I get:

GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)} [@A] ''   5
 |GpuContiguous [@B] ''   4
 | |GpuDimShuffle{1,0,2,3} [@C] ''   3
 |   |GpuSubtensor{::, ::, ::int64, ::int64} [@D] ''   2
 |     |GpuDimShuffle{1,0,2,3} [@E] ''   1
 |     | |<CudaNdarrayType(float32, 4D)> [@F]
 |     |Constant{-1} [@G]
 |     |Constant{-1} [@G]
 |GpuContiguous [@H] ''   0
   |<CudaNdarrayType(float32, 4D)> [@I]
(experimental) theano.sandbox.cuda.blas.CorrMM bprop inputs: 0.0 GFLOP/s ( tm = 0.0244923362732 )

It's 100% slower than the former, seemingly because of the two dimshuffles that are actually redundant. The original conv2d graph for this case looks like this:

GpuConv{full, (1, 1), None, (6, 6), True, (32, 282, 24), (6, 6)} [@A] ''   2
 |<CudaNdarrayType(float32, 4D)> [@B]
 |GpuSubtensor{::, ::, ::int64, ::int64} [@C] ''   1
   |GpuDimShuffle{1,0,2,3} [@D] ''   0
   | |<CudaNdarrayType(float32, 4D)> [@E]
   |Constant{-1} [@F]
   |Constant{-1} [@F]
theano.tensor.nnet.conv.conv2d bprop inputs: 0.0 GFLOP/s ( tm = 0.0263266887665 )

Any idea on how to fix that? Swap the order of dimshuffle and [:,:,::-1,::-1] in the ConvOp gradient and hope they get merged? Introduce an optimizer that swaps dimshuffle and subtensor (shuffling the subtensor slices as needed)?
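(Either swap is safe because flipping the last two axes commutes with swapping the first two; a quick numpy check with arbitrary values:)

import numpy
x = numpy.random.rand(3, 4, 5, 6)
a = x.transpose(1, 0, 2, 3)[:, :, ::-1, ::-1]
b = x[:, :, ::-1, ::-1].transpose(1, 0, 2, 3)
assert numpy.array_equal(a, b)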

/edit: Swapping the dimshuffle and flipping in the ConvOp gradient results in:

GpuCorrMM_gradInputs{valid, (1, 1), pad=(0, 0)} [@A] ''   5
 |GpuContiguous [@B] ''   4
 | |GpuDimShuffle{1,0,2,3} [@C] ''   3
 |   |GpuDimShuffle{1,0,2,3} [@D] ''   2
 |     |GpuSubtensor{::, ::, ::int64, ::int64} [@E] ''   1
 |       |<CudaNdarrayType(float32, 4D)> [@F]
 |       |Constant{-1} [@G]
 |       |Constant{-1} [@G]
 |GpuContiguous [@H] ''   0
   |<CudaNdarrayType(float32, 4D)> [@I]
(experimental) theano.sandbox.cuda.blas.CorrMM bprop inputs: 0.0 GFLOP/s ( tm = 0.0244253120422 )

There's something wrong with the dimshuffles. a) They should have been eliminated in optimization, and b) they should be nearly zero-cost operations, shouldn't they? I guess they are not eliminated in optimization because the first one has already been moved to GPU when the second one is added to the graph, and there is no optimization for GpuDimShuffles.

Anyway, I'll open a separate issue when this PR has been merged.

@nouiz nouiz merged commit cfc493d into Theano:master Sep 1, 2014

1 check passed

continuous-integration/travis-ci The Travis CI build passed
@f0k f0k deleted the f0k:corrmm-faster-fullconv branch Sep 1, 2014
@f0k
Contributor
f0k commented Sep 1, 2014

Yay, thanks for merging, @nouiz!

@stencilman
Contributor

Thanks a lot @nouiz! I can tell people to use the master and not to merge with @f0k's branch any more! 👍

@nouiz
Member
nouiz commented Sep 2, 2014

Can you time the GpuContiguous in the fast and slow cases with the profile=True and profile_memory=True Theano flags? My guess is that the -1 in the subtensor causes the GpuContiguous to be slower. This causes a different memory layout of the input, and we didn't optimize for that case.
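For reference, a minimal way to get those timings (toy graph below, just to show the mechanism; the same flags apply to the benchmark graphs above):

import numpy
import theano
import theano.tensor as T

x = T.matrix('x')
f = theano.function([x], x[:, ::-1] * 2, profile=True)  # or THEANO_FLAGS=profile=True,profile_memory=True
f(numpy.ones((3, 4), dtype='float32'))
f.profile.summary()   # prints per-Apply timings, including any GpuContiguous nodes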

@nouiz
Member
nouiz commented Sep 2, 2014

Should we update the GpuCorrMM doc to say that it is faster when called directly?

What about doing a corr2d() method that calls conv2d on the CPU with the flipping, so that the flips could cancel each other out after substitution? Maybe we will need to make the merge-subtensor optimization work for the GPU op. What do you think of this as a user interface? @lamblin @abergeron, your input on that interface would be welcome.
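Something like this rough sketch (a hypothetical helper, not part of this PR, just to illustrate the idea):

from theano.tensor.nnet import conv

def corr2d(input, filters, **kwargs):
    # pre-flip the kernels; after the gemm optimizer substitutes GpuCorrMM,
    # this flip and the one inserted by the substitution can cancel each other out
    return conv.conv2d(input, filters[:, :, ::-1, ::-1], **kwargs)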

@f0k
Contributor
f0k commented Sep 2, 2014

My guess is that the -1 in the subtensor causes the GpuContiguous to be slower.

That's the same in both the slow and fast case, because I explicitly did a [:,:,::-1,::-1] on the weights when using GpuCorrMM directly. So the only difference is the dimshuffle. But I tried with some other shapes and then both cases perform the same... maybe we shouldn't worry too much.
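(For reference, the direct call looked roughly like this, with made-up variable names:)

import theano.tensor as T
from theano.sandbox.cuda.blas import GpuCorrMM
from theano.sandbox.cuda.basic_ops import gpu_contiguous, as_cuda_ndarray_variable

inp = T.tensor4('inp')     # (batch, channels, rows, cols)
kern = T.tensor4('kern')   # (nfilters, channels, rows, cols)
# GpuCorrMM computes a correlation, so the kernels are flipped explicitly
# to get the same result as conv2d with border_mode='valid'
out = GpuCorrMM('valid', (1, 1))(
    gpu_contiguous(as_cuda_ndarray_variable(inp)),
    gpu_contiguous(as_cuda_ndarray_variable(kern[:, :, ::-1, ::-1])))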

Oh, maybe another thing that I observed: When I tried calling GpuCorrMM directly with some SharedVariable instances, they were not C-contiguous. Is this to be expected?

Should we update the GpuCorrMM doc to say that it is faster when called directly?

No, because depending on the shapes, it can actually be slower when called directly. The goal should be to have the automatic graph optimization from conv2d be on par or better than a manual GpuCorrMM in all cases.

What about doing a corr2d() method

I think it's not needed for now. Let's first see if the kernel flipping really is a bottleneck and if so, we can still introduce a corr2d or update the conv2d documentation to say that flipping the kernels may improve performance in some cases.

@abergeron
Member

I would go for making merge subtensor work on GPU ops. I don't think a proliferation of mostly similar interfaces is a good idea.

@ballasn ballasn referenced this pull request Sep 4, 2014
Closed

GPUCorr3Dmm #2077
