Skip to content

Commit

Permalink
Modularize AE & VAE (#196)
Browse files Browse the repository at this point in the history
* 🐛 make data dir kwarg instead of arg

* ✨ surface AE up in init

* 🚧 wip

* 🚧 .

* 🚧 .

* ✅ update tests

* ✅ update tests

* ✅ pytest is cute

* 💄 apply style

* 💄 style

* 💄 .

* 💄 .

* ✅ add tests

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
  • Loading branch information
nateraw and Borda committed Sep 11, 2020
1 parent ca8e7b2 commit 752c433
Show file tree
Hide file tree
Showing 7 changed files with 230 additions and 239 deletions.
45 changes: 18 additions & 27 deletions pl_bolts/datamodules/mnist_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,18 @@

class MNISTDataModule(LightningDataModule):

name = 'mnist'
name = "mnist"

def __init__(
self,
data_dir: str,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
*args,
**kwargs,
self,
data_dir: str = "./",
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
Expand Down Expand Up @@ -87,17 +88,15 @@ def train_dataloader(self, batch_size=32, transforms=None):
dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_train,
batch_size=batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
pin_memory=True,
)
return loader

Expand All @@ -113,17 +112,15 @@ def val_dataloader(self, batch_size=32, transforms=None):
dataset = MNIST(self.data_dir, train=True, download=True, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
pin_memory=True,
)
return loader

Expand All @@ -139,21 +136,15 @@ def test_dataloader(self, batch_size=32, transforms=None):

dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True
)
return loader

def _default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
mnist_transforms = transform_lib.ToTensor()

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
Collection of PyTorchLightning models
"""

from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
from pl_bolts.models.mnist_module import LitMNIST
from pl_bolts.models.regression import LinearRegression
from pl_bolts.models.regression import LogisticRegression
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
124 changes: 62 additions & 62 deletions pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,22 @@
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.nn import functional as F

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.datamodules import (CIFAR10DataModule, ImagenetDataModule,
MNISTDataModule, STL10DataModule)
from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder
from pl_bolts.models.autoencoders.basic_vae.components import Decoder


class AE(LightningModule):

def __init__(
self,
datamodule: LightningDataModule = None,
input_channels=1,
input_height=28,
input_width=28,
latent_dim=32,
batch_size=32,
hidden_dim=128,
learning_rate=0.001,
num_workers=8,
data_dir='.',
**kwargs
self,
input_channels: int,
input_height: int,
input_width: int,
latent_dim=32,
hidden_dim=128,
learning_rate=0.001,
**kwargs
):
"""
Args:
Expand All @@ -42,25 +38,24 @@ def __init__(
super().__init__()
self.save_hyperparameters()

# link default data
if datamodule is None:
datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)

self.datamodule = datamodule

self.img_dim = self.datamodule.size()

self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
self.hparams.input_width, self.hparams.input_height)
self.encoder = self.init_encoder(
self.hparams.hidden_dim,
self.hparams.latent_dim,
self.hparams.input_channels,
self.hparams.input_width,
self.hparams.input_height,
)
self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)

def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height)
def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
# Build the default encoder; every argument is forwarded positionally, in this same order, to AEEncoder.
# NOTE(review): the call in __init__ passes self.hparams.input_width before self.hparams.input_height,
# which does not match the (input_height, input_width) order declared here - confirm the intended order.
encoder = AEEncoder(hidden_dim, latent_dim, input_channels, input_height, input_width)
return encoder

def init_decoder(self, hidden_dim, latent_dim):
c, h, w = self.img_dim
decoder = Decoder(hidden_dim, latent_dim, w, h, c)
# c, h, w = self.img_dim
decoder = Decoder(
hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels
)
return decoder

def forward(self, z):
Expand All @@ -78,76 +73,81 @@ def training_step(self, batch, batch_idx):
loss = self._run_step(batch)

tensorboard_logs = {
'mse_loss': loss,
"mse_loss": loss,
}

return {'loss': loss, 'log': tensorboard_logs}
return {"loss": loss, "log": tensorboard_logs}

def validation_step(self, batch, batch_idx):
loss = self._run_step(batch)

return {
'val_loss': loss,
"val_loss": loss,
}

def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

tensorboard_logs = {'mse_loss': avg_loss}
tensorboard_logs = {"mse_loss": avg_loss}

return {
'val_loss': avg_loss,
'log': tensorboard_logs
}
return {"val_loss": avg_loss, "log": tensorboard_logs}

def test_step(self, batch, batch_idx):
loss = self._run_step(batch)

return {
'test_loss': loss,
"test_loss": loss,
}

def test_epoch_end(self, outputs):
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()

