Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Remove dataloader related methods from LitMNIST and use MNISTDataModule. Add docstring. * Fix test_base_log_interval_fallback and test_base_log_interval_override * Remove data_dir arg * Move logging to shared step * Add data_dir and batch_size to parser args. Add typing hints to LitMNIST methods. * Remove Literal as not available in Python <3.8. * Get dataset-specific args in CLI from MNIST datamodule * message is regex * accelerator auto * Double limit_train, val, test batches for trainer Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> Co-authored-by: otaj <ota@lightning.ai>
- Loading branch information
1 parent
cbe4143
commit ac98469
Showing
3 changed files
with
82 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,32 @@ | ||
import warnings | ||
|
||
from pytorch_lightning import Trainer, seed_everything | ||
from pytorch_lightning.utilities.warnings import PossibleUserWarning | ||
|
||
from pl_bolts.datamodules import MNISTDataModule | ||
from pl_bolts.models import LitMNIST | ||
|
||
|
||
def test_mnist(tmpdir, datadir): | ||
seed_everything() | ||
def test_mnist(tmpdir, datadir, catch_warnings): | ||
warnings.filterwarnings( | ||
"ignore", | ||
message=".+does not have many workers which may be a bottleneck.+", | ||
category=PossibleUserWarning, | ||
) | ||
|
||
seed_everything(1234) | ||
|
||
model = LitMNIST(data_dir=datadir, num_workers=0) | ||
datamodule = MNISTDataModule(data_dir=datadir, num_workers=0) | ||
model = LitMNIST() | ||
trainer = Trainer( | ||
limit_train_batches=0.01, | ||
limit_val_batches=0.01, | ||
limit_train_batches=0.02, | ||
limit_val_batches=0.02, | ||
max_epochs=1, | ||
limit_test_batches=0.01, | ||
limit_test_batches=0.02, | ||
default_root_dir=tmpdir, | ||
log_every_n_steps=5, | ||
accelerator="auto", | ||
) | ||
trainer.fit(model) | ||
trainer.fit(model, datamodule=datamodule) | ||
loss = trainer.callback_metrics["train_loss"] | ||
assert loss <= 2.2, "mnist failed" |