Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the rest of pl_bolts.models.self_supervised #481

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 12 additions & 0 deletions pl_bolts/models/self_supervised/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,15 @@
from pl_bolts.models.self_supervised.simclr.simclr_module import SimCLR # noqa: F401
from pl_bolts.models.self_supervised.ssl_finetuner import SSLFineTuner # noqa: F401
from pl_bolts.models.self_supervised.swav.swav_module import SwAV # noqa: F401


__all__ = [
"AMDIM",
"BYOL",
"CPCV2",
"SSLEvaluator"
"MocoV2",
"SimCLR",
"SSLFineTuner",
"SwAV",
]
13 changes: 13 additions & 0 deletions pl_bolts/models/self_supervised/cpc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,16 @@
CPCTrainTransformsImageNet128,
CPCTrainTransformsSTL10,
)


__all__ = [
"CPCV2",
"cpc_resnet50",
"cpc_resnet101",
"CPCEvalTransformsCIFAR10",
"CPCEvalTransformsImageNet128",
"CPCEvalTransformsSTL10",
"CPCTrainTransformsCIFAR10",
"CPCTrainTransformsImageNet128",
"CPCTrainTransformsSTL10",
]
28 changes: 14 additions & 14 deletions pl_bolts/models/self_supervised/cpc/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class CPCTrainTransformsCIFAR10:
Expand Down Expand Up @@ -39,8 +39,8 @@ def __init__(self, patch_size=8, overlap=4):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -97,8 +97,8 @@ def __init__(self, patch_size: int = 8, overlap: int = 4):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -151,8 +151,8 @@ def __init__(self, patch_size: int = 16, overlap: int = 8):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -210,8 +210,8 @@ def __init__(self, patch_size: int = 16, overlap: int = 8):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -260,8 +260,8 @@ def __init__(self, patch_size: int = 32, overlap: int = 16):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -320,8 +320,8 @@ def __init__(self, patch_size: int = 32, overlap: int = 16):
patch_size: size of patches when cutting up the image into overlapping patches
overlap: how much to overlap patches
"""
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down
36 changes: 18 additions & 18 deletions pl_bolts/models/self_supervised/moco/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

if _TORCHVISION_AVAILABLE:
from torchvision import transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')

if _PIL_AVAILABLE:
from PIL import ImageFilter
else:
warn_missing_pkg('PIL', pypi_name='Pillow') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('PIL', pypi_name='Pillow')


class Moco2TrainCIFAR10Transforms:
Expand All @@ -25,8 +25,8 @@ class Moco2TrainCIFAR10Transforms:
https://arxiv.org/pdf/2003.04297.pdf
"""
def __init__(self, height: int = 32):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -55,8 +55,8 @@ class Moco2EvalCIFAR10Transforms:
https://arxiv.org/pdf/2003.04297.pdf
"""
def __init__(self, height: int = 32):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand All @@ -79,8 +79,8 @@ class Moco2TrainSTL10Transforms:
https://arxiv.org/pdf/2003.04297.pdf
"""
def __init__(self, height: int = 64):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -109,8 +109,8 @@ class Moco2EvalSTL10Transforms:
https://arxiv.org/pdf/2003.04297.pdf
"""
def __init__(self, height: int = 64):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand All @@ -135,8 +135,8 @@ class Moco2TrainImagenetTransforms:
"""

def __init__(self, height: int = 128):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -165,8 +165,8 @@ class Moco2EvalImagenetTransforms:
https://arxiv.org/pdf/2003.04297.pdf
"""
def __init__(self, height: int = 128):
if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand All @@ -187,8 +187,8 @@ class GaussianBlur(object):
"""Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709"""

def __init__(self, sigma=(0.1, 2.0)):
if not _PIL_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _PIL_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `Pillow` which is not installed yet, install it with `pip install Pillow`.'
)
self.sigma = sigma
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
)


def cli_main(): # pragma: no-cover
def cli_main(): # pragma: no cover
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule, STL10DataModule

pl.seed_everything(1234)
Expand Down
17 changes: 11 additions & 6 deletions pl_bolts/models/self_supervised/simclr/transforms.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import numpy as np

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils import _TORCHVISION_AVAILABLE, _OPENCV_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _TORCHVISION_AVAILABLE:
import torchvision.transforms as transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover

try:
if _OPENCV_AVAILABLE:
import cv2
except ModuleNotFoundError:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('cv2', pypi_name='opencv-python')


class SimCLRTrainDataTransform(object):
Expand Down Expand Up @@ -43,8 +43,8 @@ def __init__(
normalize=None
) -> None:

if not _TORCHVISION_AVAILABLE:
raise ModuleNotFoundError( # pragma: no-cover
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `transforms` from `torchvision` which is not installed yet.'
)

Expand Down Expand Up @@ -187,6 +187,11 @@ def __call__(self, sample):
class GaussianBlur(object):
# Implements Gaussian blur as described in the SimCLR paper
def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
'You want to use `GaussianBlur` from `cv2` which is not installed yet.'
)

self.min = min
self.max = max

Expand Down
10 changes: 10 additions & 0 deletions pl_bolts/models/self_supervised/swav/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,13 @@
SwAVFinetuneTransform,
SwAVTrainDataTransform,
)


__all__ = [
"SwAV",
"resnet18",
"resnet50",
"SwAVEvalDataTransform",
"SwAVFinetuneTransform",
"SwAVTrainDataTransform",
]
2 changes: 1 addition & 1 deletion pl_bolts/models/self_supervised/swav/swav_finetuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization, stl10_normalization


def cli_main(): # pragma: no-cover
def cli_main(): # pragma: no cover
from pl_bolts.datamodules import ImagenetDataModule, STL10DataModule

pl.seed_everything(1234)
Expand Down
8 changes: 4 additions & 4 deletions pl_bolts/models/self_supervised/swav/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@

if _TORCHVISION_AVAILABLE:
import torchvision.transforms as transforms
else:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('cv2', pypi_name='opencv-python')


class SwAVTrainDataTransform(object):
Expand Down