Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor pl_bolts.callbacks #477

Merged
merged 6 commits into from Jan 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion pl_bolts/callbacks/__init__.py
@@ -1,7 +1,21 @@
"""
Collection of PyTorchLightning callbacks
"""
from pl_bolts.callbacks.byol_updates import BYOLMAWeightUpdate # noqa: F401
from pl_bolts.callbacks.data_monitor import ModuleDataMonitor, TrainingDataMonitor # noqa: F401
from pl_bolts.callbacks.printing import PrintTableMetricsCallback # noqa: F401
from pl_bolts.callbacks.ssl_online import SSLOnlineEvaluator # noqa: F401
from pl_bolts.callbacks.variational import LatentDimInterpolator # noqa: F401
from pl_bolts.callbacks.vision import TensorboardGenerativeModelImageSampler # noqa: F401
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback # noqa: F401
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler # noqa: F401

__all__ = [
"BYOLMAWeightUpdate",
"ModuleDataMonitor",
"TrainingDataMonitor",
"PrintTableMetricsCallback",
"SSLOnlineEvaluator",
"LatentDimInterpolator",
"ConfusedLogitCallback",
"TensorboardGenerativeModelImageSampler",
]
13 changes: 11 additions & 2 deletions pl_bolts/callbacks/data_monitor.py
Expand Up @@ -10,9 +10,13 @@
from torch import Tensor
from torch.utils.hooks import RemovableHandle

try:
from pl_bolts.utils import _WANDB_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

if _WANDB_AVAILABLE:
import wandb
except ModuleNotFoundError:
else: # pragma: no cover
warn_missing_pkg("wandb")
wandb = None


Expand Down Expand Up @@ -87,6 +91,11 @@ def log_histogram(self, tensor: Tensor, name: str) -> None:
)

if isinstance(logger, WandbLogger):
if not _WANDB_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `wandb` which is not installed yet."
)

logger.experiment.log(
data={name: wandb.Histogram(tensor)}, commit=False,
)
Expand Down
2 changes: 1 addition & 1 deletion pl_bolts/callbacks/ssl_online.py
Expand Up @@ -6,7 +6,7 @@
from torch.nn import functional as F


class SSLOnlineEvaluator(Callback): # pragma: no-cover
class SSLOnlineEvaluator(Callback): # pragma: no cover
"""
Attaches a MLP for fine-tuning using the standard self-supervised protocol.

Expand Down
12 changes: 9 additions & 3 deletions pl_bolts/callbacks/variational.py
Expand Up @@ -4,12 +4,13 @@
import torch
from pytorch_lightning.callbacks import Callback

from pl_bolts.utils import _TORCHVISION_AVAILABLE
from pl_bolts.utils.warnings import warn_missing_pkg

try:
if _TORCHVISION_AVAILABLE:
import torchvision
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
else: # pragma: no cover
warn_missing_pkg("torchvision")


class LatentDimInterpolator(Callback):
Expand Down Expand Up @@ -44,6 +45,11 @@ def __init__(
num_samples: default 2
normalize: default True (change image to (0, 1) range)
"""
if not _TORCHVISION_AVAILABLE: # pragma: no cover
raise ModuleNotFoundError(
"You want to use `torchvision` which is not installed yet."
)

super().__init__()
self.interpolate_epoch_interval = interpolate_epoch_interval
self.range_start = range_start
Expand Down
2 changes: 0 additions & 2 deletions pl_bolts/callbacks/vision/__init__.py
@@ -1,2 +0,0 @@
from pl_bolts.callbacks.vision.confused_logit import ConfusedLogitCallback # noqa: F401
from pl_bolts.callbacks.vision.image_generation import TensorboardGenerativeModelImageSampler # noqa: F401
1 change: 1 addition & 0 deletions pl_bolts/utils/__init__.py
Expand Up @@ -8,3 +8,4 @@
_SKLEARN_AVAILABLE: bool = _module_available("sklearn")
_PIL_AVAILABLE: bool = _module_available("PIL")
_OPENCV_AVAILABLE: bool = _module_available("cv2")
_WANDB_AVAILABLE: bool = _module_available("wandb")