
Modularize AE & VAE #196

Merged
merged 14 commits on Sep 11, 2020
45 changes: 18 additions & 27 deletions pl_bolts/datamodules/mnist_datamodule.py
@@ -7,17 +7,18 @@

class MNISTDataModule(LightningDataModule):

name = 'mnist'
name = "mnist"

def __init__(
self,
data_dir: str,
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
*args,
**kwargs,
self,
data_dir: str = "./",
val_split: int = 5000,
num_workers: int = 16,
normalize: bool = False,
seed: int = 42,
batch_size: int = 32,
*args,
**kwargs,
):
"""
.. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png
@@ -87,17 +88,15 @@ def train_dataloader(self, batch_size=32, transforms=None):
dataset = MNIST(self.data_dir, train=True, download=False, transform=transforms)
train_length = len(dataset)
dataset_train, _ = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_train,
batch_size=batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
pin_memory=True,
)
return loader
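Note that the train and val dataloaders below each re-create the split with the same seeded generator, which is what keeps the two subsets disjoint and reproducible across calls. A tiny standalone sketch of that pattern (toy dataset, illustrative sizes):

```python
import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(100))
seed = 42

# calling random_split twice with identically seeded generators yields the same partition
train_a, val_a = random_split(dataset, [80, 20], generator=torch.Generator().manual_seed(seed))
train_b, val_b = random_split(dataset, [80, 20], generator=torch.Generator().manual_seed(seed))
assert val_a.indices == val_b.indices  # same validation indices both times
```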

@@ -113,17 +112,15 @@ def val_dataloader(self, batch_size=32, transforms=None):
dataset = MNIST(self.data_dir, train=True, download=True, transform=transforms)
train_length = len(dataset)
_, dataset_val = random_split(
dataset,
[train_length - self.val_split, self.val_split],
generator=torch.Generator().manual_seed(self.seed)
dataset, [train_length - self.val_split, self.val_split], generator=torch.Generator().manual_seed(self.seed)
)
loader = DataLoader(
dataset_val,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
pin_memory=True,
)
return loader

@@ -139,21 +136,15 @@ def test_dataloader(self, batch_size=32, transforms=None):

dataset = MNIST(self.data_dir, train=False, download=False, transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
dataset, batch_size=batch_size, shuffle=False, num_workers=self.num_workers, drop_last=True, pin_memory=True
)
return loader

def _default_transforms(self):
if self.normalize:
mnist_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(mean=(0.5,), std=(0.5,)),
])
mnist_transforms = transform_lib.Compose(
[transform_lib.ToTensor(), transform_lib.Normalize(mean=(0.5,), std=(0.5,))]
)
else:
mnist_transforms = transform_lib.ToTensor()

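For orientation, a minimal usage sketch of the refactored datamodule (argument values are hypothetical; note that the dataloader methods in these hunks still accept batch_size per call):

```python
from pl_bolts.datamodules import MNISTDataModule

# batch_size is now accepted by the constructor alongside the existing arguments
dm = MNISTDataModule(data_dir="./", num_workers=4, normalize=True, batch_size=64)

# assumes MNIST is already downloaded to data_dir (train_dataloader uses download=False)
train_loader = dm.train_dataloader(batch_size=64)
```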
4 changes: 2 additions & 2 deletions pl_bolts/models/__init__.py
@@ -2,9 +2,9 @@
Collection of PyTorchLightning models
"""

from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
from pl_bolts.models.mnist_module import LitMNIST
from pl_bolts.models.regression import LinearRegression
from pl_bolts.models.regression import LogisticRegression
from pl_bolts.models.regression import LinearRegression, LogisticRegression
from pl_bolts.models.vision import PixelCNN
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT
124 changes: 62 additions & 62 deletions pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -4,26 +4,22 @@
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch.nn import functional as F

from pl_bolts.datamodules import MNISTDataModule
from pl_bolts.datamodules import (CIFAR10DataModule, ImagenetDataModule,
MNISTDataModule, STL10DataModule)
from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder
from pl_bolts.models.autoencoders.basic_vae.components import Decoder


class AE(LightningModule):

def __init__(
self,
datamodule: LightningDataModule = None,
input_channels=1,
input_height=28,
input_width=28,
latent_dim=32,
batch_size=32,
hidden_dim=128,
learning_rate=0.001,
num_workers=8,
data_dir='.',
**kwargs
self,
input_channels: int,
input_height: int,
input_width: int,
latent_dim=32,
hidden_dim=128,
learning_rate=0.001,
**kwargs
):
"""
Args:
@@ -42,25 +38,24 @@ def __init__(
super().__init__()
self.save_hyperparameters()

# link default data
if datamodule is None:
datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)

self.datamodule = datamodule

self.img_dim = self.datamodule.size()

self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
self.hparams.input_width, self.hparams.input_height)
self.encoder = self.init_encoder(
self.hparams.hidden_dim,
self.hparams.latent_dim,
self.hparams.input_channels,
self.hparams.input_width,
self.hparams.input_height,
)
self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)

def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height)
def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
encoder = AEEncoder(hidden_dim, latent_dim, input_channels, input_height, input_width)
return encoder

def init_decoder(self, hidden_dim, latent_dim):
c, h, w = self.img_dim
decoder = Decoder(hidden_dim, latent_dim, w, h, c)
# c, h, w = self.img_dim
decoder = Decoder(
hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels
Member: let's keep this order if it is everywhere else

Contributor Author: it should be (c, h, w), right? I started switching over to that from (c, w, h), which is wrong, if I'm not mistaken. That's what I kept seeing in this file, but I believe it's mixed up.

Member: just see that the order is different in several places...

)
return decoder
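To make the ordering discussion above concrete, here is a short sketch of the two argument orders currently in play (values are illustrative, imports as in this file):

```python
from pl_bolts.models.autoencoders.basic_ae.components import AEEncoder
from pl_bolts.models.autoencoders.basic_vae.components import Decoder

hidden_dim, latent_dim = 128, 32
channels, height, width = 1, 28, 28  # (c, h, w), e.g. MNIST

# the new AEEncoder signature in this PR takes (..., channels, height, width)
encoder = AEEncoder(hidden_dim, latent_dim, channels, height, width)
# the existing Decoder is still called with (..., width, height, channels)
decoder = Decoder(hidden_dim, latent_dim, width, height, channels)
```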

def forward(self, z):
@@ -78,76 +73,81 @@ def training_step(self, batch, batch_idx):
loss = self._run_step(batch)

tensorboard_logs = {
'mse_loss': loss,
"mse_loss": loss,
}

return {'loss': loss, 'log': tensorboard_logs}
return {"loss": loss, "log": tensorboard_logs}

def validation_step(self, batch, batch_idx):
loss = self._run_step(batch)

return {
'val_loss': loss,
"val_loss": loss,
}

def validation_epoch_end(self, outputs):
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
avg_loss = torch.stack([x["val_loss"] for x in outputs]).mean()

tensorboard_logs = {'mse_loss': avg_loss}
tensorboard_logs = {"mse_loss": avg_loss}

return {
'val_loss': avg_loss,
'log': tensorboard_logs
}
return {"val_loss": avg_loss, "log": tensorboard_logs}

def test_step(self, batch, batch_idx):
loss = self._run_step(batch)

return {
'test_loss': loss,
"test_loss": loss,
}

def test_epoch_end(self, outputs):
avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
avg_loss = torch.stack([x["test_loss"] for x in outputs]).mean()

tensorboard_logs = {'mse_loss': avg_loss}
tensorboard_logs = {"mse_loss": avg_loss}

return {
'test_loss': avg_loss,
'log': tensorboard_logs
}
return {"test_loss": avg_loss, "log": tensorboard_logs}

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)

@staticmethod
def add_model_specific_args(parent_parser):
parser = ArgumentParser(parents=[parent_parser], add_help=False)
parser.add_argument('--hidden_dim', type=int, default=128,
help='itermediate layers dimension before embedding for default encoder/decoder')
parser.add_argument('--latent_dim', type=int, default=32,
help='dimension of latent variables z')
parser.add_argument('--input_width', type=int, default=28,
help='input image width - 28 for MNIST (must be even)')
parser.add_argument('--input_height', type=int, default=28,
help='input image height - 28 for MNIST (must be even)')
parser.add_argument('--batch_size', type=int, default=32)
parser.add_argument('--num_workers', type=int, default=8, help="num dataloader workers")
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--data_dir', type=str, default='')
parser.add_argument(
"--hidden_dim",
type=int,
default=128,
help="itermediate layers dimension before embedding for default encoder/decoder",
)
parser.add_argument("--latent_dim", type=int, default=32, help="dimension of latent variables z")
parser.add_argument("--learning_rate", type=float, default=1e-3)
return parser


def cli_main():
def cli_main(args=None):
parser = ArgumentParser()
parser.add_argument("--dataset", default="mnist", type=str, help="mnist, cifar10, stl10, imagenet")
script_args, _ = parser.parse_known_args(args)
nateraw marked this conversation as resolved.
Show resolved Hide resolved

if script_args.dataset == "mnist":
dm_cls = MNISTDataModule
elif script_args.dataset == "cifar10":
dm_cls = CIFAR10DataModule
elif script_args.dataset == "stl10":
dm_cls = STL10DataModule
elif script_args.dataset == "imagenet":
dm_cls = ImagenetDataModule

Member: let's add else and raise an error?
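Following that suggestion, a minimal sketch of what the fallback branch could look like (not part of this diff; it reuses the names already in this hunk):

```python
if script_args.dataset == "mnist":
    dm_cls = MNISTDataModule
elif script_args.dataset == "cifar10":
    dm_cls = CIFAR10DataModule
elif script_args.dataset == "stl10":
    dm_cls = STL10DataModule
elif script_args.dataset == "imagenet":
    dm_cls = ImagenetDataModule
else:
    # fail fast on an unknown dataset name instead of hitting an unbound dm_cls later
    raise ValueError(f"unknown --dataset value: {script_args.dataset!r}")
```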

parser = dm_cls.add_argparse_args(parser)
parser = Trainer.add_argparse_args(parser)
parser = AE.add_model_specific_args(parser)
args = parser.parse_args()
args = parser.parse_args(args)

ae = AE(**vars(args))
dm = dm_cls.from_argparse_args(args)
model = AE(*dm.size(), **vars(args))
trainer = Trainer.from_argparse_args(args)
trainer.fit(ae)
trainer.fit(model, dm)
return dm, model, trainer


if __name__ == '__main__':
cli_main()
if __name__ == "__main__":
dm, model, trainer = cli_main()
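For reference, a usage sketch of the refactored entry point. The --dataset, --hidden_dim, --latent_dim and --learning_rate flags are defined in this diff; anything else (e.g. --max_epochs, --batch_size) is assumed to be contributed by the datamodule's and Trainer's add_argparse_args:

```python
# programmatic use of the new cli_main(args=None) signature
from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import cli_main

dm, model, trainer = cli_main([
    "--dataset", "cifar10",  # selects CIFAR10DataModule in the branch above
    "--latent_dim", "64",    # flag added by AE.add_model_specific_args
    "--max_epochs", "1",     # assumed to come from Trainer.add_argparse_args
])
```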
13 changes: 7 additions & 6 deletions pl_bolts/models/autoencoders/basic_ae/components.py
@@ -9,26 +9,27 @@ class AEEncoder(torch.nn.Module):
get split into a mu and sigma vector
"""

def __init__(self, hidden_dim, latent_dim, input_width, input_height):
def __init__(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
super().__init__()
self.hidden_dim = hidden_dim
self.latent_dim = latent_dim
self.input_width = input_width
self.input_channels = input_channels
self.input_height = input_height
self.input_width = input_width

self.c1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.c1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
self.c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
self.c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)

conv_out_dim = self._calculate_output_dim(input_width, input_height)
conv_out_dim = self._calculate_output_dim(input_channels, input_width, input_height)

self.fc1 = DenseBlock(conv_out_dim, hidden_dim)
self.fc2 = DenseBlock(hidden_dim, hidden_dim)

self.fc_z_out = nn.Linear(hidden_dim, latent_dim)

def _calculate_output_dim(self, input_width, input_height):
x = torch.rand(1, 1, input_width, input_height)
def _calculate_output_dim(self, input_channels, input_width, input_height):
x = torch.rand(1, input_channels, input_width, input_height)
x = self.c3(self.c2(self.c1(x)))
x = x.view(-1)
return x.size(0)
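The _calculate_output_dim change above uses the common trick of pushing a dummy tensor through the conv stack to size the first linear layer. A standalone sketch of the same idea (layer sizes mirror AEEncoder's stack; the 3x32x32 input is illustrative):

```python
import torch
from torch import nn

def conv_output_dim(convs: nn.Module, input_channels: int, input_height: int, input_width: int) -> int:
    # run one fake batch through the conv layers and count the flattened features
    with torch.no_grad():
        x = torch.rand(1, input_channels, input_height, input_width)
        out = convs(x)
    return out.view(-1).size(0)

convs = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, padding=1),
    nn.Conv2d(32, 32, kernel_size=3, padding=1),
    nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1),
)
print(conv_output_dim(convs, 3, 32, 32))  # 32 * 16 * 16 = 8192
```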