Skip to content

Commit

Permalink
Merge pull request #6293 from notoraptor/support-cudnn-v7
Browse files Browse the repository at this point in the history
Add support for cuDNN V7
  • Loading branch information
nouiz committed Aug 14, 2017
2 parents a6b12aa + e7b8ab7 commit 26d4705
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 13 deletions.
3 changes: 2 additions & 1 deletion theano/gpuarray/c_code/dnn_fwd.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,8 @@ APPLY_SPECIFIC(conv_fwd)(PyGpuArrayObject *input, PyGpuArrayObject *kerns,

// Algo `small` does not work for a batch size > 2^16, with cuDNN >= V5.1.
// Issue should be resolved for cuDNN > V6.0.
if (cudnnGetVersion() < 6100 &&
// NB: In cuDNN V7, issue is resolved for 2D convolutionss only.
if ((cudnnGetVersion() < 6100 || PyGpuArray_NDIM(input) == 5) &&
algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM &&
PyGpuArray_DIM(input, 0) > 65536)
{
Expand Down
8 changes: 7 additions & 1 deletion theano/gpuarray/c_code/dnn_rnn_desc.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,17 @@ int dnn_rnn_desc(int hidden_size, int num_layers,
PyErr_SetString(PyExc_RuntimeError, "Can't create RNN descriptor");
return -1;
}

#if CUDNN_MAJOR < 7
err = cudnnSetRNNDescriptor(desc, hidden_size, num_layers, ddesc,
(cudnnRNNInputMode_t)input_mode,
(cudnnDirectionMode_t)direction_mode,
(cudnnRNNMode_t)rnn_mode, data_type);
#else
err = cudnnSetRNNDescriptor(_handle, desc, hidden_size, num_layers, ddesc,
(cudnnRNNInputMode_t)input_mode,
(cudnnDirectionMode_t)direction_mode,
(cudnnRNNMode_t)rnn_mode, CUDNN_RNN_ALGO_STANDARD, data_type);
#endif
if (err != CUDNN_STATUS_SUCCESS) {
cudnnDestroyRNNDescriptor(desc);
PyErr_SetString(PyExc_RuntimeError, "Can't set RNN descriptor");
Expand Down
25 changes: 18 additions & 7 deletions theano/gpuarray/cudnn_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- v5.1
- v6.0
- v7.0
"""

Expand Down Expand Up @@ -102,8 +103,7 @@ class CuDNNV6(CuDNNV51):
# new in v6
('CUDNN_DATA_INT8', 'int8'),
('CUDNN_DATA_INT32', 'int32'),
# Also in v6, but restrictions make this fail
# CUDNN_DATA_INT8x4
# ('CUDNN_DATA_INT8X4', 'int8x4'),
ctype='cudnnDataType_t')

cudnnPoolingMode_t = CEnumType(('CUDNN_POOLING_MAX', 'max'),
Expand All @@ -117,10 +117,8 @@ class CuDNNV6(CuDNNV51):
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1', 'deterministic'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT', 'fft'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3', 'small'),
# not implemented:
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD'),
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED', 'winograd_non_fused'),
# TODO: not yet tested/documented:
# new in v6:
('CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING', 'fft_tiling'),
ctype='cudnnConvolutionBwdFilterAlgo_t')
Expand All @@ -136,6 +134,16 @@ class CuDNNV6(CuDNNV51):
ctype='cudnnReduceTensorOp_t')


class CuDNNV7(CuDNNV6):
version = 7
cudnnMathType_t = CEnumType(('CUDNN_DEFAULT_MATH', 'non_tensor_op'),
('CUDNN_TENSOR_OP_MATH', 'tensor_op'),
ctype='cudnnMathType_t')
cudnnDeterminism_t = CEnumType(('CUDNN_NON_DETERMINISTIC', 'non_deterministic'),
('CUDNN_DETERMINISTIC', 'deterministic'),
ctype='cudnnDeterminism_t')


def get_definitions(cudnn_version=None):
"""
Return cuDNN definitions to be used by Theano for the given cuDNN version.
Expand All @@ -145,7 +153,10 @@ def get_definitions(cudnn_version=None):
if None, return definitions for the most recent supported cuDNN version.
"""
if cudnn_version is not None and cudnn_version // 1000 == 5:
return CuDNNV51()
if cudnn_version is not None:
if cudnn_version // 1000 == 5:
return CuDNNV51()
if cudnn_version // 1000 == 6:
return CuDNNV6()
# By default, we use definitions for the last supported cuDNN version.
return CuDNNV6()
return CuDNNV7()
8 changes: 4 additions & 4 deletions theano/gpuarray/dnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
pass

# Update these names when new versions of cudnn are supported.
WIN32_CUDNN_NAMES = ['cudnn64_6.dll', 'cudnn64_5.dll']
WIN32_CUDNN_NAMES = ['cudnn64_7.dll', 'cudnn64_6.dll', 'cudnn64_5.dll']


def _load_lib(name):
Expand Down Expand Up @@ -90,7 +90,7 @@ def _dnn_lib():
if lib_name:
break
if lib_name is None:
raise RuntimeError('Could not find cudnn library (looked for v5* or v6*)')
raise RuntimeError('Could not find cudnn library (looked for v5* to v7*)')
else:
dnn_handle = ctypes.cdll.LoadLibrary(lib_name)

Expand Down Expand Up @@ -166,11 +166,11 @@ def _dnn_check_version():
v = version()
if v < 5000:
return False, "cuDNN version is too old. Update to v5* or higher, was %d." % v
if v >= 6100:
if v >= 7200:
warnings.warn("Your cuDNN version is more recent than "
"Theano. If you encounter problems, try "
"updating Theano or downgrading cuDNN to "
"a version >= v5 and <= v6.")
"a version >= v5 and <= v7.")
return True, None


Expand Down

0 comments on commit 26d4705

Please sign in to comment.