Merge pull request #4591 from pcs-theano/pcs-corr-opt
corr_gemm optimization to improve CNN performance
theano-bot committed Jul 8, 2016
2 parents b9813e0 + cc3deb5 commit 2cdc1e6
Showing 4 changed files with 186 additions and 54 deletions.
3 changes: 2 additions & 1 deletion theano/gof/cmodule.py
@@ -1873,7 +1873,8 @@ def compile_args(march_flags=True):

        if ('g++' not in theano.config.cxx and
                'clang++' not in theano.config.cxx and
                'clang-omp++' not in theano.config.cxx):
                'clang-omp++' not in theano.config.cxx and
                'icpc' not in theano.config.cxx):
            _logger.warn(
                "OPTIMIZATION WARNING: your Theano flag `cxx` seems not to be"
                " the g++ compiler. So we disable the compiler optimization"
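The hunk above only affects when Theano prints its "OPTIMIZATION WARNING": adding 'icpc' to the list lets Intel's C++ compiler pass the check silently, so the CPU-specific optimization flags are not reported as disabled. A minimal sketch of the equivalent predicate, with helper names invented here for illustration (they are not part of the commit):

```python
# Hypothetical helper mirroring the check above; KNOWN_CXX_NAMES and
# cxx_supports_march are invented names, not Theano APIs.
KNOWN_CXX_NAMES = ('g++', 'clang++', 'clang-omp++', 'icpc')

def cxx_supports_march(cxx_cmd):
    """Return True if the configured C++ compiler looks like one Theano
    trusts with its CPU-specific optimization flags."""
    return any(name in cxx_cmd for name in KNOWN_CXX_NAMES)

# e.g. cxx_supports_march('icpc') -> True, so no warning would be logged
```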
43 changes: 43 additions & 0 deletions theano/tensor/blas_headers.py
@@ -961,6 +961,49 @@ def blas_header_text():
    return header


def mkl_threads_text():
"""C header for MKL threads interface"""
header = """
extern "C"
{
int MKL_Set_Num_Threads_Local(int);
#define mkl_set_num_threads_local MKL_Set_Num_Threads_Local
void MKL_Set_Num_Threads(int);
#define mkl_set_num_threads MKL_Set_Num_Threads
int MKL_Get_Max_Threads(void);
#define mkl_get_max_threads MKL_Get_Max_Threads
int MKL_Domain_Set_Num_Threads(int, int);
#define mkl_domain_set_num_threads MKL_Domain_Set_Num_Threads
int MKL_Domain_Get_Max_Threads(int);
#define mkl_domain_get_max_threads MKL_Domain_Get_Max_Threads
void MKL_Set_Dynamic(int);
#define mkl_set_dynamic MKL_Set_Dynamic
int MKL_Get_Dynamic(void);
#define mkl_get_dynamic MKL_Get_Dynamic
}
"""
return header


def openblas_threads_text():
"""C header for OpenBLAS threads interface"""
header = """
extern "C"
{
void openblas_set_num_threads(int);
void goto_set_num_threads(int);
int openblas_get_num_threads(void);
}
"""
return header


def blas_header_version():
    # Version for the base header
    version = (1,)
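These two helpers only declare the thread-control entry points; the calls themselves live in the generated C code (corr_gemm.c, which is not shown in this diff). Below is a hypothetical sketch, written in the same style as blas_headers.py (a Python function returning a C snippet), of how such declarations could be used to keep BLAS single-threaded inside an OpenMP loop; the function name and the `batch_size` variable are invented for illustration:

```python
def blas_thread_guard_sketch():
    """Illustrative only: pin BLAS to one thread per OpenMP worker around a
    parallel loop, then restore the caller's setting. OpenBLAS spelling shown;
    MKL would use mkl_get_max_threads() / mkl_set_num_threads() instead."""
    return """
    int prev_blas_threads = openblas_get_num_threads();
    openblas_set_num_threads(1);  // avoid nested BLAS threading inside OpenMP
    #pragma omp parallel for schedule(static)
    for (int n = 0; n < batch_size; n++)
    {
        // ... one gemm call per image in the minibatch ...
    }
    openblas_set_num_threads(prev_blas_threads);  // restore previous setting
    """
```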
78 changes: 52 additions & 26 deletions theano/tensor/nnet/corr.py
@@ -9,14 +9,13 @@
from theano import gof
from theano.tensor import as_tensor_variable, TensorType
from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas import ldflags

from theano.tensor import blas_headers
from theano.tensor.blas import ldflags, blas_header_version

_logger = logging.getLogger(__name__)


class BaseCorrMM(gof.Op):
class BaseCorrMM(gof.OpenMPOp):
"""
Base class for `CorrMM`, `CorrMM_gradWeights` and
`CorrMM_gradInputs`. Cannot be used directly.
@@ -34,7 +33,8 @@ class BaseCorrMM(gof.Op):
    __props__ = ('border_mode', 'subsample', 'filter_dilation')

    def __init__(self, border_mode="valid", subsample=(1, 1),
                 filter_dilation=(1, 1)):
                 filter_dilation=(1, 1), openmp=None):
        super(BaseCorrMM, self).__init__(openmp=openmp)
        if isinstance(border_mode, integer_types):
            if border_mode < 0:
                raise ValueError(
@@ -62,6 +62,16 @@ def __init__(self, border_mode="valid", subsample=(1, 1),
        self.subsample = tuple(subsample)
        self.filter_dilation = tuple(filter_dilation)

        if not theano.config.blas.ldflags:
            raise NotImplementedError("C code for corrMM* classes need a blas library.")
        else:
            if 'openblas' in theano.config.blas.ldflags:
                self.blas_type = 'openblas'
            elif 'mkl' in theano.config.blas.ldflags:
                self.blas_type = 'mkl'
            else:
                self.blas_type = ''

    @property
    def pad(self):
        if self.border_mode != 'valid':
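As an aside on the `blas_type` detection added to `__init__` above, the value is derived purely from substring matches on `theano.config.blas.ldflags`. A small stand-alone mirror of that logic (the example flag strings are assumptions, not taken from the commit):

```python
def detect_blas_type(ldflags_str):
    """Stand-alone mirror of the detection in BaseCorrMM.__init__, for illustration."""
    if 'openblas' in ldflags_str:
        return 'openblas'
    if 'mkl' in ldflags_str:
        return 'mkl'
    return ''  # generic BLAS: thread control is skipped

# Example flag strings (assumed, not from the commit):
#   detect_blas_type('-L/opt/openblas/lib -lopenblas')  -> 'openblas'
#   detect_blas_type('-lmkl_rt -liomp5 -lpthread')      -> 'mkl'
#   detect_blas_type('-lblas')                          -> ''
```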
@@ -76,13 +86,20 @@ def __str__(self):
            str(self.filter_dilation))

    def c_support_code(self):
        return blas_header_text()
        ccodes = blas_headers.blas_header_text()
        if self.blas_type == 'openblas':
            ccodes += blas_headers.openblas_threads_text()
        elif self.blas_type == 'mkl':
            ccodes += blas_headers.mkl_threads_text()
        return ccodes

    def c_libraries(self):
        return ldflags()

    def c_compile_args(self):
        return ldflags(libs=False, flags=True)
        compile_args = ldflags(libs=False, flags=True)
        compile_args += super(BaseCorrMM, self).c_compile_args()
        return compile_args

    def c_lib_dirs(self):
        return ldflags(libs=False, libs_dir=True)
@@ -91,11 +108,13 @@ def c_header_dirs(self):
        return ldflags(libs=False, include_dir=True)

    def c_headers(self):
        return ['<stdio.h>']
        headers = ['<stdio.h>']
        headers += super(BaseCorrMM, self).c_headers()
        return headers

    def c_code_cache_version(self):
        # raise this whenever modifying any of the support_code_files
        return (1, 2)
        return (1, self.openmp, blas_header_version())

    def c_support_code_apply(self, node, nodename):
        # REMEMBER TO RAISE c_code_cache_version when changing any of
@@ -115,6 +134,28 @@ def c_support_code_apply(self, node, nodename):
            sub['float_typenum'] = 'NPY_DOUBLE'
            sub['n_bytes'] = 8
            sub['c_float_type'] = 'double'

        if self.openmp:
            sub['omp_flags'] = '#pragma omp parallel for schedule(static)'
            sub['omp_get_max_threads'] = 'omp_get_max_threads()'
            sub['omp_get_thread_num'] = 'omp_get_thread_num()'

            if self.blas_type == 'openblas':
                sub['blas_set_num_threads'] = 'openblas_set_num_threads'
                sub['blas_get_num_threads'] = 'openblas_get_num_threads()'
            elif self.blas_type == 'mkl':
                sub['blas_set_num_threads'] = 'mkl_set_num_threads'
                sub['blas_get_num_threads'] = 'mkl_get_max_threads()'
            else:
                sub['blas_set_num_threads'] = ''
                sub['blas_get_num_threads'] = '0'
        else:
            sub['omp_flags'] = ''
            sub['omp_get_max_threads'] = '1'
            sub['omp_get_thread_num'] = '0'
            sub['blas_set_num_threads'] = ''
            sub['blas_get_num_threads'] = '0'

        files = ['corr_gemm.c']
        codes = [open(os.path.join(os.path.split(__file__)[0], f)).read()
                 for f in files]
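The `sub` mapping built above is spliced into the corr_gemm.c template through Python %-formatting of the support code, so a `%(omp_flags)s` placeholder expands to either the OpenMP pragma or an empty string. A toy illustration with an invented template (the real corr_gemm.c is far larger and is not shown in this diff):

```python
# Toy illustration of the placeholder expansion; the template string is
# invented and stands in for the real corr_gemm.c contents.
template = """
%(omp_flags)s
for (int n = 0; n < batchsize; n++)
{
    // gemm for image n, accumulating %(c_float_type)s data
}
"""
sub = {
    'omp_flags': '#pragma omp parallel for schedule(static)',
    'c_float_type': 'float',
}
print(template % sub)  # with openmp disabled, 'omp_flags' would simply be ''
```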
@@ -158,8 +199,6 @@ def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width
        If self.border_mode == 'half', a variable giving the width of the
        filters for direction="backprop weights". Ignored otherwise.
        """
        if not theano.config.blas.ldflags:
            raise NotImplementedError("C code for CorrMM* classes need a blas library.")
        dH, dW = self.subsample
        dilH, dilW = self.filter_dilation
        if self.border_mode == "half":
@@ -325,7 +364,8 @@ def c_code_helper(self, bottom, weights, top, direction, sub, height=None, width
        else {
            typenum = PyArray_TYPE(bottom);
        }
        %(out)s = (PyArrayObject*)PyArray_EMPTY(4,
        //Change to PyArray_ZEROS which is faster than PyArray_EMPTY.
        %(out)s = (PyArrayObject*)PyArray_ZEROS(4,
                  out_dim,
                  typenum,
                  0);
@@ -376,9 +416,6 @@ class CorrMM(BaseCorrMM):
        Set to `(1, 1)` to disable filter dilation.
    """
    def __init__(self, border_mode="valid", subsample=(1, 1),
                 filter_dilation=(1, 1)):
        super(CorrMM, self).__init__(border_mode, subsample, filter_dilation)

    def make_node(self, img, kern):
        img = as_tensor_variable(img)
@@ -436,12 +473,6 @@ class CorrMM_gradWeights(BaseCorrMM):
"""

def __init__(self, border_mode="valid", subsample=(1, 1),
filter_dilation=(1, 1)):
super(CorrMM_gradWeights, self).__init__(border_mode,
subsample,
filter_dilation)

def make_node(self, img, topgrad, shape=None):
img = as_tensor_variable(img)
topgrad = as_tensor_variable(topgrad)
@@ -538,11 +569,6 @@ class CorrMM_gradInputs(BaseCorrMM):
"""

def __init__(self, border_mode="valid", subsample=(1, 1), filter_dilation=(1, 1)):
super(CorrMM_gradInputs, self).__init__(border_mode,
subsample,
filter_dilation)

def make_node(self, kern, topgrad, shape=None):
kern = as_tensor_variable(kern)
topgrad = as_tensor_variable(topgrad)
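To close, a minimal usage sketch (not part of the commit) of the updated Op: because `BaseCorrMM` now derives from `gof.OpenMPOp` and the subclasses no longer override `__init__`, an `openmp` argument can be passed straight to `CorrMM`. Shapes below are arbitrary, and a BLAS library must be configured in `theano.config.blas.ldflags`, otherwise construction raises `NotImplementedError`:

```python
# Minimal usage sketch; assumes a working C++ compiler and a BLAS configured
# in theano.config.blas.ldflags. Shapes and data are arbitrary examples.
import numpy
import theano
import theano.tensor as T
from theano.tensor.nnet.corr import CorrMM

img = T.tensor4('img')    # (batch, channels, rows, cols)
kern = T.tensor4('kern')  # (nkern, channels, krows, kcols)

# openmp=None falls back to the config default; openmp=True requests the
# OpenMP-parallelized corr_gemm path introduced by this pull request.
corr = CorrMM(border_mode='valid', subsample=(1, 1), openmp=True)
f = theano.function([img, kern], corr(img, kern))

out = f(numpy.random.rand(8, 3, 32, 32).astype(theano.config.floatX),
        numpy.random.rand(16, 3, 5, 5).astype(theano.config.floatX))
print(out.shape)  # (8, 16, 28, 28) for valid correlation with 5x5 kernels
```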
