From 2a8870560b87d741b9efcd4267f55b30aa5db999 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 13:55:32 -0600 Subject: [PATCH 01/14] :bug: make data dir kwarg instead of arg --- pl_bolts/datamodules/mnist_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 2adc03fc11..45409e0e42 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -11,7 +11,7 @@ class MNISTDataModule(LightningDataModule): def __init__( self, - data_dir: str, + data_dir: str = './', val_split: int = 5000, num_workers: int = 16, normalize: bool = False, From b4831d3f91e1bafbb7796f468a3461578d5085fa Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 13:56:09 -0600 Subject: [PATCH 02/14] :sparkles: surface AE up in init --- pl_bolts/models/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py index 05f2533a33..178cb6f2ba 100644 --- a/pl_bolts/models/__init__.py +++ b/pl_bolts/models/__init__.py @@ -2,6 +2,7 @@ Collection of PyTorchLightning models """ +from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE from pl_bolts.models.mnist_module import LitMNIST from pl_bolts.models.regression import LinearRegression From f171bdb6d0a41d393086d471cc5218510115463c Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 13:56:46 -0600 Subject: [PATCH 03/14] :construction: wip --- .../autoencoders/basic_ae/basic_ae_module.py | 20 +++++----- .../autoencoders/basic_ae/components.py | 13 +++--- .../basic_vae/basic_vae_module.py | 40 +++++++++---------- 3 files changed, 37 insertions(+), 36 deletions(-) diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 840f00a988..e59cbcf85d 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -13,7 +13,7 @@ class AE(LightningModule): def __init__( self, - datamodule: LightningDataModule = None, + # datamodule: LightningDataModule = None, input_channels=1, input_height=28, input_width=28, @@ -43,24 +43,24 @@ def __init__( self.save_hyperparameters() # link default data - if datamodule is None: - datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers) + # if datamodule is None: + # datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers) - self.datamodule = datamodule + # self.datamodule = datamodule - self.img_dim = self.datamodule.size() + # self.img_dim = self.datamodule.size() - self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim, + self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim, self.hparams.input_channels, self.hparams.input_width, self.hparams.input_height) self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim) - def init_encoder(self, hidden_dim, latent_dim, input_width, input_height): - encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height) + def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width): + encoder = AEEncoder(hidden_dim, latent_dim, input_channels, input_height, input_width) return encoder def init_decoder(self, hidden_dim, latent_dim): - c, h, w = 
self.img_dim - decoder = Decoder(hidden_dim, latent_dim, w, h, c) + # c, h, w = self.img_dim + decoder = Decoder(hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels) return decoder def forward(self, z): diff --git a/pl_bolts/models/autoencoders/basic_ae/components.py b/pl_bolts/models/autoencoders/basic_ae/components.py index 442012bcf8..a4bb4ec6ba 100644 --- a/pl_bolts/models/autoencoders/basic_ae/components.py +++ b/pl_bolts/models/autoencoders/basic_ae/components.py @@ -9,26 +9,27 @@ class AEEncoder(torch.nn.Module): get split into a mu and sigma vector """ - def __init__(self, hidden_dim, latent_dim, input_width, input_height): + def __init__(self, hidden_dim, latent_dim, input_channels, input_height, input_width): super().__init__() self.hidden_dim = hidden_dim self.latent_dim = latent_dim - self.input_width = input_width + self.input_channels = input_channels self.input_height = input_height + self.input_width = input_width - self.c1 = nn.Conv2d(1, 32, kernel_size=3, padding=1) + self.c1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1) self.c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1) self.c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1) - conv_out_dim = self._calculate_output_dim(input_width, input_height) + conv_out_dim = self._calculate_output_dim(input_channels, input_width, input_height) self.fc1 = DenseBlock(conv_out_dim, hidden_dim) self.fc2 = DenseBlock(hidden_dim, hidden_dim) self.fc_z_out = nn.Linear(hidden_dim, latent_dim) - def _calculate_output_dim(self, input_width, input_height): - x = torch.rand(1, 1, input_width, input_height) + def _calculate_output_dim(self, input_channels, input_width, input_height): + x = torch.rand(1, input_channels, input_width, input_height) x = self.c3(self.c2(self.c1(x))) x = x.view(-1) return x.size(0) diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 9c5cffd90d..032d83b510 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -61,11 +61,11 @@ def __init__( super().__init__() self.save_hyperparameters() - self.datamodule = datamodule - self.__set_pretrained_dims(pretrained) + #self.datamodule = datamodule + #self.__set_pretrained_dims(pretrained) # use mnist as the default module - self._set_default_datamodule(datamodule) + #self._set_default_datamodule(datamodule) # init actual model self.__init_system() @@ -77,23 +77,23 @@ def __init_system(self): self.encoder = self.init_encoder() self.decoder = self.init_decoder() - def __set_pretrained_dims(self, pretrained): - if pretrained == 'imagenet2012': - self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir) - (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.datamodule.size() - - def _set_default_datamodule(self, datamodule): - # link default data - if datamodule is None: - datamodule = MNISTDataModule( - data_dir=self.hparams.data_dir, - num_workers=self.hparams.num_workers, - normalize=False - ) - self.datamodule = datamodule - self.img_dim = self.datamodule.size() - - (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim + # def __set_pretrained_dims(self, pretrained): + # if pretrained == 'imagenet2012': + # self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir) + # (self.hparams.input_channels, self.hparams.input_height, 
self.hparams.input_width) = self.datamodule.size() + + # def _set_default_datamodule(self, datamodule): + # # link default data + # if datamodule is None: + # datamodule = MNISTDataModule( + # data_dir=self.hparams.data_dir, + # num_workers=self.hparams.num_workers, + # normalize=False + # ) + # self.datamodule = datamodule + # self.img_dim = self.datamodule.size() + + # (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim def load_pretrained(self, pretrained): available_weights = {'imagenet2012'} From f0f0a8e985a82c142aee332a923bbd3cc124bc80 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 14:49:29 -0600 Subject: [PATCH 04/14] :construction: . --- .../autoencoders/basic_ae/basic_ae_module.py | 45 ++++++++++--------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index e59cbcf85d..098bcdd1c8 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -4,7 +4,7 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule +from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule, STL10DataModule, ImagenetDataModule from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder from pl_bolts.models.autoencoders.basic_vae.components import Decoder @@ -13,16 +13,12 @@ class AE(LightningModule): def __init__( self, - # datamodule: LightningDataModule = None, - input_channels=1, - input_height=28, - input_width=28, + input_channels: int, + input_height: int, + input_width: int, latent_dim=32, - batch_size=32, hidden_dim=128, learning_rate=0.001, - num_workers=8, - data_dir='.', **kwargs ): """ @@ -127,27 +123,36 @@ def add_model_specific_args(parent_parser): help='itermediate layers dimension before embedding for default encoder/decoder') parser.add_argument('--latent_dim', type=int, default=32, help='dimension of latent variables z') - parser.add_argument('--input_width', type=int, default=28, - help='input image width - 28 for MNIST (must be even)') - parser.add_argument('--input_height', type=int, default=28, - help='input image height - 28 for MNIST (must be even)') - parser.add_argument('--batch_size', type=int, default=32) - parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers") parser.add_argument('--learning_rate', type=float, default=1e-3) - parser.add_argument('--data_dir', type=str, default='') return parser -def cli_main(): +def cli_main(args=None): + # cli_main() parser = ArgumentParser() + parser.add_argument('--dataset', default='mnist', type=str, help='mnist, cifar10, stl10, imagenet') + script_args, _ = parser.parse_known_args(args) + + if script_args.dataset == 'mnist': + dm_cls = MNISTDataModule + elif script_args.dataset == 'cifar10': + dm_cls = CIFAR10DataModule + elif script_args.dataset == 'stl10': + dm_cls = STL10DataModule + elif script_args.dataset == 'imagenet': + dm_cls = ImagenetDataModule + + parser = dm_cls.add_argparse_args(parser) parser = Trainer.add_argparse_args(parser) parser = AE.add_model_specific_args(parser) - args = parser.parse_args() + args = parser.parse_args(args) - ae = AE(**vars(args)) + dm = dm_cls.from_argparse_args(args) + model = AE(*dm.size(), **vars(args)) trainer = Trainer.from_argparse_args(args) - trainer.fit(ae) + 
trainer.fit(model, dm) + return dm, model, trainer if __name__ == '__main__': - cli_main() + dm, model, trainer = cli_main() From 9a6f9896a29eab23d3e79e82ea2150bc864bb9c5 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:02:31 -0600 Subject: [PATCH 05/14] :construction: . --- .../basic_vae/basic_vae_module.py | 92 ++++++------------- 1 file changed, 27 insertions(+), 65 deletions(-) diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 032d83b510..a31b542e8b 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -6,7 +6,7 @@ from torch import distributions from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule, ImagenetDataModule, STL10DataModule, BinaryMNISTDataModule +from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule, ImagenetDataModule, STL10DataModule, BinaryMNISTDataModule from pl_bolts.models.autoencoders.basic_vae.components import Encoder, Decoder from pl_bolts.utils.pretrained_weights import load_pretrained from pl_bolts.utils import shaping @@ -16,16 +16,12 @@ class VAE(pl.LightningModule): def __init__( self, + input_channels: int, + input_height: int, + input_width: int, hidden_dim: int = 128, latent_dim: int = 32, - input_channels: int = 3, - input_width: int = 224, - input_height: int = 224, - batch_size: int = 32, learning_rate: float = 0.001, - data_dir: str = '.', - datamodule: pl.LightningDataModule = None, - num_workers: int = 8, pretrained: str = None, **kwargs ): @@ -60,12 +56,7 @@ def __init__( """ super().__init__() self.save_hyperparameters() - - #self.datamodule = datamodule - #self.__set_pretrained_dims(pretrained) - - # use mnist as the default module - #self._set_default_datamodule(datamodule) + self.img_dim = (input_channels, input_height, input_width) # init actual model self.__init_system() @@ -77,24 +68,6 @@ def __init_system(self): self.encoder = self.init_encoder() self.decoder = self.init_decoder() - # def __set_pretrained_dims(self, pretrained): - # if pretrained == 'imagenet2012': - # self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir) - # (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.datamodule.size() - - # def _set_default_datamodule(self, datamodule): - # # link default data - # if datamodule is None: - # datamodule = MNISTDataModule( - # data_dir=self.hparams.data_dir, - # num_workers=self.hparams.num_workers, - # normalize=False - # ) - # self.datamodule = datamodule - # self.img_dim = self.datamodule.size() - - # (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim - def load_pretrained(self, pretrained): available_weights = {'imagenet2012'} @@ -247,51 +220,40 @@ def add_model_specific_args(parent_parser): help='itermediate layers dimension before embedding for default encoder/decoder') parser.add_argument('--latent_dim', type=int, default=4, help='dimension of latent variables z') - parser.add_argument('--input_width', type=int, default=224, - help='input width (used Imagenet downsampled size)') - parser.add_argument('--input_height', type=int, default=224, - help='input width (used Imagenet downsampled size)') - parser.add_argument('--input_channels', type=int, default=3, - help='number of input channels') - parser.add_argument('--batch_size', type=int, default=64) parser.add_argument('--pretrained', type=str, 
default=None) - parser.add_argument('--data_dir', type=str, default=os.getcwd()) - parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers") - parser.add_argument('--learning_rate', type=float, default=1e-3) return parser -def cli_main(): +def cli_main(args=None): from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler - from pl_bolts.datamodules import ImagenetDataModule - pl.seed_everything(1234) + # cli_main() parser = ArgumentParser() - parser.add_argument('--dataset', default='mnist', type=str, help='mnist, stl10, imagenet') - + parser.add_argument('--dataset', default='mnist', type=str, help='mnist, cifar10, stl10, imagenet') + script_args, _ = parser.parse_known_args(args) + + if script_args.dataset == 'mnist': + dm_cls = MNISTDataModule + elif script_args.dataset == 'cifar10': + dm_cls = CIFAR10DataModule + elif script_args.dataset == 'stl10': + dm_cls = STL10DataModule + elif script_args.dataset == 'imagenet': + dm_cls = ImagenetDataModule + + parser = dm_cls.add_argparse_args(parser) parser = pl.Trainer.add_argparse_args(parser) parser = VAE.add_model_specific_args(parser) - parser = ImagenetDataModule.add_argparse_args(parser) - parser = MNISTDataModule.add_argparse_args(parser) - args = parser.parse_args() - - # default is mnist - datamodule = None - if args.dataset == 'imagenet2012': - datamodule = ImagenetDataModule.from_argparse_args(args) - elif args.dataset == 'stl10': - datamodule = STL10DataModule.from_argparse_args(args) + args = parser.parse_args(args) + dm = dm_cls.from_argparse_args(args) + model = VAE(*dm.size(), **vars(args)) callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)] - vae = VAE(**vars(args), datamodule=datamodule) - trainer = pl.Trainer.from_argparse_args( - args, - callbacks=callbacks, - progress_bar_refresh_rate=10, - ) - trainer.fit(vae) + trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20) + trainer.fit(model, dm) + return dm, model, trainer if __name__ == '__main__': - cli_main() + dm, model, trainer = cli_main() From f279d4a23e5b7189eff2af0c05a6d957a6d349be Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:10:16 -0600 Subject: [PATCH 06/14] :white_check_mark: update tests --- pl_bolts/datamodules/mnist_datamodule.py | 1 + pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py | 8 -------- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 45409e0e42..841a2dfa07 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -16,6 +16,7 @@ def __init__( num_workers: int = 16, normalize: bool = False, seed: int = 42, + batch_size: int = 32, *args, **kwargs, ): diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 098bcdd1c8..93fcbcf8c1 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -38,14 +38,6 @@ def __init__( super().__init__() self.save_hyperparameters() - # link default data - # if datamodule is None: - # datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers) - - # self.datamodule = datamodule - - # self.img_dim = self.datamodule.size() - self.encoder = self.init_encoder(self.hparams.hidden_dim, 
self.hparams.latent_dim, self.hparams.input_channels, self.hparams.input_width, self.hparams.input_height) self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim) From 1cd11bd2115269d98623c12ba06c7551a9236f1a Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:21:48 -0600 Subject: [PATCH 07/14] :white_check_mark: update tests --- tests/models/test_autoencoders.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index cf57a7a317..a9b13fec90 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -2,6 +2,7 @@ import torch from pytorch_lightning import seed_everything +from pl_bolts.datamodules import MNISTDataModule from pl_bolts.models.autoencoders import VAE, AE from pl_bolts.models.autoencoders.basic_ae import AEEncoder from pl_bolts.models.autoencoders.basic_vae import Encoder, Decoder @@ -9,11 +10,11 @@ def test_vae(tmpdir): seed_everything() - - model = VAE(data_dir=tmpdir, batch_size=2, num_workers=0) + dm = MNISTDataModule(batch_size=2, num_workers=0) + model = VAE(*dm.size()) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir, deterministic=True) - trainer.fit(model) - results = trainer.test(model)[0] + trainer.fit(model, dm) + results = trainer.test(model, datamodule=dm)[0] loss = results['test_loss'] assert loss > 0, 'VAE failed' @@ -22,10 +23,11 @@ def test_vae(tmpdir): def test_ae(tmpdir): seed_everything() - model = AE(data_dir=tmpdir, batch_size=2) + dm = MNISTDataModule(batch_size=2, num_workers=0) + model = VAE(*dm.size()) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) - trainer.fit(model) - trainer.test(model) + trainer.fit(model, dm) + trainer.test(model, datamodule=dm) def test_basic_ae_encoder(tmpdir): @@ -37,7 +39,7 @@ def test_basic_ae_encoder(tmpdir): batch_size = 16 channels = 1 - encoder = AEEncoder(hidden_dim, latent_dim, width, height) + encoder = AEEncoder(hidden_dim, latent_dim, channels, width, height) x = torch.randn(batch_size, channels, width, height) z = encoder(x) From 33ce0330e4f8c8659111067aa25015f4530a9c47 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:36:43 -0600 Subject: [PATCH 08/14] :white_check_mark: pytest is cute --- tests/models/test_autoencoders.py | 88 ++++++++++++++++++++++--------- 1 file changed, 64 insertions(+), 24 deletions(-) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index a9b13fec90..3d4136e962 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -1,16 +1,28 @@ +import pytest import pytorch_lightning as pl import torch from pytorch_lightning import seed_everything -from pl_bolts.datamodules import MNISTDataModule +from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule from pl_bolts.models.autoencoders import VAE, AE from pl_bolts.models.autoencoders.basic_ae import AEEncoder from pl_bolts.models.autoencoders.basic_vae import Encoder, Decoder -def test_vae(tmpdir): +@pytest.mark.parametrize( + "dm_cls", + [ + pytest.param( + MNISTDataModule, id='mnist' + ), + pytest.param( + CIFAR10DataModule, id='cifar10' + ), + ] +) +def test_vae(tmpdir, dm_cls): seed_everything() - dm = MNISTDataModule(batch_size=2, num_workers=0) + dm = dm_cls(batch_size=2, num_workers=0) model = VAE(*dm.size()) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir, deterministic=True) trainer.fit(model, dm) @@ -19,42 +31,70 @@ def 
test_vae(tmpdir): assert loss > 0, 'VAE failed' - -def test_ae(tmpdir): +@pytest.mark.parametrize( + "dm_cls", + [ + pytest.param( + MNISTDataModule, id='mnist' + ), + pytest.param( + CIFAR10DataModule, id='cifar10' + ), + ] +) +def test_ae(tmpdir, dm_cls): seed_everything() - - dm = MNISTDataModule(batch_size=2, num_workers=0) + dm = dm_cls(batch_size=2, num_workers=0) model = VAE(*dm.size()) trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir) trainer.fit(model, dm) trainer.test(model, datamodule=dm) -def test_basic_ae_encoder(tmpdir): +@pytest.mark.parametrize( + "hidden_dim,latent_dim,batch_size,channels,height,width", + [ + pytest.param( + 128, 2, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-2' + ), + pytest.param( + 128, 4, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-4' + ), + pytest.param( + 64, 4, 16, 1, 28, 28, id='like-mnist-hidden-64-latent-4' + ), + pytest.param( + 128, 2, 16, 3, 32, 32, id='like-cifar10-hidden-128-latent-2' + ), + ] +) +def test_basic_ae_encoder(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width): seed_everything() - - hidden_dim = 128 - latent_dim = 2 - width = height = 28 - batch_size = 16 - channels = 1 - encoder = AEEncoder(hidden_dim, latent_dim, channels, width, height) x = torch.randn(batch_size, channels, width, height) z = encoder(x) - assert z.shape == (batch_size, latent_dim) -def test_basic_vae_components(tmpdir): +@pytest.mark.parametrize( + "hidden_dim,latent_dim,batch_size,channels,height,width", + [ + pytest.param( + 128, 2, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-2' + ), + pytest.param( + 128, 4, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-4' + ), + pytest.param( + 64, 4, 16, 1, 28, 28, id='like-mnist-hidden-64-latent-4' + ), + pytest.param( + 128, 2, 16, 3, 32, 32, id='like-cifar10-hidden-128-latent-2' + ), + ] +) +def test_basic_vae_components(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width): seed_everything() - - hidden_dim = 128 - latent_dim = 2 - width = height = 28 - batch_size = 16 - channels = 1 - enc = Encoder(hidden_dim, latent_dim, channels, width, height) x = torch.randn(batch_size, channels, width, height) mu, sigma = enc(x) From 499c4b04eb252dfeceedcd1ed57efaa1b6ee8b8a Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:49:57 -0600 Subject: [PATCH 09/14] :lipstick: apply style --- pl_bolts/datamodules/mnist_datamodule.py | 46 ++++------ pl_bolts/models/__init__.py | 3 +- .../autoencoders/basic_ae/basic_ae_module.py | 84 +++++++++--------- .../basic_vae/basic_vae_module.py | 88 +++++++++---------- tests/models/test_autoencoders.py | 67 ++++---------- 5 files changed, 123 insertions(+), 165 deletions(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index 841a2dfa07..ab8a60f32d 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,18 +7,18 @@ class MNISTDataModule(LightningDataModule): - name = 'mnist' + name = "mnist" def __init__( - self, - data_dir: str = './', - val_split: int = 5000, - num_workers: int = 16, - normalize: bool = False, - seed: int = 42, - batch_size: int = 32, - *args, - **kwargs, + self, + data_dir: str = "./", + val_split: int = 5000, + num_workers: int = 16, + normalize: bool = False, + seed: int = 42, + batch_size: int = 32, + *args, + **kwargs, ): """ .. 
figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png @@ -88,9 +88,7 @@ def train_dataloader(self, batch_size=32, transforms=None): dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms) train_length = len(dataset) dataset_train, _ = random_split( - dataset, - [train_length - self.val_split, self.val_split], - generator=torch.Generator().manual_seed(self.seed) + dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed) ) loader = DataLoader( dataset_train, @@ -98,7 +96,7 @@ def train_dataloader(self, batch_size=32, transforms=None): shuffle=True, num_workers=self.num_workers, drop_last=True, - pin_memory=True + pin_memory=True, ) return loader @@ -114,9 +112,7 @@ def val_dataloader(self, batch_size=32, transforms=None): dataset = MNIST(self.data_dir, train=True, download=True, transform=transforms) train_length = len(dataset) _, dataset_val = random_split( - dataset, - [train_length - self.val_split, self.val_split], - generator=torch.Generator().manual_seed(self.seed) + dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed) ) loader = DataLoader( dataset_val, @@ -124,7 +120,7 @@ def val_dataloader(self, batch_size=32, transforms=None): shuffle=False, num_workers=self.num_workers, drop_last=True, - pin_memory=True + pin_memory=True, ) return loader @@ -140,21 +136,15 @@ def test_dataloader(self, batch_size=32, transforms=None): dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms) loader = DataLoader( - dataset, - batch_size=batch_size, - shuffle=False, - num_workers=self.num_workers, - drop_last=True, - pin_memory=True + dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True ) return loader def _default_transforms(self): if self.normalize: - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor(), - transform_lib.Normalize(mean=(0.5,), std=(0.5,)), - ]) + mnist_transforms = transform_lib.Compose( + [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,)),] + ) else: mnist_transforms = transform_lib.ToTensor() diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py index 178cb6f2ba..caf9de1cba 100644 --- a/pl_bolts/models/__init__.py +++ b/pl_bolts/models/__init__.py @@ -5,7 +5,6 @@ from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE from pl_bolts.models.mnist_module import LitMNIST -from pl_bolts.models.regression import LinearRegression -from pl_bolts.models.regression import LogisticRegression +from pl_bolts.models.regression import LinearRegression, LogisticRegression from pl_bolts.models.vision import PixelCNN from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 93fcbcf8c1..7803042a92 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -4,22 +4,22 @@ from pytorch_lightning import LightningDataModule, LightningModule, Trainer from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule, STL10DataModule, ImagenetDataModule +from pl_bolts.datamodules import (CIFAR10DataModule, ImagenetDataModule, + MNISTDataModule, STL10DataModule) from 
pl_bolts.models.autoencoders.basic_ae.components import AEEncoder from pl_bolts.models.autoencoders.basic_vae.components import Decoder class AE(LightningModule): - def __init__( - self, - input_channels: int, - input_height: int, - input_width: int, - latent_dim=32, - hidden_dim=128, - learning_rate=0.001, - **kwargs + self, + input_channels: int, + input_height: int, + input_width: int, + latent_dim=32, + hidden_dim=128, + learning_rate=0.001, + **kwargs ): """ Args: @@ -38,8 +38,13 @@ def __init__( super().__init__() self.save_hyperparameters() - self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim, self.hparams.input_channels, - self.hparams.input_width, self.hparams.input_height) + self.encoder = self.init_encoder( + self.hparams.hidden_dim, + self.hparams.latent_dim, + self.hparams.input_channels, + self.hparams.input_width, + self.hparams.input_height, + ) self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim) def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width): @@ -48,7 +53,9 @@ def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, inp def init_decoder(self, hidden_dim, latent_dim): # c, h, w = self.img_dim - decoder = Decoder(hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels) + decoder = Decoder( + hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels + ) return decoder def forward(self, z): @@ -66,44 +73,38 @@ def training_step(self, batch, batch_idx): loss = self._run_step(batch) tensorboard_logs = { - 'mse_loss': loss, + "mse_loss": loss, } - return {'loss': loss, 'log': tensorboard_logs} + return {"loss": loss, "log": tensorboard_logs} def validation_step(self, batch, batch_idx): loss = self._run_step(batch) return { - 'val_loss': loss, + "val_loss": loss, } def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() + avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean() - tensorboard_logs = {'mse_loss': avg_loss} + tensorboard_logs = {"mse_loss": avg_loss} - return { - 'val_loss': avg_loss, - 'log': tensorboard_logs - } + return {"val_loss": avg_loss, "log": tensorboard_logs} def test_step(self, batch, batch_idx): loss = self._run_step(batch) return { - 'test_loss': loss, + "test_loss": loss, } def test_epoch_end(self, outputs): - avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() + avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean() - tensorboard_logs = {'mse_loss': avg_loss} + tensorboard_logs = {"mse_loss": avg_loss} - return { - 'test_loss': avg_loss, - 'log': tensorboard_logs - } + return {"test_loss": avg_loss, "log": tensorboard_logs} def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) @@ -111,27 +112,30 @@ def configure_optimizers(self): @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument('--hidden_dim', type=int, default=128, - help='itermediate layers dimension before embedding for default encoder/decoder') - parser.add_argument('--latent_dim', type=int, default=32, - help='dimension of latent variables z') - parser.add_argument('--learning_rate', type=float, default=1e-3) + parser.add_argument( + "--hidden_dim", + type=int, + default=128, + help="itermediate layers dimension before embedding for default 
encoder/decoder", + ) + parser.add_argument("--latent_dim", type=int, default=32, help="dimension of latent variables z") + parser.add_argument("--learning_rate", type=float, default=1e-3) return parser def cli_main(args=None): # cli_main() parser = ArgumentParser() - parser.add_argument('--dataset', default='mnist', type=str, help='mnist, cifar10, stl10, imagenet') + parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet") script_args, _ = parser.parse_known_args(args) - if script_args.dataset == 'mnist': + if script_args.dataset == "mnist": dm_cls = MNISTDataModule - elif script_args.dataset == 'cifar10': + elif script_args.dataset == "cifar10": dm_cls = CIFAR10DataModule - elif script_args.dataset == 'stl10': + elif script_args.dataset == "stl10": dm_cls = STL10DataModule - elif script_args.dataset == 'imagenet': + elif script_args.dataset == "imagenet": dm_cls = ImagenetDataModule parser = dm_cls.add_argparse_args(parser) @@ -146,5 +150,5 @@ def cli_main(args=None): return dm, model, trainer -if __name__ == '__main__': +if __name__ == "__main__": dm, model, trainer = cli_main() diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index a31b542e8b..ca55f6f8f3 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -1,29 +1,30 @@ import os from argparse import ArgumentParser -import torch import pytorch_lightning as pl +import torch from torch import distributions from torch.nn import functional as F -from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule, ImagenetDataModule, STL10DataModule, BinaryMNISTDataModule -from pl_bolts.models.autoencoders.basic_vae.components import Encoder, Decoder -from pl_bolts.utils.pretrained_weights import load_pretrained +from pl_bolts.datamodules import (BinaryMNISTDataModule, CIFAR10DataModule, + ImagenetDataModule, MNISTDataModule, + STL10DataModule) +from pl_bolts.models.autoencoders.basic_vae.components import Decoder, Encoder from pl_bolts.utils import shaping +from pl_bolts.utils.pretrained_weights import load_pretrained class VAE(pl.LightningModule): - def __init__( - self, - input_channels: int, - input_height: int, - input_width: int, - hidden_dim: int = 128, - latent_dim: int = 32, - learning_rate: float = 0.001, - pretrained: str = None, - **kwargs + self, + input_channels: int, + input_height: int, + input_width: int, + hidden_dim: int = 128, + latent_dim: int = 32, + learning_rate: float = 0.001, + pretrained: str = None, + **kwargs ): """ Standard VAE with Gaussian Prior and approx posterior. 
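The cli_main(args=None) entry point shown above for the AE module (and mirrored for the VAE further down) now accepts an explicit argument list via parse_known_args(args) / parse_args(args) and returns dm, model, trainer, so the scripts can be driven in-process as well as from the shell. A rough usage sketch, not part of the patch itself, assuming the argparse wiring behaves as written here (the --dataset flag is added by the script; --max_epochs comes from Trainer.add_argparse_args; the flag values are only illustrative):

    # Illustrative only, not part of this patch series.
    from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main

    # cli_main parses the given list instead of sys.argv, builds the datamodule,
    # model and trainer, runs trainer.fit(model, dm), and returns all three.
    dm, model, trainer = cli_main(["--dataset", "cifar10", "--max_epochs", "1"])

The same pattern applies to the VAE script, whose cli_main below follows the identical structure.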
@@ -69,10 +70,10 @@ def __init_system(self): self.decoder = self.init_decoder() def load_pretrained(self, pretrained): - available_weights = {'imagenet2012'} + available_weights = {"imagenet2012"} if pretrained in available_weights: - weights_name = f'vae-{pretrained}' + weights_name = f"vae-{pretrained}" load_pretrained(self, weights_name) def init_encoder(self): @@ -81,7 +82,7 @@ def init_encoder(self): self.hparams.latent_dim, self.hparams.input_channels, self.hparams.input_width, - self.hparams.input_height + self.hparams.input_height, ) return encoder @@ -91,7 +92,7 @@ def init_decoder(self): self.hparams.latent_dim, self.hparams.input_width, self.hparams.input_height, - self.hparams.input_channels + self.hparams.input_channels, ) return decoder @@ -130,7 +131,7 @@ def elbo_loss(self, x, P, Q, num_samples): x = shaping.tile(x.unsqueeze(1), 1, num_samples) pxz = torch.sigmoid(pxz) - recon_loss = F.binary_cross_entropy(pxz, x, reduction='none') + recon_loss = F.binary_cross_entropy(pxz, x, reduction="none") # sum across dimensions because sum of log probabilities of iid univariate gaussians is the same as # multivariate gaussian @@ -183,31 +184,23 @@ def _run_step(self, batch): def training_step(self, batch, batch_idx): loss, recon_loss, kl_div, pxz = self._run_step(batch) result = pl.TrainResult(loss) - result.log_dict({ - 'train_elbo_loss': loss, - 'train_recon_loss': recon_loss, - 'train_kl_loss': kl_div - }) + result.log_dict({"train_elbo_loss": loss, "train_recon_loss": recon_loss, "train_kl_loss": kl_div}) return result def validation_step(self, batch, batch_idx): loss, recon_loss, kl_div, pxz = self._run_step(batch) result = pl.EvalResult(loss, checkpoint_on=loss) - result.log_dict({ - 'val_loss': loss, - 'val_recon_loss': recon_loss, - 'val_kl_div': kl_div, - }) + result.log_dict( + {"val_loss": loss, "val_recon_loss": recon_loss, "val_kl_div": kl_div,} + ) return result def test_step(self, batch, batch_idx): loss, recon_loss, kl_div, pxz = self._run_step(batch) result = pl.EvalResult(loss) - result.log_dict({ - 'test_loss': loss, - 'test_recon_loss': recon_loss, - 'test_kl_div': kl_div, - }) + result.log_dict( + {"test_loss": loss, "test_recon_loss": recon_loss, "test_kl_div": kl_div,} + ) return result def configure_optimizers(self): @@ -216,12 +209,15 @@ def configure_optimizers(self): @staticmethod def add_model_specific_args(parent_parser): parser = ArgumentParser(parents=[parent_parser], add_help=False) - parser.add_argument('--hidden_dim', type=int, default=128, - help='itermediate layers dimension before embedding for default encoder/decoder') - parser.add_argument('--latent_dim', type=int, default=4, - help='dimension of latent variables z') - parser.add_argument('--pretrained', type=str, default=None) - parser.add_argument('--learning_rate', type=float, default=1e-3) + parser.add_argument( + "--hidden_dim", + type=int, + default=128, + help="itermediate layers dimension before embedding for default encoder/decoder", + ) + parser.add_argument("--latent_dim", type=int, default=4, help="dimension of latent variables z") + parser.add_argument("--pretrained", type=str, default=None) + parser.add_argument("--learning_rate", type=float, default=1e-3) return parser @@ -230,16 +226,16 @@ def cli_main(args=None): # cli_main() parser = ArgumentParser() - parser.add_argument('--dataset', default='mnist', type=str, help='mnist, cifar10, stl10, imagenet') + parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet") script_args, _ = 
parser.parse_known_args(args) - if script_args.dataset == 'mnist': + if script_args.dataset == "mnist": dm_cls = MNISTDataModule - elif script_args.dataset == 'cifar10': + elif script_args.dataset == "cifar10": dm_cls = CIFAR10DataModule - elif script_args.dataset == 'stl10': + elif script_args.dataset == "stl10": dm_cls = STL10DataModule - elif script_args.dataset == 'imagenet': + elif script_args.dataset == "imagenet": dm_cls = ImagenetDataModule parser = dm_cls.add_argparse_args(parser) @@ -255,5 +251,5 @@ def cli_main(args=None): return dm, model, trainer -if __name__ == '__main__': +if __name__ == "__main__": dm, model, trainer = cli_main() diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 3d4136e962..6312c82b36 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -3,22 +3,14 @@ import torch from pytorch_lightning import seed_everything -from pl_bolts.datamodules import MNISTDataModule, CIFAR10DataModule -from pl_bolts.models.autoencoders import VAE, AE +from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule +from pl_bolts.models.autoencoders import AE, VAE from pl_bolts.models.autoencoders.basic_ae import AEEncoder -from pl_bolts.models.autoencoders.basic_vae import Encoder, Decoder +from pl_bolts.models.autoencoders.basic_vae import Decoder, Encoder @pytest.mark.parametrize( - "dm_cls", - [ - pytest.param( - MNISTDataModule, id='mnist' - ), - pytest.param( - CIFAR10DataModule, id='cifar10' - ), - ] + "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10"),] ) def test_vae(tmpdir, dm_cls): seed_everything() @@ -27,20 +19,13 @@ def test_vae(tmpdir, dm_cls): trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir, deterministic=True) trainer.fit(model, dm) results = trainer.test(model, datamodule=dm)[0] - loss = results['test_loss'] + loss = results["test_loss"] + + assert loss > 0, "VAE failed" - assert loss > 0, 'VAE failed' @pytest.mark.parametrize( - "dm_cls", - [ - pytest.param( - MNISTDataModule, id='mnist' - ), - pytest.param( - CIFAR10DataModule, id='cifar10' - ), - ] + "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10"),] ) def test_ae(tmpdir, dm_cls): seed_everything() @@ -54,19 +39,11 @@ def test_ae(tmpdir, dm_cls): @pytest.mark.parametrize( "hidden_dim,latent_dim,batch_size,channels,height,width", [ - pytest.param( - 128, 2, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-2' - ), - pytest.param( - 128, 4, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-4' - ), - pytest.param( - 64, 4, 16, 1, 28, 28, id='like-mnist-hidden-64-latent-4' - ), - pytest.param( - 128, 2, 16, 3, 32, 32, id='like-cifar10-hidden-128-latent-2' - ), - ] + pytest.param(128, 2, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-2"), + pytest.param(128, 4, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-4"), + pytest.param(64, 4, 16, 1, 28, 28, id="like-mnist-hidden-64-latent-4"), + pytest.param(128, 2, 16, 3, 32, 32, id="like-cifar10-hidden-128-latent-2"), + ], ) def test_basic_ae_encoder(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width): seed_everything() @@ -79,19 +56,11 @@ def test_basic_ae_encoder(tmpdir, hidden_dim, latent_dim, batch_size, channels, @pytest.mark.parametrize( "hidden_dim,latent_dim,batch_size,channels,height,width", [ - pytest.param( - 128, 2, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-2' - ), - pytest.param( - 128, 4, 16, 1, 28, 28, id='like-mnist-hidden-128-latent-4' - ), - 
pytest.param( - 64, 4, 16, 1, 28, 28, id='like-mnist-hidden-64-latent-4' - ), - pytest.param( - 128, 2, 16, 3, 32, 32, id='like-cifar10-hidden-128-latent-2' - ), - ] + pytest.param(128, 2, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-2"), + pytest.param(128, 4, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-4"), + pytest.param(64, 4, 16, 1, 28, 28, id="like-mnist-hidden-64-latent-4"), + pytest.param(128, 2, 16, 3, 32, 32, id="like-cifar10-hidden-128-latent-2"), + ], ) def test_basic_vae_components(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width): seed_everything() From 0ccaf69cd65b2aaf9f01f70e18e5eacd6b1b6743 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:54:12 -0600 Subject: [PATCH 10/14] :lipstick: style --- tests/models/test_autoencoders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 6312c82b36..57333ea82e 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize( - "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10"),] + "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")] ) def test_vae(tmpdir, dm_cls): seed_everything() @@ -25,7 +25,7 @@ def test_vae(tmpdir, dm_cls): @pytest.mark.parametrize( - "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10"),] + "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")] ) def test_ae(tmpdir, dm_cls): seed_everything() From f092fe9f57e9ee733a958685413c06c2953d5226 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:55:32 -0600 Subject: [PATCH 11/14] :lipstick: . --- pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index ca55f6f8f3..2aad593ab3 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -191,7 +191,7 @@ def validation_step(self, batch, batch_idx): loss, recon_loss, kl_div, pxz = self._run_step(batch) result = pl.EvalResult(loss, checkpoint_on=loss) result.log_dict( - {"val_loss": loss, "val_recon_loss": recon_loss, "val_kl_div": kl_div,} + {"val_loss": loss, "val_recon_loss": recon_loss, "val_kl_div": kl_div} ) return result @@ -199,7 +199,7 @@ def test_step(self, batch, batch_idx): loss, recon_loss, kl_div, pxz = self._run_step(batch) result = pl.EvalResult(loss) result.log_dict( - {"test_loss": loss, "test_recon_loss": recon_loss, "test_kl_div": kl_div,} + {"test_loss": loss, "test_recon_loss": recon_loss, "test_kl_div": kl_div} ) return result From db5de3e15f0bce64e15a4ff9052dc41734d3ed92 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 15:56:00 -0600 Subject: [PATCH 12/14] :lipstick: . 
--- pl_bolts/datamodules/mnist_datamodule.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index ab8a60f32d..e22e2c8a07 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -143,7 +143,7 @@ def test_dataloader(self, batch_size=32, transforms=None): def _default_transforms(self): if self.normalize: mnist_transforms = transform_lib.Compose( - [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,)),] + [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))] ) else: mnist_transforms = transform_lib.ToTensor() From ddcf680f395055264ce97ca58575f35d350af460 Mon Sep 17 00:00:00 2001 From: nateraw Date: Thu, 10 Sep 2020 16:11:46 -0600 Subject: [PATCH 13/14] :white_check_mark: add tests --- tests/models/test_executable_scripts.py | 42 +++++++++++++++++++++---- 1 file changed, 36 insertions(+), 6 deletions(-) diff --git a/tests/models/test_executable_scripts.py b/tests/models/test_executable_scripts.py index f381a90005..7923357b6c 100644 --- a/tests/models/test_executable_scripts.py +++ b/tests/models/test_executable_scripts.py @@ -15,14 +15,44 @@ def test_cli_basic_gan(cli_args): cli_main() -@pytest.mark.parametrize('cli_args', ['--max_epochs 1' - ' --limit_train_batches 3' - ' --limit_val_batches 3' - ' --batch_size 3']) -def test_cli_basic_vae(cli_args): +@pytest.mark.parametrize( + "dataset_name", [ + pytest.param('mnist', id="mnist"), + pytest.param('cifar10', id="cifar10") + ] +) +def test_cli_basic_vae(dataset_name): from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import cli_main - cli_args = cli_args.split(' ') if cli_args else [] + cli_args = f""" + --dataset {dataset_name} + --max_epochs 1 + --limit_train_batches 3 + --limit_val_batches 3 + --batch_size 3 + """.strip().split() + + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize( + "dataset_name", [ + pytest.param('mnist', id="mnist"), + pytest.param('cifar10', id="cifar10") + ] +) +def test_cli_basic_ae(dataset_name): + from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main + + cli_args = f""" + --dataset {dataset_name} + --max_epochs 1 + --limit_train_batches 3 + --limit_val_batches 3 + --batch_size 3 + """.strip().split() + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): cli_main() From 7fb2efe683a5926925c8460ff46100436704db20 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Thu, 10 Sep 2020 17:12:59 -0600 Subject: [PATCH 14/14] Apply suggestions from code review Co-authored-by: Jirka Borovec --- pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 7803042a92..5e38a6e9c7 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -124,7 +124,6 @@ def add_model_specific_args(parent_parser): def cli_main(args=None): - # cli_main() parser = ArgumentParser() parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet") script_args, _ = parser.parse_known_args(args)
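Taken together, the series removes the models' hard-wired dependence on an MNIST datamodule: AE and VAE now take input_channels, input_height and input_width as explicit constructor arguments, and the datamodule is handed to the Trainer instead. A minimal usage sketch consistent with the updated tests and cli_main, assuming the MNISTDataModule defaults and the (channels, height, width) ordering of dm.size() introduced in this series:

    # Sketch of the post-refactor API, mirroring tests/models/test_autoencoders.py.
    from pytorch_lightning import Trainer
    from pl_bolts.datamodules import MNISTDataModule
    from pl_bolts.models.autoencoders import AE

    # The datamodule now owns data_dir (default './') and batch_size.
    dm = MNISTDataModule(batch_size=32, num_workers=0)

    # dm.size() supplies (input_channels, input_height, input_width) positionally.
    model = AE(*dm.size())

    # Data is passed to fit() rather than being created inside the model.
    trainer = Trainer(fast_dev_run=True)
    trainer.fit(model, dm)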