Improve some tests #5049

Merged · 3 commits · Dec 13, 2020
tests/checkpointing/test_model_checkpoint.py: 142 changes (30 additions, 112 deletions)
@@ -12,15 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import os.path as osp
 import pickle
 import platform
 import re
 from argparse import Namespace
 from distutils.version import LooseVersion
 from pathlib import Path
 from unittest import mock
-from unittest.mock import MagicMock, Mock
+from unittest.mock import Mock

 import cloudpickle
 import pytest
@@ -641,20 +639,17 @@ def validation_epoch_end(self, outputs):
 @pytest.mark.parametrize("enable_pl_optimizer", [False, True])
 def test_checkpoint_repeated_strategy(enable_pl_optimizer, tmpdir):
     """
-    This test validates that the checkpoint can be called when provided to callacks list
+    This test validates that the checkpoint can be called when provided to callbacks list
     """
-
     checkpoint_callback = ModelCheckpoint(monitor='val_loss', dirpath=tmpdir, filename="{epoch:02d}")

     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}

     model = ExtendedBoringModel()
-    model.validation_step_end = None
     model.validation_epoch_end = None
     trainer = Trainer(
         max_epochs=1,
@@ -663,92 +658,30 @@ def validation_step(self, batch, batch_idx):
         limit_test_batches=2,
         callbacks=[checkpoint_callback],
         enable_pl_optimizer=enable_pl_optimizer,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
     )
-
     trainer.fit(model)
     assert os.listdir(tmpdir) == ['epoch=00.ckpt']

-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
+    for idx in range(4):
         # load from checkpoint
-        chk = get_last_checkpoint()
-        model = BoringModel.load_from_checkpoint(chk)
-        trainer = pl.Trainer(
-            max_epochs=1,
-            limit_train_batches=2,
-            limit_val_batches=2,
-            limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-        trainer.fit(model)
-        trainer.test(model)
-
-        assert str(os.listdir(tmpdir)) == "['epoch=00.ckpt']"
-
-
-@mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})
-@pytest.mark.parametrize("enable_pl_optimizer", [False, True])
-def test_checkpoint_repeated_strategy_tmpdir(enable_pl_optimizer, tmpdir):
Comment from the author (Contributor):

Note for reviewers: I merged these two tests because they are very similar, and we don't need to test the filepath parameter here; we already have tests for it, and it is deprecated.
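For context on the deprecation mentioned above: in the Lightning releases of this era, ModelCheckpoint's single filepath argument was split into dirpath plus filename. A minimal sketch of the equivalence, using an illustrative path that is not taken from this PR:

    from pytorch_lightning.callbacks import ModelCheckpoint

    # deprecated style, as in the removed test:
    # directory and filename template combined in one argument
    old_style = ModelCheckpoint(monitor='val_loss', filepath='/tmp/ckpts/{epoch:02d}')

    # current style, as in the kept test:
    # the same configuration split into two arguments
    new_style = ModelCheckpoint(monitor='val_loss', dirpath='/tmp/ckpts', filename='{epoch:02d}')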

-    """
-    This test validates that the checkpoint can be called when provided to callacks list
-    """
-
-    checkpoint_callback = ModelCheckpoint(monitor='val_loss', filepath=os.path.join(tmpdir, "{epoch:02d}"))
-
-    class ExtendedBoringModel(BoringModel):
-
-        def validation_step(self, batch, batch_idx):
-            output = self.layer(batch)
-            loss = self.loss(batch, output)
-            return {"val_loss": loss}
-
-    model = ExtendedBoringModel()
-    model.validation_step_end = None
-    model.validation_epoch_end = None
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        limit_train_batches=2,
-        limit_val_batches=2,
-        limit_test_batches=2,
-        callbacks=[checkpoint_callback],
-        enable_pl_optimizer=enable_pl_optimizer,
-    )
-
-    trainer.fit(model)
-    assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-    path_to_lightning_logs = osp.join(tmpdir, 'lightning_logs')
-    assert sorted(os.listdir(path_to_lightning_logs)) == sorted(['version_0'])
-
-    def get_last_checkpoint():
-        ckpts = os.listdir(tmpdir)
-        ckpts_map = {int(x.split("=")[1].split('.')[0]): osp.join(tmpdir, x) for x in ckpts if "epoch" in x}
-        num_ckpts = len(ckpts_map) - 1
-        return ckpts_map[num_ckpts]
-
-    for idx in range(1, 5):
-
-        # load from checkpoint
-        chk = get_last_checkpoint()
-        model = LogInTwoMethods.load_from_checkpoint(chk)
+        model = LogInTwoMethods.load_from_checkpoint(checkpoint_callback.best_model_path)
         trainer = pl.Trainer(
             default_root_dir=tmpdir,
             max_epochs=1,
             limit_train_batches=2,
             limit_val_batches=2,
             limit_test_batches=2,
-            resume_from_checkpoint=chk,
-            enable_pl_optimizer=enable_pl_optimizer)
-
+            resume_from_checkpoint=checkpoint_callback.best_model_path,
+            enable_pl_optimizer=enable_pl_optimizer,
+            weights_summary=None,
+            progress_bar_refresh_rate=0,
+        )
         trainer.fit(model)
-        trainer.test(model)
-        assert sorted(os.listdir(tmpdir)) == sorted(['epoch=00.ckpt', 'lightning_logs'])
-        assert sorted(os.listdir(path_to_lightning_logs)) == sorted([f'version_{i}' for i in range(idx + 1)])
+        trainer.test(model, verbose=False)
+        assert set(os.listdir(tmpdir)) == {'epoch=00.ckpt', 'lightning_logs'}
+    assert set(os.listdir(tmpdir.join("lightning_logs"))) == {f'version_{i}' for i in range(4)}


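An aside on the attribute the merged loop now relies on: after fit() returns, a ModelCheckpoint callback records where its best-scoring checkpoint was written, so the test no longer needs the get_last_checkpoint directory scan it used to do. A minimal sketch; the BoringModel import path is assumed from this test suite's layout, and the dirpath is illustrative:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from tests.base.boring_model import BoringModel  # assumed test-suite import path

    ckpt_cb = ModelCheckpoint(dirpath='/tmp/ckpts', filename='{epoch:02d}')
    trainer = pl.Trainer(max_epochs=1, callbacks=[ckpt_cb],
                         weights_summary=None, progress_bar_refresh_rate=0)
    trainer.fit(BoringModel())

    # populated during training by the callback; an empty string if nothing was saved
    print(ckpt_cb.best_model_path)  # e.g. '/tmp/ckpts/epoch=00.ckpt'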
 @mock.patch.dict(os.environ, {"PL_DEV_DEBUG": "1"})

@@ -760,86 +693,71 @@ def test_checkpoint_repeated_strategy_extended(enable_pl_optimizer, tmpdir):
     """

     class ExtendedBoringModel(BoringModel):
-
         def validation_step(self, batch, batch_idx):
             output = self.layer(batch)
             loss = self.loss(batch, output)
             return {"val_loss": loss}

         def validation_epoch_end(self, *_):
             ...

     def assert_trainer_init(trainer):
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == 0
         assert trainer.current_epoch == 0

     def get_last_checkpoint(ckpt_dir):
-        ckpts = os.listdir(ckpt_dir)
-        ckpts.sort()
-        return osp.join(ckpt_dir, ckpts[-1])
+        last = ckpt_dir.listdir(sort=True)[-1]
+        return str(last)

     def assert_checkpoint_content(ckpt_dir):
         chk = pl_load(get_last_checkpoint(ckpt_dir))
         assert chk["epoch"] == epochs
         assert chk["global_step"] == 4

     def assert_checkpoint_log_dir(idx):
-        lightning_logs_path = osp.join(tmpdir, 'lightning_logs')
-        assert sorted(os.listdir(lightning_logs_path)) == [f'version_{i}' for i in range(idx + 1)]
-        assert len(os.listdir(ckpt_dir)) == epochs
-
-    def get_model():
-        model = ExtendedBoringModel()
-        model.validation_step_end = None
-        model.validation_epoch_end = None
-        return model
+        lightning_logs = tmpdir / 'lightning_logs'
+        actual = [d.basename for d in lightning_logs.listdir(sort=True)]
+        assert actual == [f'version_{i}' for i in range(idx + 1)]
+        assert len(ckpt_dir.listdir()) == epochs

-    ckpt_dir = osp.join(tmpdir, 'checkpoints')
+    ckpt_dir = tmpdir / 'checkpoints'
     checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
     epochs = 2
     limit_train_batches = 2

-    model = get_model()
-
     trainer_config = dict(
         default_root_dir=tmpdir,
         max_epochs=epochs,
         limit_train_batches=limit_train_batches,
         limit_val_batches=3,
         limit_test_batches=4,
         enable_pl_optimizer=enable_pl_optimizer,
+        callbacks=[checkpoint_cb],
     )

-    trainer = pl.Trainer(
-        **trainer_config,
-        callbacks=[checkpoint_cb],
-    )
+    trainer = pl.Trainer(**trainer_config)
     assert_trainer_init(trainer)

+    model = ExtendedBoringModel()
     trainer.fit(model)
     assert trainer.checkpoint_connector.has_trained
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs - 1
     assert_checkpoint_log_dir(0)
+    assert_checkpoint_content(ckpt_dir)

     trainer.test(model)
     assert trainer.current_epoch == epochs - 1
-
-    assert_checkpoint_content(ckpt_dir)

     for idx in range(1, 5):
         chk = get_last_checkpoint(ckpt_dir)
         assert_checkpoint_content(ckpt_dir)

-        checkpoint_cb = ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)
-        model = get_model()
-
         # load from checkpoint
-        trainer = pl.Trainer(
-            **trainer_config,
-            resume_from_checkpoint=chk,
-            callbacks=[checkpoint_cb],
-        )
+        trainer_config["callbacks"] = [ModelCheckpoint(dirpath=ckpt_dir, save_top_k=-1)]
+        trainer = pl.Trainer(**trainer_config, resume_from_checkpoint=chk)
         assert_trainer_init(trainer)

+        model = ExtendedBoringModel()
         trainer.test(model)
         assert not trainer.checkpoint_connector.has_trained
         assert trainer.global_step == epochs * limit_train_batches
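To make the numbers in assert_checkpoint_content concrete: with epochs = 2 and limit_train_batches = 2, a full run takes 2 * 2 = 4 optimizer steps, which is why the last checkpoint must report global_step == 4. A quick sketch of inspecting a saved checkpoint outside the test; torch.load stands in for the pl_load helper used above, and the filename is illustrative:

    import torch

    # a Lightning checkpoint is a plain dict of training state
    ckpt = torch.load('/tmp/ckpts/epoch=1.ckpt', map_location='cpu')
    assert ckpt['epoch'] == 2        # epochs completed when it was written
    assert ckpt['global_step'] == 4  # epochs * limit_train_batches = 2 * 2
    print(sorted(ckpt))              # 'epoch', 'global_step', 'state_dict', ...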