-
Notifications
You must be signed in to change notification settings - Fork 291
Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) #178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d6e053a
cbbc7e5
3e811cc
0277c89
6818b58
e3a09b5
cb1c473
a760a73
c8b4ac2
f2540aa
30348da
b05a594
6a781e5
dba65b1
fe65950
e83b22e
21802fd
c16f789
cf32669
0f421d6
5fa209a
0e6bf34
d7112d3
88d621a
5d36f7a
680cfaa
a41f548
f919809
976815e
ea09fd3
bfc8076
d92fb7e
b6fe118
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| add_example_executable(example_gemm_softmax_xdl_fp16 gemm_softmax_xdl_fp16.cpp) |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| #pragma once | ||
| #include <iostream> | ||
| #include <vector> | ||
|
|
||
| #include "device_base.hpp" | ||
|
|
||
| namespace ck { | ||
| namespace tensor_operation { | ||
| namespace device { | ||
|
|
||
| template <typename ElementwiseFunctor> | ||
| struct DeviceBinaryElementwise : public BaseOperator | ||
| { | ||
|
|
||
| virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, | ||
| const void* p_b, | ||
| void* p_c, | ||
| std::vector<int> shape_a, | ||
| std::vector<int> stride_a, | ||
| std::vector<int> shape_b, | ||
| std::vector<int> stride_b, | ||
| ElementwiseFunctor functor, | ||
| index_t threadPerBlock) = 0; | ||
|
|
||
| virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; | ||
| }; | ||
|
|
||
| } // namespace device | ||
| } // namespace tensor_operation | ||
| } // namespace ck |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,206 @@ | ||
| #pragma once | ||
| #include <iostream> | ||
| #include <vector> | ||
|
|
||
| #include "device.hpp" | ||
| #include "device_binary_elementwise.hpp" | ||
| #include "gridwise_binary_elementwise_1d.hpp" | ||
|
|
||
| namespace ck { | ||
| namespace tensor_operation { | ||
| namespace device { | ||
|
|
||
| template <typename ADataType, | ||
| typename BDataType, | ||
| typename CDataType, | ||
| typename ComputeDataType, | ||
| typename ElementwiseFunctor, | ||
| index_t ScalarPerVector> | ||
| struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor> | ||
| { | ||
| static constexpr auto I0 = Number<0>{}; | ||
|
|
||
| static auto MakeDescriptor_M0(const std::vector<int>& shape, | ||
| const std::vector<int>& stride, | ||
| index_t gridSize, | ||
| index_t threadPerBlock) | ||
| { | ||
| const int m = shape[0]; | ||
| const int n = shape[1]; | ||
|
|
||
| // 2d desc - [m, n] | ||
| const auto desc_m_n = | ||
| make_naive_tensor_descriptor(make_tuple(m, n), make_tuple(stride[0], stride[1])); | ||
|
|
||
| // 1d desc - [m * n] | ||
| const auto desc_m0 = | ||
| transform_tensor_descriptor(desc_m_n, | ||
| make_tuple(make_merge_transform(make_tuple(m, n))), | ||
| make_tuple(Sequence<0, 1>{}), | ||
| make_tuple(Sequence<0>{})); | ||
|
|
||
| // pad | ||
| const auto m0 = desc_m0.GetLength(I0); | ||
| const index_t loop_step = gridSize * threadPerBlock * ScalarPerVector; | ||
| const auto pad = math::integer_least_multiple(m0, loop_step) - m0; | ||
| const auto desc_m0_pad = | ||
| transform_tensor_descriptor(desc_m0, | ||
| make_tuple(make_right_pad_transform(m0, pad)), | ||
| make_tuple(Sequence<0>{}), | ||
| make_tuple(Sequence<0>{})); | ||
| return desc_m0_pad; | ||
| } | ||
|
|
||
| using GridDesc_M0 = decltype(MakeDescriptor_M0({1, 1}, {1, 1}, 1, 1)); | ||
| using GridwiseBinEltwise = GridwiseBinaryElementwise_1D<ADataType, | ||
| BDataType, | ||
| CDataType, | ||
| ComputeDataType, | ||
| GridDesc_M0, | ||
| ElementwiseFunctor, | ||
| ScalarPerVector>; | ||
|
|
||
| struct Argument : public BaseArgument | ||
| { | ||
| Argument(const ADataType* p_a, | ||
| const BDataType* p_b, | ||
| CDataType* p_c, | ||
| const std::vector<int>& shape, | ||
| const std::vector<int>& stride_a, | ||
| const std::vector<int>& stride_b, | ||
| const std::vector<int>& stride_c, | ||
| ElementwiseFunctor functor, | ||
| index_t threadPerBlock) | ||
| : p_a_(p_a), | ||
| p_b_(p_b), | ||
| p_c_(p_c), | ||
| functor_(functor), | ||
| threadPerBlock_(threadPerBlock), | ||
| gridSize_(128) // FIXME - Calculate the grid size by number of CU in the future | ||
| { | ||
| a_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_a, gridSize_, threadPerBlock_); | ||
| b_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_b, gridSize_, threadPerBlock_); | ||
| c_grid_desc_m0_ = MakeDescriptor_M0(shape, stride_c, gridSize_, threadPerBlock_); | ||
| } | ||
|
|
||
| const ADataType* p_a_; | ||
| const BDataType* p_b_; | ||
| CDataType* p_c_; | ||
| GridDesc_M0 a_grid_desc_m0_; | ||
| GridDesc_M0 b_grid_desc_m0_; | ||
| GridDesc_M0 c_grid_desc_m0_; | ||
| ElementwiseFunctor functor_; | ||
| index_t threadPerBlock_; | ||
| index_t gridSize_; | ||
| }; | ||
|
|
||
| struct Invoker : public BaseInvoker | ||
| { | ||
| float Run(const Argument& arg, int nrepeat = 1) | ||
| { | ||
| (void)arg; | ||
| const auto kernel = kernel_elementwise_1d<GridwiseBinEltwise, | ||
| ADataType, | ||
| BDataType, | ||
| CDataType, | ||
| GridDesc_M0, | ||
| ElementwiseFunctor>; | ||
| float avgTime = 0; | ||
| if(nrepeat == 0) | ||
| { | ||
| launch_kernel(kernel, | ||
| dim3(arg.gridSize_), | ||
| dim3(arg.threadPerBlock_), | ||
| 0, | ||
| arg.p_a_, | ||
| arg.p_b_, | ||
| arg.p_c_, | ||
| arg.a_grid_desc_m0_, | ||
| arg.b_grid_desc_m0_, | ||
| arg.c_grid_desc_m0_, | ||
| arg.functor_); | ||
| } | ||
| else | ||
| { | ||
| avgTime = launch_and_time_kernel(kernel, | ||
| nrepeat, | ||
| dim3(arg.gridSize_), | ||
| dim3(arg.threadPerBlock_), | ||
| 0, | ||
| arg.p_a_, | ||
| arg.p_b_, | ||
| arg.p_c_, | ||
| arg.a_grid_desc_m0_, | ||
| arg.b_grid_desc_m0_, | ||
| arg.c_grid_desc_m0_, | ||
| arg.functor_); | ||
| } | ||
| return avgTime; | ||
| } | ||
|
|
||
| float Run(const BaseArgument* p_arg, int nrepeat = 1) override | ||
| { | ||
| return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat); | ||
| }; | ||
| }; | ||
|
|
||
| bool IsSupportedArgument(const BaseArgument* p_arg) override | ||
| { | ||
| const Argument* pArg = dynamic_cast<const Argument*>(p_arg); | ||
|
|
||
| if(pArg == nullptr) | ||
| return false; | ||
|
|
||
| // m * n | ||
| const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0); | ||
|
|
||
| if(m0 % ScalarPerVector != 0) | ||
| return false; | ||
|
|
||
| return true; | ||
| }; | ||
|
|
||
| std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a, | ||
| const void* p_b, | ||
| void* p_c, | ||
| std::vector<int> shape, | ||
| std::vector<int> stride_a, | ||
| std::vector<int> stride_b, | ||
| std::vector<int> stride_c, | ||
| ElementwiseFunctor functor, | ||
| index_t threadPerBlock) override | ||
| { | ||
| return std::make_unique<Argument>(static_cast<const ADataType*>(p_a), | ||
| static_cast<const BDataType*>(p_b), | ||
| static_cast<CDataType*>(p_c), | ||
| shape, | ||
| stride_a, | ||
| stride_b, | ||
| stride_c, | ||
| functor, | ||
| threadPerBlock); | ||
| } | ||
|
|
||
| std::unique_ptr<BaseInvoker> MakeInvokerPointer() override | ||
| { | ||
| return std::make_unique<Invoker>(Invoker{}); | ||
| } | ||
|
|
||
| std::string GetTypeString() const override | ||
| { | ||
| auto str = std::stringstream(); | ||
|
|
||
| // clang-format off | ||
| str << "DeviceBinaryElementwise_2D" | ||
| << "<" | ||
| << "ScalarPerVector = " << ScalarPerVector | ||
| << ">"; | ||
| // clang-format on | ||
|
|
||
| return str.str(); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace device | ||
| } // namespace tensor_operation | ||
| } // namespace ck | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,150 @@ | ||
| #pragma once | ||
|
|
||
| #include "cluster_descriptor.hpp" | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
You don't have to change, there are lots of similar issues in other C.K codes. |
||
| #include "data_type.hpp" | ||
| #include "element_wise_operation.hpp" | ||
| #include "threadwise_tensor_slice_transfer.hpp" | ||
|
|
||
| namespace ck { | ||
|
|
||
| template <typename GridwiseBinEltwise, | ||
| typename ADataType, | ||
| typename BDataType, | ||
| typename CDataType, | ||
| typename GridDesc_M0, | ||
| typename ElementwiseFunctor> | ||
| __global__ void kernel_elementwise_1d(const ADataType* __restrict__ p_a_global, | ||
| const BDataType* __restrict__ p_b_global, | ||
| CDataType* __restrict__ p_c_global, | ||
| const GridDesc_M0 a_grid_desc_m0, | ||
| const GridDesc_M0 b_grid_desc_m0, | ||
| const GridDesc_M0 c_grid_desc_m0, | ||
| const ElementwiseFunctor functor) | ||
| { | ||
| GridwiseBinEltwise::Run(p_a_global, | ||
| p_b_global, | ||
| p_c_global, | ||
| a_grid_desc_m0, | ||
| b_grid_desc_m0, | ||
| c_grid_desc_m0, | ||
| functor); | ||
| } | ||
|
|
||
| template <typename ADataType, | ||
| typename BDataType, | ||
| typename CDataType, | ||
| typename ComputeDataType, | ||
| typename GridDesc_M0, | ||
| typename ElementwiseFunctor, | ||
| index_t ScalarPerVector> | ||
| struct GridwiseBinaryElementwise_1D | ||
| { | ||
| static constexpr auto I0 = Number<0>{}; | ||
| static constexpr auto thread_desc_m0 = | ||
| make_naive_tensor_descriptor_packed(make_tuple(Number<ScalarPerVector>{})); | ||
|
|
||
| using PassThrough = tensor_operation::element_wise::PassThrough; | ||
|
|
||
| static __device__ __host__ auto CalculateElementwiseIndex() | ||
| { | ||
| const index_t global_thread_id = get_thread_global_1d_id(); | ||
| return make_multi_index(global_thread_id * ScalarPerVector); | ||
| } | ||
|
|
||
| __device__ static void Run(const ADataType* __restrict__ p_a_global, | ||
| const BDataType* __restrict__ p_b_global, | ||
| CDataType* __restrict__ p_c_global, | ||
| const GridDesc_M0 a_grid_desc_m0, | ||
| const GridDesc_M0 b_grid_desc_m0, | ||
| const GridDesc_M0 c_grid_desc_m0, | ||
| const ElementwiseFunctor functor) | ||
| { | ||
| const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( | ||
| p_a_global, a_grid_desc_m0.GetElementSpaceSize()); | ||
| const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( | ||
| p_b_global, b_grid_desc_m0.GetElementSpaceSize()); | ||
| auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>( | ||
| p_c_global, c_grid_desc_m0.GetElementSpaceSize()); | ||
|
|
||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> a_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> b_thread_buf; | ||
| StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, ScalarPerVector, true> c_thread_buf; | ||
|
|
||
| const auto thread_to_global_offset = CalculateElementwiseIndex(); | ||
|
|
||
| auto a_global_load = | ||
| ThreadwiseTensorSliceTransfer_v2<ADataType, | ||
| ComputeDataType, | ||
| GridDesc_M0, | ||
| decltype(thread_desc_m0), | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| ScalarPerVector, | ||
| 1, // SrcScalarStrideInVector | ||
| false>{a_grid_desc_m0, thread_to_global_offset}; | ||
|
|
||
| auto b_global_load = | ||
| ThreadwiseTensorSliceTransfer_v2<BDataType, | ||
| ComputeDataType, | ||
| GridDesc_M0, | ||
| decltype(thread_desc_m0), | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // SrcVectorDim | ||
| ScalarPerVector, | ||
| 1, // SrcScalarStrideInVector | ||
| false>{b_grid_desc_m0, thread_to_global_offset}; | ||
|
|
||
| auto c_global_write = | ||
| ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType, | ||
| CDataType, | ||
| decltype(thread_desc_m0), | ||
| GridDesc_M0, | ||
| PassThrough, | ||
| Sequence<ScalarPerVector>, // SliceLengths | ||
| Sequence<0>, // DimAccessOrder | ||
| 0, // DstVectorDim | ||
| ScalarPerVector, | ||
| InMemoryDataOperationEnum::Set, | ||
| 1, // DstScalarStrideInVector | ||
| false>{ | ||
| c_grid_desc_m0, thread_to_global_offset, PassThrough{}}; | ||
|
|
||
| const index_t threadPerBlock = get_block_size(); | ||
| const index_t blockPerGrid = get_grid_size(); | ||
| const auto m0 = c_grid_desc_m0.GetLength(I0); | ||
| const index_t loop_step = blockPerGrid * threadPerBlock * ScalarPerVector; | ||
| const auto loop_step_index = make_multi_index(loop_step); | ||
|
|
||
| index_t num_iter = m0 / (loop_step); | ||
| do | ||
| { | ||
| // read and process ScalarPerVector elements | ||
| a_global_load.Run( | ||
| a_grid_desc_m0, a_global_buf, thread_desc_m0, make_tuple(I0), a_thread_buf); | ||
|
|
||
| b_global_load.Run( | ||
| b_grid_desc_m0, b_global_buf, thread_desc_m0, make_tuple(I0), b_thread_buf); | ||
|
|
||
| static_for<0, ScalarPerVector, 1>{}([&](auto m) { | ||
| constexpr auto offset = thread_desc_m0.CalculateOffset(make_tuple(m)); | ||
| functor(c_thread_buf(Number<offset>{}), | ||
| a_thread_buf(Number<offset>{}), | ||
| b_thread_buf(Number<offset>{})); | ||
| }); | ||
|
|
||
| c_global_write.Run(thread_desc_m0, | ||
| make_tuple(I0), // SrcSliceOriginIdx | ||
| c_thread_buf, | ||
| c_grid_desc_m0, | ||
| c_global_buf); | ||
|
|
||
| a_global_load.MoveSrcSliceWindow(a_grid_desc_m0, loop_step_index); | ||
| b_global_load.MoveSrcSliceWindow(b_grid_desc_m0, loop_step_index); | ||
| c_global_write.MoveDstSliceWindow(c_grid_desc_m0, loop_step_index); | ||
| } while(--num_iter); | ||
| } | ||
| }; | ||
|
|
||
| } // namespace ck | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
DeviceBinaryElementwise_ND
You could make this Device Operation supporting N-D tensor (N=1~5).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I will add this task to the backlog
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please open a JIRA task ticket and a github issue, and refer to this comment in both tickets.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#203