Revision of MNIST module (#873)
* Remove dataloader-related methods from LitMNIST and use MNISTDataModule instead (see the usage sketch below). Add docstring.
* Fix test_base_log_interval_fallback and test_base_log_interval_override
* Remove data_dir arg
* Move logging to shared step
* Add data_dir and batch_size to parser args. Add typing hints to LitMNIST methods.
* Remove Literal as not available in Python <3.8.
* Get dataset specific args in CLI from MNIST datamodule
* The message in the warnings filter is a regex
* Use accelerator="auto" in the test Trainer
* Double the limit_train/val/test_batches fractions for the test Trainer

Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: otaj <ota@lightning.ai>
3 people committed Sep 23, 2022
1 parent cbe4143 commit ac98469
Showing 3 changed files with 82 additions and 72 deletions.
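
For orientation, a minimal sketch of the training flow after this refactor: the datamodule now owns the data pipeline, and LitMNIST is reduced to the network, the train/val/test steps, and the optimizer. This assumes pl_bolts and its torchvision dependency are installed; the data_dir value and the Trainer limits are illustrative, not taken from the commit.

from pytorch_lightning import Trainer

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import LitMNIST

# The datamodule handles download, the train/val split, and the dataloaders.
datamodule = MNISTDataModule(data_dir="./data", num_workers=0)

# The model no longer consumes data_dir/batch_size/num_workers itself;
# extra keyword arguments are absorbed by **kwargs.
model = LitMNIST(hidden_dim=128, learning_rate=1e-3)

trainer = Trainer(max_epochs=1, limit_train_batches=0.02, limit_val_batches=0.02)
trainer.fit(model, datamodule=datamodule)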
116 changes: 55 additions & 61 deletions pl_bolts/models/mnist_module.py
@@ -1,24 +1,33 @@
from argparse import ArgumentParser
from typing import Any

import torch
from pytorch_lightning import LightningModule, Trainer
from torch import Tensor
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.datasets import MNIST
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
else: # pragma: no cover
warn_missing_pkg("torchvision")


@under_review()
class LitMNIST(LightningModule):
def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir="", **kwargs):
"""PyTorch Lightning implementation of a two-layer MNIST classification module.
Args:
hidden_dim (int, optional): dimension of hidden layer (default: ``128``).
learning_rate (float, optional): optimizer learning rate (default: ``1e-3``).
Example::
datamodule = MNISTDataModule()
model = LitMNIST()
trainer = Trainer()
trainer.fit(model, datamodule=datamodule)
"""

def __init__(self, hidden_dim: int = 128, learning_rate: float = 1e-3, **kwargs: Any) -> None:

if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError("You want to use `torchvision` which is not installed yet.")

@@ -28,82 +37,67 @@ def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_worker
self.l1 = torch.nn.Linear(28 * 28, self.hparams.hidden_dim)
self.l2 = torch.nn.Linear(self.hparams.hidden_dim, 10)

self.mnist_train = None
self.mnist_val = None

def forward(self, x):
x = x.view(x.size(0), -1)
x = torch.relu(self.l1(x))
x = torch.relu(self.l2(x))
return x
def forward(self, x: Tensor) -> Tensor:
out = x.view(x.size(0), -1)
out = torch.relu(self.l1(out))
out = torch.relu(self.l2(out))
return out

def training_step(self, batch, batch_idx):
def shared_step(self, batch: Any, batch_idx: int, step: str) -> Tensor:
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log("train_loss", loss)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log("val_loss", loss)
if step == "train":
self.log("train_loss", loss)
elif step == "val":
self.log("val_loss", loss)
elif step == "test":
self.log("test_loss", loss)
else:
raise ValueError(f"Step {step} is not recognized. Must be 'train', 'val', or 'test'.")

def test_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = F.cross_entropy(y_hat, y)
self.log("test_loss", loss)

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
return loss

def prepare_data(self):
MNIST(self.hparams.data_dir, train=True, download=True, transform=transforms.ToTensor())
def training_step(self, batch: Any, batch_idx: int) -> Tensor:
return self.shared_step(batch, batch_idx, "train")

def train_dataloader(self):
dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
mnist_train, _ = random_split(dataset, [55000, 5000])
loader = DataLoader(mnist_train, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
return loader
def validation_step(self, batch: Any, batch_idx: int) -> None:
self.shared_step(batch, batch_idx, "val")

def val_dataloader(self):
dataset = MNIST(self.hparams.data_dir, train=True, download=False, transform=transforms.ToTensor())
_, mnist_val = random_split(dataset, [55000, 5000])
loader = DataLoader(mnist_val, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
return loader
def test_step(self, batch: Any, batch_idx: int) -> None:
self.shared_step(batch, batch_idx, "test")

def test_dataloader(self):
test_dataset = MNIST(self.hparams.data_dir, train=False, download=True, transform=transforms.ToTensor())
loader = DataLoader(test_dataset, batch_size=self.hparams.batch_size, num_workers=self.hparams.num_workers)
return loader
def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
def add_model_specific_args(parent_parser) -> ArgumentParser:
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument("--batch_size", type=int, default=32)
parser.add_argument("--num_workers", type=int, default=4)
parser.add_argument("--hidden_dim", type=int, default=128)
parser.add_argument("--data_dir", type=str, default="")
parser.add_argument("--learning_rate", type=float, default=0.0001)
parser.add_argument("--learning_rate", type=float, default=1e-3)
return parser


@under_review()
def cli_main():
# args
from pl_bolts.datamodules import MNISTDataModule

parser = ArgumentParser()
parser = Trainer.add_argparse_args(parser)
parser = LitMNIST.add_model_specific_args(parser)
parser = MNISTDataModule.add_dataset_specific_args(parser)

args = parser.parse_args()

# model
# Initialize MNISTDatamodule
datamodule = MNISTDataModule.from_argparse_args(args)

# Initialize LitMNIST model
model = LitMNIST(**vars(args))

# training
# Train LitMNIST model
trainer = Trainer.from_argparse_args(args)
trainer.fit(model)
trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__": # pragma: no cover
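
Not part of the diff: a quick sanity-check sketch of the refactored module's forward pass. The batch shape and values are illustrative, and torchvision must be installed for LitMNIST to be constructed.

import torch

from pl_bolts.models import LitMNIST

model = LitMNIST(hidden_dim=128)
fake_batch = torch.rand(4, 1, 28, 28)  # four fake 28x28 single-channel images
scores = model(fake_batch)             # forward() flattens each image to 784 features
print(scores.shape)                    # torch.Size([4, 10]), one score per digit class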
11 changes: 7 additions & 4 deletions tests/callbacks/test_data_monitor.py
@@ -8,6 +8,7 @@
from torch import nn

from pl_bolts.callbacks import ModuleDataMonitor, TrainingDataMonitor
from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import LitMNIST


@@ -16,15 +17,16 @@
def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
"""Test logging interval set by log_every_n_steps argument."""
monitor = TrainingDataMonitor(log_every_n_steps=log_every_n_steps)
model = LitMNIST(data_dir=datadir, num_workers=0)
model = LitMNIST(num_workers=0)
datamodule = MNISTDataModule(data_dir=datadir)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=1,
max_steps=max_steps,
callbacks=[monitor],
)

trainer.fit(model)
trainer.fit(model, datamodule=datamodule)
assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call


@@ -41,14 +43,15 @@ def test_base_log_interval_override(log_histogram, tmpdir, log_every_n_steps, ma
def test_base_log_interval_fallback(log_histogram, tmpdir, log_every_n_steps, max_steps, expected_calls, datadir):
"""Test that if log_every_n_steps not set in the callback, fallback to what is defined in the Trainer."""
monitor = TrainingDataMonitor()
model = LitMNIST(data_dir=datadir, num_workers=0)
model = LitMNIST(num_workers=0)
datamodule = MNISTDataModule(data_dir=datadir)
trainer = Trainer(
default_root_dir=tmpdir,
log_every_n_steps=log_every_n_steps,
max_steps=max_steps,
callbacks=[monitor],
)
trainer.fit(model)
trainer.fit(model, datamodule=datamodule)
assert log_histogram.call_count == (expected_calls * 2) # 2 tensors per log call


27 changes: 20 additions & 7 deletions tests/models/test_mnist_templates.py
@@ -1,19 +1,32 @@
import warnings

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities.warnings import PossibleUserWarning

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.models import LitMNIST


def test_mnist(tmpdir, datadir):
seed_everything()
def test_mnist(tmpdir, datadir, catch_warnings):
warnings.filterwarnings(
"ignore",
message=".+does not have many workers which may be a bottleneck.+",
category=PossibleUserWarning,
)

seed_everything(1234)

model = LitMNIST(data_dir=datadir, num_workers=0)
datamodule = MNISTDataModule(data_dir=datadir, num_workers=0)
model = LitMNIST()
trainer = Trainer(
limit_train_batches=0.01,
limit_val_batches=0.01,
limit_train_batches=0.02,
limit_val_batches=0.02,
max_epochs=1,
limit_test_batches=0.01,
limit_test_batches=0.02,
default_root_dir=tmpdir,
log_every_n_steps=5,
accelerator="auto",
)
trainer.fit(model)
trainer.fit(model, datamodule=datamodule)
loss = trainer.callback_metrics["train_loss"]
assert loss <= 2.2, "mnist failed"
