diff --git a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py index a409c40f2..1b044b845 100644 --- a/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py +++ b/tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py @@ -212,7 +212,8 @@ def _forward_backward_test_impl( params = list(model_module.parameters()) rank = params[0].get_device() offset = pipeline_model_parallel_world_size - param_id = rank // data_parallel_size + vm_id * offset + param_id = parallel_state.get_pipeline_model_parallel_rank() + vm_id * pipeline_model_parallel_world_size + # param_id = rank // data_parallel_size + vm_id * offset target_params = target_model[param_id] self.assertEqual(params[0].cpu(), target_params[0])