Commit
* training batch clean up
* fixes
* hydra ddp support
* update docs
1 parent 050611e · commit 2a8218b
Showing 28 changed files with 537 additions and 148 deletions.
@@ -0,0 +1,181 @@
Autoencoder Models
==================

These are off-the-shelf autoencoders that can be used for research or as feature extractors.

Pretrained models
-----------------

This is a basic template for implementing a Variational Autoencoder in PyTorch Lightning.

A default encoder and decoder are provided, but they can easily be replaced by custom models.

This template uses the MNIST dataset, but image data of any size can be fed in as long as the image width and height are even values. For other types of data, such as sound, you will need to change the encoder and decoder.

The default encoder and decoder are both convolutional, with a 128-dimensional hidden layer and a 32-dimensional latent space. The model also assumes a Gaussian prior and a Gaussian approximate posterior distribution.
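As a quick illustration of those defaults, the sketch below constructs the VAE with the values described above. The keyword names mirror the command-line flags documented further down (``--hidden_dim``, ``--latent_dim``, ``--input_width``, ``--input_height``); treat this as an assumption about the constructor signature rather than a guaranteed API.

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE

    # default-like configuration for 28x28 MNIST images; the keyword names are
    # assumed to match the CLI flags listed later in this page
    vae = VAE(hidden_dim=128, latent_dim=32, input_width=28, input_height=28)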
To use in your project or as a feature extractor:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE
    import pytorch_lightning as pl

    class YourResearchModel(pl.LightningModule):
        def __init__(self):
            super().__init__()

            # pretrained, frozen VAE used as a fixed feature extractor
            self.vae = VAE.load_from_checkpoint(PATH)
            self.vae.freeze()

            self.some_other_model = MyModel()

        def forward(self, z):
            # generate a sample from z ~ N(0, 1)
            x = self.vae(z)

            # do stuff with the sample
            x = self.some_other_model(x)
            return x
To use in production or for predictions:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE

    vae = VAE.load_from_checkpoint(PATH)
    vae.freeze()

    z = ...  # z ~ N(0, 1)
    predictions = vae(z)
Research use case
-----------------

You can train the VAE on its own:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE
    import pytorch_lightning as pl

    vae = VAE()
    trainer = pl.Trainer(gpus=1)
    trainer.fit(vae)
You can also use the VAE as a template for research (for example, changing only the prior or posterior):

.. code-block:: python

    import torch
    from pytorch_lightning_bolts.models.autoencoders import VAE

    class MyVAEFlavor(VAE):
        def get_posterior(self, mu, std):
            # default:
            # P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
            # do something other than the default here; as a placeholder,
            # this returns a diagonal Gaussian parameterized by mu and std
            P = torch.distributions.Normal(loc=mu, scale=std)
            return P
Or pass in your own encoders and decoders:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE
    import pytorch_lightning as pl

    encoder = MyEncoder()
    decoder = MyDecoder()

    vae = VAE(encoder=encoder, decoder=decoder)
    trainer = pl.Trainer(gpus=1)
    trainer.fit(vae)
Train the VAE from the command line:

.. code-block:: bash

    cd pytorch_lightning_bolts/models/autoencoders/basic_vae
    python vae.py
The vae.py script accepts the following arguments:

.. code-block:: bash

    optional arguments:
    --hidden_dim     if using the default encoder/decoder, the dimension of the intermediate (dense) layers before the embedding
    --latent_dim     dimension of the latent variables z
    --input_width    input image width (must be even) - 28 for MNIST
    --input_height   input image height (must be even) - 28 for MNIST
    --batch_size     batch size

    any arguments from pl.Trainer - e.g. max_epochs, gpus

.. code-block:: bash

    python vae.py --hidden_dim 128 --latent_dim 32 --batch_size 32 --gpus 4 --max_epochs 12
---------------

Autoencoders
------------

The following is a collection of autoencoders.

Basic AE
^^^^^^^^

This is the simplest autoencoder. You can use it like so:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import AE
    import pytorch_lightning as pl

    model = AE()
    trainer = pl.Trainer()
    trainer.fit(model)
You can override any part of this AE to build your own variation.

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import AE

    class MyAEFlavor(AE):
        def init_encoder(self, hidden_dim, latent_dim, input_width, input_height):
            encoder = YourSuperFancyEncoder(...)
            return encoder

.. autoclass:: pl_bolts.models.autoencoders.AE
    :noindex:
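If you only want the trained AE as a frozen feature extractor, the same checkpoint-loading pattern shown for the VAE above should carry over. This is a minimal sketch using the standard LightningModule ``load_from_checkpoint``/``freeze`` API; ``PATH`` is a placeholder for your own checkpoint file.

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import AE

    # PATH is a placeholder for your own saved checkpoint
    ae = AE.load_from_checkpoint(PATH)
    ae.freeze()

Once frozen, the model can be dropped into another LightningModule exactly as in the VAE feature-extractor example above.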
---------------

Variational Autoencoders
------------------------

Basic VAE
^^^^^^^^^

Use the VAE like so:

.. code-block:: python

    from pytorch_lightning_bolts.models.autoencoders import VAE
    import pytorch_lightning as pl

    model = VAE()
    trainer = pl.Trainer()
    trainer.fit(model)
You can override any part of this VAE to build your own variation.

.. code-block:: python

    import torch
    from pytorch_lightning_bolts.models.autoencoders import VAE

    class MyVAEFlavor(VAE):
        def get_posterior(self, mu, std):
            # default:
            # P = self.get_distribution(self.prior, loc=torch.zeros_like(mu), scale=torch.ones_like(std))
            # do something other than the default here; as a placeholder,
            # this returns a diagonal Gaussian parameterized by mu and std
            P = torch.distributions.Normal(loc=mu, scale=std)
            return P

.. autoclass:: pl_bolts.models.autoencoders.VAE
    :noindex:
@@ -0,0 +1,20 @@
GANs
====

A collection of generative adversarial networks.

Basic GAN
---------

This is a basic GAN.

Example::

    from pytorch_lightning_bolts.models.gans import BasicGAN
    import pytorch_lightning as pl

    gan = BasicGAN()
    trainer = pl.Trainer()
    trainer.fit(gan)
.. autoclass:: pl_bolts.models.gans.BasicGAN
    :noindex:
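For inference, a trained GAN is usually sampled by feeding latent noise through its generator. The sketch below follows the same checkpoint-loading pattern as the VAE production example; the assumption that ``BasicGAN``'s forward pass maps noise ``z`` to generated samples is not confirmed here, so check the class documentation above for the exact signature. ``PATH`` and ``latent_dim`` are placeholders.

Example::

    import torch
    from pytorch_lightning_bolts.models.gans import BasicGAN

    # PATH is a placeholder for your own saved checkpoint
    gan = BasicGAN.load_from_checkpoint(PATH)
    gan.freeze()

    # assumption: the forward pass maps latent noise to generated samples
    latent_dim = 32  # use the latent size the model was trained with
    z = torch.randn(16, latent_dim)
    images = gan(z)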
@@ -0,0 +1,16 @@
Losses
======

This package lists common losses across research domains.

-------------

Self-supervised
---------------

Here are losses for popular self-supervised approaches.

NCE Loss
^^^^^^^^

The noise contrastive estimation (NCE) loss used in AMDIM.

.. autoclass:: pl_bolts.losses.self_supervised_learning.AmdimNCELoss
    :noindex:
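For intuition, the snippet below sketches a generic InfoNCE-style contrastive loss in plain PyTorch. It is an illustrative, simplified version rather than the exact ``AmdimNCELoss`` implementation documented above, and all names in it are placeholders.

.. code-block:: python

    import torch
    import torch.nn.functional as F

    def info_nce(anchors, positives, temperature=0.1):
        """Generic InfoNCE-style loss: each anchor should score highest
        against its own positive among all positives in the batch."""
        anchors = F.normalize(anchors, dim=1)
        positives = F.normalize(positives, dim=1)

        # (batch, batch) similarity matrix; diagonal entries are the positive pairs
        logits = anchors @ positives.t() / temperature
        targets = torch.arange(anchors.size(0), device=anchors.device)
        return F.cross_entropy(logits, targets)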
@@ -0,0 +1,25 @@
Self-supervised
===============

CPC (V2)
--------

PyTorch implementation of `Data-Efficient Image Recognition with Contrastive Predictive Coding <https://arxiv.org/abs/1905.09272>`_
by Olivier J. Hénaff, Aravind Srinivas, Jeffrey De Fauw, Ali Razavi, Carl Doersch, S. M. Ali Eslami, and Aaron van den Oord.

.. code-block:: python

    from pl_bolts.models.self_supervised import CPCV2

.. autoclass:: pl_bolts.models.self_supervised.CPCV2
    :noindex:
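As with the other models in bolts, CPCV2 can presumably be trained directly with a Lightning Trainer. The sketch below follows the same pattern used for the VAE and GAN above and is an assumption rather than a verified recipe; in particular, the constructor defaults and data handling are not shown here.

.. code-block:: python

    from pl_bolts.models.self_supervised import CPCV2
    import pytorch_lightning as pl

    # minimal sketch, assuming the default constructor arguments are usable as-is
    model = CPCV2()
    trainer = pl.Trainer(gpus=1)
    trainer.fit(model)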
AMDIM
-----

.. autoclass:: pl_bolts.models.self_supervised.AMDIM
    :noindex:

SimCLR
------

.. autoclass:: pl_bolts.models.self_supervised.SimCLR
    :noindex:

MocoV2
------

.. autoclass:: pl_bolts.models.self_supervised.MocoV2
    :noindex: