Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 27 additions & 1 deletion src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,16 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
return "Unknown cuRAND status";
}

/*!
 * \brief Device-side maximum of two values.
 * \param a first value
 * \param b second value
 * \return the larger of \p a and \p b; when neither compares greater
 *         (e.g. a NaN operand for floating types), \p b is returned,
 *         matching the original `a > b ? a : b` form.
 */
template <typename DType>
inline DType __device__ CudaMax(DType a, DType b) {
  if (a > b) {
    return a;
  }
  return b;
}

/*!
 * \brief Device-side minimum of two values.
 * \param a first value
 * \param b second value
 * \return the smaller of \p a and \p b; when neither compares less
 *         (e.g. a NaN operand for floating types), \p b is returned,
 *         matching the original `a < b ? a : b` form.
 */
template <typename DType>
inline DType __device__ CudaMin(DType a, DType b) {
  if (a < b) {
    return a;
  }
  return b;
}

} // namespace cuda
} // namespace common
} // namespace mxnet
Expand Down Expand Up @@ -219,6 +229,14 @@ inline const char* CurandGetErrorString(curandStatus_t status) {
<< "cuRAND: " << common::cuda::CurandGetErrorString(e); \
}

/*!
 * \brief Loop-unrolling hint macros for device code.
 *
 * CUDA_UNROLL asks the compiler to unroll the following loop and
 * CUDA_NOUNROLL asks it not to. On MSVC the standard C99/C++11
 * `_Pragma` operator is not supported (MSVC historically only offers
 * its own `__pragma` extension), so both macros expand to nothing
 * there and the hint is simply dropped — NOTE(review): `__pragma`
 * could be used on MSVC instead; confirm whether the hint matters
 * for that toolchain.
 */
#if !defined(_MSC_VER)
#define CUDA_UNROLL _Pragma("unroll")
#define CUDA_NOUNROLL _Pragma("nounroll")
#else
#define CUDA_UNROLL
#define CUDA_NOUNROLL
#endif

/*!
* \brief Determine major version number of the gpu's cuda compute architecture.
* \param device_id The device index of the cuda-capable gpu of interest.
Expand Down Expand Up @@ -291,7 +309,6 @@ inline bool GetEnvAllowTensorCore() {
return dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE",
dmlc::optional<bool>(default_value)).value();
}

#endif // MXNET_USE_CUDA

#if MXNET_USE_CUDNN
Expand Down Expand Up @@ -401,6 +418,15 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}

/*!
 * \brief Load a value from device memory via the read-only cache path
 *        when available.
 * \param address pointer to the value to load; `__ldg` requires the
 *        pointed-to data to be read-only for the duration of the kernel
 *        — TODO(review): confirm all callers honor this.
 * \return the loaded value.
 *
 * On devices of compute capability 3.5 or higher this routes the load
 * through `__ldg`; otherwise (including the host compilation pass,
 * where __CUDA_ARCH__ is undefined and evaluates to 0 here) it falls
 * back to a plain dereference.
 */
template <typename DType>
__device__ inline DType ldg(const DType* address) {
#if __CUDA_ARCH__ >= 350
  return __ldg(address);
#else
  return *address;
#endif
}
#endif

#endif // MXNET_COMMON_CUDA_UTILS_H_
14 changes: 14 additions & 0 deletions src/operator/convolution.cu
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include "./cudnn_convolution-inl.h"
#endif // MXNET_USE_CUDNN

#include "./depthwise_convolution-inl.h"

namespace mxnet {
namespace op {

Expand All @@ -45,6 +47,18 @@ Operator* CreateOp<gpu>(ConvolutionParam param, int dtype,
})
return op;
}

// depth wise conv
if (param.num_filter == param.num_group &&
param.layout.value() == mshadow::kNCHW &&
param.num_filter == (*in_shape)[conv::kData][1] &&
param.kernel.ndim() == 2 &&
param.dilate == mshadow::Shape2(1, 1) &&
dtype == mshadow::kFloat32) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any reason for limiting to float32?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not processing in the cuda kernel when dtype==mshadow::kFloat16

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there plan to support dilation with depthwise kernel? It is used in mobilenet v2 + deeplabv3 for segmentation. Tensorflow has efficient implementation. mxnet is much slower in this case.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@crazy-cat Will you please implement support for dilation rate > 1?

op = new DepthwiseConvolutionOp<float>(param, *in_shape, *out_shape);
return op;
}

#if MXNET_USE_CUDNN == 1
// The NVIDIA Pascal architecture was the first to include 16-bit ALUs.
// Thus, when the framework is compiled with MSHADOW_USE_PASCAL == 1, we
Expand Down
Loading