
Optimize update_loss_scaling_op #32554

Merged
merged 4 commits into PaddlePaddle:develop on Apr 28, 2021

Conversation

@thisjiang (Contributor) commented Apr 26, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Motivation:

Similar to CheckFiniteAndUnscale, the timeline shows that update_loss_scaling_op calls FillIf many times within a single run, up to 300 times, and these calls consist of several small kernels, so there is room for optimization.

Code analysis:

As before, the original code contains a for loop:

```cpp
for (size_t i = 0; i < xs.size(); ++i) {
  ...
  FillIf<<<...>>>(outs[i]->mutable_data<T>(), ...);
  ...
}
```

outs is a vector<Tensor*>; no matter how large each tensor is, the for loop has to launch FillIf once for every tensor in it.

Optimization

Optimization 1:

Commit id: ad79dff

Clearly, fusing the kernels, i.e. removing the outer for loop so that a single kernel launch suffices regardless of xs.size(), should give the most noticeable improvement.

The basic idea is the same as in PR31954, so it is not repeated here. One extra note: since this FillIf merely assigns value to the elements of outs one by one, letting each thread handle only a single element would spawn far too many threads and under-utilize the compute resources. To improve this, each thread is set to process 50 elements, which lowers the warp-switching overhead; a sketch of the fused kernel is given below.
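
The following is not the PR's actual kernel, only a minimal sketch of the fused-fill idea. The names FusedFillIf, out_addrs, starts, out_num, and found_inf are illustrative assumptions: starts is taken to be a prefix-sum array with starts[0] = 0 and starts[out_num] = total_num, and the fill is assumed to be gated by a found_inf flag.

```cuda
#include <cstdint>

// Sketch of a fused fill kernel: one launch covers every output tensor.
// Assumptions (not the PR's exact signature): starts[0..out_num] are prefix
// sums of the tensor sizes, out_addrs[i] points to the i-th output's data,
// and the fill only happens when *found_inf is true.
template <typename T>
__global__ void FusedFillIf(T** out_addrs, const int64_t* starts, int out_num,
                            int64_t total_num, T value, const bool* found_inf) {
  if (!(*found_inf)) return;  // uniform across all threads, so safe to return early

  // Stage the (out_num + 1) start offsets in shared memory, as the PR does.
  extern __shared__ int64_t s_starts[];
  for (int i = threadIdx.x; i <= out_num; i += blockDim.x) {
    s_starts[i] = starts[i];
  }
  __syncthreads();

  const int64_t tid =
      threadIdx.x + static_cast<int64_t>(blockIdx.x) * blockDim.x;
  int out_index = 0;
  // Grid-stride loop: with a launch sized at roughly total_num / 50 threads,
  // each thread ends up filling about 50 elements instead of a single one.
  for (int64_t id = tid; id < total_num;
       id += static_cast<int64_t>(blockDim.x) * gridDim.x) {
    // Find the output tensor that owns flat index `id`; ids only grow, so the
    // running out_index never needs to move backwards.
    while (id >= s_starts[out_index + 1]) ++out_index;
    out_addrs[out_index][id - s_starts[out_index]] = value;
  }
}
```

Such a kernel would be launched once, with a dynamic shared-memory size of (out_num + 1) * sizeof(int64_t) to hold the staged offsets, replacing the per-tensor loop shown above.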

Optimization 2:

Commit id: 527779a

  1. Removed the useless line `while (id < s_starts[index]) index++;` from the check_finite_and_unscale and update_loss_scaling_op kernels; it was verified that this line is never reached in either kernel.
  2. Renamed variables in the check_finite_and_unscale and update_loss_scaling_op kernels to make them clearer.
  3. Added a number of comments to make the code easier for later maintainers to understand.

Performance results:

| ernie_doc model speed (single V100-SXM2-16GB card) | FP32 | AMP | Speedup |
| --- | --- | --- | --- |
| Before optimization (BS=2048) | 4.48 sequence/s | 9.78 sequence/s | 2.18 |
| Optimization 1 (BS=2048) | 4.48 sequence/s | 9.85 sequence/s | 2.19 |

| ernie_doc op cost | Before | Optimization 1 |
| --- | --- | --- |
| update_loss_scaling_op | 1.406 ms | 0.685 ms |

| ResNet50 AMP model speed (single V100-SXM2-16GB card) | Before | Optimization 1 |
| --- | --- | --- |
| Average ips over steps 10~510 (BS=208) | 1415 images/sec | 1416 images/sec |
| Average ips over steps 10~510 (BS=128) | 1331 images/sec | 1331 images/sec |

| Timeline share | Before | Optimization 1 |
| --- | --- | --- |
| ernie_doc AMP (BS=2048) | 1% | 0.7% |
| ResNet50 AMP (BS=208) | 0.2% | <0.1% |
| ResNet50 AMP (BS=128) | 0.4% | <0.1% |

ResNet50 convergence verification

Model script: [ResNet50_fp16.sh]
[Convergence curve screenshots: train loss, train avg loss, test avg loss, test acc 1, test acc 5]

@paddle-bot-old commented:

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

wangxicoding (Contributor) previously approved these changes Apr 26, 2021

@wangxicoding left a comment:

LGTM

Comment on lines 97 to 103
auto starts_h_tensor =
memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));
int64_t* starts_h = reinterpret_cast<int64_t*>(starts_h_tensor->ptr());

auto starts_d_tensor =
memory::Alloc(dev_ctx, (xs_size + 1) * sizeof(int64_t));
int64_t* starts_d = reinterpret_cast<int64_t*>(starts_d_tensor->ptr());

Contributor:

Some suggestions about variable names:
starts_h_tensor --> h_in_starts_mem
starts_h --> h_in_starts
starts_d_tensor --> d_in_starts_mem
starts_d --> d_in_starts


Contributor Author:

Done

size_t xs_size = xs.size();
// alloc each tensor's start index and copy to device
auto starts_h_tensor =
memory::Alloc(platform::CPUPlace(), (xs_size + 1) * sizeof(int64_t));

Contributor:

You can construct the platform::CPUPlace() object once and reuse it afterwards; there is no need to build a temporary object multiple times.
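
For illustration only, the suggested pattern would look roughly as follows; cpu_place is a hypothetical local, and the *_mem names follow the rename suggestions elsewhere in this review:

```cpp
// Sketch: construct CPUPlace once and reuse it for both host-side allocations.
const platform::CPUPlace cpu_place;
auto h_in_starts_mem =
    memory::Alloc(cpu_place, (xs_size + 1) * sizeof(int64_t));
auto h_out_addrs_mem = memory::Alloc(cpu_place, xs_size * sizeof(T*));
```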

Contributor Author:

Done

Comment on lines 116 to 121
auto outs_addr_h_tensor =
memory::Alloc(platform::CPUPlace(), xs_size * sizeof(T*));
T** outs_addr_h = reinterpret_cast<T**>(outs_addr_h_tensor->ptr());

auto outs_addr_d_tensor = memory::Alloc(dev_ctx, xs_size * sizeof(T*));
T** outs_addr_d = reinterpret_cast<T**>(outs_addr_d_tensor->ptr());

Contributor:

Some suggestions about variable names:
outs_addr_h_tensor --> h_out_addrs_mem
outs_addr_h --> h_out_addrs
outs_addr_d_tensor --> d_out_addrs_mem
outs_addr_d --> d_out_addrs

Contributor Author:

Done

Comment on lines 132 to 134
int64_t block = std::min(static_cast<int64_t>(1024), total_num);
int64_t block_num = block * 50; // each thread deal with 50 data
int64_t grid = (total_num + block_num - 1) / block_num;

Contributor:

block --> threads_per_block
block_num --> elements_per_block
grid --> blocks_per_grid
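
With those names, the quoted lines would read roughly as below; the logic is unchanged from the code quoted above:

```cpp
// Same launch-parameter computation, only renamed for readability.
int64_t threads_per_block = std::min(static_cast<int64_t>(1024), total_num);
int64_t elements_per_block =
    threads_per_block * 50;  // each thread deals with 50 elements
int64_t blocks_per_grid =
    (total_num + elements_per_block - 1) / elements_per_block;
```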

Contributor Author:

Done

const int tid = threadIdx.x + blockIdx.x * blockDim.x;

// copy starts array from global memory to shared memory
extern __shared__ int64_t starts_s[];

Contributor:

starts_s --> s_starts

Contributor Author:

Done

for (int64_t id = tid; id < total_num; id += blockDim.x * gridDim.x) {
// get the "out" index of "id"
int next_out_index = out_index;
while (id < starts_s[next_out_index]) next_out_index++;

@wzzju (Contributor) commented Apr 26, 2021:

The code on line 57 will never be triggered.

Contributor Author:

Verified: this line is indeed never reached; it has been deleted.

@wzzju (Contributor) left a comment:

LGTM.

@wzzju wzzju merged commit 0dc02dc into PaddlePaddle:develop Apr 28, 2021
@thisjiang thisjiang deleted the optimize-update_loss_scaling branch April 28, 2021 02:30
lanxianghit pushed a commit that referenced this pull request Apr 28, 2021
* optimize update_loss_scaling_op by fused for loop to one kernel, test=develop

* remove useless while loop and optimize variable name, test=develop

* optimize variable name from out_addrs_tensor to out_addrs_mem, test=develop

* optimize variable name for readable by change prefix identifier from t_ to local_