diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 31154eac1bf0d..9817dfa4526c6 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import os.path as osp
 import pickle
 import platform
 import re
 from argparse import Namespace
-from distutils.version import LooseVersion
 from pathlib import Path
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock
 
 import cloudpickle
 import pytest
@@ -641,20 +639,17 @@ def validation_epoch_end(self, outputs):
 @pytest.mark.parametrize("enable_pl_optimizer", [False, True])
 def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
     """
-    This test validates that the checkpoint can be called when provided to callacks list
+    This test validates that the checkpoint can be called when provided to callbacks list
     """
-
     checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")
 
     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
     model = ExtendedBoringModel()
-    model.validation_step_end = None
     model.validation_epoch_end = None
     trainer = Trainer(
         max_epochs=1,
@@ -663,92 +658,30 @@ def validation_step(self, batch, batch_idx):
         limit_test_batches=2,
         callbacks=[checkpoint_callback],
         enable_pl_optimizer=enable_pl_optimizer,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
     )
-
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']
 
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
+    for idx in range(4):
         # load from checkpoint
-        chk = get_last_checkpoint()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(
-            max_epochs=1,
-            limit_train_batches=2,
-            limit_val_batches=2,
-            limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-        trainer.fit(model)
-        trainer.test(model)
-
-        assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
-
-
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
-def test_checkpoint_repeated_strategy_tmpdir(enable_pl_optimizer, tmpdir):
-    """
-    This test validates that the checkpoint can be called when provided to callacks list
-    """
-
-    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
-
-    class ExtendedBoringModel(BoringModel):
-
-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"val_loss": loss}
-
-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        limit_test_batches=2,
-        callbacks=[checkpoint_callback],
-        enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer.fit(model)
-    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
-    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
-
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
-
-        # load from checkpoint
-        chk = get_last_checkpoint()
-        model = LogInTwoMethods.load_from_checkpoint(chk)
+        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
         trainer = pl.Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
             limit_train_batches=2,
             limit_val_batches=2,
             limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-
+            resume_from_checkpoint=checkpoint_callback.best_model_path,
+            enable_pl_optimizer=enable_pl_optimizer,
+            weights_summary=None,
+            progress_bar_refresh_rate=0,
+        )
         trainer.fit(model)
