Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor refactors - cleaning models #524

Merged
merged 14 commits into from Jan 19, 2021
13 changes: 13 additions & 0 deletions pl_bolts/models/__init__.py
Expand Up @@ -8,3 +8,16 @@
from pl_bolts.models.regression import LinearRegression, LogisticRegression # noqa: F401
from pl_bolts.models.vision import PixelCNN, SemSegment, UNet # noqa: F401
from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT # noqa: F401

__all__ = [
"AE",
"VAE",
"LitMNIST",
"LinearRegression",
"LogisticRegression",
"PixelCNN",
"SemSegment",
"UNet",
"GPT2",
"ImageGPT",
]
9 changes: 9 additions & 0 deletions pl_bolts/models/autoencoders/__init__.py
Expand Up @@ -11,3 +11,12 @@
resnet50_decoder,
resnet50_encoder,
)

__all__ = [
"AE",
"VAE",
"resnet18_decoder",
"resnet18_encoder",
"resnet50_decoder",
"resnet50_encoder",
]
11 changes: 4 additions & 7 deletions pl_bolts/models/detection/__init__.py
@@ -1,8 +1,5 @@
__all__ = []
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401

try:
from pl_bolts.models.detection.faster_rcnn import FasterRCNN # noqa: F401
except ModuleNotFoundError: # pragma: no-cover
pass # pragma: no-cover
else:
__all__.append('FasterRCNN')
__all__ = [
"FasterRCNN",
]
13 changes: 10 additions & 3 deletions pl_bolts/models/detection/faster_rcnn.py
Expand Up @@ -3,20 +3,24 @@
import pytorch_lightning as pl
import torch

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision.models.detection import faster_rcnn, fasterrcnn_resnet50_fpn
from torchvision.ops import box_iou
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
Expand Down Expand Up @@ -61,6 +65,9 @@ def __init__(
pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
trainable_backbone_layers: number of trainable resnet layers starting from final block
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()

model = fasterrcnn_resnet50_fpn(
Expand Down
4 changes: 4 additions & 0 deletions pl_bolts/models/gans/__init__.py
@@ -1 +1,5 @@
from pl_bolts.models.gans.basic.basic_gan_module import GAN # noqa: F401

__all__ = [
"GAN",
]
10 changes: 7 additions & 3 deletions pl_bolts/models/mnist_module.py
Expand Up @@ -5,18 +5,22 @@
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
from torchvision import transforms
from torchvision.datasets import MNIST
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('torchvision')


class LitMNIST(LightningModule):

def __init__(self, hidden_dim=128, learning_rate=1e-3, batch_size=32, num_workers=4, data_dir='', **kwargs):
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `torchvision` which is not installed yet.')

super().__init__()
self.save_hyperparameters()

Expand Down
5 changes: 5 additions & 0 deletions pl_bolts/models/regression/__init__.py
@@ -1,2 +1,7 @@
from pl_bolts.models.regression.linear_regression import LinearRegression # noqa: F401
from pl_bolts.models.regression.logistic_regression import LogisticRegression # noqa: F401

__all__ = [
"LinearRegression",
"LogisticRegression",
]
27 changes: 17 additions & 10 deletions pl_bolts/models/rl/__init__.py
@@ -1,10 +1,17 @@
try:
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401
except ModuleNotFoundError:
pass
from pl_bolts.models.rl.double_dqn_model import DoubleDQN # noqa: F401
from pl_bolts.models.rl.dqn_model import DQN # noqa: F401
from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN # noqa: F401
from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN # noqa: F401
from pl_bolts.models.rl.per_dqn_model import PERDQN # noqa: F401
from pl_bolts.models.rl.reinforce_model import Reinforce # noqa: F401
from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient # noqa: F401

__all__ = [
"DoubleDQN",
"DQN",
"DuelingDQN",
"NoisyDQN",
"PERDQN",
"Reinforce",
"VanillaPolicyGradient",
]
24 changes: 17 additions & 7 deletions pl_bolts/models/rl/common/gym_wrappers.py
Expand Up @@ -14,21 +14,24 @@
import gym.spaces
from gym import make as gym_make
from gym import ObservationWrapper, Wrapper
else: # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Wrapper = object
ObservationWrapper = object

if _OPENCV_AVAILABLE:
import cv2
else:
warn_missing_pkg('cv2', pypi_name='opencv-python') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('cv2', pypi_name='opencv-python')


class ToTensor(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(ToTensor, self).__init__(env)

def step(self, action):
Expand All @@ -45,6 +48,9 @@ class FireResetEnv(Wrapper):
"""For environments where the user need to press FIRE for the game to start."""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(FireResetEnv, self).__init__(env)
assert env.unwrapped.get_action_meanings()[1] == "FIRE"
assert len(env.unwrapped.get_action_meanings()) >= 3
Expand All @@ -69,6 +75,9 @@ class MaxAndSkipEnv(Wrapper):
"""Return only every `skip`-th frame"""

def __init__(self, env=None, skip=4):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super(MaxAndSkipEnv, self).__init__(env)
# most recent raw observations (for max pooling across time steps)
self._obs_buffer = collections.deque(maxlen=2)
Expand Down Expand Up @@ -99,8 +108,7 @@ class ProcessFrame84(ObservationWrapper):
"""preprocessing images from env"""

def __init__(self, env=None):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ProcessFrame84, self).__init__(env)
Expand Down Expand Up @@ -130,8 +138,7 @@ class ImageToPyTorch(ObservationWrapper):
"""converts image to pytorch format"""

def __init__(self, env):

if not _OPENCV_AVAILABLE:
if not _OPENCV_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.')

super(ImageToPyTorch, self).__init__(env)
Expand Down Expand Up @@ -188,6 +195,9 @@ class DataAugmentation(ObservationWrapper):
"""

def __init__(self, env=None):
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('You want to use `gym` which is not installed yet.')

super().__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

Expand Down
4 changes: 2 additions & 2 deletions pl_bolts/models/rl/dqn_model.py
Expand Up @@ -25,8 +25,8 @@

if _GYM_AVAILABLE:
from gym import Env
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')
Env = object


Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/models/rl/reinforce_model.py
Expand Up @@ -21,8 +21,8 @@

if _GYM_AVAILABLE:
import gym
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')


class Reinforce(pl.LightningModule):
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
"""
super().__init__()

if not _GYM_AVAILABLE:
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.')

# Hyperparameters
Expand Down
6 changes: 3 additions & 3 deletions pl_bolts/models/rl/vanilla_policy_gradient_model.py
Expand Up @@ -20,8 +20,8 @@

if _GYM_AVAILABLE:
import gym
else:
warn_missing_pkg('gym') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg('gym')


class VanillaPolicyGradient(pl.LightningModule):
Expand Down Expand Up @@ -79,7 +79,7 @@ def __init__(
"""
super().__init__()

if not _GYM_AVAILABLE:
if not _GYM_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.')

# Hyperparameters
Expand Down
6 changes: 6 additions & 0 deletions pl_bolts/models/vision/__init__.py
@@ -1,3 +1,9 @@
from pl_bolts.models.vision.pixel_cnn import PixelCNN # noqa: F401
from pl_bolts.models.vision.segmentation import SemSegment # noqa: F401
from pl_bolts.models.vision.unet import UNet # noqa: F401

__all__ = [
"PixelCNN",
"SemSegment",
"UNet",
]