
Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) #178

Closed

rocking5566 wants to merge 33 commits into develop from gemm_softmax

Conversation

@rocking5566 Collaborator

No description provided.

@asroy

This comment was marked as resolved.

@asroy asroy changed the title Gemm + reduce_max (part of gemm + softmax) [WIP] Gemm + reduce_max (part of gemm + softmax) Apr 7, 2022
@asroy asroy changed the title [WIP] Gemm + reduce_max (part of gemm + softmax) Gemm + reduce_max (part of gemm + softmax) Apr 12, 2022
@asroy Contributor

asroy commented Apr 14, 2022

This PR is doing GEMM and reduction separately.
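
For reference, the operator chain in the PR title computes the numerically stable softmax of each row of the GEMM output C = A * B:

\[
  \mathrm{softmax}(C)_{mn}
    = \frac{\exp\left(C_{mn} - \max_{n'} C_{mn'}\right)}
           {\sum_{n'} \exp\left(C_{mn'} - \max_{n''} C_{mn''}\right)}
\]

i.e. reduce_max, broadcast sub, exp, reduce_sum, broadcast div, applied after the GEMM.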

@rocking5566 rocking5566 changed the title Gemm + reduce_max (part of gemm + softmax) Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) Apr 15, 2022
Comment thread include/ck/tensor_operation/gpu/grid/gridwise_elementwise_1d.hpp Outdated
@qianfengz Contributor

Also, the file name gridwise_elementwise_1d.hpp should be changed to indicate that the elementwise operator used here is a binary operator, because we might also use unary or ternary operators on the 1d tensor. Names could be like

  1. gridwise_1d_binary_operate.hpp
  2. gridwise_1d_unary_operate.hpp
  3. gridwise_1d_ternary_operate.hpp

My PR #182 and PR #192 also have similar elementwise binary/unary kernels defined. But I think your implementation is more generic, since it supports ThreadSliceSize/ThreadTileSize, whereas my implementation assumes each thread processes only one element; for int8 and fp16, using ThreadSliceSize = 4/2 could be beneficial for data access performance.

After merging your PR, I will change my Batch-Norm forward code to use your kernel.

@rocking5566 Collaborator Author

rocking5566 commented Apr 19, 2022

Also, the file name gridwise_elementwise_1d.hpp should be changed to indicate that the elementwise operator used here is a binary operator, because we might also use unary or ternary operators on the 1d tensor. Names could be like

  1. gridwise_1d_binary_operate.hpp
  2. gridwise_1d_unary_operate.hpp
  3. gridwise_1d_ternary_operate.hpp

My PR #182 and PR #192 also have similar elementwise binary/unary kernels defined. But I think your implementation is more generic, since it supports ThreadSliceSize/ThreadTileSize, whereas my implementation assumes each thread processes only one element; for int8 and fp16, using ThreadSliceSize = 4/2 could be beneficial for data access performance.

After merging your PR, I will change my Batch-Norm forward code to use your kernel.

Your suggestion is great!
I will rename it to gridwise_binary_elementwise_1d.hpp, because "elementwise operation" is a term in the deep learning area, and it implies some sort of operation on two tensors.
https://caffe.berkeleyvision.org/tutorial/layers/eltwise.html

There are also other binary and ternary operations in deep learning, e.g. concatenation.

In addition, as discussed with @asroy before, we will rename the original elementwise parameter in each kernel to a functor that stands for the fusion operation.
I opened the ticket here:
#179
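
A minimal sketch of the binary elementwise pattern under discussion, including the per-thread slice loop that makes the implementation more generic. All names here are illustrative, not the actual CK kernel:

// Illustrative binary elementwise functor plus a per-thread slice loop.
// ThreadSliceSize > 1 lets fp16/int8 accesses be vectorized, as noted above.
#include <cstddef>

struct Add
{
    template <typename T>
    __host__ __device__ void operator()(T& dst, const T& src1, const T& src2) const
    {
        dst = src1 + src2;
    }
};

template <typename DataType, typename BinaryFunctor, std::size_t ThreadSliceSize>
__device__ void thread_binary_elementwise_1d(DataType* dst,
                                             const DataType* src1,
                                             const DataType* src2,
                                             BinaryFunctor functor)
{
    // Each thread handles ThreadSliceSize contiguous elements
    // (e.g. 2 for fp16, 4 for int8).
    for(std::size_t i = 0; i < ThreadSliceSize; ++i)
        functor(dst[i], src1[i], src2[i]);
}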

[Why] Prevent loss of precision
[Why] Similar to acc datatype, it increases precision
Let memory coalesce between blocks
Comment thread example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp Outdated
Comment thread example/19_gemm_softmax/gemm_softmax_xdl_fp16.cpp Outdated
Comment thread include/ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp Outdated
Comment thread include/ck/tensor_operation/gpu/device/device_binary_elementwise_2d.hpp Outdated
typename ElementwiseFunctor,
index_t ThreadPerBlock,
index_t ScalarPerVector>
struct DeviceBinaryElementwise_2D : public DeviceBinaryElementwise<ElementwiseFunctor>
Contributor

DeviceBinaryElementwise_ND

You could make this device operation support N-D tensors (N = 1~5).
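
A hypothetical shape for that generalization (declaration only; the NumDim parameter and names are illustrative, not existing CK API):

#include <cstdint>

using index_t = std::int32_t; // stand-in for ck::index_t

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ElementwiseFunctor,
          index_t NumDim, // 1..5 instead of hard-coded 2D
          index_t ThreadPerBlock,
          index_t ScalarPerVector>
struct DeviceBinaryElementwise_ND; // rank-parameterized device op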

Collaborator Author

I will add this task to the backlog

Contributor

Please open a JIRA task ticket and a GitHub issue, and refer to this comment in both tickets.

Collaborator Author

2. Use DeviceGemm_Xdl_CShuffle instead of deprecated DeviceGemmXdl_C_Shuffle
@rocking5566 rocking5566 mentioned this pull request Apr 28, 2022
@zjing14 zjing14 requested review from asroy and qianfengz April 29, 2022 01:26
[Why] F16 issue for host reduction has been fixed in c1ef731
@asroy asroy changed the title Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) [WIP] Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) Apr 30, 2022
@asroy Contributor

asroy commented Apr 30, 2022

After PR #209 gets merged, please fix the following issues in this PR before we merge it:

  1. refactor this PR to use the fused GEMM+reduction done in Gemm reduce max #209
  2. refactor the device element-wise operation to support 1D~5D (#178 (comment))


// do reduce max
auto reduce_max = DeviceReduceMaxInstance{};
auto reduce_max_workaspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
@aosewski Collaborator May 6, 2022

Spelling

Suggested change
- auto reduce_max_workaspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);
+ auto reduce_max_workspace_size = reduce_max.GetWorkspaceSizeInBytes(c_m_n_shape, reduceDims);

Collaborator Author

Good catch!
Fixed in b6fe118

@asroy asroy changed the title [WIP] Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) Gemm + softmax (gemm + reduce_max + broadcast sub + exp + reduce_sum + broadcast div) May 7, 2022
@asroy asroy unassigned qianfengz and asroy May 7, 2022
@asroy asroy added the on hold label May 7, 2022
@myamlak myamlak mentioned this pull request May 16, 2022
// m * n
const auto m0 = pArg->c_grid_desc_m0_.GetLength(I0);

if(m0 % BlockTileSize != 0)
Contributor

I think requiring the merged length to be evenly divisible by BlockTileSize is too strong a restriction. You should pad the tensor and relax the restriction.
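
A minimal sketch of the padding idea, assuming the padded tail is masked off (or filled with a neutral value) inside the kernel; the helper name is hypothetical:

#include <cstdint>

// Round the merged length up to a multiple of BlockTileSize instead of
// rejecting sizes that do not divide evenly.
constexpr std::int64_t round_up(std::int64_t length, std::int64_t tile)
{
    return ((length + tile - 1) / tile) * tile;
}

// Example: BlockTileSize = 256 and m0 = 1000 pad up to 1024, so the grid
// launches 4 block tiles and the last 24 elements are padding.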

std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::vector<int>& shape,
Contributor

Don't use reference types for the arguments here. Since MakeArgumentPointer() is an API, we cannot assume the user always passes lvalues.

MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
const std::vector<int>& shape_a,
Contributor

Also, don't use reference types when declaring the arguments, as this is an API. We cannot always assume the user will pass addressable values.

template <typename ADataType,
typename BDataType,
typename CDataType,
typename ElementwiseFunctor,
Contributor

Explicitly rename the ElementwiseFunctor type to a BinaryOperator type, since the kernel called here uses a binary operator. Also, the base class DeviceElementwise should be renamed to indicate its usage, since using a unary operator leads to a different API (e.g. p_a, p_b as in/out data) than using a binary operator (e.g. p_a, p_b, p_c as in/out data).
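
A sketch of the distinction (hypothetical base classes, only to illustrate why the two cases need differently named types):

#include <vector>

// Unary elementwise: one input (p_a), one output (p_b).
struct DeviceUnaryElementwise
{
    virtual void Run(const void* p_a, void* p_b, const std::vector<int>& shape) = 0;
    virtual ~DeviceUnaryElementwise() = default;
};

// Binary elementwise: two inputs (p_a, p_b), one output (p_c).
struct DeviceBinaryElementwise
{
    virtual void Run(const void* p_a, const void* p_b, void* p_c, const std::vector<int>& shape) = 0;
    virtual ~DeviceBinaryElementwise() = default;
};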

{
dst = src1 - src2;
// FIXME - use float16 exponential
float dst_f32 = static_cast<float>(dst);
Contributor

To simplify, I suggest defining dst, src1, src2 as AccDataType here, assuming operator() works on the VGPRs storing the already-converted values. ThreadwiseTransfer() can do the conversion automatically when the data is loaded from device memory into the static buffer.

An expression like dst = src1 - src2 will otherwise lead to an implicit loss of precision. Remember: always do + - * / in AccDataType.

Also, use ck::type_convert for type conversion, since static_cast<>() does not work, at least when ck::bhalf_t is involved.
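
A minimal sketch of the suggested shape, assuming ThreadwiseTransfer() has already converted the inputs into AccDataType (illustrative only):

// The functor sees values already converted to AccDataType, so the
// arithmetic itself cannot lose precision; conversions at the memory
// boundary would go through ck::type_convert, not static_cast.
template <typename AccDataType>
struct Sub
{
    __host__ __device__ constexpr void
    operator()(AccDataType& dst, const AccDataType& src1, const AccDataType& src2) const
    {
        dst = src1 - src2; // computed entirely in AccDataType (e.g. float)
    }
};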

__host__ __device__ constexpr void
operator()(CDataType& dst, const CDataType& src1, const CDataType& src2) const
{
dst = src1 / src2;
Contributor

The same as above. It is horrible if the division is done using half_t.


using DeviceReduceSumInstance =
ck::tensor_operation::device::DeviceReduceBlockWise<CDataType,
CDataType,
Contributor

Use AccDataType!

@@ -0,0 +1,150 @@
#pragma once

#include "cluster_descriptor.hpp"
Contributor

#include "cluster_descriptor.hpp" is not needed since you don't use make_cluster_descriptor(). Also, data_type.hpp is not needed. Actually several other headers are needed even though they are included in-directly, eg. tensor_descriptor_helper.hpp and get_id.hpp.

You don't have to change, there are lots of similar issues in other C.K codes.


// CAUTION - host reduce_max will call numeric_limits<ck::half_t>::lowest()
// However, numeric_limits<ck::half_t>::lowest() will return zero. So, used half_float::half instead
using HostReduceDataType = half_float::half;
Contributor

Remove using half_float::half, since the host reduction can now support ck::half_t. Check PR #195.
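
For context, a host-side max reduction seeds its accumulator with lowest(); if that wrongly returns zero (as the old numeric_limits<ck::half_t> did, per the code comment above), any all-negative row reduces to zero instead of its true maximum. A minimal host-side sketch:

#include <algorithm>
#include <limits>
#include <vector>

template <typename T>
T reduce_max(const std::vector<T>& v)
{
    // The seed must be the true lowest value of T, otherwise rows whose
    // maximum is negative are reduced incorrectly.
    T acc = std::numeric_limits<T>::lowest();
    for(const T& x : v)
        acc = std::max(acc, x);
    return acc;
}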

ComputeDataType Bm = static_cast<ComputeDataType>(B(m));
functor(Cmn, Amn, Bm);
}
C(m, n) = static_cast<ComputeDataType>(Cmn);
Contributor

Use ck::type_convert<ComputeDataType>(), or else conversion from bhalf_t will not work

@asroy Contributor

asroy commented May 20, 2022

Closing this PR. A new PR will be created.

@asroy asroy closed this May 20, 2022
@illsilin illsilin deleted the gemm_softmax branch December 7, 2023 18:40