Revert "Modefied reduce op for store temp_data with MpType (PaddlePad…
Browse files Browse the repository at this point in the history
…dle#55709) (PaddlePaddle#60427)"

This reverts commit 51d97a6.
SylarTiaNII committed Feb 6, 2024
1 parent 274c15c commit ada37e5
Showing 1 changed file with 36 additions and 65 deletions.
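
Context for the diff below: the reverted patch had made the two-pass reduction keep its first-pass partial results in an MPType buffer (ReduceConfig::tmp_data), where MPType is the wider accumulation type, e.g. float for fp16 inputs, casting down to the output type Ty only in the final pass. The revert restores the original scheme: ReduceConfig holds a single Ty* output_data that points either at the user's output or at a Ty-typed workspace, and every pass casts its partials back to Ty before storing. A minimal CUDA sketch of the one point of difference, assuming a plain two-pass sum over __half inputs (PartialSum, the 256-thread block, and the float accumulator are illustrative stand-ins, not Paddle's kernels):

#include <cuda_fp16.h>

// Pass 1 of a two-pass sum: each block reduces a strided slice of x and
// stores one partial result. Launch with 256 threads per block.
// Accumulation is always done in float (the "MPType"); StoreT controls
// the type the partial is stored as, which is the whole difference
// between the reverted patch (StoreT = float) and the restored code
// (StoreT = __half, i.e. Ty).
template <typename StoreT>
__global__ void PartialSum(const __half* x, StoreT* partial, int n) {
  __shared__ float smem[256];
  float acc = 0.0f;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    acc += __half2float(x[i]);
  }
  smem[threadIdx.x] = acc;
  __syncthreads();
  // Tree-reduce the block's per-thread sums in shared memory.
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) smem[threadIdx.x] += smem[threadIdx.x + s];
    __syncthreads();
  }
  // StoreT = __half rounds here, before the second pass reads the value
  // back; StoreT = float defers all rounding to the final store.
  if (threadIdx.x == 0) partial[blockIdx.x] = static_cast<StoreT>(smem[0]);
}

Instantiated with StoreT = __half this models the restored behavior (each partial is rounded before the second pass folds it); StoreT = float models the reverted patch, which bought that precision with a wider workspace and the tmp_data/need_store_tmp plumbing removed throughout the diff.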
101 changes: 36 additions & 65 deletions paddle/phi/kernels/funcs/reduce_function.h
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };
 
 // reduce config
-template <typename Ty, typename MPType>
+template <typename Ty>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  MPType* tmp_data;
+  Ty* output_data;
   dim3 block;
   dim3 grid;
 
@@ -288,9 +288,11 @@ struct ReduceConfig {
                      const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(
-          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
-      tmp_data = dev_ctx.Alloc<MPType>(tmp);
+      tmp->Resize(phi::make_ddim(
+          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
+      output_data = dev_ctx.Alloc<Ty>(tmp);
+    } else {
+      output_data = y_data;
     }
   }
 
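A side note on the restored SetOutputData above: the workspace must hold one partial result per (output element, reducing block) pair, i.e. left_num * grid.y * grid.z values. The restored Resize also multiplies by sizeof(Ty), so it appears to size the tensor in bytes' worth of elements (over-allocating), whereas the reverted patch sized it in elements. A tiny model of the element count (WorkspacePartials is our name, not Paddle's):

#include <cstdint>
#include <cstdio>

// Hypothetical helper (not part of Paddle): the number of partial
// results the second pass must combine, one per reducing block per
// kept output element.
int64_t WorkspacePartials(int64_t left_num, int64_t grid_y, int64_t grid_z) {
  return left_num * grid_y * grid_z;
}

int main() {
  // e.g. 128 kept elements, reduced by a grid 4 blocks deep in y and
  // 2 in z: 1024 partials for the second pass to fold.
  std::printf("%lld\n", static_cast<long long>(WorkspacePartials(128, 4, 2)));
}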
@@ -581,9 +583,7 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean,
-                                MPType* tmp_data,
-                                bool need_store_tmp = false) {
+                                bool is_mean) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,15 +686,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    if (!need_store_tmp) {
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::details::WriteData<Ty>(
-          y + store_offset + i, &result, static_cast<int>(need_store));
-    } else {
-      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
-                                      &reduce_var,
-                                      static_cast<int>(need_store));
-    }
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::details::WriteData<Ty>(
+        y + store_offset + i, &result, static_cast<int>(need_store));
   }
 }
 
@@ -713,9 +707,7 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean,
-                                      MPType* tmp_data,
-                                      bool need_store_tmp = false) {
+                                      bool is_mean) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -747,14 +739,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
    }
-    if (!need_store_tmp) {
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, false>(
-          y + store_offset + idx, &result, block.BlockDimX());
-    } else {
-      kps::WriteData<MPType, 1, 1, false>(
-          tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
-    }
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::WriteData<Ty, 1, 1, false>(
+        y + store_offset + idx, &result, block.BlockDimX());
   }
 
   if (idx < left_num) {
@@ -776,14 +763,8 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    if (!need_store_tmp) {
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, true>(
-          y + store_offset + idx, &result, dim.rem_x);
-    } else {
-      kps::WriteData<MPType, 1, 1, true>(
-          tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
-    }
+    Ty result = static_cast<Ty>(reduce_var);
+    kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
   }
 }
 
@@ -798,7 +779,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty, MPType> config,
+                               ReduceConfig<Ty> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -825,7 +806,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid_num, block_num, 0, stream>>>(
            x_data,
-            y_data,
+            config.output_data,
             reducer,
             transform,
             init,
@@ -835,9 +816,7 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again),
-            config.tmp_data,
-            config.should_reduce_again);
+            is_mean && (!config.should_reduce_again));
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -866,7 +845,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            y_data,
+            config.output_data,
             reducer,
             transform,
             init,
@@ -876,9 +855,7 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again),
-            config.tmp_data,
-            config.should_reduce_again);
+            is_mean && (!config.should_reduce_again));
   }
 
   if (config.should_reduce_again) {
@@ -902,25 +879,23 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<MPType,
+    ReduceHigherDimKernel<Ty,
                           Ty,
                           MPType,
                           ReduceOp,
-                          kps::IdentityFunctor<MPType, MPType>>
+                          kps::IdentityFunctor<Ty, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.tmp_data,
+            config.output_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<MPType, MPType>(),
+            kps::IdentityFunctor<Ty, MPType>(),
             init,
             config.grid.y,
             config.left_num,
             config.grid.y,
             dim,
             config.reduce_num,
-            is_mean,
-            config.tmp_data,
-            false);
+            is_mean);
   }
 }
 
@@ -1029,8 +1004,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }
 
-  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
-  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
+  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
Expand Down Expand Up @@ -1073,6 +1047,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
}
#endif

using MPType = typename kps::details::MPTypeTrait<Ty>::Type;
auto reducer = ReduceOp<MPType>();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1102,7 +1077,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            y_data,
+            config.output_data,
             reducer,
             transform,
             reducer.initial(),
@@ -1111,9 +1086,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
             config.blocking_size,
             dim,
             config.reduce_num,
-            is_mean && (!config.should_reduce_again),
-            config.tmp_data,
-            config.should_reduce_again);
+            is_mean && (!config.should_reduce_again));
 
   if (config.should_reduce_again) {
     dim3 block = dim3(config.block.x, 1, 1);
@@ -1129,25 +1102,23 @@
       auto grid_size = grid;
       auto block_size = block;
 #endif
-      ReduceHigherDimKernel<MPType,
+      ReduceHigherDimKernel<Ty,
                             Ty,
                             MPType,
                             ReduceOp<MPType>,
-                            kps::IdentityFunctor<MPType, MPType>>
+                            kps::IdentityFunctor<Ty, MPType>>
           <<<grid_size, block_size, 0, stream>>>(
-              config.tmp_data,
+              config.output_data,
               y_data,
               reducer,
-              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
+              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
               reducer.initial(),
              config.grid.y,
              config.left_num,
              config.grid.y,
              dim2,
              config.reduce_num,
-              is_mean,
-              config.tmp_data,
-              false);
+              is_mean);
     }
     return;
   }
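One pattern worth calling out across the hunks above: the first-pass kernels take is_mean && (!config.should_reduce_again) while the second pass takes plain is_mean with mean_div = config.reduce_num, so the division that turns a sum into a mean happens exactly once, in whichever pass writes the final result. A scalar model of that protocol (ReducePass is illustrative, not a Paddle function):

#include <cstdio>

// Illustrative stand-in for one reduction pass: sum n values and divide
// by mean_div only when this pass is the one producing the final output.
float ReducePass(const float* v, int n, int mean_div, bool do_mean) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += v[i];
  return do_mean ? acc / mean_div : acc;
}

int main() {
  float data[8] = {1, 1, 1, 1, 1, 1, 1, 1};
  const bool should_reduce_again = true, is_mean = true;
  const int reduce_num = 8;
  // Pass 1: two partial sums; the mean is deferred because a second
  // pass will run (is_mean && !should_reduce_again is false here).
  float partials[2] = {
      ReducePass(data, 4, reduce_num, is_mean && !should_reduce_again),
      ReducePass(data + 4, 4, reduce_num, is_mean && !should_reduce_again)};
  // Pass 2: fold the partials and apply the division exactly once.
  std::printf("mean = %g\n", ReducePass(partials, 2, reduce_num, is_mean));
}

The printf confirms the deferred division: pass 1 leaves raw sums of 4 in partials, and pass 2 folds them and divides by the full reduce_num of 8, printing mean = 1.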
