This section houses autoencoders and variational autoencoders.
Note
We rely on the community to keep these updated and working. If something doesn't work, we'd really appreciate a contribution to fix it!
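This is the simplest autoencoder. You can use it like so:
from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import AE
dm = CIFAR10DataModule()  # any DataModule that yields (image, label) batches
model = AE(input_height=32)  # input_height must match the image size
trainer = Trainer()
trainer.fit(model, datamodule=dm)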
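Conceptually, an autoencoder compresses each input into a low-dimensional latent code with an encoder, reconstructs the input from that code with a decoder, and trains both to minimize reconstruction error. The following is a minimal NumPy sketch of that idea, assuming a linear encoder and decoder trained by plain gradient descent on toy data. It is not the bolts `AE` (the pretrained name `cifar10-resnet18` suggests a ResNet-18 encoder there), and every name in it is illustrative.

```python
import numpy as np

# Illustrative linear autoencoder, NOT the pl_bolts AE implementation.
rng = np.random.default_rng(0)

# Toy data: 200 samples in 8 dimensions that actually lie on a 2-D subspace.
basis = rng.normal(size=(2, 8))
x = rng.normal(size=(200, 2)) @ basis

latent_dim = 2
w_enc = rng.normal(scale=0.1, size=(8, latent_dim))  # encoder weights
w_dec = rng.normal(scale=0.1, size=(latent_dim, 8))  # decoder weights
lr = 0.01


def reconstruction_loss(x, w_enc, w_dec):
    """Mean squared error between the inputs and their reconstructions."""
    x_hat = (x @ w_enc) @ w_dec
    return np.mean((x_hat - x) ** 2)


initial_loss = reconstruction_loss(x, w_enc, w_dec)
for step in range(2000):
    z = x @ w_enc        # encode: 8-D input -> 2-D latent code
    x_hat = z @ w_dec    # decode: 2-D latent code -> 8-D reconstruction
    err = x_hat - x
    # Gradients of the mean squared error with respect to both weight matrices.
    grad_dec = z.T @ err * (2.0 / err.size)
    grad_enc = x.T @ (err @ w_dec.T) * (2.0 / err.size)
    w_dec -= lr * grad_dec
    w_enc -= lr * grad_enc
final_loss = reconstruction_loss(x, w_enc, w_dec)

print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Because the toy data has only two underlying degrees of freedom, a 2-D latent code is enough to drive the reconstruction error well below its starting value.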
You can override any part of this AE to build your own variation.
from pl_bolts.models.autoencoders import AE
class MyAEFlavor(AE):
    def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
        encoder = YourSuperFancyEncoder(...)
        return encoder
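Overriding a single method works when the base class builds its components by calling overridable hook methods from its constructor. The pure-Python sketch below illustrates that pattern under this assumption; `BaseAutoencoder`, `MyFlavor`, and the string stand-ins for networks are hypothetical, and whether a given bolts version calls an `init_encoder` hook should be checked against its source.

```python
class BaseAutoencoder:
    """Hypothetical base class that builds its parts through hook methods."""

    def __init__(self, latent_dim=4):
        self.latent_dim = latent_dim
        # The constructor calls the hooks, so subclasses only override the hooks.
        self.encoder = self.init_encoder(latent_dim)
        self.decoder = self.init_decoder(latent_dim)

    def init_encoder(self, latent_dim):
        return f"default-encoder({latent_dim})"  # stand-in for a real network

    def init_decoder(self, latent_dim):
        return f"default-decoder({latent_dim})"  # stand-in for a real network


class MyFlavor(BaseAutoencoder):
    def init_encoder(self, latent_dim):
        # Only the encoder is replaced; the decoder hook is inherited unchanged.
        return f"fancy-encoder({latent_dim})"


model = MyFlavor(latent_dim=8)
print(model.encoder)  # fancy-encoder(8)
print(model.decoder)  # default-decoder(8)
```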
You can use the pretrained models present in bolts.
CIFAR-10 pretrained model:
from pl_bolts.models.autoencoders import AE
ae = AE(input_height=32)
print(AE.pretrained_weights_available())
ae = ae.from_pretrained('cifar10-resnet18')
ae.freeze()
[Figure: Training]
[Figure: Reconstructions]
Both the input and generated images are shown in normalized form, because the model was trained on normalized images.
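To view such images in their natural colours, invert the per-channel normalization: multiply by the standard deviation and add back the mean. A minimal NumPy sketch follows; the CIFAR-10 mean/std values below are an assumption (commonly used statistics), so substitute whatever your training transforms actually applied.

```python
import numpy as np

# Assumed per-channel CIFAR-10 statistics (RGB, in [0, 1] units).
CIFAR10_MEAN = np.array([125.3, 123.0, 113.9]) / 255.0
CIFAR10_STD = np.array([63.0, 62.1, 66.7]) / 255.0


def normalize(img):
    """img: float array of shape (3, H, W) with values in [0, 1]."""
    return (img - CIFAR10_MEAN[:, None, None]) / CIFAR10_STD[:, None, None]


def denormalize(img):
    """Invert normalize() and clip to [0, 1] so the image can be displayed."""
    out = img * CIFAR10_STD[:, None, None] + CIFAR10_MEAN[:, None, None]
    return np.clip(out, 0.0, 1.0)


rng = np.random.default_rng(0)
img = rng.random((3, 32, 32))          # fake image in [0, 1]
restored = denormalize(normalize(img))  # round trip recovers the original
```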
pl_bolts.models.autoencoders.AE
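Use the VAE like so:
from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.autoencoders import VAE
dm = CIFAR10DataModule()  # any DataModule that yields (image, label) batches
model = VAE(input_height=32)  # input_height must match the image size
trainer = Trainer()
trainer.fit(model, datamodule=dm)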
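Unlike the plain autoencoder, a VAE encodes each input as a Gaussian with mean `mu` and variance `exp(log_var)`, samples a latent code with the reparameterization trick, and adds a KL-divergence penalty toward a standard-normal prior to the reconstruction loss. The NumPy sketch below shows those two pieces under the assumption of a diagonal Gaussian posterior and an N(0, 1) prior; it is not the bolts implementation, and the function names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)


def reparameterize(mu, log_var, rng):
    """Sample z ~ N(mu, exp(log_var)) as mu + std * eps, with eps ~ N(0, 1)."""
    std = np.exp(0.5 * log_var)
    eps = rng.standard_normal(mu.shape)
    return mu + std * eps


def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(log_var) + mu**2 - 1.0 - log_var, axis=-1)


print(kl_to_standard_normal(np.zeros(2), np.zeros(2)))            # 0.0
print(kl_to_standard_normal(np.array([1.0, 0.0]), np.zeros(2)))   # 0.5

# Samples drawn with mean 3 and variance 4 (std 2).
z = reparameterize(np.full((100_000, 2), 3.0), np.full((100_000, 2), np.log(4.0)), rng)
```

Writing the sample as `mu + std * eps` keeps the randomness in `eps`, so gradients can flow through `mu` and `log_var` when the same computation is expressed in an autodiff framework.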
You can override any part of this VAE to build your own variation.
import torch
from pl_bolts.models.autoencoders import VAE
class MyVAEFlavor(VAE):
    def get_posterior(self, mu, std):
        # do something other than the default, e.g. build a different distribution
        P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
        return P
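Swapping in a different distribution changes the KL term of the loss. As a framework-independent sketch, suppose (hypothetically) the prior is a wider Gaussian N(0, 2^2) instead of N(0, 1). For two diagonal Gaussians the KL divergence still has a closed form, which the code below checks against a Monte Carlo estimate of E_q[log q(z) - log p(z)]. All names and parameter values are illustrative assumptions.

```python
import numpy as np


def kl_gaussians(mu_q, std_q, mu_p, std_p):
    """Closed-form KL(N(mu_q, std_q^2) || N(mu_p, std_p^2)), summed over the last axis."""
    var_q, var_p = std_q**2, std_p**2
    return np.sum(
        np.log(std_p / std_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5,
        axis=-1,
    )


def log_normal(z, mu, std):
    """Log-density of N(mu, std^2) evaluated elementwise at z."""
    return -0.5 * np.log(2.0 * np.pi) - np.log(std) - (z - mu) ** 2 / (2.0 * std**2)


rng = np.random.default_rng(0)
mu_q, std_q = np.array([0.5, -1.0]), np.array([0.8, 1.2])  # example posterior
mu_p, std_p = np.zeros(2), np.full(2, 2.0)                 # hypothetical wider prior

closed_form = kl_gaussians(mu_q, std_q, mu_p, std_p)
z = mu_q + std_q * rng.standard_normal((200_000, 2))       # samples from q
monte_carlo = np.mean(np.sum(log_normal(z, mu_q, std_q) - log_normal(z, mu_p, std_p), axis=-1))
print(closed_form, monte_carlo)  # the two estimates should agree closely
```

With `mu_p = 0` and `std_p = 1`, `kl_gaussians` reduces to the standard-normal KL term used by a default VAE.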
You can use the pretrained models present in bolts.
CIFAR-10 pretrained model:
from pl_bolts.models.autoencoders import VAE
vae = VAE(input_height=32)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('cifar10-resnet18')
vae.freeze()
[Figure: Training]
[Figure: Reconstructions]
Both the input and generated images are shown in normalized form, because the model was trained on normalized images.
STL-10 pretrained model:
from pl_bolts.models.autoencoders import VAE
vae = VAE(input_height=96, first_conv=True)
print(VAE.pretrained_weights_available())
vae = vae.from_pretrained('stl10-resnet18')
vae.freeze()
[Figure: Training]
pl_bolts.models.autoencoders.VAE