Skip to content

Commit

Permalink
Update TensorboardGenerativeModelImageSampler args (#494)
Browse files Browse the repository at this point in the history
* Update args

* Add docs

* Fix codefactor and update changelog

* Update docs

* Apply yapf

* chlog

Co-authored-by: Christoph Clement <christoph.clement@artorg.unibe.ch>
Co-authored-by: Jirka Borovec <jirka.borovec@seznam.cz>
  • Loading branch information
3 people committed Jan 18, 2021
1 parent 42cfa8f commit 6e15643
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -42,6 +42,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactored `pl_bolts.callbacks` ([#477](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/477))
- Refactored the rest of `pl_bolts.models.self_supervised` ([#481](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/481),
[#479](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/479)
- Update [`torchvision.utils.make_grid`(https://pytorch.org/docs/stable/torchvision/utils.html#torchvision.utils.make_grid)] kwargs to `TensorboardGenerativeModelImageSampler` ([#494](https://github.com/PyTorchLightning/pytorch-lightning-bolts/pull/494))

### Fixed

Expand Down
50 changes: 45 additions & 5 deletions pl_bolts/callbacks/vision/image_generation.py
@@ -1,3 +1,5 @@
from typing import Optional, Tuple

import torch
from pytorch_lightning import Callback, LightningModule, Trainer

Expand All @@ -6,7 +8,7 @@
try:
import torchvision
except ModuleNotFoundError:
warn_missing_pkg('torchvision') # pragma: no-cover
warn_missing_pkg("torchvision") # pragma: no-cover


class TensorboardGenerativeModelImageSampler(Callback):
Expand All @@ -30,9 +32,39 @@ class TensorboardGenerativeModelImageSampler(Callback):
trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()])
"""

def __init__(self, num_samples: int = 3) -> None:
def __init__(
self,
num_samples: int = 3,
nrow: int = 8,
padding: int = 2,
normalize: bool = False,
norm_range: Optional[Tuple[int, int]] = None,
scale_each: bool = False,
pad_value: int = 0,
) -> None:
"""
Args:
num_samples: Number of images displayed in the grid. Default: ``3``.
nrow: Number of images displayed in each row of the grid.
The final grid size is ``(B / nrow, nrow)``. Default: ``8``.
padding: Amount of padding. Default: ``2``.
normalize: If ``True``, shift the image to the range (0, 1),
by the min and max values specified by :attr:`range`. Default: ``False``.
norm_range: Tuple (min, max) where min and max are numbers,
then these numbers are used to normalize the image. By default, min and max
are computed from the tensor.
scale_each: If ``True``, scale each image in the batch of
images separately rather than the (min, max) over all images. Default: ``False``.
pad_value: Value for the padded pixels. Default: ``0``.
"""
super().__init__()
self.num_samples: int = num_samples
self.num_samples = num_samples
self.nrow = nrow
self.padding = padding
self.normalize = normalize
self.norm_range = norm_range
self.scale_each = scale_each
self.pad_value = pad_value

def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
dim = (self.num_samples, pl_module.hparams.latent_dim) # type: ignore[union-attr]
Expand All @@ -48,6 +80,14 @@ def on_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
img_dim = pl_module.img_dim
images = images.view(self.num_samples, *img_dim)

grid = torchvision.utils.make_grid(images)
str_title = f'{pl_module.__class__.__name__}_images'
grid = torchvision.utils.make_grid(
tensor=images,
nrow=self.nrow,
padding=self.padding,
normalize=self.normalize,
range=self.norm_range,
scale_each=self.scale_each,
pad_value=self.pad_value,
)
str_title = f"{pl_module.__class__.__name__}_images"
trainer.logger.experiment.add_image(str_title, grid, global_step=trainer.global_step)

0 comments on commit 6e15643

Please sign in to comment.