Commit ac9bd15
change the condition for get qkv tensor from linear_qkv output (#8965)
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Co-authored-by: Adi Renduchintala <adithya.r@gmail.com>
Signed-off-by: Ao Tang <aot@nvidia.com>
2 people authored and suiyoubi committed May 2, 2024
1 parent e86485e commit ac9bd15
Showing 1 changed file with 12 additions and 14 deletions.
```diff
@@ -86,29 +86,27 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         linear_qkv_output, _ = self.linear_qkv(hidden_states)
         layernorm_output = None
 
-        # In megatron/core/models/gpt/gpt_layer_specs.py TELayerNormColumnParallelLinear is used for linear_qkv.
-        # TELayerNormColumnParallelLinear fuses LN and linear, so both will be returned.
-        # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py TEColumnParallelLinear is used for linear_qkv,
+        # In megatron/core/models/gpt/gpt_layer_specs.py, when a fused module is used (e.g. TELayerNormColumnParallelLinear),
+        # both LN and qkv will be returned.
+        # In nemo/collections/nlp/models/language_modeling/megatron/falcon/falcon_spec.py, TEColumnParallelLinear (non-fused) is used for linear_qkv,
         # which only returns linear.
-        if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear):
-            mixed_qkv, layernorm_output = linear_qkv_output
-        elif isinstance(self.linear_qkv, TEColumnParallelLinear):  # only mixed_qkv
+        if isinstance(linear_qkv_output, tuple):
+            if len(linear_qkv_output) == 2:  # fused module: qkv & LN
+                mixed_qkv, layernorm_output = linear_qkv_output
+            else:
+                raise ValueError(f"Unexpected number of outputs from linear_qkv output: {len(linear_qkv_output)}")
+        else:  # qkv & LN not fused: only mixed_qkv
             mixed_qkv = linear_qkv_output
-        else:
-            raise ValueError(
-                f"Unrecognized module type '{type(self.linear_qkv)}' when getting query, key, value tensors for mcore mixins. "
-            )
+
         # LoRA logic
         if self.is_adapter_available():
             lora_kqv_adapter = self.get_adapter_module(AdapterName.LORA_KQV_ADAPTER)
             if lora_kqv_adapter and self.adapter_cfg[AdapterName.LORA_KQV_ADAPTER]['enabled']:
-                if isinstance(self.linear_qkv, TELayerNormColumnParallelLinear):
+                if layernorm_output is not None:
                     lora_mixed_qkv = lora_kqv_adapter(layernorm_output)
-                elif isinstance(self.linear_qkv, TEColumnParallelLinear):
-                    lora_mixed_qkv = lora_kqv_adapter(hidden_states)
                 else:
-                    raise ValueError(f"Unrecognized module type '{type(self.linear_qkv)}' when applying lora.")
+                    lora_mixed_qkv = lora_kqv_adapter(hidden_states)
 
                 mixed_qkv = mixed_qkv + lora_mixed_qkv
 
         # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
```
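For context on the new condition: Transformer Engine's fused TELayerNormColumnParallelLinear returns the layernorm output alongside the qkv projection, while the non-fused TEColumnParallelLinear returns only the projection. Dispatching on the shape of the returned value instead of the module type covers any linear_qkv that follows either contract, with no further isinstance branches. A minimal, self-contained sketch of the dispatch (plain placeholder values, not NeMo code):

```python
def unpack_qkv(linear_qkv_output):
    """Mirror of the new condition: dispatch on the returned value's shape."""
    layernorm_output = None
    if isinstance(linear_qkv_output, tuple):  # fused module: (qkv, LN)
        if len(linear_qkv_output) == 2:
            mixed_qkv, layernorm_output = linear_qkv_output
        else:
            raise ValueError(f"Unexpected number of outputs: {len(linear_qkv_output)}")
    else:  # non-fused module: qkv only
        mixed_qkv = linear_qkv_output
    return mixed_qkv, layernorm_output

# Fused path (e.g. TELayerNormColumnParallelLinear) returns a (qkv, LN) tuple:
assert unpack_qkv(("qkv", "ln")) == ("qkv", "ln")
# Non-fused path (e.g. TEColumnParallelLinear) returns the qkv tensor alone:
assert unpack_qkv("qkv") == ("qkv", None)
```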
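The LoRA branch simplifies for the same reason: what matters is not the module type but whether a layernorm output exists. The adapter must consume exactly the input the frozen linear_qkv consumed, i.e. the LN output when LN and linear are fused, and the raw hidden states otherwise. A hedged sketch of that flow with a generic low-rank adapter (the LoRA class below is illustrative, not NeMo's LORA_KQV_ADAPTER implementation):

```python
import torch

class LoRA(torch.nn.Module):
    """Generic low-rank adapter: up(down(x)) * scale, with up zero-initialized."""
    def __init__(self, d_in, d_out, rank=8, alpha=16):
        super().__init__()
        self.down = torch.nn.Linear(d_in, rank, bias=False)
        self.up = torch.nn.Linear(rank, d_out, bias=False)
        torch.nn.init.zeros_(self.up.weight)  # adapter starts as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        return self.up(self.down(x)) * self.scale

sq, b, d_model = 4, 2, 64
hidden_states = torch.randn(sq, b, d_model)                       # [sq, b, h]
layernorm_output = torch.nn.functional.layer_norm(hidden_states, (d_model,))
lora_kqv_adapter = LoRA(d_model, 3 * d_model)

# Feed the adapter whatever the frozen linear_qkv saw: the LN output when
# LN and linear are fused, otherwise the raw hidden states.
x = layernorm_output if layernorm_output is not None else hidden_states
lora_mixed_qkv = lora_kqv_adapter(x)                               # [sq, b, 3h]
# mixed_qkv = mixed_qkv + lora_mixed_qkv  # added onto the frozen projection
```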
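The shape comment at the bottom of the hunk describes the grouped-query-attention split that follows it: hp = (np + 2*ng) * hn is the fused qkv width, and each of the ng query groups carries np/ng query heads plus one key head and one value head. A worked example with illustrative sizes (the numbers are assumptions, not taken from the commit):

```python
import torch

sq, b = 5, 2                  # sequence length, micro-batch
n_heads, ng, hn = 8, 2, 16    # np=8 query heads, ng=2 groups, head dim 16
hp = (n_heads + 2 * ng) * hn                   # (8 + 4) * 16 = 192

mixed_qkv = torch.randn(sq, b, hp)             # [sq, b, hp]
# [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
mixed_qkv = mixed_qkv.view(sq, b, ng, (n_heads // ng + 2) * hn)  # [5, 2, 2, 96]

# Per group: np/ng query heads, then one key head and one value head.
q, k, v = torch.split(mixed_qkv, [n_heads // ng * hn, hn, hn], dim=3)
assert q.shape == (sq, b, ng, 64) and k.shape == v.shape == (sq, b, ng, 16)
```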
