
Modularize AE & VAE #196

Merged · 14 commits · Sep 11, 2020
2 changes: 1 addition & 1 deletion pl_bolts/datamodules/mnist_datamodule.py
@@ -11,7 +11,7 @@ class MNISTDataModule(LightningDataModule):
 
     def __init__(
         self,
-        data_dir: str,
+        data_dir: str = './',
         val_split: int = 5000,
         num_workers: int = 16,
         normalize: bool = False,
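With `data_dir` now defaulting to `'./'`, the datamodule can be constructed with no arguments. A minimal sketch, assuming MNIST may be downloaded into the working directory:

```python
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule

# data_dir defaults to './' after this change
dm = MNISTDataModule()
dm.prepare_data()  # standard LightningDataModule hook; fetches the dataset if missing
```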
1 change: 1 addition & 0 deletions pl_bolts/models/__init__.py
@@ -2,6 +2,7 @@
 Collection of PyTorchLightning models
 """
 
+from pl_bolts.models.autoencoders.basic_ae.basic_ae_module import AE
 from pl_bolts.models.autoencoders.basic_vae.basic_vae_module import VAE
 from pl_bolts.models.mnist_module import LitMNIST
 from pl_bolts.models.regression import LinearRegression
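This re-export puts `AE` next to `VAE` at the package root:

```python
# both autoencoders are now importable from the top-level models package
from pl_bolts.models import AE, VAE
```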
20 changes: 10 additions & 10 deletions pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py
@@ -13,7 +13,7 @@ class AE(LightningModule):
 
     def __init__(
         self,
-        datamodule: LightningDataModule = None,
+        # datamodule: LightningDataModule = None,
         input_channels=1,
         input_height=28,
         input_width=28,
@@ -43,24 +43,24 @@ def __init__(
         self.save_hyperparameters()
 
         # link default data
-        if datamodule is None:
-            datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)
+        # if datamodule is None:
+        #     datamodule = MNISTDataModule(data_dir=self.hparams.data_dir, num_workers=self.hparams.num_workers)
 
-        self.datamodule = datamodule
+        # self.datamodule = datamodule
 
-        self.img_dim = self.datamodule.size()
+        # self.img_dim = self.datamodule.size()
 
-        self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim,
+        self.encoder = self.init_encoder(self.hparams.hidden_dim, self.hparams.latent_dim, self.hparams.input_channels,
                                          self.hparams.input_width, self.hparams.input_height)
         self.decoder = self.init_decoder(self.hparams.hidden_dim, self.hparams.latent_dim)
 
-    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
-        encoder = AEEncoder(hidden_dim, latent_dim, input_width, input_height)
+    def init_encoder(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
+        encoder = AEEncoder(hidden_dim, latent_dim, input_channels, input_height, input_width)
         return encoder
 
     def init_decoder(self, hidden_dim, latent_dim):
-        c, h, w = self.img_dim
-        decoder = Decoder(hidden_dim, latent_dim, w, h, c)
+        # c, h, w = self.img_dim
+        decoder = Decoder(hidden_dim, latent_dim, self.hparams.input_width, self.hparams.input_height, self.hparams.input_channels)
         return decoder
 
     def forward(self, z):
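With the datamodule wiring commented out, the model is built from explicit dimensions and paired with data at fit time. A sketch of the intended usage (remaining constructor arguments are assumed to keep their defaults):

```python
import pytorch_lightning as pl

from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.models import AE

# input dims are plain hyperparameters now, not read from a datamodule
model = AE(input_channels=1, input_height=28, input_width=28)

# the datamodule is supplied to the trainer rather than stored on the model
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=MNISTDataModule())
```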
13 changes: 7 additions & 6 deletions pl_bolts/models/autoencoders/basic_ae/components.py
@@ -9,26 +9,27 @@ class AEEncoder(torch.nn.Module):
     get split into a mu and sigma vector
     """
 
-    def __init__(self, hidden_dim, latent_dim, input_width, input_height):
+    def __init__(self, hidden_dim, latent_dim, input_channels, input_height, input_width):
         super().__init__()
         self.hidden_dim = hidden_dim
         self.latent_dim = latent_dim
-        self.input_width = input_width
+        self.input_channels = input_channels
         self.input_height = input_height
+        self.input_width = input_width
 
-        self.c1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
+        self.c1 = nn.Conv2d(input_channels, 32, kernel_size=3, padding=1)
         self.c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
         self.c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)
 
-        conv_out_dim = self._calculate_output_dim(input_width, input_height)
+        conv_out_dim = self._calculate_output_dim(input_channels, input_width, input_height)
 
         self.fc1 = DenseBlock(conv_out_dim, hidden_dim)
         self.fc2 = DenseBlock(hidden_dim, hidden_dim)
 
         self.fc_z_out = nn.Linear(hidden_dim, latent_dim)
 
-    def _calculate_output_dim(self, input_width, input_height):
-        x = torch.rand(1, 1, input_width, input_height)
+    def _calculate_output_dim(self, input_channels, input_width, input_height):
+        x = torch.rand(1, input_channels, input_width, input_height)
         x = self.c3(self.c2(self.c1(x)))
         x = x.view(-1)
         return x.size(0)
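The `_calculate_output_dim` helper infers the flattened conv output size by pushing a dummy tensor through the stack, which is why it now needs `input_channels`. A standalone sketch of the same pattern, with layer sizes copied from the encoder above and an illustrative 3-channel 32x32 input:

```python
import torch
from torch import nn

# same conv stack as AEEncoder
c1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)        # 3 channels in, spatial size preserved
c2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)       # spatial size preserved
c3 = nn.Conv2d(32, 32, kernel_size=3, stride=2, padding=1)  # halves spatial size

x = torch.rand(1, 3, 32, 32)                  # dummy batch of one
conv_out_dim = c3(c2(c1(x))).view(-1).size(0)
print(conv_out_dim)                           # 32 * 16 * 16 = 8192
```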
40 changes: 20 additions & 20 deletions pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py
@@ -61,11 +61,11 @@ def __init__(
         super().__init__()
         self.save_hyperparameters()
 
-        self.datamodule = datamodule
-        self.__set_pretrained_dims(pretrained)
+        #self.datamodule = datamodule
+        #self.__set_pretrained_dims(pretrained)
 
         # use mnist as the default module
-        self._set_default_datamodule(datamodule)
+        #self._set_default_datamodule(datamodule)
 
         # init actual model
         self.__init_system()
@@ -77,23 +77,23 @@ def __init_system(self):
         self.encoder = self.init_encoder()
         self.decoder = self.init_decoder()
 
-    def __set_pretrained_dims(self, pretrained):
-        if pretrained == 'imagenet2012':
-            self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
-            (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.datamodule.size()
+    # def __set_pretrained_dims(self, pretrained):
+    #     if pretrained == 'imagenet2012':
+    #         self.datamodule = ImagenetDataModule(data_dir=self.hparams.data_dir)
+    #         (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.datamodule.size()
 
-    def _set_default_datamodule(self, datamodule):
-        # link default data
-        if datamodule is None:
-            datamodule = MNISTDataModule(
-                data_dir=self.hparams.data_dir,
-                num_workers=self.hparams.num_workers,
-                normalize=False
-            )
-        self.datamodule = datamodule
-        self.img_dim = self.datamodule.size()
+    # def _set_default_datamodule(self, datamodule):
+    #     # link default data
+    #     if datamodule is None:
+    #         datamodule = MNISTDataModule(
+    #             data_dir=self.hparams.data_dir,
+    #             num_workers=self.hparams.num_workers,
+    #             normalize=False
+    #         )
+    #     self.datamodule = datamodule
+    #     self.img_dim = self.datamodule.size()
 
-        (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim
+    # (self.hparams.input_channels, self.hparams.input_height, self.hparams.input_width) = self.img_dim
 
     def load_pretrained(self, pretrained):
         available_weights = {'imagenet2012'}
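Since `__set_pretrained_dims` is gone, pretrained weights go through the explicit `load_pretrained` hook instead. A hedged sketch: only the method head and the `'imagenet2012'` key are visible in this diff, and the ImageNet input dims below are illustrative:

```python
from pl_bolts.models import VAE

# dims must now be passed explicitly rather than read from ImagenetDataModule.size()
vae = VAE(input_channels=3, input_height=224, input_width=224)

# 'imagenet2012' is the only entry in available_weights per the diff above
vae.load_pretrained('imagenet2012')
```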