
Commit

Fix async ckpt unit test
mikolajblaz committed May 13, 2024
1 parent df86290 commit b36e603
Showing 1 changed file with 33 additions and 25 deletions.
tests/core/test_dist_ckpt.py (33 additions, 25 deletions)
@@ -61,17 +61,22 @@ def save_checkpoint(self, *args, **kwargs) -> None:
 
 def _get_last_checkpoint_dir(root_dir: Path, model: pl.LightningModule, suffix: str = '') -> Path:
     steps = len(model.train_dataloader().dataset) * model.trainer.max_epochs // torch.distributed.get_world_size()
-    return root_dir / 'checkpoints' / f'epoch=1-step={steps}{suffix}'
+    return root_dir / 'checkpoints' / f'epoch={model.trainer.max_epochs - 1}-step={steps}{suffix}'
 
 
+def _get_nlp_strategy_without_optimizer_state():
+    strategy = NLPDDPStrategy()
+    # this ensures optimizer sharded state creation is skipped
+    strategy.optimizer_sharded_state_dict = types.MethodType(
+        lambda self, unsharded_optim_state: unsharded_optim_state, strategy
+    )
+    return strategy
+
+
 class TestDistCkptIO:
     @pytest.mark.run_only_on('GPU')
     def test_dist_ckpt_io_called_for_mcore_models(self, tmp_path):
-        strategy = NLPDDPStrategy()
-        # skip optimizer sharded state creation:
-        strategy.optimizer_sharded_state_dict = types.MethodType(
-            lambda self, unsharded_optim_state: unsharded_optim_state, strategy
-        )
+        strategy = _get_nlp_strategy_without_optimizer_state()
        checkpoint_io = MockDistributedCheckpointIO('xxx')
 
         test_trainer = pl.Trainer(
@@ -120,43 +125,46 @@ def test_dist_ckpt_path_not_executed_for_non_core_models(self, tmp_path):
 class TestAsyncSave:
     @pytest.mark.run_only_on('GPU')
     def test_async_save_produces_same_checkpoints_as_sync(self, tmp_path):
-        strategy = NLPDDPStrategy()
-        # skip optimizer sharded state creation:
-        strategy.optimizer_sharded_state_dict = types.MethodType(
-            lambda self, unsharded_optim_state: unsharded_optim_state, strategy
-        )
+        strategy = _get_nlp_strategy_without_optimizer_state()
         sync_checkpoint_io = DistributedCheckpointIO('torch_dist')
         async_checkpoint_io = AsyncFinalizableCheckpointIO(DistributedCheckpointIO('torch_dist', async_save=True))
 
+        model = ExampleMCoreModel()
+
+        # dummy_trainer just to initialize NCCL
+        dummy_trainer = pl.Trainer(
+            enable_checkpointing=False,
+            logger=False,
+            max_epochs=1,
+            strategy=_get_nlp_strategy_without_optimizer_state(),
+            plugins=[sync_checkpoint_io],
+        )
+        dummy_trainer.fit(model)
+        tmp_path = strategy.broadcast(tmp_path)
+
         sync_ckpt_dir = tmp_path / 'sync_checkpoints'
         async_ckpt_dir = tmp_path / 'async_checkpoints'
 
-        test_trainer = pl.Trainer(
+        sync_test_trainer = pl.Trainer(
             enable_checkpointing=True,
             logger=False,
-            max_epochs=2,
-            strategy=strategy,
+            max_epochs=1,
+            strategy=_get_nlp_strategy_without_optimizer_state(),
             plugins=[sync_checkpoint_io],
             default_root_dir=sync_ckpt_dir,
         )
-        model = ExampleMCoreModel()
-        test_trainer.fit(model)
+        sync_test_trainer.fit(model)
 
-        strategy = NLPDDPStrategy()
-        # skip optimizer sharded state creation:
-        strategy.optimizer_sharded_state_dict = types.MethodType(
-            lambda self, unsharded_optim_state: unsharded_optim_state, strategy
-        )
-        test_trainer = pl.Trainer(
+        async_test_trainer = pl.Trainer(
             enable_checkpointing=True,
             logger=False,
-            max_epochs=2,
-            strategy=strategy,
+            max_epochs=1,
+            strategy=_get_nlp_strategy_without_optimizer_state(),
             plugins=[async_checkpoint_io],
             callbacks=AsyncFinalizerCallback(),
             default_root_dir=async_ckpt_dir,
         )
-        test_trainer.fit(model)
+        async_test_trainer.fit(model)
 
         # Load and compare checkpoints
         checkpoint = {'sharded_state_dict': model.sharded_state_dict()}

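A note on the refactor above: the new _get_nlp_strategy_without_optimizer_state helper relies on types.MethodType to bind a pass-through lambda as an instance method, so optimizer sharded-state creation is skipped for that one strategy object without changing the NLPDDPStrategy class. A minimal sketch of the same pattern, using a hypothetical DummyStrategy stand-in rather than the real NeMo strategy:

    import types

    class DummyStrategy:
        # Stand-in for NLPDDPStrategy; only the patched method matters here.
        def optimizer_sharded_state_dict(self, unsharded_optim_state):
            raise NotImplementedError("expensive sharded-state creation")

    strategy = DummyStrategy()
    # Bind a lambda to this instance only: it returns its argument unchanged,
    # so the optimizer state is passed through without sharding.
    strategy.optimizer_sharded_state_dict = types.MethodType(
        lambda self, unsharded_optim_state: unsharded_optim_state, strategy
    )

    optim_state = {'step': 0}
    assert strategy.optimizer_sharded_state_dict(optim_state) is optim_state

Because the patch lives on the instance, every trainer in the test that needs this behavior gets its own freshly patched strategy from the helper, which is why the diff calls _get_nlp_strategy_without_optimizer_state() once per pl.Trainer.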