fixes to accommodate mcore changes (#8261)
Signed-off-by: HuiyingLi <willwin.lee@gmail.com>
Signed-off-by: Pablo Garay <pagaray@nvidia.com>
HuiyingLi authored and pablo-garay committed Mar 19, 2024
1 parent 7136a3e commit f140543
Showing 1 changed file with 1 addition and 57 deletions.
@@ -107,6 +107,7 @@ def forward(
         context_mask=None,
         rotary_pos_emb=None,
         inference_params=None,
+        packed_seq_params=None,
     ):
         # hidden_states: [s, b, h]

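The single added line threads a new packed_seq_params argument through forward, matching the packed-sequence (THD) support in recent Megatron-Core. Below is a minimal, hypothetical sketch of how a caller might pass it, assuming mcore's PackedSeqParams dataclass with cu_seqlens_*/max_seqlen_* fields (the import path and field names are assumptions and vary between mcore releases); layer and hidden_states are placeholders, not names taken from this commit.

import torch
from megatron.core.packed_seq_params import PackedSeqParams  # assumed import path

def run_packed(layer, hidden_states):
    """Run the patched TransformerLayer on sequences packed into a single [t, 1, h] buffer."""
    # Cumulative sequence lengths for three packed sequences of lengths 5, 7 and 8.
    cu_seqlens = torch.tensor([0, 5, 12, 20], dtype=torch.int32, device=hidden_states.device)
    params = PackedSeqParams(
        qkv_format='thd',          # tokens from all sequences packed along one dimension
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_kv=cu_seqlens,  # self-attention: identical boundaries for keys/values
        max_seqlen_q=8,
        max_seqlen_kv=8,
    )
    # Sequence boundaries come from cu_seqlens, so no dense attention mask is passed.
    return layer(hidden_states, attention_mask=None, packed_seq_params=params)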
@@ -161,60 +162,3 @@ def forward(
         output = make_viewless_tensor(inp=hidden_states, requires_grad=hidden_states.requires_grad, keep_graph=True)

         return output, context
-
-    def sharded_state_dict(self, prefix=''):
-
-        state_dict = self.state_dict(keep_vars=True)
-
-        tensor_parallel_layers_axis_map = {
-            'self_attention.linear_qkv.weight': 0,
-            'self_attention.linear_qkv.bias': 0,
-            'self_attention.linear_proj.weight': 1,
-            'mlp.linear_fc1.weight': 0,
-            'mlp.linear_fc1.bias': 0,
-            'mlp.linear_fc2.weight': 1,
-        }
-
-        offset = self._get_layer_offset()
-        num_layers = self.config.num_layers
-
-        sharded_state_dict = {}
-
-        for layer_name in state_dict.keys():
-            tensor = state_dict[layer_name]
-            global_layer_offset = self.layer_number - 1  # self.layer_number starts at 1
-            layer_key = f'{prefix}{global_layer_offset - offset}.{layer_name}'  # module list index in TransformerBlock
-            sharded_offsets = [(0, global_layer_offset, num_layers)]  # PP sharding
-
-            if layer_name in tensor_parallel_layers_axis_map:
-                tp_axis = tensor_parallel_layers_axis_map[layer_name]
-                # TP sharding
-                sharded_offsets.append(
-                    [
-                        tp_axis + 1,  # +1 for PP dimension
-                        parallel_state.get_tensor_model_parallel_rank(),
-                        parallel_state.get_tensor_model_parallel_world_size(),
-                    ]
-                )
-                replica_id = parallel_state.get_data_parallel_rank()
-            else:
-                replica_id = (
-                    parallel_state.get_data_parallel_rank() * parallel_state.get_data_parallel_world_size()
-                    + parallel_state.get_tensor_model_parallel_rank()
-                )
-
-            if layer_name.endswith('._extra_state'):
-                sharded_state_dict[layer_key] = ShardedObject(
-                    f'{prefix}{layer_name}', tensor, (num_layers,), (global_layer_offset,), replica_id,
-                )
-
-            else:
-                sharded_state_dict[layer_key] = ShardedTensor.from_rank_offsets(
-                    f'{prefix}{layer_name}',
-                    tensor,
-                    *sharded_offsets,
-                    replica_id=replica_id,
-                    prepend_axis_num=1,  # for PP sharding
-                )
-
-        return sharded_state_dict
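The 57 deleted lines drop NeMo's hand-written per-layer sharded_state_dict override; with the matching mcore changes, the layer presumably relies on the sharded-state-dict support provided by Megatron-Core's own base classes instead. A small, hypothetical sketch of saving such a layer through mcore's distributed checkpointing follows, purely to illustrate that assumption (nothing below is part of this commit, and the sharded_state_dict signature differs between mcore versions).

from megatron.core import dist_checkpointing

def save_layer(layer, checkpoint_dir):
    # Assumes sharded_state_dict() is now supplied by the Megatron-Core base class,
    # since the NeMo-side override shown above has been removed.
    sharded_sd = layer.sharded_state_dict(prefix='')
    dist_checkpointing.save(sharded_sd, checkpoint_dir)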