-        trainer.test(model)
-        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+        trainer.test(model, verbose=False)
+    assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
+    assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}
 
 
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
@@ -760,21 +693,22 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir):
     """
 
     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}
 
+        def validation_epoch_end(self, *_):
+            ...
+
     def assert_trainer_init(trainer):
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == 0
         assert trainer.current_epoch == 0
 
     def get_last_checkpoint(ckpt_dir):
-        ckpts = os.listdir(ckpt_dir)
-        ckpts.sort()
-        return osp.join(ckpt_dir, ckpts[-1])
+        last = ckpt_dir.listdir(sort=True)[-1]
+        return str(last)
 
     def assert_checkpoint_content(ckpt_dir):
         chk = pl_load(get_last_checkpoint(ckpt_dir))
@@ -782,23 +716,15 @@ def assert_checkpoint_content(ckpt_dir):
         assert chk["epoch"] == epochs
         assert chk["global_step"] == 4
 
     def assert_checkpoint_log_dir(idx):
-        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
-        assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
-        assert len(os.listdir(ckpt_dir)) == epochs
-
-    def get_model():
-        model = ExtendedBoringModel()
-        model.validation_step_end = None
-        model.validation_epoch_end = None
-        return model
+        lightning_logs = tmpdir / 'lightning_logs'
+        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
+        assert actual == [f'version_{i}' for i in range(idx + 1)]
+        assert len(ckpt_dir.listdir()) == epochs
 
-    ckpt_dir = osp.join(tmpdir, 'checkpoints')
+    ckpt_dir = tmpdir / 'checkpoints'
     checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
     epochs = 2
     limit_train_batches = 2
-
-    model = get_model()
-
     trainer_config = dict(
         default_root_dir=tmpdir,
         max_epochs=epochs,
@@ -806,40 +732,32 @@ def get_model():
         limit_val_batches=3,
         limit_test_batches=4,
         enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer = pl.Trainer(
-        **trainer_config,
         callbacks=[checkpoint_cb],
     )
+    trainer = pl.Trainer(**trainer_config)
     assert_trainer_init(trainer)
 
+    model = ExtendedBoringModel()
     trainer.fit(model)
     assert trainer.checkpoint_connector.has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
+    assert_checkpoint_content(ckpt_dir)
 
     trainer.test(model)
     assert trainer.current_epoch == epochs - 1
 
-    assert_checkpoint_content(ckpt_dir)
-
     for idx in range(1, 5):
         chk = get_last_checkpoint(ckpt_dir)
         assert_checkpoint_content(ckpt_dir)
 
-        checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
-        model = get_model()
-
         # load from checkpoint
-        trainer = pl.Trainer(
-            **trainer_config,
-            resume_from_checkpoint=chk,
-            callbacks=[checkpoint_cb],
-        )
+        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
+        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
         assert_trainer_init(trainer)
 
+        model = ExtendedBoringModel()
         trainer.test(model)
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == epochs * limit_train_batches
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c24f1f5421e5c..9e5ceccf9b646 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -11,12 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import glob
 import math
 import os
 import pickle
 import sys
-import types
 from argparse import Namespace
 from copy import deepcopy
 from pathlib import Path
@@ -34,6 +32,7 @@
 from pytorch_lightning.loggers import TensorBoardLogger
 from pytorch_lightning.profiler.profilers import AdvancedProfiler, PassThroughProfiler, SimpleProfiler
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
+from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import NATIVE_AMP_AVAILABLE
 from pytorch_lightning.utilities.cloud_io import load as pl_load
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -61,6 +60,7 @@ def test_no_val_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
     result = trainer.fit(model)
     # training complete
     assert result == 1, "amp + ddp model failed to complete"
+    assert trainer.state == TrainerState.FINISHED
 
     # save model
     new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
@@ -107,6 +107,7 @@ def test_no_val_end_module(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
 
     # traning complete
     assert result == 1, "amp + ddp model failed to complete"
+    assert trainer.state == TrainerState.FINISHED
 
     # save model
     new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
@@ -151,6 +152,7 @@ def test_strict_model_load(monkeypatch, tmpdir, tmpdir_server, url_ckpt):
 
     # traning complete
     assert result == 1
+    assert trainer.state == TrainerState.FINISHED
 
     # save model
     new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
@@ -468,6 +470,7 @@ def test_model_checkpoint_only_weights(tmpdir):
     result = trainer.fit(model)
     # training complete
     assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
 
     checkpoint_path = list(trainer.checkpoint_callback.best_k_models.keys())[0]
 
@@ -507,35 +510,23 @@ def test_resume_from_checkpoint_epoch_restored(monkeypatch, tmpdir, tmpdir_serve
     # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
     monkeypatch.setenv("TORCH_HOME", tmpdir)
 
-    hparams = EvalModelTemplate.get_default_hparams()
-
-    def _new_model():
-        # Create a model that tracks epochs and batches seen
-        model = EvalModelTemplate(**hparams)
-        model.num_epochs_seen = 0
-        model.num_batches_seen = 0
-        model.num_on_load_checkpoint_called = 0
+    class TestModel(BoringModel):
+        # Model that tracks epochs and batches seen
+        num_epochs_seen = 0
+        num_batches_seen = 0
+        num_on_load_checkpoint_called = 0
 
-        def increment_epoch(self):
+        def on_epoch_end(self):
             self.num_epochs_seen += 1
 
-        def increment_batch(self, batch, batch_idx, dataloader_idx):
+        def on_train_batch_start(self, *_):
             self.num_batches_seen += 1
 
-        def increment_on_load_checkpoint(self, _):
+        def on_load_checkpoint(self, _):
             self.num_on_load_checkpoint_called += 1
 
-    # Bind methods to keep track of epoch numbers, batch numbers it has seen
-    # as well as number of times it has called on_load_checkpoint()
-    model.on_epoch_end = types.MethodType(increment_epoch, model)
-    model.on_train_batch_start = types.MethodType(increment_batch, model)
-    model.on_load_checkpoint = types.MethodType(increment_on_load_checkpoint, model)
-    return model
-
-    model = _new_model()
-
-    trainer_options = dict(
-        progress_bar_refresh_rate=0,
+    model = TestModel()
+    trainer = Trainer(
         max_epochs=2,
         limit_train_batches=0.65,
         limit_val_batches=1,
@@ -543,144 +534,125 @@ def increment_on_load_checkpoint(self, _):
         default_root_dir=tmpdir,
         val_check_interval=1.0,
         enable_pl_optimizer=enable_pl_optimizer,
+        progress_bar_refresh_rate=0,
+        logger=False,
+        weights_summary=None,
     )
-
-    trainer = Trainer(**trainer_options)
-
     # fit model
     trainer.fit(model)
-    training_batches = trainer.num_training_batches
-
     assert model.num_epochs_seen == 2
-    assert model.num_batches_seen == training_batches * 2
+    assert model.num_batches_seen == trainer.num_training_batches * 2
     assert model.num_on_load_checkpoint_called == 0
 
     # Other checkpoints can be uncommented if/when resuming mid-epoch is supported
-    checkpoints = sorted(glob.glob(os.path.join(trainer.checkpoint_callback.dirpath, "*.ckpt")))
+    checkpoints = Path(trainer.checkpoint_callback.dirpath).glob("*.ckpt")
     if url_ckpt:
         # transform local paths into url checkpoints
         ip, port = tmpdir_server
-        checkpoints = [f"http://{ip}:{port}/" + os.path.basename(check) for check in checkpoints]
+        checkpoints = [f"http://{ip}:{port}/" + ckpt.name for ckpt in checkpoints]
 
-    for check in checkpoints:
-        next_model = _new_model()
-        state = pl_load(check)
+    for ckpt in checkpoints:
+        next_model = TestModel()
+        state = pl_load(ckpt)
 
         # Resume training
-        trainer_options["max_epochs"] = 2
-        new_trainer = Trainer(**trainer_options, resume_from_checkpoint=check)
+        new_trainer = Trainer(resume_from_checkpoint=ckpt, max_epochs=2)
         new_trainer.fit(next_model)
-        assert state["global_step"] + next_model.num_batches_seen == training_batches * trainer_options["max_epochs"]
+        assert state["global_step"] + next_model.num_batches_seen == trainer.num_training_batches * trainer.max_epochs
         assert next_model.num_on_load_checkpoint_called == 1
 
 
-def _init_steps_model():
-    """private method for initializing a model with 5% train epochs"""
-    model = EvalModelTemplate()
-
-    # define train epoch to 5% of data
-    train_percent = 0.5
-    # get number of samples in 1 epoch
-    num_train_samples = math.floor(len(model.train_dataloader()) * train_percent)
-
-    trainer_options = dict(
-        limit_train_batches=train_percent,
-    )
-    return model, trainer_options, num_train_samples
-
-
 def test_trainer_max_steps_and_epochs(tmpdir):
     """Verify model trains according to specified max steps"""
-    model, trainer_options, num_train_samples = _init_steps_model()
+    model = BoringModel()
+    num_train_samples = math.floor(len(model.train_dataloader()) * 0.5)
 
     # define less train steps than epochs
-    trainer_options.update(
-        default_root_dir=tmpdir,
-        max_epochs=3,
-        max_steps=num_train_samples + 10,
-    )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
+    trainer_kwargs = {
+        'limit_train_batches': 0.5,
+        'default_root_dir': tmpdir,
+        'max_epochs': 3,
+        'max_steps': num_train_samples + 10,
+        'logger': False,
+        'weights_summary': None,
+        'progress_bar_refresh_rate': 0,
+    }
+    trainer = Trainer(**trainer_kwargs)
     result = trainer.fit(model)
-    assert result == 1, "Training did not complete"
 
-    # check training stopped at max_steps
+    assert result == 1, "Training did not complete"
+    assert trainer.state == TrainerState.FINISHED
    assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"
 
     # define less train epochs than steps
-    trainer_options.update(
-        max_epochs=2,
-        max_steps=trainer_options["max_epochs"] * 2 * num_train_samples,
-    )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
+    trainer_kwargs['max_epochs'] = 2
+    trainer_kwargs['max_steps'] = 3 * 2 * num_train_samples
+    trainer = Trainer(**trainer_kwargs)
     result = trainer.fit(model)
-    assert result == 1, "Training did not complete"
 
-    # check training stopped at max_epochs
+    assert result == 1, "Training did not complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.global_step == num_train_samples * trainer.max_epochs
     assert trainer.current_epoch == trainer.max_epochs - 1, "Model did not stop at max_epochs"
 
 
 def test_trainer_min_steps_and_epochs(tmpdir):
     """Verify model trains according to specified min steps"""
-    model, trainer_options, num_train_samples = _init_steps_model()
-
-    # define callback for stopping the model and default epochs
-    trainer_options.update(
-        default_root_dir=tmpdir,
-        callbacks=[EarlyStopping(monitor="early_stop_on", min_delta=1.0)],
-        val_check_interval=2,
-        min_epochs=1,
-        max_epochs=7,
-    )
-
-    # define less min steps than 1 epoch
-    trainer_options["min_steps"] = math.floor(num_train_samples / 2)
-
-    # fit model
-    trainer = Trainer(**trainer_options)
+    model = EvalModelTemplate()
+    num_train_samples = math.floor(len(model.train_dataloader()) * 0.5)
+
+    trainer_kwargs = {
+        'limit_train_batches': 0.5,
+        'default_root_dir': tmpdir,
+        # define callback for stopping the model
+        'callbacks': [EarlyStopping(monitor="early_stop_on", min_delta=1.0)],
+        'val_check_interval': 2,
+        'min_epochs': 1,
+        'max_epochs': 7,
+        # define less min steps than 1 epoch
+        'min_steps': num_train_samples // 2,
+        'logger': False,
+        'weights_summary': None,
+        'progress_bar_refresh_rate': 0,
+    }
+    trainer = Trainer(**trainer_kwargs)
     result = trainer.fit(model)
-    assert result == 1, "Training did not complete"
 
-    # check model ran for at least min_epochs
-    assert (
-        trainer.global_step >= num_train_samples and trainer.current_epoch > 0
-    ), "Model did not train for at least min_epochs"
+    assert result == 1, "Training did not complete"
+    assert trainer.state == TrainerState.FINISHED
+    assert trainer.current_epoch > 0
+    assert trainer.global_step >= num_train_samples, "Model did not train for at least min_epochs"
 
     # define less epochs than min_steps
-    trainer_options["min_steps"] = math.floor(num_train_samples * 1.5)
-
-    # fit model
-    trainer = Trainer(**trainer_options)
+    trainer_kwargs["min_steps"] = math.floor(num_train_samples * 1.5)
+    trainer = Trainer(**trainer_kwargs)
     result = trainer.fit(model)
-    assert result == 1, "Training did not complete"
 
-    # check model ran for at least num_train_samples*1.5
-    assert (
-        trainer.global_step >= math.floor(num_train_samples * 1.5) and trainer.current_epoch > 0
-    ), "Model did not train for at least min_steps"
+    assert result == 1, "Training did not complete"
+    assert trainer.state == TrainerState.FINISHED
+    assert trainer.current_epoch > 0
+    assert trainer.global_step >= math.floor(num_train_samples * 1.5), "Model did not train for at least min_steps"
 
 
 def test_trainer_max_steps_accumulate_batches(tmpdir):
     """Verify model trains according to specified max steps with grad accumulated batches"""
-    model, trainer_options, num_train_samples = _init_steps_model()
+    model = BoringModel()
+    num_train_samples = math.floor(len(model.train_dataloader()) * 0.5)
 
     # define less train steps than epochs
-    trainer_options.update(
+    trainer = Trainer(
+        limit_train_batches=0.5,
         default_root_dir=tmpdir,
-        max_steps=(num_train_samples + 10),
+        max_steps=num_train_samples + 10,
         accumulate_grad_batches=10,
+        logger=False,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
     )
-
-    # fit model
-    trainer = Trainer(**trainer_options)
     result = trainer.fit(model)
-    assert result == 1, "Training did not complete"
 
-    # check training stopped at max_steps
+    assert result == 1, "Training did not complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.global_step == trainer.max_steps, "Model did not stop at max_steps"
 
 
@@ -703,6 +675,7 @@ def test_benchmark_option(tmpdir):
 
     # verify training completed
     assert result == 1
+    assert trainer.state == TrainerState.FINISHED
 
     # verify torch.backends.cudnn.benchmark is not turned off
     assert torch.backends.cudnn.benchmark
@@ -788,6 +761,7 @@ def training_epoch_end(self, *args, **kwargs):
 
     # check that limit_train_batches=0 turns off training
     assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.current_epoch == 0
     assert not model.training_step_invoked, "`training_step` should not run when `limit_train_batches=0`"
     assert not model.training_epoch_end_invoked, "`training_epoch_end` should not run when `limit_train_batches=0`"
@@ -806,6 +780,7 @@ def training_epoch_end(self, *args, **kwargs):
         assert not torch.all(torch.eq(before_state_dict[key], after_state_dict[key]))
 
     assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.current_epoch == 0
     assert model.training_step_invoked, "did not run `training_step` with `fast_dev_run=True`"
     assert model.training_epoch_end_invoked, "did not run `training_epoch_end` with `fast_dev_run=True`"
@@ -844,6 +819,7 @@ def validation_epoch_end(self, *args, **kwargs):
 
     # check that limit_val_batches=0 turns off validation
     assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.current_epoch == 1
     assert not model.validation_step_invoked, "`validation_step` should not run when `limit_val_batches=0`"
     assert not model.validation_epoch_end_invoked, "`validation_epoch_end` should not run when `limit_val_batches=0`"
@@ -855,6 +831,7 @@ def validation_epoch_end(self, *args, **kwargs):
     result = trainer.fit(model)
 
     assert result == 1, "training failed to complete"
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.current_epoch == 0
     assert model.validation_step_invoked, "did not run `validation_step` with `fast_dev_run=True`"
     assert model.validation_epoch_end_invoked, "did not run `validation_epoch_end` with `fast_dev_run=True`"
@@ -1119,7 +1096,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
 @pytest.mark.parametrize(
     "trainer_kwargs,expected",
     [
-        pytest.param(
+        (
             dict(accelerator=None, gpus=None),
             dict(
                 use_dp=False,
@@ -1131,7 +1108,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="dp", gpus=None),
             dict(
                 use_dp=False,
@@ -1143,7 +1120,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="dp", gpus=None),
             dict(
                 use_dp=False,
@@ -1155,7 +1132,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp", gpus=None),
             dict(
                 use_dp=False,
@@ -1167,7 +1144,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
            ),
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp", num_processes=2, gpus=None),
             dict(
                 use_dp=False,
@@ -1179,7 +1156,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=2,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp", num_nodes=2, gpus=None),
             dict(
                 use_dp=False,
@@ -1191,7 +1168,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp_cpu", num_processes=2, gpus=None),
             dict(
                 use_dp=False,
@@ -1203,7 +1180,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=2,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp2", gpus=None),
             dict(
                 use_dp=False,
@@ -1215,7 +1192,7 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 num_processes=1,
             ),
         ),
-        pytest.param(
+        (
             dict(accelerator=None, gpus=1),
             dict(
                 use_dp=False,
@@ -1226,9 +1203,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=True,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="dp", gpus=1),
             dict(
                 use_dp=True,
@@ -1239,9 +1215,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=True,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp", gpus=1),
             dict(
                 use_dp=False,
@@ -1252,9 +1227,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=True,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp_cpu", num_processes=2, gpus=1),
             dict(
                 use_dp=False,
@@ -1265,9 +1239,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=2,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp2", gpus=1),
             dict(
                 use_dp=False,
@@ -1278,9 +1251,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() == 0, reason="GPU needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator=None, gpus=2),
             dict(
                 use_dp=False,
@@ -1291,9 +1263,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=2,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="dp", gpus=2),
             dict(
                 use_dp=True,
@@ -1304,9 +1275,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp", gpus=2),
             dict(
                 use_dp=False,
@@ -1317,9 +1287,8 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=2,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")],
         ),
-        pytest.param(
+        (
             dict(accelerator="ddp2", gpus=2),
             dict(
                 use_dp=False,
@@ -1330,21 +1299,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
                 use_single_gpu=False,
                 num_processes=1,
             ),
-            marks=[pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Multiple GPUs needed")],
         ),
     ],
 )
-# Todo: mock nb Gpus so all these tests can run on any device
-# todo: think about simplification, that the the expected will be just a list use_xxx which shall be true...
-def test_trainer_config(trainer_kwargs, expected):
+def test_trainer_config(trainer_kwargs, expected, monkeypatch):
+    if trainer_kwargs["gpus"] is not None:
+        monkeypatch.setattr(torch.cuda, "is_available", lambda: True)
+        monkeypatch.setattr(torch.cuda, "device_count", lambda: trainer_kwargs["gpus"])
     trainer = Trainer(**trainer_kwargs)
-    assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
-    assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
-    assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
-    assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
-    assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
-    assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
-    assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs
+    assert len(expected) == 7
+    for k, v in expected.items():
+        assert getattr(trainer, k) == v, f"Failed {k}: {v}"
 
 
 def test_trainer_subclassing():
@@ -1360,6 +1325,7 @@ def __init__(self, custom_arg, *args, custom_kwarg="test", **kwargs):
     trainer = TrainerSubclass(123, custom_kwarg="custom", fast_dev_run=True)
     result = trainer.fit(model)
     assert result == 1
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.custom_arg == 123
     assert trainer.custom_kwarg == "custom"
     assert trainer.fast_dev_run
@@ -1375,6 +1341,7 @@ def __init__(self, **kwargs):
     trainer = TrainerSubclass(custom_kwarg="custom", fast_dev_run=True)
     result = trainer.fit(model)
     assert result == 1
+    assert trainer.state == TrainerState.FINISHED
     assert trainer.custom_kwarg == "custom"
     assert trainer.fast_dev_run