tensorboard_logs = {'mse_loss': avg_loss}
tensorboard_logs = {"mse_loss": avg_loss}

return {
'test_loss': avg_loss,
'log': tensorboard_logs
}
return {"test_loss": avg_loss, "log": tensorboard_logs}

def configure_optimizers(self):
# Optimize all of this module's parameters with Adam, using the learning rate stored in hparams.
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
    """Extend ``parent_parser`` with the autoencoder's hyperparameter flags.

    Args:
        parent_parser: parser whose arguments are inherited by the returned one.

    Returns:
        ArgumentParser: a new parser, created with ``add_help=False``, that
        exposes ``--hidden_dim``, ``--latent_dim`` and ``--learning_rate`` on
        top of everything ``parent_parser`` already defines.
    """
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    parser.add_argument(
        "--hidden_dim",
        type=int,
        default=128,
        help="intermediate layers dimension before embedding for default encoder/decoder",
    )
    parser.add_argument("--latent_dim", type=int, default=32, help="dimension of latent variables z")
    parser.add_argument("--learning_rate", type=float, default=1e-3)
    return parser


def cli_main():
def cli_main(args=None):
    """Train an ``AE`` from the command line on the dataset chosen with ``--dataset``.

    Args:
        args: optional list of argument strings; ``None`` lets argparse read
            the process command line.

    Returns:
        tuple: ``(datamodule, model, trainer)`` once ``trainer.fit`` has returned.

    Raises:
        ValueError: if ``--dataset`` is not one of the supported names.
    """
    datamodules = {
        "mnist": MNISTDataModule,
        "cifar10": CIFAR10DataModule,
        "stl10": STL10DataModule,
        "imagenet": ImagenetDataModule,
    }

    parser = ArgumentParser()
    parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
    # First pass only looks at --dataset; the datamodule, trainer and model
    # flags can be registered only after the datamodule class is known.
    script_args, _ = parser.parse_known_args(args)

    try:
        dm_cls = datamodules[script_args.dataset]
    except KeyError:
        # Previously an unknown name left dm_cls unbound and crashed later
        # with an UnboundLocalError that did not mention the real cause.
        supported = ", ".join(sorted(datamodules))
        raise ValueError(
            f"unknown dataset {script_args.dataset!r}; expected one of: {supported}"
        ) from None

    parser = dm_cls.add_argparse_args(parser)
    parser = Trainer.add_argparse_args(parser)
    parser = AE.add_model_specific_args(parser)
    args = parser.parse_args(args)

    dm = dm_cls.from_argparse_args(args)
    # dm.size() is unpacked into the model's leading positional arguments
    # (input_channels, input_height, input_width).
    model = AE(*dm.size(), **vars(args))
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model, dm)
    return dm, model, trainer


if __name__ == '__main__':
cli_main()
if __name__ == "__main__":
dm, model, trainer = cli_main()
13 changes: 7 additions & 6 deletions pl_bolts/models/autoencoders/basic_ae/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,26 +9,27 @@ class AEEncoder(torch.nn.Module):
get split into a mu and sigma vector
"""

def __init__(self, hidden_dim, latent_dim, input_width, input_height):
def __init__(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
super().__init__()
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.input_width = input_width
self.input_channels = input_channels
self.input_height = input_height
self.input_width = input_width

self.c1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.c1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
self.c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
self.c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

conv_out_dim = self._calculate_output_dim(input_width, input_height)
conv_out_dim = self._calculate_output_dim(input_channels, input_width, input_height)

self.fc1 = DenseBlock(conv_out_dim, hidden_dim)
self.fc2 = DenseBlock(hidden_dim, hidden_dim)

self.fc_z_out = nn.Linear(hidden_dim, latent_dim)

def _calculate_output_dim(self, input_width, input_height):
x = torch.rand(1, 1, input_width, input_height)
def _calculate_output_dim(self, input_channels, input_width, input_height):
    """Return the flattened size of the conv stack's output for one image.

    A dummy single-image batch is pushed through ``c1 -> c2 -> c3`` and the
    number of elements in the result is returned.

    Args:
        input_channels: number of channels of the input image.
        input_width: width of the input image.
        input_height: height of the input image.

    Returns:
        int: number of elements ``c3`` produces for a single image.
    """
    # Conv2d takes (N, C, H, W), so height goes before width. torch.rand is
    # kept (rather than zeros) so the probe still draws from the global RNG
    # as it did before.
    x = torch.rand(1, input_channels, input_height, input_width)
    # Shape probe only: no gradients are needed for the dummy pass.
    with torch.no_grad():
        x = self.c3(self.c2(self.c1(x)))
    return x.numel()
Expand Down
Loading

0 comments on commit 752c433

Please sign in to comment.