
[Op] Refactor qkv processing#46

Merged
szhengac merged 4 commits into awslabs:main from comaniac:fix_qkv
Feb 9, 2023

Conversation


@comaniac comaniac commented Feb 8, 2023

Description

As pointed out by @szhengac, the current logic that uses .chunk(3, dim=-1) to split qkv assumes different data layouts for the TP=1 and TP>1 cases. Specifically, when TP=1, we assume the qkv weight is contiguous, meaning the layout is [q0, q1, ..., k0, k1, ..., v0, v1, ...]. However, when TP>1, the weight is sharded along axis=0, so each partition has 3 * H // TP rows; this implicitly assumes an interleaved qkv layout (i.e., [q0, k0, v0, ...]).

This won't be an issue if we always run the model in the same case, but it produces incorrect results if, for example, we trained the model with TP=2 and now want to fine-tune it with TP=1. Transposing the trained weights could also resolve this issue, but that is not straightforward for users.

This PR fixes the issue by always assuming the qkv weights are interleaved, which is also the methodology used in Megatron-LM. Accordingly, the unit test manually transposes the weights to match the GPT-2 attention results.
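To illustrate the layout difference, here is a minimal PyTorch sketch. It is not the PR's actual code; the sizes (num_heads, head_dim) and the permutation are hypothetical, chosen only to show why .chunk(3) is correct for the contiguous layout but not for the interleaved one, and how the two layouts convert into each other.

```python
import torch

# Hypothetical small sizes for illustration (not from the PR).
num_heads, head_dim = 2, 3
hidden = num_heads * head_dim  # H = 6

# Contiguous qkv weight: rows are [q | k | v], each block of size hidden.
w_contig = torch.arange(3 * hidden * hidden, dtype=torch.float32).reshape(3 * hidden, hidden)
q, k, v = w_contig.chunk(3, dim=0)  # only valid for the contiguous layout

# Convert to the interleaved (Megatron-LM style) layout: each head's
# q, k, v rows are grouped together, so sharding along dim 0 keeps
# complete (q, k, v) triples on every tensor-parallel rank.
w_interleaved = (
    w_contig.reshape(3, num_heads, head_dim, hidden)  # [qkv, head, head_dim, in]
    .permute(1, 0, 2, 3)                              # [head, qkv, head_dim, in]
    .reshape(3 * hidden, hidden)
)

# In the interleaved layout, q is no longer the first chunk; it must be
# gathered per head.
q_from_interleaved = (
    w_interleaved.reshape(num_heads, 3, head_dim, hidden)[:, 0]
    .reshape(hidden, hidden)
)
assert torch.equal(q, q_from_interleaved)
# Naively chunking the interleaved weight gives a wrong "q".
assert not torch.equal(w_interleaved.chunk(3, dim=0)[0], q)
```

Because each tensor-parallel rank's shard of the interleaved weight contains whole (q, k, v) triples for a subset of heads, the same checkpoint loads correctly under any TP degree, which is the property the PR relies on.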

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@szhengac szhengac merged commit 9d6aed3 into awslabs:main Feb 9, 2023
@comaniac comaniac deleted the fix_qkv branch February 9, 2023 21:58
