Skip to content

Commit

Permalink
Merge pull request #7 from f0k/cormm-cleantest
Browse files Browse the repository at this point in the history
Cleanup CUDA convolution tests
  • Loading branch information
stencilman committed Aug 13, 2014
2 parents 4d4c928 + 68e82ec commit b5e340b
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions theano/sandbox/cuda/tests/test_conv_cuda_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,6 @@ def exec_conv(version, shapes, verbose, random, mode,

failed_version = set()
failed_id = []
# I put -1 in case we forget to add version in the test to.
for ver in version:
for id, (ishape, kshape, subshape,
istride, kstride) in enumerate(shapes):
Expand Down Expand Up @@ -616,7 +615,7 @@ def test_valid_9_10():
print_=print_, ones=ones, rtol=1.1e-5)


def test_valid():
def test_valid(conv_gemm=False):
seed_rng()
shapes = get_valid_shapes()

Expand All @@ -625,29 +624,32 @@ def test_valid():
# I put -2 to test the reference version.
version = [-2, -1, 6]
verbose = 0
# version=[1]

random = True
print_ = False
ones = False
if ones:
random = False

# exec_conv(version, shapes, verbose, random, 'valid',
# print_=print_, ones=ones, rtol=1.1e-5)

mode = theano_mode.including("conv_gemm")

version = [-1]
# Add tests with strided inputs by still square images and filters.
shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2))
shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2))
if conv_gemm:
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough
# Add tests with strided inputs by still square images and filters.
shapes += get_shapes2(scales_img=(2, 2), img_stride=(2, 2))
shapes += get_shapes2(scales_kern=(2, 2), kern_stride=(2, 2))
else:
mode = cls = None
exec_conv(version, shapes, verbose, random, 'valid',
print_=print_, ones=ones, rtol=1.1e-5,
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
theano_mode=mode, cls=cls)

def test_gemm_valid():
test_valid(conv_gemm=True)


def test_full():
def test_full(conv_gemm=False):
seed_rng()
shapes = get_basic_shapes()
shapes += get_shapes2()
Expand Down Expand Up @@ -704,22 +706,24 @@ def test_full():
# shapes=shapes[:277]
version = [-2, -1, 0, 1, 2, 3, 4, 5]
verbose = 0
# version=[4]
random = True

# exec_conv(version, shapes, verbose, random, 'full')

# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")

shapes = shapes[0:10]
if conv_gemm:
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version = [-1] # dummy version; not used by GpuCorrMM so one version is enough
else:
mode = cls = None
exec_conv(version, shapes, verbose, random, 'full',
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
theano_mode=mode, cls=cls)

def test_gemm_full():
test_full(conv_gemm=True)


def test_subsample():
def test_subsample(conv_gemm=False):
seed_rng()
# implement when
shapes = [((1, 1, 1, 1), (1, 1, 1, 1), (1, 1), (1, 1), (1, 1)),
((1, 1, 1, 1), (1, 1, 1, 1), (2, 2), (1, 1), (1, 1)),
((4, 2, 10, 10), (3, 2, 2, 2), (1, 3), (1, 1), (1, 1)),
Expand All @@ -741,20 +745,23 @@ def test_subsample():
if ones:
random = False

# exec_conv(version_valid, shapes, verbose, random, 'valid',
# print_=print_, ones=ones)
# exec_conv(version_full, shapes, verbose, random, 'full',
# print_=print_, ones=ones)

# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
if conv_gemm:
# Test the GpuCorrMM version
mode = theano_mode.including("conv_gemm")
cls = cuda.blas.GpuCorrMM
version_valid = version_full = [-1] # dummy version; not used by GpuCorrMM so one version is enough
else:
mode = cls = None

exec_conv(version_valid, shapes, verbose, random, 'valid',
print_=print_, ones=ones,
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
theano_mode=mode, cls=cls)
exec_conv(version_full, shapes, verbose, random, 'full',
print_=print_, ones=ones,
theano_mode=mode, cls=cuda.blas.GpuCorrMM)
theano_mode=mode, cls=cls)

def test_gemm_subsample():
test_subsample(conv_gemm=True)


class TestConv2DGPU(unittest.TestCase):
Expand Down Expand Up @@ -829,7 +836,7 @@ def test_invalid_input_shape(self):



def test_gemm():
def test_gemm_directly():
"""
input: (batch size, channels, rows, columns)
filters: (number of filters, channels, rows, columns)
Expand Down

0 comments on commit b5e340b

Please sign in to comment.