
Merge pull request #191 from drnikolaev/experimental_v5/fp16

Add cudnn_v5 and support L4T aarch64
thatguymike committed Jul 11, 2016
2 parents 28eaf01 + 87df88f commit fca1cf475d1d0a6d355f8b9877abcc4e13951c9c
@@ -88,9 +88,9 @@ namespace cub {
*/
__forceinline__ void YieldProcessor()
{
#ifndef __arm__
#if !defined(__arm__) && !defined(__aarch64__)
asm volatile("pause\n": : :"memory");
#endif // __arm__
#endif // __arm__ && __aarch64__
}
#endif // defined(_MSC_VER)
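For context, this guard now compiles the x86 `pause` spin-wait hint out entirely on 32-bit ARM and aarch64 (L4T) builds. A minimal alternative sketch, not part of this commit, would emit the ARM `yield` hint in that branch instead:

__forceinline__ void YieldProcessor()
{
#if defined(__arm__) || defined(__aarch64__)
    // ARM/aarch64 spin-wait hint (hypothetical variant, not in this commit)
    asm volatile("yield\n": : :"memory");
#else
    // x86 spin-wait hint, as in the original code
    asm volatile("pause\n": : :"memory");
#endif
}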
@@ -3,10 +3,7 @@ net: "examples/mnist/lenet_train_test.prototxt"
# test_iter specifies how many forward passes the test should carry out.
# In the case of MNIST, we have test batch size 100 and 100 test iterations,
# covering the full 10,000 testing images.
#test_iter: 100
test_iter: 10
test_iter: 100
# Carry out testing every 500 training iterations.
test_interval: 500
# The base learning rate, momentum and the weight decay of the network.
@@ -20,10 +17,7 @@ power: 0.75
# Display every 100 iterations
display: 100
# The maximum number of iterations
#max_iter: 10000
max_iter: 100
max_iter: 10000
# snapshot intermediate results
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
@@ -20,9 +20,16 @@
#include "caffe/util/float16.hpp"
#ifndef CPU_ONLY
#include "cuda_fp16.h"
# include "caffe/util/float16.hpp"
# include "cuda_fp16.h"
# if CUDA_VERSION >= 8000
# define CAFFE_DATA_HALF CUDA_R_16F
# else
# define CAFFE_DATA_HALF CUBLAS_DATA_HALF
# endif
#endif
// We only build 1 flavor per host architecture:
// <float16,float> for Intel
// <float16,float16> for ARM
@@ -499,6 +499,7 @@ class CuDNNReLULayer : public ReLULayer<Dtype,Mtype> {
bool handles_setup_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif
@@ -581,6 +582,7 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype,Mtype> {
bool handles_setup_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif
@@ -665,6 +667,7 @@ class CuDNNTanHLayer : public TanHLayer<Dtype,Mtype> {
bool handles_setup_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activ_desc_;
};
#endif
@@ -8,6 +8,9 @@
#include "caffe/proto/caffe.pb.h"
#include "caffe/util/float16.hpp"
#define CUDNN_VERSION_MIN(major, minor, patch) \
(CUDNN_VERSION >= (major * 1000 + minor * 100 + patch))
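As a quick sanity check of the new macro, assuming cuDNN 5.0.5, whose cudnn.h defines CUDNN_VERSION as CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL, i.e. 5005:

// CUDNN_VERSION_MIN(5, 0, 0) -> (5005 >= (5 * 1000 + 0 * 100 + 0)) -> true  : take the v5 code paths below
// CUDNN_VERSION_MIN(6, 0, 0) -> (5005 >= 6000)                     -> false : fall back to the _v4 names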
#define CUDNN_CHECK(condition) \
do { \
cudnnStatus_t status = condition; \
@@ -95,8 +98,13 @@ template <typename Dtype>
inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
CUDNN_TENSOR_NCHW, n, c, h, w));
#else
CUDNN_CHECK(cudnnSetFilter4dDescriptor_v4(*desc, dataType<Dtype>::type,
CUDNN_TENSOR_NCHW, n, c, h, w));
#endif
}
template <typename Dtype>
@@ -111,14 +119,8 @@ inline void setConvolutionDesc(cudnnConvolutionDescriptor_t* conv,
int padA[2] = {pad_h,pad_w};
int strideA[2] = {stride_h,stride_w};
int upscaleA[2] = {1, 1};
CUDNN_CHECK(cudnnSetConvolutionNdDescriptor_v3(*conv,
CUDNN_CHECK(cudnnSetConvolutionNdDescriptor(*conv,
2, padA, strideA, upscaleA, CUDNN_CROSS_CORRELATION, dataType<Dtype>::type));
int array_length;
cudnnConvolutionMode_t mode;
cudnnDataType_t dataType;
CUDNN_CHECK(cudnnGetConvolutionNdDescriptor(*conv,1,&array_length,
padA, strideA, upscaleA, &mode, &dataType));
}
template <typename Dtype>
@@ -136,8 +138,21 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
int dimA[2] = {h,w};
int padA[2] = {pad_h,pad_w};
int strideA[2] = {stride_h,stride_w};
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPoolingNdDescriptor(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, 2, dimA,
padA, strideA));
#else
CUDNN_CHECK(cudnnSetPoolingNdDescriptor_v4(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, 2, dimA,
padA, strideA));
#endif
}
} // namespace cudnn
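The Nd arrays introduced above map onto the arguments of the removed 2-D call as follows (illustrative comments only); CUDNN_PROPAGATE_NAN is the NaN-propagation option added alongside the _v4 descriptor setters, which the old cudnnSetPooling2dDescriptor call shown here had no slot for:

// dimA    = {h, w}               -> pooling window height, width
// padA    = {pad_h, pad_w}       -> zero-padding along height, width
// strideA = {stride_h, stride_w} -> vertical, horizontal stride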
@@ -5,6 +5,26 @@
#include <mkl.h>
// MKL doesn't support fp16 yet
#ifndef CPU_ONLY
#define DEFINE_VSL_UNARY_FUNC(name, operation) \
template<typename Dtype> \
void v##name(const int n, const Dtype* a, Dtype* y) { \
CHECK_GT(n, 0); CHECK(a); CHECK(y); \
for (int i = 0; i < n; ++i) { operation; } \
} \
inline void vh##name( \
const int n, const caffe::float16* a, caffe::float16* y) { \
v##name<caffe::float16>(n, a, y); \
}
DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
DEFINE_VSL_UNARY_FUNC(Ln, y[i] = log(a[i]));
DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
#endif // !CPU_ONLY
#else // If use MKL, simply include the MKL header
extern "C" {
@@ -32,7 +32,7 @@ layer {
mean_file: "data/ilsvrc12/imagenet_mean.binaryproto"
}
data_param {
source: "data/ilsvrc12/imagenet_mean.binaryproto"
source: "examples/imagenet/ilsvrc12_val_lmdb"
batch_size: 50
backend: LMDB
}
@@ -12,8 +12,8 @@
// in presence of <80 characters rule
#define cudnnConvFwd cudnnConvolutionForward
#define cudnnConvBwdBias cudnnConvolutionBackwardBias
#define cudnnConvBwdFilter cudnnConvolutionBackwardFilter_v3
#define cudnnConvBwdData cudnnConvolutionBackwardData_v3
#define cudnnConvBwdFilter cudnnConvolutionBackwardFilter
#define cudnnConvBwdData cudnnConvolutionBackwardData
namespace caffe {
@@ -58,7 +58,7 @@ void CuDNNConvolutionLayer<Dtype,Mtype>::Forward_gpu(
// Bias.
if (this->bias_term_) {
const Dtype* bias_data = this->blobs_[1]->gpu_data();
CUDNN_CHECK(cudnnAddTensor_v3(Caffe::cudnn_handle(),
CUDNN_CHECK(cudnnAddTensor(Caffe::cudnn_handle(),
cudnn::dataType<Dtype>::one,
bias_desc_,
bias_data + bias_offset_ * g,
@@ -14,6 +14,8 @@ void CuDNNReLULayer<Dtype,Mtype>::LayerSetUp(const vector<Blob<Dtype,Mtype>*>& b
// initialize cuDNN
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnnCreateActivationDescriptor(&activ_desc_);
cudnnSetActivationDescriptor(activ_desc_, CUDNN_ACTIVATION_RELU, CUDNN_PROPAGATE_NAN, 0.0);
handles_setup_ = true;
}
@@ -34,6 +36,7 @@ CuDNNReLULayer<Dtype,Mtype>::~CuDNNReLULayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
cudnnDestroyActivationDescriptor(this->activ_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
}
@@ -17,12 +17,21 @@ void CuDNNReLULayer<Dtype,Mtype>::Forward_gpu(const vector<Blob<Dtype,Mtype>*>&
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_RELU,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}
template <typename Dtype, typename Mtype>
@@ -42,13 +51,23 @@ void CuDNNReLULayer<Dtype,Mtype>::Backward_gpu(const vector<Blob<Dtype,Mtype>*>&
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_RELU,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNReLULayer);
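The same pattern recurs in the sigmoid and TanH layers below: cuDNN v5 replaces the bare CUDNN_ACTIVATION_* mode argument with an activation descriptor. A minimal standalone sketch of that descriptor lifecycle, with error checking omitted and the ceiling/coef fixed at 0.0 as in this commit:

cudnnActivationDescriptor_t desc;
cudnnCreateActivationDescriptor(&desc);
cudnnSetActivationDescriptor(desc, CUDNN_ACTIVATION_RELU,
                             CUDNN_PROPAGATE_NAN, 0.0);
// ... pass desc to cudnnActivationForward / cudnnActivationBackward ...
cudnnDestroyActivationDescriptor(desc);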
@@ -14,6 +14,8 @@ void CuDNNSigmoidLayer<Dtype,Mtype>::LayerSetUp(const vector<Blob<Dtype,Mtype>*>
// initialize cuDNN
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
cudnnCreateActivationDescriptor(&activ_desc_);
cudnnSetActivationDescriptor(activ_desc_, CUDNN_ACTIVATION_SIGMOID, CUDNN_PROPAGATE_NAN, 0.0);
handles_setup_ = true;
}
@@ -34,6 +36,7 @@ CuDNNSigmoidLayer<Dtype,Mtype>::~CuDNNSigmoidLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
cudnnDestroyActivationDescriptor(this->activ_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
}
@@ -12,12 +12,21 @@ void CuDNNSigmoidLayer<Dtype,Mtype>::Forward_gpu(const vector<Blob<Dtype,Mtype>*
const vector<Blob<Dtype,Mtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_SIGMOID,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}
template <typename Dtype, typename Mtype>
@@ -32,13 +41,23 @@ void CuDNNSigmoidLayer<Dtype,Mtype>::Backward_gpu(const vector<Blob<Dtype,Mtype>
const Dtype* top_diff = top[0]->gpu_diff();
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_SIGMOID,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNSigmoidLayer);
@@ -12,6 +12,8 @@ void CuDNNTanHLayer<Dtype,Mtype>::LayerSetUp(const vector<Blob<Dtype,Mtype>*>& b
const vector<Blob<Dtype,Mtype>*>& top) {
TanHLayer<Dtype,Mtype>::LayerSetUp(bottom, top);
// initialize cuDNN
cudnnCreateActivationDescriptor( &activ_desc_);
cudnnSetActivationDescriptor( activ_desc_, CUDNN_ACTIVATION_TANH, CUDNN_PROPAGATE_NAN, 0.0);
cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
cudnn::createTensor4dDesc<Dtype>(&top_desc_);
handles_setup_ = true;
@@ -34,6 +36,7 @@ CuDNNTanHLayer<Dtype,Mtype>::~CuDNNTanHLayer() {
// Check that handles have been setup before destroying.
if (!handles_setup_) { return; }
cudnnDestroyActivationDescriptor(this->activ_desc_);
cudnnDestroyTensorDescriptor(this->bottom_desc_);
cudnnDestroyTensorDescriptor(this->top_desc_);
}
@@ -12,12 +12,21 @@ void CuDNNTanHLayer<Dtype,Mtype>::Forward_gpu(const vector<Blob<Dtype,Mtype>*>&
const vector<Blob<Dtype,Mtype>*>& top) {
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = top[0]->mutable_gpu_data();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationForward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_TANH,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#else
CUDNN_CHECK(cudnnActivationForward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->top_desc_, top_data));
#endif
}
template <typename Dtype, typename Mtype>
@@ -33,13 +42,23 @@ void CuDNNTanHLayer<Dtype,Mtype>::Backward_gpu(const vector<Blob<Dtype,Mtype>*>&
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnActivationBackward(Caffe::cudnn_handle(),
CUDNN_ACTIVATION_TANH,
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#else
CUDNN_CHECK(cudnnActivationBackward_v4(Caffe::cudnn_handle(),
activ_desc_,
cudnn::dataType<Dtype>::one,
this->top_desc_, top_data, this->top_desc_, top_diff,
this->bottom_desc_, bottom_data,
cudnn::dataType<Dtype>::zero,
this->bottom_desc_, bottom_diff));
#endif
}
INSTANTIATE_LAYER_GPU_FUNCS(CuDNNTanHLayer);