Skip to content

Commit

Permalink
Move SSL transforms to pl_bolts/transforms (#905)
Browse files Browse the repository at this point in the history
* Move SSL transforms to pl_bolts/transforms
* Update self supervised docs
* Update self-supervised transforms docs
* Add simclr transforms doc strings and typing
* Fix swav_transform import error. Change assert to assertion error
* call train_transform
* Fix gaussian_blur super init arg
* Update moco transforms docs
* Transfer MoCo transforms. MoCo tests passing.
* Fix interpolation mode deprecation warning
* CPC_transforms typing and interpolation deprecation fix.
* MoCo transforms typing hints
* Keep under_review tags for unreviewed files
* Fix MoCo docs error

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: otaj <6065855+otaj@users.noreply.github.com>
Co-authored-by: Jirka B <j.borovec+github@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
7 people committed May 20, 2023
1 parent fbdaef6 commit 5669578
Show file tree
Hide file tree
Showing 23 changed files with 304 additions and 180 deletions.
19 changes: 13 additions & 6 deletions docs/source/models/self_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ These models are perfect for training from scratch when you have a huge set of u
.. code-block:: python
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.transforms.self_supervised.simclr_transforms import (
SimCLREvalDataTransform,
SimCLRTrainDataTransform
)
train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
Expand Down Expand Up @@ -120,8 +122,10 @@ To Train::
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc import (
CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10)
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCTrainTransformsCIFAR10,
CPCEvalTransformsCIFAR10
)

# data
dm = CIFAR10DataModule(num_workers=0)
Expand Down Expand Up @@ -277,7 +281,9 @@ To Train::
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)
SimCLREvalDataTransform,
SimCLRTrainDataTransform
)

# data
dm = CIFAR10DataModule(num_workers=0)
Expand Down Expand Up @@ -466,7 +472,8 @@ To Train::
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
SwAVTrainDataTransform, SwAVEvalDataTransform
SwAVTrainDataTransform,
SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

Expand Down
40 changes: 20 additions & 20 deletions docs/source/transforms/self_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,32 @@ Transforms used for CPC
CIFAR-10 Train (c)
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsCIFAR10
:noindex:

CIFAR-10 Eval (c)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsCIFAR10
:noindex:

Imagenet Train (c)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsImageNet128
:noindex:

Imagenet Eval (c)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsImageNet128
:noindex:

STL-10 Train (c)
^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsSTL10
:noindex:

STL-10 Eval (c)
^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsSTL10
:noindex:

AMDIM transforms
Expand All @@ -56,32 +56,32 @@ Transforms used for AMDIM
CIFAR-10 Train (a)
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsCIFAR10
:noindex:

CIFAR-10 Eval (a)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsCIFAR10
:noindex:

Imagenet Train (a)
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsImageNet128
:noindex:

Imagenet Eval (a)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsImageNet128
:noindex:

STL-10 Train (a)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsSTL10
:noindex:

STL-10 Eval (a)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsSTL10
:noindex:

MOCO V2 transforms
Expand All @@ -92,32 +92,32 @@ Transforms used for MOCO V2
CIFAR-10 Train (m2)
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainCIFAR10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainCIFAR10Transforms
:noindex:

CIFAR-10 Eval (m2)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalCIFAR10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalCIFAR10Transforms
:noindex:

Imagenet Train (m2)
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainSTL10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainSTL10Transforms
:noindex:

Imagenet Eval (m2)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalSTL10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalSTL10Transforms
:noindex:

STL-10 Train (m2)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainImagenetTransforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainImagenetTransforms
:noindex:

STL-10 Eval (m2)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalImagenetTransforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalImagenetTransforms
:noindex:

SimCLR transforms
Expand All @@ -126,12 +126,12 @@ Transforms used for SimCLR

Train (sc)
^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.simclr.transforms.SimCLRTrainDataTransform
.. autoclass:: pl_bolts.transforms.self_supervised.simclr_transforms.SimCLRTrainDataTransform
:noindex:

