caffe conv kernel for theano. tests work, but needs integration and some... #2002

Merged
merged 28 commits into from Aug 5, 2014

28 commits
fb66035
caffe conv kernel for theano. tests work, but needs integration and s…
stencilman Jul 29, 2014
53630ed
changes after abergeron commented on the code
stencilman Jul 29, 2014
41ab038
remove old code that isn't used.
nouiz Jul 30, 2014
5a4e453
Add error about the missing implementation.
nouiz Jul 30, 2014
5741566
remove old code and small fix for not yet used code.
nouiz Jul 30, 2014
ea8153b
Opt to use GpuConvMM in valid mode.
nouiz Jul 30, 2014
c5728f5
Better handling of transfer and better error reporting and fix refcount.
nouiz Jul 30, 2014
598f485
warn about bugged code.
nouiz Jul 30, 2014
998b9bc
Reuse the current gpu conv test for gpuconvmm
nouiz Jul 30, 2014
2955b33
Reuse pre allocated memory.
nouiz Jul 30, 2014
1df727f
Add check and better error message
nouiz Jul 30, 2014
549c2fd
Partial fix
nouiz Jul 30, 2014
9bf8ef2
code simplication
nouiz Jul 30, 2014
a1509a7
Indentation.
nouiz Jul 30, 2014
80dd43e
Hi Fred, I tried it out, but for me, it doesnt find conv() in package…
stencilman Jul 31, 2014
f18c849
Look what I did on like 117 in file theano/sandbox/cuda/tests/test_co…
stencilman Jul 31, 2014
c649d66
- fixed a bug in cafe conv (values of _M, _N, _K)
stencilman Aug 1, 2014
0d876f7
renamed to test_gemm
stencilman Aug 1, 2014
6b078b2
renamed conv_gemm function
stencilman Aug 1, 2014
d823134
Some more fix.
nouiz Aug 1, 2014
5174e47
Merge branch 'conv_gemm2' of https://github.com/nouiz/Theano into con…
stencilman Aug 1, 2014
03ed592
Fix error detection and update tests to tests only case covered.
nouiz Aug 1, 2014
b6febcb
Remove warning as it work now!
nouiz Aug 1, 2014
46f64a6
Merge commit 'refs/pullreqs/origin/pr/2' into conv_gemm after Fred's …
stencilman Aug 1, 2014
497a2d9
Including support for 'full' convolutions. It uses the existing pad f…
stencilman Aug 2, 2014
74ea01a
fixed the bug after which test_full also passes all tests.
stencilman Aug 2, 2014
1e3de2c
Changes suggested by Fred
stencilman Aug 4, 2014
4c55bc4
- added some documentaiton
stencilman Aug 4, 2014
13 changes: 13 additions & 0 deletions doc/library/tensor/nnet/conv.txt
@@ -51,8 +51,21 @@ TODO: Give examples for how to use these things! They are pretty complicated.
implementation.

Also, there are restrictions on which shapes are supported.
- :func:`GpuCorrMM <theano.sandbox.cuda.blas.GpuCorrMM>`
This is a GPU-only version of correlation that computes correlations
the same way as `caffe <https://github.com/BVLC/caffe/blob/master/src/caffe/layers/conv_layer.cu>`_ does.
For each element in a batch, it first creates a
`Toeplitz <http://en.wikipedia.org/wiki/Toeplitz_matrix>`_ matrix in a CUDA kernel.
Then, it performs a `gemm` call to multiply this Toeplitz matrix with the kernel.
It needs extra memory for this, namely the size of the Toeplitz matrix. Precisely,
the dimensions of this Toeplitz matrix are
(number of channels * filter width * filter height, output width * output height).
You can enable it for calls to conv2d by setting 'THEANO_FLAGS=optimizer_including=conv_gemm'
in your environment. This is not enabled by default because it
uses some extra memory. It does not support strides for now and requires square kernels.

.. autofunction:: theano.tensor.nnet.conv.conv2d
.. autofunction:: theano.tensor.nnet.Conv3D.conv3D
.. autofunction:: theano.tensor.nnet.conv3d2d.conv3d
.. autofunction:: theano.sandbox.cuda.fftconv.conv2d_fft
.. autofunction:: theano.sandbox.cuda.blas.GpuCorrMM
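
A minimal sketch of how the conv_gemm optimizer described in the documentation hunk above would be enabled; the variable names and shapes are illustrative assumptions, not part of the patch.

# Run with: THEANO_FLAGS=device=gpu,optimizer_including=conv_gemm python this_script.py
import theano
import theano.tensor as T
from theano.tensor.nnet import conv

images = T.tensor4('images')    # (batch, channels, rows, cols)
filters = T.tensor4('filters')  # (nkern, channels, krows, kcols)
out = conv.conv2d(images, filters, border_mode='valid')
f = theano.function([images, filters], out)
# With the flag above, the optimizer should replace the GPU convolution node
# by GpuCorrMM, the caffe-style gemm correlation introduced in this PR.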
171 changes: 171 additions & 0 deletions theano/sandbox/cuda/blas.py
@@ -7,6 +7,7 @@
from theano.compat.six import StringIO
from theano.sandbox.cuda.type import CudaNdarrayType
from theano.sandbox.cuda import GpuOp
from theano.sandbox.cuda import as_cuda_ndarray_variable


class GpuDot22(GpuOp):
@@ -497,9 +498,179 @@ def c_code(self, node, name, inputs, outputs, sub):
gpu_ger_inplace = GpuGer(inplace=True)


class GpuCorrMM(GpuOp):
"""
Author: Arjun Jain
Implement the caffe convolution: a correlation computed with a Toeplitz matrix and a GEMM call.
"""
def __init__(self, border_mode,
subsample=(1, 1),
pad=0):
"""
:param border_mode: "valid" or "full"
:param subsample: not yet supported

Review comment (Member): This isn't true anymore, you are passing the subsample parameters to the C code now.

Review comment (Member): We don't pass it to validMM(), we just init the output memory to the right size, but will write to the full memory output region! So it is not working for now.

:param pad: not yet supported
"""
self.border_mode = border_mode
self.subsample = subsample
self.pad = pad
if pad != 0:
raise NotImplementedError(
"GpuCorrMM does not implement the pad parameter")
if subsample != (1, 1):
raise NotImplementedError(
"GpuCorrMM does not implement the subsample parameter")

def __eq__(self, other):
return type(self) == type(other) \
and self.border_mode == other.border_mode \
and self.subsample == other.subsample \
and self.pad == other.pad

def __hash__(self):
# don't use hash(self.version) as hash(-1)==-2 and
# hash(-2)==-2 in python!
return hash(type(self)) \
^ hash(self.border_mode) \
^ hash(self.subsample) \
^ hash(self.pad)

def __str__(self):
return '%s{%s, %s, pad=%d}' % (
self.__class__.__name__,
self.border_mode,
str(self.subsample),
self.pad)

def make_node(self, img, kern):
img = as_cuda_ndarray_variable(img)
kern = as_cuda_ndarray_variable(kern)
if img.type.ndim != 4:
raise TypeError('img must be 4D tensor')
if kern.type.ndim != 4:
raise TypeError('kern must be 4D tensor')

broadcastable = [img.type.broadcastable[0], kern.type.broadcastable[0],
False, False]
return Apply(self, [img, kern], [CudaNdarrayType(broadcastable)()])

def flops(self, inputs, outputs):
""" Useful with the hack in profilemode to print the MFlops"""
images, kerns = inputs
out, = outputs
assert images[1] == kerns[1]
flops = 0
if self.border_mode == "valid":
# nb mul and add by output pixel
flops = kerns[2] * kerns[3] * 2
# nb flops by output image
flops *= out[2] * out[3]
# nb patch multiplied
flops *= images[1] * kerns[0] * images[0]
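# e.g. images (2, 3, 10, 10) and kerns (4, 3, 3, 3) in "valid" mode give
# out of shape (2, 4, 8, 8): flops = (3*3*2) * (8*8) * (3*4*2) = 27648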
else:
flops = (images[0] * kerns[0] * images[1] *
kerns[2] * kerns[3] *
images[2] * images[3] * 2)
return flops

def c_headers(self):
return ['cuda_ndarray.cuh', '<stdio.h>']

def c_code_cache_version(self):
# The bare return below disables the compiled-code cache while this Op
# is still changing; remove it to re-enable the versioned cache.
return
# raise this whenever modifying any of the support_code_files
return (0, 21)

def c_support_code_apply(self, node, nodename):
# REMEMBER TO RAISE c_code_cache_version when changing any of
# these files
files = ['conv_gemm.cu']
codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
for f in files]
return reduce(str.__add__, codes)

def c_code(self, node, nodename, inp, out_, sub):
img, kern = inp
out, = out_
dx = self.subsample[0]
dy = self.subsample[1]
border_mode = self.border_mode
sub = sub.copy()
pad = self.pad
sub.update(locals())

return """
//Mandatory args
const char *mode_str = "%(border_mode)s";

//Optional args
int dx = %(dx)s;
int dy = %(dy)s;
int pad = 0;
CudaNdarray * img = %(img)s;
CudaNdarray * kern = %(kern)s;
CudaNdarray * out2 = NULL;
int mode;
if (strcmp(mode_str, "full") == 0)
{
mode = 0;
}
else if (strcmp(mode_str, "valid") == 0)
{
mode = 1;
}
else
{
PyErr_SetString(PyExc_ValueError,
"mode must be one of 'full' or 'valid'");
%(fail)s;
}
//TODO: Send self.pad, stride, etc

int out_dim[4];
out_dim[0] = CudaNdarray_HOST_DIMS(img)[0];
out_dim[1] = CudaNdarray_HOST_DIMS(kern)[0];
int logical_rows, logical_cols;
if (mode == 1)
{
logical_rows = CudaNdarray_HOST_DIMS(img)[2] - CudaNdarray_HOST_DIMS(kern)[2] + 1;
logical_cols = CudaNdarray_HOST_DIMS(img)[3] - CudaNdarray_HOST_DIMS(kern)[3] + 1;
}
else
{
logical_rows = CudaNdarray_HOST_DIMS(img)[2] + CudaNdarray_HOST_DIMS(kern)[2] - 1;
logical_cols = CudaNdarray_HOST_DIMS(img)[3] + CudaNdarray_HOST_DIMS(kern)[3] - 1;
pad = CudaNdarray_HOST_DIMS(kern)[2] - 1;
}
out_dim[2] = ceil_intdiv(logical_rows, dx);
out_dim[3] = ceil_intdiv(logical_cols, dy);
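//output rows/cols; with dx == dy == 1 this is just logical_rows/cols,
//ceil_intdiv will only matter once subsampling is actually supported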

if ( !(%(out)s
&& %(out)s->nd==4
&& CudaNdarray_is_c_contiguous(%(out)s)
&& CudaNdarray_HOST_DIMS(%(out)s)[0]==out_dim[0]
&& CudaNdarray_HOST_DIMS(%(out)s)[1]==out_dim[1]
&& CudaNdarray_HOST_DIMS(%(out)s)[2]==out_dim[2]
&& CudaNdarray_HOST_DIMS(%(out)s)[3]==out_dim[3]))
{
Py_XDECREF(%(out)s);
%(out)s = (CudaNdarray*)CudaNdarray_NewDims(4,out_dim);
if (NULL == %(out)s)
{
//guard against a failed output allocation before calling corrMM
PyErr_SetString(PyExc_MemoryError,
"GpuCorrMM: failed to allocate output");
%(fail)s
}
}

out2 = corrMM(%(img)s, %(kern)s, %(out)s, pad);
if (out2==NULL){
%(fail)s
}
assert (out2 == %(out)s);

""" % sub


##
# Not really a BLAS operation, but whatever.
#

class GpuConv(GpuOp):
"""
Implement the batched and stacked 2d convolution on the gpu.
53 changes: 53 additions & 0 deletions theano/sandbox/cuda/caffe_common.hpp
@@ -0,0 +1,53 @@
/*

Review comment (Member): Same here with the copyright notice.

Copyright (c) 2014, The Regents of the University of California (Regents)
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

#ifndef CAFFE_COMMON_HPP_
#define CAFFE_COMMON_HPP_

#include <cublas_v2.h>
#include <cuda.h>
#include <driver_types.h> // cuda driver types

// CUDA: grid stride looping
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
i += blockDim.x * gridDim.x)

// CUDA: thread number configuration.
// Use 1024 threads per block, which requires cuda sm_2x or above,
// or fall back to attempt compatibility (best of luck to you).
#if __CUDA_ARCH__ >= 200
const int CAFFE_CUDA_NUM_THREADS = 1024;
#else
const int CAFFE_CUDA_NUM_THREADS = 512;
#endif

// CUDA: number of blocks for threads.
inline int CAFFE_GET_BLOCKS(const int N) {
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
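// Example: with CAFFE_CUDA_NUM_THREADS == 1024, CAFFE_GET_BLOCKS(4000) returns 4
// (4096 threads total); the i < (n) bound in CUDA_KERNEL_LOOP keeps the surplus
// 96 threads from touching out-of-range indices.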

#endif // CAFFE_COMMON_HPP_