Skip to content

Commit

Permalink
feat(cutlass): add uint8 testcase for rrconv wgrad
Browse files Browse the repository at this point in the history
GitOrigin-RevId: bda81cbd24692cfd5581c61fd3adb2339714364a
  • Loading branch information
megvii-mge committed Nov 21, 2022
1 parent 0027375 commit c8d1437
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
Expand Up @@ -452,8 +452,8 @@ bool TestRegionRestrictedDepthwiseConv2dWgrad() {

for (int n : {160, 48, 33}) {
for (int g : {3, 7}) {
for (int ih : {16}) {
for (int iw : {16}) {
for (int ih : {16, 19}) {
for (int iw : {16, 19}) {
for (int fh : {15, 7, 5, 3}) {
for (int ph : {static_cast<int>(fh / 2), 0}) {
for (int sh : {1, 2}) {
Expand Down
Expand Up @@ -55,7 +55,7 @@

#include "region_restricted_conv2d_wgrad_testbed.h"

#define RUN_DEPTHWISE_CONVOLUTION(stage) \
#define RUN_DEPTHWISE_CONVOLUTION(stage, dt) \
do { \
using ElementOutput = float; \
using ElementAccumulator = float; \
Expand All @@ -64,8 +64,8 @@
using Convolution = cutlass::conv::device:: \
RegionRestrictedConvolutionBackwardFilter< \
float, cutlass::layout::TensorNCHW, float, \
cutlass::layout::TensorNCHW, int32_t, \
cutlass::layout::TensorNCHW, int32_t, \
cutlass::layout::TensorNCHW, dt, \
cutlass::layout::TensorNCHW, dt, \
cutlass::layout::TensorNCHW, ElementOutput, \
cutlass::layout::TensorNCHW, float, \
cutlass::conv::ConvType::kDepthwiseConvolution, \
Expand All @@ -91,62 +91,71 @@ TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     128x128x8_32x64x8) {
    // Region-restricted depthwise conv2d wgrad: threadblock tile 128x128x8,
    // warp tile 32x64x8. The macro is invoked once per `dt` so this tile
    // configuration is exercised with both int32_t and uint8_t.
    // NOTE(review): `dt` is presumably the region-mask element type, and the
    // macro presumably reads ThreadBlockShape/WarpShape from this scope --
    // confirm against the full RUN_DEPTHWISE_CONVOLUTION definition.
    using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 128x128x8, warp
// tile 64x32x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     128x128x8_64x32x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<128, 128, 8>;
    using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 64x128x8, warp
// tile 32x64x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     64x128x8_32x64x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<64, 128, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 128x64x8, warp
// tile 32x64x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     128x64x8_32x64x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<128, 64, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 32x128x8, warp
// tile 32x64x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     32x128x8_32x64x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<32, 128, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 128x32x8, warp
// tile 64x32x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     128x32x8_64x32x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<128, 32, 8>;
    using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 32x64x8, warp
// tile 32x64x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     32x64x8_32x64x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<32, 64, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 64, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 64x32x8, warp
// tile 64x32x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     64x32x8_64x32x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<64, 32, 8>;
    using WarpShape = cutlass::gemm::GemmShape<64, 32, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}

// Region-restricted depthwise conv2d wgrad: threadblock tile 32x32x8, warp
// tile 32x32x8. The macro is invoked once per `dt` so this tile configuration
// is exercised with both int32_t and uint8_t.
// NOTE(review): `dt` is presumably the region-mask element type, and the macro
// presumably reads ThreadBlockShape/WarpShape from this scope -- confirm
// against the full RUN_DEPTHWISE_CONVOLUTION definition.
TEST(SM50_DeviceRegion_Resticted__Depthwise_Conv2dWgrad_f32_f32_NCHW_simt_op,
     32x32x8_32x32x8) {
    using ThreadBlockShape = cutlass::gemm::GemmShape<32, 32, 8>;
    using WarpShape = cutlass::gemm::GemmShape<32, 32, 8>;
    RUN_DEPTHWISE_CONVOLUTION(1, int32_t);
    RUN_DEPTHWISE_CONVOLUTION(1, uint8_t);
}
//////////////////////////////////////////////////////////////////////////////////

0 comments on commit c8d1437

Please sign in to comment.