diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py
index 2adc03fc11..e22e2c8a07 100644
--- a/pl_bolts/datamodules/mnist_datamodule.py
+++ b/pl_bolts/datamodules/mnist_datamodule.py
@@ -7,17 +7,18 @@ class MNISTDataModule(LightningDataModule):
-    name = 'mnist'
+    name = "mnist"

     def __init__(
-            self,
-            data_dir: str,
-            val_split: int = 5000,
-            num_workers: int = 16,
-            normalize: bool = False,
-            seed: int = 42,
-            *args,
-            **kwargs,
+        self,
+        data_dir: str = "./",
+        val_split: int = 5000,
+        num_workers: int = 16,
+        normalize: bool = False,
+        seed: int = 42,
+        batch_size: int = 32,
+        *args,
+        **kwargs,
     ):
         """
         .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
@@ -87,9 +88,7 @@ def train_dataloader(self, batch_size=32, transforms=None):
         dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
         train_length = len(dataset)
         dataset_train, _ = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
+            dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
         )
         loader = DataLoader(
             dataset_train,
@@ -97,7 +96,7 @@ def train_dataloader(self, batch_size=32, transforms=None):
             shuffle=True,
             num_workers=self.num_workers,
             drop_last=True,
-            pin_memory=True
+            pin_memory=True,
         )
         return loader
@@ -113,9 +112,7 @@ def val_dataloader(self, batch_size=32, transforms=None):
         dataset = MNIST(self.data_dir, train=True, download=True, transform=transforms)
         train_length = len(dataset)
         _, dataset_val = random_split(
-            dataset,
-            [train_length - self.val_split, self.val_split],
-            generator=torch.Generator().manual_seed(self.seed)
+            dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
         )
         loader = DataLoader(
             dataset_val,
@@ -123,7 +120,7 @@ def val_dataloader(self, batch_size=32, transforms=None):
             shuffle=False,
             num_workers=self.num_workers,
             drop_last=True,
-            pin_memory=True
+            pin_memory=True,
         )
         return loader
@@ -139,21 +136,15 @@ def test_dataloader(self, batch_size=32, transforms=None):
         dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
         loader = DataLoader(
-            dataset,
-            batch_size=batch_size,
-            shuffle=False,
-            num_workers=self.num_workers,
-            drop_last=True,
-            pin_memory=True
+            dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True
         )
         return loader

     def _default_transforms(self):
         if self.normalize:
-            mnist_transforms = transform_lib.Compose([
-                transform_lib.ToTensor(),
-                transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
-            ])
+            mnist_transforms = transform_lib.Compose(
+                [transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
+            )
         else:
             mnist_transforms = transform_lib.ToTensor()
diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py
index 05f2533a33..caf9de1cba 100644
--- a/pl_bolts/models/__init__.py
+++ b/pl_bolts/models/__init__.py
@@ -2,9 +2,9 @@
 Collection of PyTorchLightning models
 """

+from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
 from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
 from pl_bolts.models.mnist_module import LitMNIST
-from pl_bolts.models.regression import LinearRegression
-from pl_bolts.models.regression import LogisticRegression
+from pl_bolts.models.regression import LinearRegression, LogisticRegression
 from pl_bolts.models.vision import PixelCNN
 from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
index 840f00a988..5e38a6e9c7 100644
--- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
+++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -4,26 +4,22 @@
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from torch.nn import functional as F

-from pl_bolts.datamodules import MNISTDataModule
+from pl_bolts.datamodules import (CIFAR10DataModule, ImagenetDataModule,
+                                  MNISTDataModule, STL10DataModule)
 from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder
 from pl_bolts.models.autoencoders.basic_vae.components import Decoder


 class AE(LightningModule):
-
     def __init__(
-            self,
-            datamodule: LightningDataModule = None,
-            input_channels=1,
-            input_height=28,
-            input_width=28,
-            latent_dim=32,
-            batch_size=32,
-            hidden_dim=128,
-            learning_rate=0.001,
-            num_workers=8,
-            data_dir='.',
-            **kwargs
+        self,
+        input_channels: int,
+        input_height: int,
+        input_width: int,
+        latent_dim=32,
+        hidden_dim=128,
+        learning_rate=0.001,
+        **kwargs
     ):
         """
         Args:
@@ -42,25 +38,24 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()

-        # link default data
-        if datamodule is None:
-            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)
-
-        self.datamodule = datamodule
-
-        self.img_dim = self.datamodule.size()
-
-        self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
-                                         self.hparams.input_width, self.hparams.input_height)
+        self.encoder = self.init_encoder(
+            self.hparams.hidden_dim,
+            self.hparams.latent_dim,
+            self.hparams.input_channels,
+            self.hparams.input_width,
+            self.hparams.input_height,
+        )
         self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)

-    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
-        encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height)
+    def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
+        encoder = AEEncoder(hidden_dim, latent_dim, input_channels, input_height, input_width)
         return encoder

     def init_decoder(self, hidden_dim, latent_dim):
-        c, h, w = self.img_dim
-        decoder = Decoder(hidden_dim, latent_dim, w, h, c)
+        # c, h, w = self.img_dim
+        decoder = Decoder(
+            hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels
+        )
         return decoder

     def forward(self, z):
@@ -78,44 +73,38 @@ def training_step(self, batch, batch_idx):
         loss = self._run_step(batch)

         tensorboard_logs = {
-            'mse_loss': loss,
+            "mse_loss": loss,
         }

-        return {'loss': loss, 'log': tensorboard_logs}
+        return {"loss": loss, "log": tensorboard_logs}

     def validation_step(self, batch, batch_idx):
         loss = self._run_step(batch)
         return {
-            'val_loss': loss,
+            "val_loss": loss,
         }

     def validation_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

-        tensorboard_logs = {'mse_loss': avg_loss}
+        tensorboard_logs = {"mse_loss": avg_loss}

-        return {
-            'val_loss': avg_loss,
-            'log': tensorboard_logs
-        }
+        return {"val_loss": avg_loss, "log": tensorboard_logs}

     def test_step(self, batch, batch_idx):
         loss = self._run_step(batch)
         return {
-            'test_loss': loss,
+            "test_loss": loss,
         }

     def test_epoch_end(self, outputs):
-        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
+        avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()

-        tensorboard_logs = {'mse_loss': avg_loss}
+        tensorboard_logs = {"mse_loss": avg_loss}

-        return {
-            'test_loss': avg_loss,
-            'log': tensorboard_logs
-        }
+        return {"test_loss": avg_loss, "log": tensorboard_logs}

     def configure_optimizers(self):
         return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
@@ -123,31 +112,42 @@ def configure_optimizers(self):
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument('--hidden_dim', type=int, default=128,
-                            help='itermediate layers dimension before embedding for default encoder/decoder')
-        parser.add_argument('--latent_dim', type=int, default=32,
-                            help='dimension of latent variables z')
-        parser.add_argument('--input_width', type=int, default=28,
-                            help='input image width - 28 for MNIST (must be even)')
-        parser.add_argument('--input_height', type=int, default=28,
-                            help='input image height - 28 for MNIST (must be even)')
-        parser.add_argument('--batch_size', type=int, default=32)
-        parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
-        parser.add_argument('--learning_rate', type=float, default=1e-3)
-        parser.add_argument('--data_dir', type=str, default='')
+        parser.add_argument(
+            "--hidden_dim",
+            type=int,
+            default=128,
+            help="intermediate layers dimension before embedding for default encoder/decoder",
+        )
+        parser.add_argument("--latent_dim", type=int, default=32, help="dimension of latent variables z")
+        parser.add_argument("--learning_rate", type=float, default=1e-3)

         return parser


-def cli_main():
+def cli_main(args=None):
     parser = ArgumentParser()
+    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
+    script_args, _ = parser.parse_known_args(args)
+
+    if script_args.dataset == "mnist":
+        dm_cls = MNISTDataModule
+    elif script_args.dataset == "cifar10":
+        dm_cls = CIFAR10DataModule
+    elif script_args.dataset == "stl10":
+        dm_cls = STL10DataModule
+    elif script_args.dataset == "imagenet":
+        dm_cls = ImagenetDataModule
+
+    parser = dm_cls.add_argparse_args(parser)
     parser = Trainer.add_argparse_args(parser)
     parser = AE.add_model_specific_args(parser)
-    args = parser.parse_args()
+    args = parser.parse_args(args)

-    ae = AE(**vars(args))
+    dm = dm_cls.from_argparse_args(args)
+    model = AE(*dm.size(), **vars(args))
     trainer = Trainer.from_argparse_args(args)
-    trainer.fit(ae)
+    trainer.fit(model, dm)
+    return dm, model, trainer


-if __name__ == '__main__':
-    cli_main()
+if __name__ == "__main__":
+    dm, model, trainer = cli_main()
diff --git a/pl_bolts/models/autoencoders/basic_ae/components.py b/pl_bolts/models/autoencoders/basic_ae/components.py
index 442012bcf8..a4bb4ec6ba 100644
--- a/pl_bolts/models/autoencoders/basic_ae/components.py
+++ b/pl_bolts/models/autoencoders/basic_ae/components.py
@@ -9,26 +9,27 @@ class AEEncoder(torch.nn.Module):
     get split into a mu and sigma vector
     """

-    def __init__(self, hidden_dim, latent_dim, input_width, input_height):
+    def __init__(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.latent_dim = latent_dim
-        self.input_width = input_width
+        self.input_channels = input_channels
         self.input_height = input_height
+        self.input_width = input_width

-        self.c1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.c1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
         self.c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
         self.c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

-        conv_out_dim = self._calculate_output_dim(input_width, input_height)
+        conv_out_dim = self._calculate_output_dim(input_channels, input_width, input_height)
         self.fc1 = DenseBlock(conv_out_dim, hidden_dim)
         self.fc2 = DenseBlock(hidden_dim, hidden_dim)
         self.fc_z_out = nn.Linear(hidden_dim, latent_dim)

-    def _calculate_output_dim(self, input_width, input_height):
-        x = torch.rand(1, 1, input_width, input_height)
+    def _calculate_output_dim(self, input_channels, input_width, input_height):
+        x = torch.rand(1, input_channels, input_width, input_height)
         x = self.c3(self.c2(self.c1(x)))
         x = x.view(-1)
         return x.size(0)
diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
index 9c5cffd90d..2aad593ab3 100644
--- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
+++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
@@ -1,33 +1,30 @@
 import os
 from argparse import ArgumentParser

-import torch
 import pytorch_lightning as pl
+import torch
 from torch import distributions
 from torch.nn import functional as F

-from pl_bolts.datamodules import MNISTDataModule, ImagenetDataModule, STL10DataModule, BinaryMNISTDataModule
-from pl_bolts.models.autoencoders.basic_vae.components import Encoder, Decoder
-from pl_bolts.utils.pretrained_weights import load_pretrained
+from pl_bolts.datamodules import (BinaryMNISTDataModule, CIFAR10DataModule,
+                                  ImagenetDataModule, MNISTDataModule,
+                                  STL10DataModule)
+from pl_bolts.models.autoencoders.basic_vae.components import Decoder, Encoder
 from pl_bolts.utils import shaping
+from pl_bolts.utils.pretrained_weights import load_pretrained


 class VAE(pl.LightningModule):
-
     def __init__(
-            self,
-            hidden_dim: int = 128,
-            latent_dim: int = 32,
-            input_channels: int = 3,
-            input_width: int = 224,
-            input_height: int = 224,
-            batch_size: int = 32,
-            learning_rate: float = 0.001,
-            data_dir: str = '.',
-            datamodule: pl.LightningDataModule = None,
-            num_workers: int = 8,
-            pretrained: str = None,
-            **kwargs
+        self,
+        input_channels: int,
+        input_height: int,
+        input_width: int,
+        hidden_dim: int = 128,
+        latent_dim: int = 32,
+        learning_rate: float = 0.001,
+        pretrained: str = None,
+        **kwargs
     ):
         """
         Standard VAE with Gaussian Prior and approx posterior.
@@ -60,12 +57,7 @@ def __init__(
         """
         super().__init__()
         self.save_hyperparameters()
-
-        self.datamodule = datamodule
-        self.__set_pretrained_dims(pretrained)
-
-        # use mnist as the default module
-        self._set_default_datamodule(datamodule)
+        self.img_dim = (input_channels, input_height, input_width)

         # init actual model
         self.__init_system()
@@ -77,29 +69,11 @@ def __init_system(self):
         self.encoder = self.init_encoder()
         self.decoder = self.init_decoder()

-    def __set_pretrained_dims(self, pretrained):
-        if pretrained == 'imagenet2012':
-            self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
-            (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.datamodule.size()
-
-    def _set_default_datamodule(self, datamodule):
-        # link default data
-        if datamodule is None:
-            datamodule = MNISTDataModule(
-                data_dir=self.hparams.data_dir,
-                num_workers=self.hparams.num_workers,
-                normalize=False
-            )
-        self.datamodule = datamodule
-        self.img_dim = self.datamodule.size()
-
-        (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim
-
     def load_pretrained(self, pretrained):
-        available_weights = {'imagenet2012'}
+        available_weights = {"imagenet2012"}

         if pretrained in available_weights:
-            weights_name = f'vae-{pretrained}'
+            weights_name = f"vae-{pretrained}"
             load_pretrained(self, weights_name)

     def init_encoder(self):
@@ -108,7 +82,7 @@ def init_encoder(self):
             self.hparams.latent_dim,
             self.hparams.input_channels,
             self.hparams.input_width,
-            self.hparams.input_height
+            self.hparams.input_height,
         )
         return encoder
@@ -118,7 +92,7 @@ def init_decoder(self):
             self.hparams.latent_dim,
             self.hparams.input_width,
             self.hparams.input_height,
-            self.hparams.input_channels
+            self.hparams.input_channels,
         )
         return decoder
@@ -157,7 +131,7 @@ def elbo_loss(self, x, P, Q, num_samples):
         x = shaping.tile(x.unsqueeze(1), 1, num_samples)
         pxz = torch.sigmoid(pxz)

-        recon_loss = F.binary_cross_entropy(pxz, x, reduction='none')
+        recon_loss = F.binary_cross_entropy(pxz, x, reduction="none")

         # sum across dimensions because sum of log probabilities of iid univariate gaussians is the same as
         # multivariate gaussian
@@ -210,31 +184,23 @@ def _run_step(self, batch):
     def training_step(self, batch, batch_idx):
         loss, recon_loss, kl_div, pxz = self._run_step(batch)
         result = pl.TrainResult(loss)
-        result.log_dict({
-            'train_elbo_loss': loss,
-            'train_recon_loss': recon_loss,
-            'train_kl_loss': kl_div
-        })
+        result.log_dict({"train_elbo_loss": loss, "train_recon_loss": recon_loss, "train_kl_loss": kl_div})
         return result

     def validation_step(self, batch, batch_idx):
         loss, recon_loss, kl_div, pxz = self._run_step(batch)
         result = pl.EvalResult(loss, checkpoint_on=loss)
-        result.log_dict({
-            'val_loss': loss,
-            'val_recon_loss': recon_loss,
-            'val_kl_div': kl_div,
-        })
+        result.log_dict(
+            {"val_loss": loss, "val_recon_loss": recon_loss, "val_kl_div": kl_div}
+        )
         return result

     def test_step(self, batch, batch_idx):
         loss, recon_loss, kl_div, pxz = self._run_step(batch)
         result = pl.EvalResult(loss)
-        result.log_dict({
-            'test_loss': loss,
-            'test_recon_loss': recon_loss,
-            'test_kl_div': kl_div,
-        })
+        result.log_dict(
+            {"test_loss": loss, "test_recon_loss": recon_loss, "test_kl_div": kl_div}
+        )
         return result

     def configure_optimizers(self):
@@ -243,55 +209,47 @@ def configure_optimizers(self):
     @staticmethod
     def add_model_specific_args(parent_parser):
         parser = ArgumentParser(parents=[parent_parser], add_help=False)
-        parser.add_argument('--hidden_dim', type=int, default=128,
-                            help='itermediate layers dimension before embedding for default encoder/decoder')
-        parser.add_argument('--latent_dim', type=int, default=4,
-                            help='dimension of latent variables z')
-        parser.add_argument('--input_width', type=int, default=224,
-                            help='input width (used Imagenet downsampled size)')
-        parser.add_argument('--input_height', type=int, default=224,
-                            help='input width (used Imagenet downsampled size)')
-        parser.add_argument('--input_channels', type=int, default=3,
-                            help='number of input channels')
-        parser.add_argument('--batch_size', type=int, default=64)
-        parser.add_argument('--pretrained', type=str, default=None)
-        parser.add_argument('--data_dir', type=str, default=os.getcwd())
-        parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
-
-        parser.add_argument('--learning_rate', type=float, default=1e-3)
+        parser.add_argument(
+            "--hidden_dim",
+            type=int,
+            default=128,
+            help="intermediate layers dimension before embedding for default encoder/decoder",
+        )
+        parser.add_argument("--latent_dim", type=int, default=4, help="dimension of latent variables z")
+        parser.add_argument("--pretrained", type=str, default=None)
+        parser.add_argument("--learning_rate", type=float, default=1e-3)
         return parser


-def cli_main():
+def cli_main(args=None):
     from pl_bolts.callbacks import LatentDimInterpolator, TensorboardGenerativeModelImageSampler
-    from pl_bolts.datamodules import ImagenetDataModule

-    pl.seed_everything(1234)
+    # cli_main()
     parser = ArgumentParser()
-    parser.add_argument('--dataset', default='mnist', type=str, help='mnist, stl10, imagenet')
-
+    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
+    script_args, _ = parser.parse_known_args(args)
+
+    if script_args.dataset == "mnist":
+        dm_cls = MNISTDataModule
+    elif script_args.dataset == "cifar10":
+        dm_cls = CIFAR10DataModule
+    elif script_args.dataset == "stl10":
+        dm_cls = STL10DataModule
+    elif script_args.dataset == "imagenet":
+        dm_cls = ImagenetDataModule
+
+    parser = dm_cls.add_argparse_args(parser)
     parser = pl.Trainer.add_argparse_args(parser)
     parser = VAE.add_model_specific_args(parser)
-    parser = ImagenetDataModule.add_argparse_args(parser)
-    parser = MNISTDataModule.add_argparse_args(parser)
-    args = parser.parse_args()
-
-    # default is mnist
-    datamodule = None
-    if args.dataset == 'imagenet2012':
-        datamodule = ImagenetDataModule.from_argparse_args(args)
-    elif args.dataset == 'stl10':
-        datamodule = STL10DataModule.from_argparse_args(args)
+    args = parser.parse_args(args)

+    dm = dm_cls.from_argparse_args(args)
+    model = VAE(*dm.size(), **vars(args))
     callbacks = [TensorboardGenerativeModelImageSampler(), LatentDimInterpolator(interpolate_epoch_interval=5)]

-    vae = VAE(**vars(args), datamodule=datamodule)
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        callbacks=callbacks,
-        progress_bar_refresh_rate=10,
-    )
-    trainer.fit(vae)
+    trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks, progress_bar_refresh_rate=20)
+    trainer.fit(model, dm)
+    return dm, model, trainer


-if __name__ == '__main__':
-    cli_main()
+if __name__ == "__main__":
+    dm, model, trainer = cli_main()
diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py
index cf57a7a317..57333ea82e 100644
--- a/tests/models/test_autoencoders.py
+++ b/tests/models/test_autoencoders.py
@@ -1,58 +1,69 @@
+import pytest
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning import seed_everything

-from pl_bolts.models.autoencoders import VAE, AE
+from pl_bolts.datamodules import CIFAR10DataModule, MNISTDataModule
+from pl_bolts.models.autoencoders import AE, VAE
 from pl_bolts.models.autoencoders.basic_ae import AEEncoder
-from pl_bolts.models.autoencoders.basic_vae import Encoder, Decoder
+from pl_bolts.models.autoencoders.basic_vae import Decoder, Encoder


-def test_vae(tmpdir):
+@pytest.mark.parametrize(
+    "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
+)
+def test_vae(tmpdir, dm_cls):
     seed_everything()
-
-    model = VAE(data_dir=tmpdir, batch_size=2, num_workers=0)
+    dm = dm_cls(batch_size=2, num_workers=0)
+    model = VAE(*dm.size())
     trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir, deterministic=True)
-    trainer.fit(model)
-    results = trainer.test(model)[0]
-    loss = results['test_loss']
+    trainer.fit(model, dm)
+    results = trainer.test(model, datamodule=dm)[0]
+    loss = results["test_loss"]

-    assert loss > 0, 'VAE failed'
+    assert loss > 0, "VAE failed"


-def test_ae(tmpdir):
+@pytest.mark.parametrize(
+    "dm_cls", [pytest.param(MNISTDataModule, id="mnist"), pytest.param(CIFAR10DataModule, id="cifar10")]
+)
+def test_ae(tmpdir, dm_cls):
     seed_everything()
-
-    model = AE(data_dir=tmpdir, batch_size=2)
+    dm = dm_cls(batch_size=2, num_workers=0)
+    model = AE(*dm.size())
     trainer = pl.Trainer(fast_dev_run=True, default_root_dir=tmpdir)
-    trainer.fit(model)
-    trainer.test(model)
-
-
-def test_basic_ae_encoder(tmpdir):
+    trainer.fit(model, dm)
+    trainer.test(model, datamodule=dm)
+
+
+@pytest.mark.parametrize(
+    "hidden_dim,latent_dim,batch_size,channels,height,width",
+    [
+        pytest.param(128, 2, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-2"),
+        pytest.param(128, 4, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-4"),
+        pytest.param(64, 4, 16, 1, 28, 28, id="like-mnist-hidden-64-latent-4"),
+        pytest.param(128, 2, 16, 3, 32, 32, id="like-cifar10-hidden-128-latent-2"),
+    ],
+)
+def test_basic_ae_encoder(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width):
     seed_everything()
-
-    hidden_dim = 128
-    latent_dim = 2
-    width = height = 28
-    batch_size = 16
-    channels = 1
-
-    encoder = AEEncoder(hidden_dim, latent_dim, width, height)
+    encoder = AEEncoder(hidden_dim, latent_dim, channels, width, height)
     x = torch.randn(batch_size, channels, width, height)
     z = encoder(x)
-
     assert z.shape == (batch_size, latent_dim)


-def test_basic_vae_components(tmpdir):
+@pytest.mark.parametrize(
+    "hidden_dim,latent_dim,batch_size,channels,height,width",
+    [
+        pytest.param(128, 2, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-2"),
+        pytest.param(128, 4, 16, 1, 28, 28, id="like-mnist-hidden-128-latent-4"),
+        pytest.param(64, 4, 16, 1, 28, 28, id="like-mnist-hidden-64-latent-4"),
+        pytest.param(128, 2, 16, 3, 32, 32, id="like-cifar10-hidden-128-latent-2"),
+    ],
+)
+def test_basic_vae_components(tmpdir, hidden_dim, latent_dim, batch_size, channels, height, width):
     seed_everything()
-
-    hidden_dim = 128
-    latent_dim = 2
-    width = height = 28
-    batch_size = 16
-    channels = 1
-
     enc = Encoder(hidden_dim, latent_dim, channels, width, height)
     x = torch.randn(batch_size, channels, width, height)
     mu, sigma = enc(x)
diff --git a/tests/models/test_executable_scripts.py b/tests/models/test_executable_scripts.py
index f381a90005..7923357b6c 100644
--- a/tests/models/test_executable_scripts.py
+++ b/tests/models/test_executable_scripts.py
@@ -15,14 +15,44 @@ def test_cli_basic_gan(cli_args):
     cli_main()


-@pytest.mark.parametrize('cli_args', ['--max_epochs 1'
-                                      ' --limit_train_batches 3'
-                                      ' --limit_val_batches 3'
-                                      ' --batch_size 3'])
-def test_cli_basic_vae(cli_args):
+@pytest.mark.parametrize(
+    "dataset_name", [
+        pytest.param('mnist', id="mnist"),
+        pytest.param('cifar10', id="cifar10")
+    ]
+)
+def test_cli_basic_vae(dataset_name):
     from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import cli_main

-    cli_args = cli_args.split(' ') if cli_args else []
+    cli_args = f"""
+        --dataset {dataset_name}
+        --max_epochs 1
+        --limit_train_batches 3
+        --limit_val_batches 3
+        --batch_size 3
+    """.strip().split()
+
+    with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
+        cli_main()
+
+
+@pytest.mark.parametrize(
+    "dataset_name", [
+        pytest.param('mnist', id="mnist"),
+        pytest.param('cifar10', id="cifar10")
+    ]
+)
+def test_cli_basic_ae(dataset_name):
+    from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main
+
+    cli_args = f"""
+        --dataset {dataset_name}
+        --max_epochs 1
+        --limit_train_batches 3
+        --limit_val_batches 3
+        --batch_size 3
+    """.strip().split()

     with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
         cli_main()
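Usage sketch (not part of the patch): with the refactored entry points above, cli_main(args=None) can also be driven programmatically with the same flags the CLI tests pass via argv. This assumes the MNIST data are already available under the datamodule's default data_dir.

    # Hypothetical smoke run; mirrors what tests/models/test_executable_scripts.py exercises.
    from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main

    # cli_main parses the list, builds the datamodule from --dataset, and returns (dm, model, trainer).
    dm, model, trainer = cli_main([
        "--dataset", "mnist",
        "--max_epochs", "1",
        "--limit_train_batches", "3",
        "--limit_val_batches", "3",
        "--batch_size", "3",
    ])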