add num_splist for flash_attn_bwd and FlashAttnUnpaddedGradKernel #56363

Merged 10 commits on Sep 4, 2023

@AnnaTrainingG (Contributor) commented Aug 16, 2023

PR types

Others

PR changes

Others

Description

Pcard-70458
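The llama model was run following the README. Comparing two runs gives the results below; the loss shows no diff.
[image: comparison of the two runs]
Detailed data, run 1: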

[    INFO] - loss: 2.00048828, learning_rate: 2.1e-06, global_step: 20, interval_runtime: 26.3106, interval_samples_per_second: 6.081193361471736, interval_steps_per_second: 0.760149170183967, epoch: 0.0014
[    INFO] - loss: 1.94342518, learning_rate: 4.1e-06, global_step: 40, interval_runtime: 18.792, interval_samples_per_second: 8.514269788523704, interval_steps_per_second: 1.064283723565463, epoch: 0.0028
[    INFO] - loss: 1.96450806, learning_rate: 6.1e-06, global_step: 60, interval_runtime: 18.8364, interval_samples_per_second: 8.494203585698061, interval_steps_per_second: 1.0617754482122577, epoch: 0.0042
[    INFO] - loss: 1.97299156, learning_rate: 8.1e-06, global_step: 80, interval_runtime: 18.7085, interval_samples_per_second: 8.552269357792328, interval_steps_per_second: 1.069033669724041, epoch: 0.0056
[    INFO] - loss: 1.97750511, learning_rate: 1e-05, global_step: 100, interval_runtime: 18.7949, interval_samples_per_second: 8.512925658562247, interval_steps_per_second: 1.064115707320281, epoch: 0.0069
[    INFO] - loss: 1.95206642, learning_rate: 1e-05, global_step: 120, interval_runtime: 18.7179, interval_samples_per_second: 8.547971573787688, interval_steps_per_second: 1.068496446723461, epoch: 0.0083
[    INFO] - loss: 1.94697475, learning_rate: 1e-05, global_step: 140, interval_runtime: 18.7172, interval_samples_per_second: 8.548273071849662, interval_steps_per_second: 1.0685341339812078, epoch: 0.0097
[    INFO] - loss: 1.96742477, learning_rate: 1e-05, global_step: 160, interval_runtime: 18.9591, interval_samples_per_second: 8.439229664237219, interval_steps_per_second: 1.0549037080296524, epoch: 0.0111
[    INFO] - loss: 1.94573994, learning_rate: 9.999e-06, global_step: 180, interval_runtime: 18.7157, interval_samples_per_second: 8.548989500870817, interval_steps_per_second: 1.0686236876088522, epoch: 0.0125
[    INFO] - loss: 1.9408371, learning_rate: 9.999e-06, global_step: 200, interval_runtime: 18.9711, interval_samples_per_second: 8.433899198356313, interval_steps_per_second: 1.0542373997945391, epoch: 0.0139
[    INFO] - loss: 1.99254551, learning_rate: 9.998e-06, global_step: 220, interval_runtime: 19.6621, interval_samples_per_second: 8.137464978509199, interval_steps_per_second: 1.0171831223136498, epoch: 0.0153
[    INFO] - loss: 1.96885376, learning_rate: 9.997e-06, global_step: 240, interval_runtime: 18.7156, interval_samples_per_second: 8.549017162948847, interval_steps_per_second: 1.0686271453686058, epoch: 0.0167
[    INFO] - loss: 1.97938385, learning_rate: 9.997e-06, global_step: 260, interval_runtime: 18.9203, interval_samples_per_second: 8.456534692087457, interval_steps_per_second: 1.0570668365109321, epoch: 0.0181
[    INFO] - loss: 1.97675762, learning_rate: 9.996e-06, global_step: 280, interval_runtime: 19.3148, interval_samples_per_second: 8.283782019454836, interval_steps_per_second: 1.0354727524318545, epoch: 0.0194
[    INFO] - loss: 1.97571411, learning_rate: 9.995e-06, global_step: 300, interval_runtime: 19.0757, interval_samples_per_second: 8.38764279200571, interval_steps_per_second: 1.0484553490007138, epoch: 0.0208

Detailed data, run 2:

[    INFO] - loss: 2.00048828, learning_rate: 2.1e-06, global_step: 20, interval_runtime: 21.0772, interval_samples_per_second: 7.591150606927911, interval_steps_per_second: 0.9488938258659889, epoch: 0.0014
[    INFO] - loss: 1.94342518, learning_rate: 4.1e-06, global_step: 40, interval_runtime: 18.8188, interval_samples_per_second: 8.502147053349322, interval_steps_per_second: 1.0627683816686653, epoch: 0.0028
[    INFO] - loss: 1.96450806, learning_rate: 6.1e-06, global_step: 60, interval_runtime: 18.8288, interval_samples_per_second: 8.497600131243356, interval_steps_per_second: 1.0622000164054195, epoch: 0.0042
[    INFO] - loss: 1.97299156, learning_rate: 8.1e-06, global_step: 80, interval_runtime: 18.7106, interval_samples_per_second: 8.551290530216257, interval_steps_per_second: 1.068911316277032, epoch: 0.0056
[    INFO] - loss: 1.97750511, learning_rate: 1e-05, global_step: 100, interval_runtime: 18.802, interval_samples_per_second: 8.509752841450895, interval_steps_per_second: 1.063719105181362, epoch: 0.0069
[    INFO] - loss: 1.95206642, learning_rate: 1e-05, global_step: 120, interval_runtime: 18.7325, interval_samples_per_second: 8.541326055744099, interval_steps_per_second: 1.0676657569680124, epoch: 0.0083
[    INFO] - loss: 1.94697475, learning_rate: 1e-05, global_step: 140, interval_runtime: 18.744, interval_samples_per_second: 8.53607879100723, interval_steps_per_second: 1.0670098488759037, epoch: 0.0097
[    INFO] - loss: 1.96742477, learning_rate: 1e-05, global_step: 160, interval_runtime: 18.9852, interval_samples_per_second: 8.427604951044907, interval_steps_per_second: 1.0534506188806134, epoch: 0.0111
[    INFO] - loss: 1.94573994, learning_rate: 9.999e-06, global_step: 180, interval_runtime: 18.7522, interval_samples_per_second: 8.532339308513075, interval_steps_per_second: 1.0665424135641344, epoch: 0.0125
[    INFO] - loss: 1.9408371, learning_rate: 9.999e-06, global_step: 200, interval_runtime: 18.9974, interval_samples_per_second: 8.42220448954145, interval_steps_per_second: 1.0527755611926812, epoch: 0.0139
[    INFO] - loss: 1.99254551, learning_rate: 9.998e-06, global_step: 220, interval_runtime: 18.7538, interval_samples_per_second: 8.531624257296887, interval_steps_per_second: 1.0664530321621108, epoch: 0.0153
[    INFO] - loss: 1.96885376, learning_rate: 9.997e-06, global_step: 240, interval_runtime: 18.7428, interval_samples_per_second: 8.536628007676303, interval_steps_per_second: 1.0670785009595378, epoch: 0.0167
[    INFO] - loss: 1.97938385, learning_rate: 9.997e-06, global_step: 260, interval_runtime: 18.8998, interval_samples_per_second: 8.465717244773764, interval_steps_per_second: 1.0582146555967205, epoch: 0.0181
[    INFO] - loss: 1.97675762, learning_rate: 9.996e-06, global_step: 280, interval_runtime: 18.7282, interval_samples_per_second: 8.543248921566835, interval_steps_per_second: 1.0679061151958544, epoch: 0.0194
[    INFO] - loss: 1.97571411, learning_rate: 9.995e-06, global_step: 300, interval_runtime: 18.7413, interval_samples_per_second: 8.537290028215388, interval_steps_per_second: 1.0671612535269235, epoch: 0.0208

Performance without the deterministic algorithm enabled:

- loss: 2.00049, learning_rate: 2.1e-06, global_step: 20, interval_runtime: 21.3354, interval_samples_per_second: 7.499260565531818, interval_steps_per_second: 0.9374075706914773, epoch: 0.0014
- loss: 1.94349308, learning_rate: 4.1e-06, global_step: 40, interval_runtime: 18.761, interval_samples_per_second: 8.52834850825146, interval_steps_per_second: 1.0660435635314325, epoch: 0.0028
- loss: 1.96464062, learning_rate: 6.1e-06, global_step: 60, interval_runtime: 18.8037, interval_samples_per_second: 8.508971442820126, interval_steps_per_second: 1.0636214303525158, epoch: 0.0042
- loss: 1.97301292, learning_rate: 8.1e-06, global_step: 80, interval_runtime: 18.6895, interval_samples_per_second: 8.560953695414609, interval_steps_per_second: 1.0701192119268261, epoch: 0.0056
- loss: 1.97778549, learning_rate: 1e-05, global_step: 100, interval_runtime: 18.7648, interval_samples_per_second: 8.526600043709887, interval_steps_per_second: 1.065825005463736, epoch: 0.0069
- loss: 1.95245552, learning_rate: 1e-05, global_step: 120, interval_runtime: 18.7123, interval_samples_per_second: 8.550504971226191, interval_steps_per_second: 1.068813121403274, epoch: 0.0083
- loss: 1.94730988, learning_rate: 1e-05, global_step: 140, interval_runtime: 18.7294, interval_samples_per_second: 8.54271657817509, interval_steps_per_second: 1.0678395722718863, epoch: 0.0097
- loss: 1.96823559, learning_rate: 1e-05, global_step: 160, interval_runtime: 18.9585, interval_samples_per_second: 8.439481829428935, interval_steps_per_second: 1.0549352286786169, epoch: 0.0111
- loss: 1.94611969, learning_rate: 9.999e-06, global_step: 180, interval_runtime: 18.7352, interval_samples_per_second: 8.540070636271357, interval_steps_per_second: 1.0675088295339197, epoch: 0.0125
- loss: 1.9408411, learning_rate: 9.999e-06, global_step: 200, interval_runtime: 18.9657, interval_samples_per_second: 8.436264988184817, interval_steps_per_second: 1.0545331235231021, epoch: 0.0139
- loss: 1.99281082, learning_rate: 9.998e-06, global_step: 220, interval_runtime: 18.7107, interval_samples_per_second: 8.551269282269537, interval_steps_per_second: 1.0689086602836921, epoch: 0.0153

paddle-bot bot commented Aug 16, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.test_dot_scale_product()
paddle.set_flags({'FLAGS_cudnn_deterministic': 0})

Contributor:

Please add a unit test that compares the results of two executions. With identical inputs, the outputs must be exactly identical; compare the results with np.equal.

Contributor Author:

done

@Xreki (Contributor) left a comment:

The PR needs to update the flash-attention submodule so that CI can test the changes made in the flash-attention repo.

np.testing.assert_allclose(out1.numpy(), out1_, rtol=5e-03, atol=1e-03)

out2, out2_ = self.get_out_data()
np.equal(out1.numpy(), out2.numpy())
Contributor:

np.equal only returns the comparison result; an assert still needs to be added.
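To illustrate the point above: `np.equal` returns an elementwise boolean array and never fails a test on its own. The sketch below uses made-up arrays (not the PR's tensors) and a hypothetical helper name, `outputs_identical`, to show one way of turning the comparison into a real check.

```python
import numpy as np

# Illustrative arrays only; these are not outputs from the PR.
a = np.array([1.0, 2.0, 3.0], dtype=np.float32)
b = np.array([1.0, 2.0, 3.5], dtype=np.float32)

# np.equal just builds a boolean mask; a mismatch does not raise anything.
mask = np.equal(a, b)


def outputs_identical(x, y):
    """Hypothetical helper: True only if shapes and all elements match exactly."""
    return bool(np.array_equal(x, y))


# Either of these turns the comparison into an actual test assertion.
assert outputs_identical(a, a)
np.testing.assert_array_equal(a, a)
```

Inside a `unittest.TestCase`, `self.assertTrue(np.array_equal(x, y))` would serve the same purpose.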

Contributor Author:

Fixed.

Contributor Author:

The flash-attention (fa) PR has not been merged yet, so the submodule cannot be updated for now.

Contributor:

To verify correctness in this PR, you can first point the submodule at the version in your own fork.

Contributor Author:

done

@Xreki (Contributor) left a comment:

  1. Update the flash-attn submodule.
  2. Please add operator-level and model-level performance results. The deterministic implementation is bound to be slower than the non-deterministic one; we need to see roughly how much slower.
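As a rough way to answer the "how much slower" question at model level, the `interval_runtime` values in the logs from the description can be compared directly. The sketch below is only an illustration: the helper names are hypothetical, the log lines are trimmed to the relevant fields, only three intervals per run are included, and skipping step 20 assumes the first interval contains warm-up time.

```python
import re
from statistics import mean

# Matches the "global_step: N, interval_runtime: T" part of each log line.
RUNTIME_RE = re.compile(r"global_step: (\d+), interval_runtime: ([\d.]+)")


def runtimes_by_step(log_text):
    """Map global_step -> interval_runtime (seconds) for every line that matches."""
    return {int(step): float(rt) for step, rt in RUNTIME_RE.findall(log_text)}


def relative_slowdown(det_log, nondet_log, skip_steps=(20,)):
    """Mean deterministic runtime over mean non-deterministic runtime, minus 1."""
    det = runtimes_by_step(det_log)
    nondet = runtimes_by_step(nondet_log)
    # Compare only steps present in both logs, minus the assumed warm-up interval.
    common = sorted((det.keys() & nondet.keys()) - set(skip_steps))
    return mean(det[s] for s in common) / mean(nondet[s] for s in common) - 1.0


# Lines trimmed from "run 2" (deterministic) in the description.
DET_LOG = """
loss: 1.94342518, learning_rate: 4.1e-06, global_step: 40, interval_runtime: 18.8188
loss: 1.96450806, learning_rate: 6.1e-06, global_step: 60, interval_runtime: 18.8288
loss: 1.97299156, learning_rate: 8.1e-06, global_step: 80, interval_runtime: 18.7106
"""

# Lines trimmed from the run without the deterministic algorithm.
NONDET_LOG = """
loss: 1.94349308, learning_rate: 4.1e-06, global_step: 40, interval_runtime: 18.761
loss: 1.96464062, learning_rate: 6.1e-06, global_step: 60, interval_runtime: 18.8037
loss: 1.97301292, learning_rate: 8.1e-06, global_step: 80, interval_runtime: 18.6895
"""

print(f"relative slowdown: {relative_slowdown(DET_LOG, NONDET_LOG):.2%}")
```

This covers only the model-level numbers already posted; operator-level timing would need a separate benchmark of the backward kernel.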

int num_splits = 0;  // 0 for an internal heuristic, which is optimal
if (FLAGS_cudnn_deterministic) {
  num_splits = 1;
}
Contributor:

Wrap this in a function.

Contributor Author:

done

@@ -306,6 +302,29 @@ def test_all(self):
np.testing.assert_allclose(
fetches_result[0], out_, rtol=5e-03, atol=1e-03
)
return out, out_, fetches_result[0]
Contributor:

fetches_result[0] is the forward output from static-graph execution, right? For the deterministic implementation it is fine to test only the dynamic graph, but the test needs to check both the forward out and the backward dq, dk, dv, and make sure the results of two executions are exactly identical.
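The check requested here could be structured roughly as follows. This is a NumPy-only sketch, not the PR's test: `run_once` is a hypothetical stand-in for one dynamic-graph forward and backward pass, and it returns only `out` and `dv` for brevity. A real test would call the flash attention op with `FLAGS_cudnn_deterministic` enabled and also return dq and dk.

```python
import numpy as np


def run_once(seed=0):
    """Hypothetical stand-in for one forward + backward pass on fixed inputs."""
    rng = np.random.default_rng(seed)
    q, k, v = (rng.standard_normal((2, 4, 8)).astype(np.float32) for _ in range(3))
    # Plain softmax attention, used here only to produce some arrays to compare.
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    probs = np.exp(scores - scores.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    out = probs @ v
    # Gradient of sum(out) with respect to v; dq and dk are omitted in this sketch.
    dv = probs.transpose(0, 2, 1) @ np.ones_like(out)
    return {"out": out, "dv": dv}


def assert_runs_identical(run_fn):
    """Run twice and require exact equality (no tolerance) for every named tensor."""
    first, second = run_fn(), run_fn()
    assert first.keys() == second.keys()
    for name in first:
        np.testing.assert_array_equal(
            first[name], second[name], err_msg=f"{name} differs between runs"
        )


assert_runs_identical(run_once)
```

Using exact equality rather than `assert_allclose` is the point of the check: a tolerance would hide the small run-to-run differences that a deterministic implementation is supposed to eliminate.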

Contributor Author:

done

// 0 for an internal heuristic, which is optimal
return FLAGS_cudnn_deterministic ? 1 : 0;
}

Member:

I suggest encapsulating this function, together with the two `int num_splits = get_num_split();` lines in the kernels, directly inside FlashAttnBwdParamsV2.

Contributor Author:

Will change this in the next PR.

@risemeup1 (Contributor) left a comment:

LGTM

@ZzSean (Contributor) left a comment:

LGTM for skipIf

@wanghuancoder (Contributor) left a comment:

LGTM

@@ -0,0 +1,208 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Contributor:

Was this unit test moved out into its own file? If so, it needs to be added to the list of unit tests run by the GPUPS CI. Please add it in the next PR.

@Xreki Xreki merged commit 7fd6ffb into PaddlePaddle:develop Sep 4, 2023
25 of 26 checks passed
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
…ttnUnpaddedGradKernel (PaddlePaddle#56363)

* add num_splist for flash_attn_bwd and FlashAttnUnpaddedGradKernel

* Add assertTrue

* Update submodule to a specific commit
6 participants