Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
d6e053a
Part of gemm + softmax, Add gemm + reduceMax
rocking5566 Apr 3, 2022
cbbc7e5
Refine the comment
rocking5566 Apr 6, 2022
3e811cc
Add device op for elementwise 2d
rocking5566 Apr 10, 2022
0277c89
Merge branch 'develop' into gemm_softmax
rocking5566 Apr 10, 2022
6818b58
Fix compile error
rocking5566 Apr 10, 2022
e3a09b5
Add gridwise_elementwise_2d api
rocking5566 Apr 11, 2022
cb1c473
Merge remote-tracking branch 'origin/develop' into gemm_softmax
rocking5566 Apr 11, 2022
a760a73
A kernel of elementwise_2d (except global store)
rocking5566 Apr 12, 2022
c8b4ac2
Add global write
rocking5566 Apr 13, 2022
f2540aa
Add exponential
rocking5566 Apr 13, 2022
30348da
[What] Refine naming
rocking5566 Apr 13, 2022
b05a594
Add reduce sum for denominator of softmax
rocking5566 Apr 13, 2022
6a781e5
Add broadcast div, the final step of softmax
rocking5566 Apr 13, 2022
dba65b1
Rewrite the gridwise_elementwise_
rocking5566 Apr 14, 2022
fe65950
Add verification of softmax
rocking5566 Apr 15, 2022
e83b22e
[What] Use half_float::half instead of ck::half_t for host reduction
rocking5566 Apr 18, 2022
21802fd
[What] Sync input of each host kernel and device kernel
rocking5566 Apr 18, 2022
c16f789
Merge remote-tracking branch 'origin/develop' into gemm_softmax
rocking5566 Apr 18, 2022
cf32669
[What] Use F32 as the acc of reduce sum
rocking5566 Apr 20, 2022
0f421d6
[What] Add ComputeDataType to the eltwise kernel
rocking5566 Apr 20, 2022
5fa209a
Add padding
rocking5566 Apr 20, 2022
0e6bf34
Rename elementwise op to binary elementwise
rocking5566 Apr 20, 2022
d7112d3
Fix the padding
rocking5566 Apr 20, 2022
88d621a
Merge remote-tracking branch 'origin/develop' into gemm_softmax
rocking5566 Apr 21, 2022
5d36f7a
Rewrite the elementwise operation.
rocking5566 Apr 21, 2022
680cfaa
Fix the meaning of broadcast dim parameter
rocking5566 Apr 22, 2022
a41f548
1. Fix coding style
rocking5566 Apr 25, 2022
f919809
Move threadPerBlock to argument
rocking5566 Apr 26, 2022
976815e
Prevent compile error when the user passes an rvalue, e.g. {3, 4}
rocking5566 Apr 28, 2022
ea09fd3
Fix typo
rocking5566 Apr 28, 2022
bfc8076
[What] Fix data type for host reduction
rocking5566 Apr 29, 2022
d92fb7e
Merge commit 'a3c910ac6cdd0c5b724449af312255abe5b531e1' into gemm_sof…
rocking5566 May 9, 2022
b6fe118
Fix typo
rocking5566 May 9, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions example/19_gemm_softmax/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
add_example_executable(example_gemm_softmax_xdl_fp16 gemm_softmax_xdl_fp16.cpp)
547 changes: 547 additions & 0 deletions example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions example/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,4 @@ add_subdirectory(17_convnd_bwd_data_xdl)
add_subdirectory(15_grouped_gemm)
add_subdirectory(16_gemm_reduce)
add_subdirectory(18_batched_gemm_reduce)
add_subdirectory(19_gemm_softmax)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once
#include <iostream>
#include <memory>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

// Abstract device-level binary elementwise operation: C = functor(A, B).
//
// NOTE(review): the concrete implementation (DeviceBinaryElementwise_2D)
// interprets the four std::vector<int> parameters as one shape common to all
// tensors plus per-tensor strides for A, B and C. The old parameter names
// (shape_a/stride_a/shape_b/stride_b) did not match that contract and never
// mentioned C's stride; they are renamed here for consistency. Parameter
// names are not part of the virtual's signature, so overriders are unaffected.
template <typename ElementwiseFunctor>
struct DeviceBinaryElementwise : public BaseOperator
{
    // Type-erased argument factory. Implementations downcast p_a/p_b/p_c to
    // their concrete element-type pointers.
    virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                              const void* p_b,
                                                              void* p_c,
                                                              std::vector<int> shape,
                                                              std::vector<int> stride_a,
                                                              std::vector<int> stride_b,
                                                              std::vector<int> stride_c,
                                                              ElementwiseFunctor functor,
                                                              index_t threadPerBlock) = 0;

    // Invoker factory; the returned invoker launches the device kernel.
    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
#pragma once
#include <iostream>
#include <vector>

