Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/30_grouped_conv_fwd_multiple_d/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ endif() # USE_BITINT_EXTENSION_INT4

if(GPU_TARGETS MATCHES "gfx1100" OR GPU_TARGETS MATCHES "gfx1101" OR GPU_TARGETS MATCHES "gfx1102")
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_fp16 grouped_conv_fwd_bias_relu_add_wmma_fp16.cpp)
add_example_executable(example_grouped_conv_fwd_bias_relu_add_wmma_int8 grouped_conv_fwd_bias_relu_add_wmma_int8.cpp)
endif()

add_example_executable(example_grouped_conv_fwd_xdl_fp16 grouped_conv_fwd_xdl_fp16.cpp)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "common_wmma.hpp"

// kernel data types
using InKernelDataType = I8;
using WeiKernelDataType = I8;
using AccDataType = I32;
using CShuffleDataType = I8;
using BiasKernelDataType = I8;
using ResidualKernelDataType = I8;
using OutKernelDataType = I8;

// tensor data types
using InUserDataType = InKernelDataType;
using WeiUserDataType = WeiKernelDataType;
using OutUserDataType = OutKernelDataType;

using InElementOp = PassThrough;
using WeiElementOp = PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::AddReluAdd;

#include "run_grouped_conv_fwd_bias_relu_add_wmma_example.inc"

int main(int argc, char* argv[]) { return !run_grouped_conv_fwd_bias_relu_add_example(argc, argv); }
8 changes: 4 additions & 4 deletions include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,12 +262,12 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,

template <index_t MPerWmma,
index_t NPerWmma,
bool neg_a,
bool neg_b,
bool clamp,
class FloatA,
class FloatB,
class FloatC>
class FloatC,
bool neg_a = false,
bool neg_b = false,
bool clamp = false>
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
{
if constexpr(wave_size == 32)
Expand Down