diff --git a/tools/ckpts/convert_neox_to_hf.py b/tools/ckpts/convert_neox_to_hf.py index 35812383e..fd702d02b 100644 --- a/tools/ckpts/convert_neox_to_hf.py +++ b/tools/ckpts/convert_neox_to_hf.py @@ -674,7 +674,7 @@ def main(input_args=None, overwrite_values=None): # while Sequential model state dicts are saved all together in one mp_rank_xx_model_states.pt # file per tensor/model parallel shard. pipeline_world_size = get_key(loaded_config, "pipe-parallel-size", 1) - if pipeline_world_size == 0: + if pipeline_world_size <= 1: sequential = True print( f"Detected 'pipe-parallel-size' of {pipeline_world_size}, assuming model is saved as Sequential..."