From ac98469e3d7be04955a4c98822584a1b6b0ca756 Mon Sep 17 00:00:00 2001 From: Shion Matsumoto Date: Fri, 23 Sep 2022 18:29:28 -0400 Subject: [PATCH] Revision of MNIST module (#873) * Remove dataloader related methods from LitMNIST and use MNISTDataModule. Add docstring. * Fix test_base_log_interval_fallback and test_base_log_interval_override = * remove data_dir arg * Move logging to shared step * Add data_dir and batch_size to parser args. Add typing hints to LitMNIST methods. * Remove Literal as not available in Python <3.8. * Get dataset specific args in CLI from MNIST datamodule * message is regex * accelerator auto * Double limit_train, val, test batches for trainer Co-authored-by: otaj <6065855+otaj@users.noreply.github.com> Co-authored-by: otaj --- pl_bolts/models/mnist_module.py | 116 +++++++++++++-------------- tests/callbacks/test_data_monitor.py | 11 ++- tests/models/test_mnist_templates.py | 27 +++++-- 3 files changed, 82 insertions(+), 72 deletions(-) diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 7cd49dd892..1e470ea579 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -1,24 +1,33 @@ from argparse import ArgumentParser +from typing import Any import torch from pytorch_lightning import LightningModule, Trainer +from torch import Tensor from torch.nn import functional as F -from torch.utils.data import DataLoader, random_split -from pl_bolts.datasets import MNIST from pl_bolts.utils import _TORCHVISION_AVAILABLE -from pl_bolts.utils.stability import under_review -from pl_bolts.utils.warnings import warn_missing_pkg -if _TORCHVISION_AVAILABLE: - from torchvision import transforms -else: # pragma: no cover - warn_missing_pkg("torchvision") - -@under_review() class LitMNIST(LightningModule): - def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs): + """PyTorch Lightning implementation of a two-layer MNIST classification module. + + Args: + hidden_dim (int, optional): dimension of hidden layer (default: ``128``). + learning_rate (float, optional): optimizer learning rate (default: ``1e-3``). 
+ + Example:: + + datamodule = MNISTDataModule() + + model = LitMNIST() + + trainer = Trainer() + trainer.fit(model, datamodule=datamodule) + """ + + def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3, **kwargs: Any) -> None: + if not _TORCHVISION_AVAILABLE: # pragma: no cover raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.") @@ -28,82 +37,67 @@ def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_worker self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim) self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10) - self.mnist_train = None - self.mnist_val = None - - def forward(self, x): - x = x.view(x.size(0), -1) - x = torch.relu(self.l1(x)) - x = torch.relu(self.l2(x)) - return x + def forward(self, x: Tensor) -> Tensor: + out = x.view(x.size(0), -1) + out = torch.relu(self.l1(out)) + out = torch.relu(self.l2(out)) + return out - def training_step(self, batch, batch_idx): + def shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor: x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) - self.log("train_loss", loss) - return loss - def validation_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("val_loss", loss) + if step == "train": + self.log("train_loss", loss) + elif step == "val": + self.log("val_loss", loss) + elif step == "test": + self.log("test_loss", loss) + else: + raise ValueError(f"Step {step} is not recognized. Must be 'train', 'val', or 'test'.") - def test_step(self, batch, batch_idx): - x, y = batch - y_hat = self(x) - loss = F.cross_entropy(y_hat, y) - self.log("test_loss", loss) - - def configure_optimizers(self): - return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) + return loss - def prepare_data(self): - MNIST(self.hparams.data_dir, train=True, download=True, transform=transforms.ToTensor()) + def training_step(self, batch: Any, batch_idx: int) -> Tensor: + return self.shared_step(batch, batch_idx, "train") - def train_dataloader(self): - dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor()) - mnist_train, _ = random_split(dataset, [55000, 5000]) - loader = DataLoader(mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers) - return loader + def validation_step(self, batch: Any, batch_idx: int) -> None: + self.shared_step(batch, batch_idx, "val") - def val_dataloader(self): - dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor()) - _, mnist_val = random_split(dataset, [55000, 5000]) - loader = DataLoader(mnist_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers) - return loader + def test_step(self, batch: Any, batch_idx: int) -> None: + self.shared_step(batch, batch_idx, "test") - def test_dataloader(self): - test_dataset = MNIST(self.hparams.data_dir, train=False, download=True, transform=transforms.ToTensor()) - loader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers) - return loader + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @staticmethod - def add_model_specific_args(parent_parser): + def add_model_specific_args(parent_parser) -> ArgumentParser: parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument("--batch_size", type=int, default=32) - parser.add_argument("--num_workers", type=int, 
default=4) parser.add_argument("--hidden_dim", type=int, default=128) - parser.add_argument("--data_dir", type=str, default="") - parser.add_argument("--learning_rate", type=float, default=0.0001) + parser.add_argument("--learning_rate", type=float, default=1e-3) return parser -@under_review() def cli_main(): - # args + from pl_bolts.datamodules import MNISTDataModule + parser = ArgumentParser() parser = Trainer.add_argparse_args(parser) parser = LitMNIST.add_model_specific_args(parser) + parser = MNISTDataModule.add_dataset_specific_args(parser) + args = parser.parse_args() - # model + # Initialize MNISTDatamodule + datamodule = MNISTDataModule.from_argparse_args(args) + + # Initialize LitMNIST model model = LitMNIST(**vars(args)) - # training + # Train LitMNIST model trainer = Trainer.from_argparse_args(args) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) if __name__ == "__main__": # pragma: no cover diff --git a/tests/callbacks/test_data_monitor.py b/tests/callbacks/test_data_monitor.py index e4abe1f914..495f4c002a 100644 --- a/tests/callbacks/test_data_monitor.py +++ b/tests/callbacks/test_data_monitor.py @@ -8,6 +8,7 @@ from torch import nn from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor +from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST @@ -16,7 +17,8 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir): """Test logging interval set by log_every_n_steps argument.""" monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps) - model = LitMNIST(data_dir=datadir, num_workers=0) + model = LitMNIST(num_workers=0) + datamodule = MNISTDataModule(data_dir=datadir) trainer = Trainer( default_root_dir=tmpdir, log_every_n_steps=1, @@ -24,7 +26,7 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma callbacks=[monitor], ) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call @@ -41,14 +43,15 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir): """Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer.""" monitor = TrainingDataMonitor() - model = LitMNIST(data_dir=datadir, num_workers=0) + model = LitMNIST(num_workers=0) + datamodule = MNISTDataModule(data_dir=datadir) trainer = Trainer( default_root_dir=tmpdir, log_every_n_steps=log_every_n_steps, max_steps=max_steps, callbacks=[monitor], ) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call diff --git a/tests/models/test_mnist_templates.py b/tests/models/test_mnist_templates.py index 6177f9b740..fc8435803b 100644 --- a/tests/models/test_mnist_templates.py +++ b/tests/models/test_mnist_templates.py @@ -1,19 +1,32 @@ +import warnings + from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.utilities.warnings import PossibleUserWarning +from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models import LitMNIST -def test_mnist(tmpdir, datadir): - seed_everything() +def test_mnist(tmpdir, datadir, catch_warnings): + warnings.filterwarnings( + "ignore", + message=".+does not have many workers which may be a bottleneck.+", + category=PossibleUserWarning, + ) + + 
seed_everything(1234) - model = LitMNIST(data_dir=datadir, num_workers=0) + datamodule = MNISTDataModule(data_dir=datadir, num_workers=0) + model = LitMNIST() trainer = Trainer( - limit_train_batches=0.01, - limit_val_batches=0.01, + limit_train_batches=0.02, + limit_val_batches=0.02, max_epochs=1, - limit_test_batches=0.01, + limit_test_batches=0.02, default_root_dir=tmpdir, + log_every_n_steps=5, + accelerator="auto", ) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) loss = trainer.callback_metrics["train_loss"] assert loss <= 2.2, "mnist failed"
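
Usage after this patch: a minimal end-to-end sketch of the revised API, assuming the patch is applied on top of pl_bolts with pytorch_lightning and torchvision installed. The hyperparameter values and data_dir below are illustrative defaults, not taken from the patch::

    from pytorch_lightning import Trainer

    from pl_bolts.datamodules import MNISTDataModule
    from pl_bolts.models import LitMNIST

    # Data loading now lives entirely in the datamodule; LitMNIST only
    # carries model and optimizer hyperparameters.
    datamodule = MNISTDataModule(data_dir=".", num_workers=0, batch_size=32)
    model = LitMNIST(hidden_dim=128, learning_rate=1e-3)

    # train/val/test losses are all logged from the shared_step() introduced here.
    trainer = Trainer(max_epochs=1, accelerator="auto")
    trainer.fit(model, datamodule=datamodule)
    trainer.test(model, datamodule=datamodule)

The same flow is driven from the command line by the rewired cli_main(), which now composes the Trainer, LitMNIST (``--hidden_dim``, ``--learning_rate``), and MNISTDataModule argument groups into a single parser and passes the datamodule to ``trainer.fit``.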