#include "device.hpp"
#include "device_binary_elementwise.hpp"
#include "gridwise_binary_elementwise_1d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
typename BDataType,
typename CDataType,
typename ComputeDataType,
typename ElementwiseFunctor,
index_t ScalarPerVector>
struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DeviceBinaryElementwise_ND

You could make this Device Operation supporting N-D tensor (N=1~5).

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will add this task to the backlog

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please open a JIRA task ticket and a github issue, and refer to this comment in both tickets.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

{
static constexpr auto I0 = Number<0>{};

    // Build a flattened, padded 1D descriptor for a 2D tensor.
    // Steps: describe the tensor as [m, n] with the given strides, merge the
    // two dimensions into a single m*n dimension, then right-pad that length
    // up to a multiple of gridSize * threadPerBlock * ScalarPerVector so every
    // thread in the grid runs the same number of full-vector iterations.
    // Assumes shape and stride each hold exactly 2 entries — TODO confirm
    // callers validate this.
    static auto MakeDescriptor_M0(const std::vector<int>& shape,
                                  const std::vector<int>& stride,
                                  index_t gridSize,
                                  index_t threadPerBlock)
    {
        const int m = shape[0];
        const int n = shape[1];

        // 2d desc - [m, n]
        const auto desc_m_n =
            make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1]));

        // 1d desc - [m * n]
        const auto desc_m0 =
            transform_tensor_descriptor(desc_m_n,
                                        make_tuple(make_merge_transform(make_tuple(m, n))),
                                        make_tuple(Sequence<0, 1>{}),
                                        make_tuple(Sequence<0>{}));

        // pad m*n up to the next multiple of the per-iteration grid stride
        const auto m0              = desc_m0.GetLength(I0);
        const index_t loop_step    = gridSize * threadPerBlock * ScalarPerVector;
        const auto pad             = math::integer_least_multiple(m0, loop_step) - m0;
        const auto desc_m0_pad =
            transform_tensor_descriptor(desc_m0,
                                        make_tuple(make_right_pad_transform(m0, pad)),
                                        make_tuple(Sequence<0>{}),
                                        make_tuple(Sequence<0>{}));
        return desc_m0_pad;
    }

using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1));
using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType,
BDataType,
CDataType,
ComputeDataType,
GridDesc_M0,
ElementwiseFunctor,
ScalarPerVector>;

    struct Argument : public BaseArgument
    {
        // Captures the tensor pointers, builds one flattened+padded 1D
        // descriptor per tensor (A, B, C share one logical shape but may have
        // distinct strides), and records the launch configuration.
        Argument(const ADataType* p_a,
                 const BDataType* p_b,
                 CDataType* p_c,
                 const std::vector<int>& shape,
                 const std::vector<int>& stride_a,
                 const std::vector<int>& stride_b,
                 const std::vector<int>& stride_c,
                 ElementwiseFunctor functor,
                 index_t threadPerBlock)
            : p_a_(p_a),
              p_b_(p_b),
              p_c_(p_c),
              functor_(functor),
              threadPerBlock_(threadPerBlock),
              gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future
        {
            a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_);
            b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_);
            c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_);
        }

        const ADataType* p_a_; // input A
        const BDataType* p_b_; // input B
        CDataType* p_c_;       // output C
        GridDesc_M0 a_grid_desc_m0_;
        GridDesc_M0 b_grid_desc_m0_;
        GridDesc_M0 c_grid_desc_m0_;
        ElementwiseFunctor functor_;
        index_t threadPerBlock_;
        index_t gridSize_;
    };

struct Invoker : public BaseInvoker
{
float Run(const Argument& arg, int nrepeat = 1)
{
(void)arg;
const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise,
ADataType,
BDataType,
CDataType,
GridDesc_M0,
ElementwiseFunctor>;
float avgTime = 0;
if(nrepeat == 0)
{
launch_kernel(kernel,
dim3(arg.gridSize_),
dim3(arg.threadPerBlock_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.functor_);
}
else
{
avgTime = launch_and_time_kernel(kernel,
nrepeat,
dim3(arg.gridSize_),
dim3(arg.threadPerBlock_),
0,
arg.p_a_,
arg.p_b_,
arg.p_c_,
arg.a_grid_desc_m0_,
arg.b_grid_desc_m0_,
arg.c_grid_desc_m0_,
arg.functor_);
}
return avgTime;
}

float Run(const BaseArgument* p_arg, int nrepeat = 1) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
};
};

bool IsSupportedArgument(const BaseArgument* p_arg) override
{
const Argument* pArg = dynamic_cast<const Argument*>(p_arg);

if(pArg == nullptr)
return false;

// m * n
const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);

if(m0 % ScalarPerVector != 0)
return false;

return true;
};

std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
std::vector<int> shape,
std::vector<int> stride_a,
std::vector<int> stride_b,
std::vector<int> stride_c,
ElementwiseFunctor functor,
index_t threadPerBlock) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
shape,
stride_a,
stride_b,
stride_c,
functor,
threadPerBlock);
}

std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}

