Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move SSL transforms to pl_bolts/transforms #905

Merged
merged 37 commits into from
May 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
0cab7a0
Move SSL transforms to pl_bolts/transforms
matsumotosan Oct 11, 2022
2aeeb32
Update self supervised docs
matsumotosan Oct 11, 2022
2e9d2fd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 27, 2022
56f2db0
Merge branch 'master' into ssl_transforms
Borda Oct 27, 2022
e1e0176
Merge branch 'master' into ssl_transforms
matsumotosan Oct 28, 2022
026325a
Update self-supervised transforms docs.
matsumotosan Oct 28, 2022
ec0868f
Add simclr transforms doc strings and typing
matsumotosan Oct 28, 2022
e015d8f
Fix swav_transform import error. Change assert to assertion error.
matsumotosan Oct 28, 2022
a720809
call train_transform
matsumotosan Oct 28, 2022
50d459b
Merge branch 'master' into ssl_transforms
matsumotosan Oct 31, 2022
a85c81c
Merge branch 'master' into ssl_transforms
matsumotosan Nov 1, 2022
7b68830
Merge branch 'master' into ssl_transforms
matsumotosan Nov 1, 2022
f6244db
Merge branch 'master' into ssl_transforms
otaj Nov 2, 2022
be6da17
Merge branch 'master' into ssl_transforms
matsumotosan Nov 8, 2022
c455501
Merge branch 'master' into ssl_transforms
matsumotosan Dec 25, 2022
937581a
Fix gaussian_blur super init arg
matsumotosan Nov 1, 2022
421cd1a
Update moco transforms docs
matsumotosan Dec 25, 2022
a9719c1
Transfer MoCo transforms. MoCo tests passing.
matsumotosan Jan 5, 2023
3f26454
Fix interpolation mode deprecation warning
matsumotosan Jan 6, 2023
7062702
CPC_transforms typing and interpolation deprecation fix.
matsumotosan Jan 6, 2023
cb045e0
MoCo transforms typing hints
matsumotosan Jan 6, 2023
b0fbdf8
Keep under_review tags for unreviewed files
matsumotosan Jan 6, 2023
651f51b
Merge branch 'master' into ssl_transforms
matsumotosan Jan 6, 2023
7b60348
Fix MoCo docs error
matsumotosan Jan 6, 2023
d0a6857
Merge branch 'master' into ssl_transforms
matsumotosan Feb 23, 2023
6a83014
Merge branch 'master' into ssl_transforms
matsumotosan Feb 25, 2023
bd23c27
update mergify team
Borda May 19, 2023
f81a2d7
Merge branch 'master' into ssl_transforms
Borda May 19, 2023
e3ab905
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 19, 2023
afdc09a
update mergify team
Borda May 19, 2023
139e3c0
Merge branch 'master' into ssl_transforms
Borda May 19, 2023
76ad929
Merge branch 'master' into ssl_transforms
Borda May 19, 2023
472a271
Merge branch 'master' into ssl_transforms
mergify[bot] May 19, 2023
6c519cd
Merge branch 'master' into ssl_transforms
mergify[bot] May 20, 2023
b33ec0b
Merge branch 'master' into ssl_transforms
mergify[bot] May 20, 2023
ca87957
Merge branch 'master' into ssl_transforms
Borda May 20, 2023
383ecfc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 13 additions & 6 deletions docs/source/models/self_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ These models are perfect for training from scratch when you have a huge set of u
.. code-block:: python

from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.models.self_supervised.simclr import SimCLREvalDataTransform, SimCLRTrainDataTransform

from pl_bolts.transforms.self_supervised.simclr_transforms import (
SimCLREvalDataTransform,
SimCLRTrainDataTransform
)

train_dataset = MyDataset(transforms=SimCLRTrainDataTransform())
val_dataset = MyDataset(transforms=SimCLREvalDataTransform())
Expand Down Expand Up @@ -120,8 +122,10 @@ To Train::
import pytorch_lightning as pl
from pl_bolts.models.self_supervised import CPC_v2
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.cpc import (
CPCTrainTransformsCIFAR10, CPCEvalTransformsCIFAR10)
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCTrainTransformsCIFAR10,
CPCEvalTransformsCIFAR10
)

# data
dm = CIFAR10DataModule(num_workers=0)
Expand Down Expand Up @@ -277,7 +281,9 @@ To Train::
from pl_bolts.models.self_supervised import SimCLR
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import (
SimCLREvalDataTransform, SimCLRTrainDataTransform)
SimCLREvalDataTransform,
SimCLRTrainDataTransform
)

# data
dm = CIFAR10DataModule(num_workers=0)
Expand Down Expand Up @@ -466,7 +472,8 @@ To Train::
from pl_bolts.models.self_supervised import SwAV
from pl_bolts.datamodules import STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import (
SwAVTrainDataTransform, SwAVEvalDataTransform
SwAVTrainDataTransform,
SwAVEvalDataTransform
)
from pl_bolts.transforms.dataset_normalizations import stl10_normalization