Eval (sc)
^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.simclr.transforms.SimCLREvalDataTransform
.. autoclass:: pl_bolts.transforms.self_supervised.simclr_transforms.SimCLREvalDataTransform
:noindex:


Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/amdim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder
from pl_bolts.models.self_supervised.amdim.transforms import (
from pl_bolts.transforms.self_supervised.amdim_transforms import (
AMDIMEvalTransformsCIFAR10,
AMDIMEvalTransformsImageNet128,
AMDIMEvalTransformsSTL10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/amdim/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.utils.data import random_split

from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.transforms.self_supervised import amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet50, cpc_resnet101
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsImageNet128,
CPCEvalTransformsSTL10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pytorch_lightning import Trainer, seed_everything

from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsSTL10,
CPCTrainTransformsCIFAR10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/cpc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.losses.self_supervised_learning import CPCTask
from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsImageNet128,
CPCEvalTransformsSTL10,
Expand Down
14 changes: 7 additions & 7 deletions src/pl_bolts/models/self_supervised/moco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pl_bolts.models.self_supervised.moco.transforms import ( # noqa: F401
Moco2EvalCIFAR10Transforms,
Moco2EvalImagenetTransforms,
Moco2EvalSTL10Transforms,
Moco2TrainCIFAR10Transforms,
Moco2TrainImagenetTransforms,
Moco2TrainSTL10Transforms,
from pl_bolts.transforms.self_supervised.moco_transforms import ( # noqa: F401
MoCo2EvalCIFAR10Transforms,
MoCo2EvalImagenetTransforms,
MoCo2EvalSTL10Transforms,
MoCo2TrainCIFAR10Transforms,
MoCo2TrainImagenetTransforms,
MoCo2TrainSTL10Transforms,
)
26 changes: 13 additions & 13 deletions src/pl_bolts/models/self_supervised/moco/moco2_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from torch.nn import functional as F

from pl_bolts.metrics import mean, precision_at_k
from pl_bolts.models.self_supervised.moco.transforms import (
Moco2EvalCIFAR10Transforms,
Moco2EvalImagenetTransforms,
Moco2EvalSTL10Transforms,
Moco2TrainCIFAR10Transforms,
Moco2TrainImagenetTransforms,
Moco2TrainSTL10Transforms,
from pl_bolts.transforms.self_supervised.moco_transforms import (
MoCo2EvalCIFAR10Transforms,
MoCo2EvalImagenetTransforms,
MoCo2EvalSTL10Transforms,
MoCo2TrainCIFAR10Transforms,
MoCo2TrainImagenetTransforms,
MoCo2TrainSTL10Transforms,
)
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
Expand Down Expand Up @@ -372,20 +372,20 @@ def cli_main():

if args.dataset == "cifar10":
datamodule = CIFAR10DataModule.from_argparse_args(args)
datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
datamodule.train_transforms = MoCo2TrainCIFAR10Transforms()
datamodule.val_transforms = MoCo2EvalCIFAR10Transforms()

elif args.dataset == "stl10":
datamodule = STL10DataModule.from_argparse_args(args)
datamodule.train_dataloader = datamodule.train_dataloader_mixed
datamodule.val_dataloader = datamodule.val_dataloader_mixed
datamodule.train_transforms = Moco2TrainSTL10Transforms()
datamodule.val_transforms = Moco2EvalSTL10Transforms()
datamodule.train_transforms = MoCo2TrainSTL10Transforms()
datamodule.val_transforms = MoCo2EvalSTL10Transforms()

elif args.dataset == "imagenet2012":
datamodule = SSLImagenetDataModule.from_argparse_args(args)
datamodule.train_transforms = Moco2TrainImagenetTransforms()
datamodule.val_transforms = Moco2EvalImagenetTransforms()
datamodule.train_transforms = MoCo2TrainImagenetTransforms()
datamodule.val_transforms = MoCo2EvalImagenetTransforms()

else:
# replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/simclr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pl_bolts.models.self_supervised.simclr.transforms import ( # noqa: F401
from pl_bolts.transforms.self_supervised.simclr_transforms import ( # noqa: F401
SimCLREvalDataTransform,
SimCLRTrainDataTransform,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from pytorch_lightning import Trainer, seed_everything

from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRFinetuneTransform
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
stl10_normalization,
)
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLRFinetuneTransform
from pl_bolts.utils.stability import under_review


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def add_model_specific_args(parent_parser):
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

parser = ArgumentParser()

Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/swav/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
from pl_bolts.transforms.self_supervised.swav_transforms import (
SwAVEvalDataTransform,
SwAVFinetuneTransform,
SwAVTrainDataTransform,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/swav/swav_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
from pl_bolts.transforms.self_supervised.swav_transforms import SwAVFinetuneTransform


def cli_main(): # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/swav/swav_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def add_model_specific_args(parent_parser):
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

parser = ArgumentParser()

Expand Down

0 comments on commit 5669578

Please sign in to comment.