std::string GetTypeString() const override
{
auto str = std::stringstream();

// clang-format off
str << "DeviceBinaryElementwise_2D"
<< "<"
<< "ScalarPerVector = " << ScalarPerVector
<< ">";
// clang-format on

return str.str();
}
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#pragma once

#include "cluster_descriptor.hpp"
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#include "cluster_descriptor.hpp" is not needed since you don't use make_cluster_descriptor(). Also, data_type.hpp is not needed. Conversely, several other headers are needed but are only included indirectly, e.g. tensor_descriptor_helper.hpp and get_id.hpp.

You don't have to change this here; there are lots of similar issues elsewhere in the C.K. code.

#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

// Thin GPU entry point: forwards the global pointers, the three flattened
// grid descriptors, and the elementwise functor straight to
// GridwiseBinEltwise::Run, which does all the work.
template <typename GridwiseBinEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor>
__global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                      const BDataType* __restrict__ p_b_global,
                                      CDataType* __restrict__ p_c_global,
                                      const GridDesc_M0 a_grid_desc_m0,
                                      const GridDesc_M0 b_grid_desc_m0,
                                      const GridDesc_M0 c_grid_desc_m0,
                                      const ElementwiseFunctor functor)
{
    GridwiseBinEltwise::Run(p_a_global,
                            p_b_global,
                            p_c_global,
                            a_grid_desc_m0,
                            b_grid_desc_m0,
                            c_grid_desc_m0,
                            functor);
}

// Grid-level 1D binary elementwise kernel body: c[i] = functor(a[i], b[i])
// over tensors flattened to one dimension by GridDesc_M0. Each thread handles
// ScalarPerVector contiguous elements per iteration; the whole grid then
// advances by gridSize * blockSize * ScalarPerVector until the (padded)
// length is consumed.
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename GridDesc_M0,
          typename ElementwiseFunctor,
          index_t ScalarPerVector>
struct GridwiseBinaryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};
    // Per-thread register tile of ScalarPerVector packed elements.
    static constexpr auto thread_desc_m0 =
        make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    // Flat start index of this thread's first vector in the 1D tensor.
    static __device__ __host__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * ScalarPerVector);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               CDataType* __restrict__ p_c_global,
                               const GridDesc_M0 a_grid_desc_m0,
                               const GridDesc_M0 b_grid_desc_m0,
                               const GridDesc_M0 c_grid_desc_m0,
                               const ElementwiseFunctor functor)
    {
        // Wrap the raw global pointers in dynamic buffers sized by each
        // descriptor's element space.
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m0.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m0.GetElementSpaceSize());
        auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m0.GetElementSpaceSize());

        // Per-thread register staging buffers in the compute precision.
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf;

        const auto thread_to_global_offset = CalculateElementwiseIndex();

        // Vectorized global -> register readers for A and B (converting
        // ADataType/BDataType to ComputeDataType).
        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{a_grid_desc_m0, thread_to_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             GridDesc_M0,
                                             decltype(thread_desc_m0),
                                             Sequence<ScalarPerVector>, // SliceLengths
                                             Sequence<0>,               // DimAccessOrder
                                             0,                         // SrcVectorDim
                                             ScalarPerVector,
                                             1, // SrcScalarStrideInVector
                                             false>{b_grid_desc_m0, thread_to_global_offset};

        // Vectorized register -> global writer for C (plain Set, no
        // accumulation, PassThrough element op).
        auto c_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               CDataType,
                                               decltype(thread_desc_m0),
                                               GridDesc_M0,
                                               PassThrough,
                                               Sequence<ScalarPerVector>, // SliceLengths
                                               Sequence<0>,               // DimAccessOrder
                                               0,                         // DstVectorDim
                                               ScalarPerVector,
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                c_grid_desc_m0, thread_to_global_offset, PassThrough{}};

        const index_t threadPerBlock = get_block_size();
        const index_t blockPerGrid   = get_grid_size();
        const auto m0                = c_grid_desc_m0.GetLength(I0);
        const index_t loop_step      = blockPerGrid * threadPerBlock * ScalarPerVector;
        const auto loop_step_index   = make_multi_index(loop_step);

        // NOTE(review): num_iter must end up >= 1 — the do/while below runs
        // once before testing and decrements past zero otherwise. The device
        // op pads m0 up to a non-zero multiple of loop_step, which guarantees
        // this; confirm that invariant if this struct is reused elsewhere.
        index_t num_iter = m0 / (loop_step);
        do
        {
            // read and process ScalarPerVector elements
            a_global_load.Run(
                a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
                b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf);

            // apply the user functor element by element in registers
            static_for<0, ScalarPerVector, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m));
                functor(c_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}));
            });

            c_global_write.Run(thread_desc_m0,
                               make_tuple(I0), // SrcSliceOriginIdx
                               c_thread_buf,
                               c_grid_desc_m0,
                               c_global_buf);

            // advance the whole grid to its next chunk
            a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index);
            c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index);
        } while(--num_iter);
    }
};

} // namespace ck
Loading