Autoencoders

This section houses autoencoders and variational autoencoders.

Note

We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix!


Basic AE

This is the simplest autoencoder. You can use it like so:

from pl_bolts.models.autoencoders import AE
from pytorch_lightning import Trainer

model = AE()
trainer = Trainer()
trainer.fit(model)

You can override any part of this AE to build your own variation.

from pl_bolts.models.autoencoders import AE

class MyAEFlavor(AE):

    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = YourSuperFancyEncoder(...)
        return encoder
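
Overriding works this way presumably because the base class builds its encoder through a method that subclasses can replace. The following is a toy, framework-free sketch of that pattern only. `ToyAE` and `MyToyFlavor` are hypothetical stand-ins, and their constructor and hook signatures are assumptions that do not mirror the real `AE` class.

```python
class ToyAE:
    """Stand-in for a model that builds its parts through overridable hooks."""

    def __init__(self, latent_dim=2):
        self.latent_dim = latent_dim
        # The constructor asks a hook for the encoder, so subclasses can swap it.
        self.encoder = self.init_encoder(latent_dim)

    def init_encoder(self, latent_dim):
        # Default: keep the first latent_dim values of the input.
        return lambda x: list(x[:latent_dim])

    def encode(self, x):
        return self.encoder(x)


class MyToyFlavor(ToyAE):
    def init_encoder(self, latent_dim):
        # Custom: split the input into latent_dim chunks and average each one.
        def encoder(x):
            size = len(x) // latent_dim
            return [sum(x[i * size:(i + 1) * size]) / size for i in range(latent_dim)]
        return encoder


default_code = ToyAE().encode([1.0, 2.0, 3.0, 4.0])
custom_code = MyToyFlavor().encode([1.0, 2.0, 3.0, 4.0])
print(default_code, custom_code)
```

Only the hook is replaced; the rest of the class, including `encode`, is inherited unchanged.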

You can use the pretrained models present in bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import AE

ae = AE(input_height=32)
print(AE.pretrained_weights_available())
ae = ae.from_pretrained('cifar10-resnet18')

ae.freeze()

Training:

Reconstructions:

Both the input and the generated images are shown in normalized form, because the model was trained on normalized images.
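
To display such images with their natural colors, the normalization has to be inverted first. Below is a minimal sketch in plain Python, assuming the usual per-channel `(x - mean) / std` normalization. The CIFAR-10 statistics are commonly quoted values and are an assumption here, not taken from this page; `normalize` and `denormalize` are hypothetical helper names.

```python
# Per-channel statistics (assumed; commonly quoted CIFAR-10 values on a 0..1 scale).
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def normalize(pixel, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Map an (r, g, b) pixel in [0, 1] to normalized space: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


def denormalize(pixel, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Invert normalize(): x * std + mean, clamped to the displayable range [0, 1]."""
    return tuple(min(1.0, max(0.0, x * s + m)) for x, m, s in zip(pixel, mean, std))


original = (0.2, 0.5, 0.9)
restored = denormalize(normalize(original))
print(restored)
```

The clamp matters for reconstructions, since a decoder's output is not guaranteed to stay inside the valid pixel range after the normalization is undone.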

pl_bolts.models.autoencoders.AE


Variational Autoencoders

Basic VAE

Use the VAE like so.

from pl_bolts.models.autoencoders import VAE
from pytorch_lightning import Trainer

model = VAE()
trainer = Trainer()
trainer.fit(model)

You can override any part of this VAE to build your own variation.

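import torch
from pl_bolts.models.autoencoders import VAE

class MyVAEFlavor(VAE):

    def get_posterior(self, mu, std):
        # do something other than the default, which is:
        # P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
        # for example (illustrative choice only), a Gaussian built directly from mu and std:
        P = torch.distributions.Normal(loc=mu, scale=std)
        return P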
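
For context on where a posterior is used: a VAE typically samples a latent vector from the posterior and penalizes the posterior's divergence from the prior. The sketch below shows both steps in plain Python, assuming a diagonal Gaussian posterior and a standard normal prior. `sample_latent` and `kl_to_standard_normal` are hypothetical names and are not part of pl_bolts.

```python
import math
import random


def sample_latent(mu, std, rng):
    """Reparameterized sample z = mu + std * eps, with eps ~ N(0, 1)."""
    return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, std)]


def kl_to_standard_normal(mu, std):
    """Closed-form KL( N(mu, diag(std^2)) || N(0, I) ), summed over dimensions."""
    return sum(0.5 * (m * m + s * s - 1.0) - math.log(s) for m, s in zip(mu, std))


rng = random.Random(0)  # seeded so the sample is reproducible
mu, std = [0.5, -1.0], [1.0, 0.5]
z = sample_latent(mu, std, rng)
kl = kl_to_standard_normal(mu, std)
print(len(z), kl)
```

Changing the posterior family, as in the override above, would generally also change the divergence term, which is why the closed form here applies only under the stated Gaussian assumption.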

You can use the pretrained models present in bolts.

CIFAR-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=32)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('cifar10-resnet18')

vae.freeze()

Training:

Reconstructions:

Both the input and the generated images are shown in normalized form, because the model was trained on normalized images.

STL-10 pretrained model:

from pl_bolts.models.autoencoders import VAE

vae = VAE(input_height=96, first_conv=True)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('stl10-resnet18')

vae.freeze()

Training:

pl_bolts.models.autoencoders.VAE