Expand Down
40 changes: 20 additions & 20 deletions docs/source/transforms/self_supervised.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,32 +20,32 @@ Transforms used for CPC
CIFAR-10 Train (c)
^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsCIFAR10
:noindex:

CIFAR-10 Eval (c)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsCIFAR10
:noindex:

Imagenet Train (c)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsImageNet128
:noindex:

Imagenet Eval (c)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsImageNet128
:noindex:

STL-10 Train (c)
^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCTrainTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCTrainTransformsSTL10
:noindex:

STL-10 Eval (c)
^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.cpc.transforms.CPCEvalTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.cpc_transforms.CPCEvalTransformsSTL10
:noindex:

AMDIM transforms
Expand All @@ -56,32 +56,32 @@ Transforms used for AMDIM
CIFAR-10 Train (a)
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsCIFAR10
:noindex:

CIFAR-10 Eval (a)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsCIFAR10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsCIFAR10
:noindex:

Imagenet Train (a)
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsImageNet128
:noindex:

Imagenet Eval (a)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsImageNet128
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsImageNet128
:noindex:

STL-10 Train (a)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMTrainTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMTrainTransformsSTL10
:noindex:

STL-10 Eval (a)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.amdim.transforms.AMDIMEvalTransformsSTL10
.. autoclass:: pl_bolts.transforms.self_supervised.amdim_transforms.AMDIMEvalTransformsSTL10
:noindex:

MOCO V2 transforms
Expand All @@ -92,32 +92,32 @@ Transforms used for MOCO V2
CIFAR-10 Train (m2)
^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainCIFAR10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainCIFAR10Transforms
:noindex:

CIFAR-10 Eval (m2)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalCIFAR10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalCIFAR10Transforms
:noindex:

Imagenet Train (m2)
^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainSTL10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainSTL10Transforms
:noindex:

Imagenet Eval (m2)
^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalSTL10Transforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalSTL10Transforms
:noindex:

STL-10 Train (m2)
^^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2TrainImagenetTransforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2TrainImagenetTransforms
:noindex:

STL-10 Eval (m2)
^^^^^^^^^^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.moco.transforms.Moco2EvalImagenetTransforms
.. autoclass:: pl_bolts.transforms.self_supervised.moco_transforms.MoCo2EvalImagenetTransforms
:noindex:

SimCLR transforms
Expand All @@ -126,12 +126,12 @@ Transforms used for SimCLR

Train (sc)
^^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.simclr.transforms.SimCLRTrainDataTransform
.. autoclass:: pl_bolts.transforms.self_supervised.simclr_transforms.SimCLRTrainDataTransform
:noindex:

Eval (sc)
^^^^^^^^^
.. autoclass:: pl_bolts.models.self_supervised.simclr.transforms.SimCLREvalDataTransform
.. autoclass:: pl_bolts.transforms.self_supervised.simclr_transforms.SimCLREvalDataTransform
:noindex:


Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/amdim/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pl_bolts.models.self_supervised.amdim.amdim_module import AMDIM
from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder
from pl_bolts.models.self_supervised.amdim.transforms import (
from pl_bolts.transforms.self_supervised.amdim_transforms import (
AMDIMEvalTransformsCIFAR10,
AMDIMEvalTransformsImageNet128,
AMDIMEvalTransformsSTL10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/amdim/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from torch.utils.data import random_split

from pl_bolts.datasets import CIFAR10Mixed, UnlabeledImagenet
from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms
from pl_bolts.transforms.self_supervised import amdim_transforms
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
from pl_bolts.utils.warnings import warn_missing_pkg
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pl_bolts.models.self_supervised.cpc.cpc_module import CPC_v2
from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet50, cpc_resnet101
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsImageNet128,
CPCEvalTransformsSTL10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/cpc_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pytorch_lightning import Trainer, seed_everything

from pl_bolts.models.self_supervised import CPC_v2, SSLFineTuner
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsSTL10,
CPCTrainTransformsCIFAR10,
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/cpc/cpc_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.losses.self_supervised_learning import CPCTask
from pl_bolts.models.self_supervised.cpc.networks import cpc_resnet101
from pl_bolts.models.self_supervised.cpc.transforms import (
from pl_bolts.transforms.self_supervised.cpc_transforms import (
CPCEvalTransformsCIFAR10,
CPCEvalTransformsImageNet128,
CPCEvalTransformsSTL10,
Expand Down
14 changes: 7 additions & 7 deletions src/pl_bolts/models/self_supervised/moco/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pl_bolts.models.self_supervised.moco.transforms import ( # noqa: F401
Moco2EvalCIFAR10Transforms,
Moco2EvalImagenetTransforms,
Moco2EvalSTL10Transforms,
Moco2TrainCIFAR10Transforms,
Moco2TrainImagenetTransforms,
Moco2TrainSTL10Transforms,
from pl_bolts.transforms.self_supervised.moco_transforms import ( # noqa: F401
MoCo2EvalCIFAR10Transforms,
MoCo2EvalImagenetTransforms,
MoCo2EvalSTL10Transforms,
MoCo2TrainCIFAR10Transforms,
MoCo2TrainImagenetTransforms,
MoCo2TrainSTL10Transforms,
)
26 changes: 13 additions & 13 deletions src/pl_bolts/models/self_supervised/moco/moco2_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
from torch.nn import functional as F

from pl_bolts.metrics import mean, precision_at_k
from pl_bolts.models.self_supervised.moco.transforms import (
Moco2EvalCIFAR10Transforms,
Moco2EvalImagenetTransforms,
Moco2EvalSTL10Transforms,
Moco2TrainCIFAR10Transforms,
Moco2TrainImagenetTransforms,
Moco2TrainSTL10Transforms,
from pl_bolts.transforms.self_supervised.moco_transforms import (
MoCo2EvalCIFAR10Transforms,
MoCo2EvalImagenetTransforms,
MoCo2EvalSTL10Transforms,
MoCo2TrainCIFAR10Transforms,
MoCo2TrainImagenetTransforms,
MoCo2TrainSTL10Transforms,
)
from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.stability import under_review
Expand Down Expand Up @@ -372,20 +372,20 @@ def cli_main():

if args.dataset == "cifar10":
datamodule = CIFAR10DataModule.from_argparse_args(args)
datamodule.train_transforms = Moco2TrainCIFAR10Transforms()
datamodule.val_transforms = Moco2EvalCIFAR10Transforms()
datamodule.train_transforms = MoCo2TrainCIFAR10Transforms()
datamodule.val_transforms = MoCo2EvalCIFAR10Transforms()

elif args.dataset == "stl10":
datamodule = STL10DataModule.from_argparse_args(args)
datamodule.train_dataloader = datamodule.train_dataloader_mixed
datamodule.val_dataloader = datamodule.val_dataloader_mixed
datamodule.train_transforms = Moco2TrainSTL10Transforms()
datamodule.val_transforms = Moco2EvalSTL10Transforms()
datamodule.train_transforms = MoCo2TrainSTL10Transforms()
datamodule.val_transforms = MoCo2EvalSTL10Transforms()

elif args.dataset == "imagenet2012":
datamodule = SSLImagenetDataModule.from_argparse_args(args)
datamodule.train_transforms = Moco2TrainImagenetTransforms()
datamodule.val_transforms = Moco2EvalImagenetTransforms()
datamodule.train_transforms = MoCo2TrainImagenetTransforms()
datamodule.val_transforms = MoCo2EvalImagenetTransforms()

else:
# replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/simclr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pl_bolts.models.self_supervised.simclr.transforms import ( # noqa: F401
from pl_bolts.transforms.self_supervised.simclr_transforms import ( # noqa: F401
SimCLREvalDataTransform,
SimCLRTrainDataTransform,
)
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from pytorch_lightning import Trainer, seed_everything

from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR
from pl_bolts.models.self_supervised.simclr.transforms import SimCLRFinetuneTransform
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.transforms.dataset_normalizations import (
cifar10_normalization,
imagenet_normalization,
stl10_normalization,
)
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLRFinetuneTransform
from pl_bolts.utils.stability import under_review


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def add_model_specific_args(parent_parser):
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform
from pl_bolts.transforms.self_supervised.simclr_transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform

parser = ArgumentParser()

Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/swav/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pl_bolts.models.self_supervised.swav.loss import SWAVLoss
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.swav_resnet import resnet18, resnet50
from pl_bolts.models.self_supervised.swav.transforms import (
from pl_bolts.transforms.self_supervised.swav_transforms import (
SwAVEvalDataTransform,
SwAVFinetuneTransform,
SwAVTrainDataTransform,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner
from pl_bolts.models.self_supervised.swav.swav_module import SwAV
from pl_bolts.models.self_supervised.swav.transforms import SwAVFinetuneTransform
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization
from pl_bolts.transforms.self_supervised.swav_transforms import SwAVFinetuneTransform


def cli_main(): # pragma: no cover
Expand Down
2 changes: 1 addition & 1 deletion src/pl_bolts/models/self_supervised/swav/swav_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def add_model_specific_args(parent_parser):
def cli_main():
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule
from pl_bolts.models.self_supervised.swav.transforms import SwAVEvalDataTransform, SwAVTrainDataTransform
from pl_bolts.transforms.self_supervised.swav_transforms import SwAVEvalDataTransform, SwAVTrainDataTransform

parser = ArgumentParser()

Expand Down