-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #10 from TalkToTheGAN/misc/visualization
Visualizer enabler code.
- Loading branch information
Showing
5 changed files
with
144 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
# -*- coding: utf-8 -*- | ||
|
||
import torch | ||
import math | ||
irange = range | ||
|
||
|
||
def make_grid(tensor, nrow=8, padding=2, | ||
normalize=False, range=None, scale_each=False, pad_value=0): | ||
"""Make a grid of images. | ||
Args: | ||
tensor (Tensor or list): 4D mini-batch Tensor of shape (B x C x H x W) | ||
or a list of images all of the same size. | ||
nrow (int, optional): Number of images displayed in each row of the grid. | ||
The Final grid size is (B / nrow, nrow). Default is 8. | ||
padding (int, optional): amount of padding. Default is 2. | ||
normalize (bool, optional): If True, shift the image to the range (0, 1), | ||
by subtracting the minimum and dividing by the maximum pixel value. | ||
range (tuple, optional): tuple (min, max) where min and max are numbers, | ||
then these numbers are used to normalize the image. By default, min and max | ||
are computed from the tensor. | ||
scale_each (bool, optional): If True, scale each image in the batch of | ||
images separately rather than the (min, max) over all images. | ||
pad_value (float, optional): Value for the padded pixels. | ||
""" | ||
if not (torch.is_tensor(tensor) or | ||
(isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): | ||
raise TypeError('tensor or list of tensors expected, got {}'.format(type(tensor))) | ||
|
||
# if list of tensors, convert to a 4D mini-batch Tensor | ||
if isinstance(tensor, list): | ||
tensor = torch.stack(tensor, dim=0) | ||
|
||
if tensor.dim() == 2: # single image H x W | ||
tensor = tensor.view(1, tensor.size(0), tensor.size(1)) | ||
if tensor.dim() == 3: # single image | ||
if tensor.size(0) == 1: # if single-channel, convert to 3-channel | ||
tensor = torch.cat((tensor, tensor, tensor), 0) | ||
return tensor | ||
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images | ||
tensor = torch.cat((tensor, tensor, tensor), 1) | ||
|
||
if normalize is True: | ||
tensor = tensor.clone() # avoid modifying tensor in-place | ||
if range is not None: | ||
assert isinstance(range, tuple), \ | ||
"range has to be a tuple (min, max) if specified. min and max are numbers" | ||
|
||
def norm_ip(img, min, max): | ||
img.clamp_(min=min, max=max) | ||
img.add_(-min).div_(max - min) | ||
|
||
def norm_range(t, range): | ||
if range is not None: | ||
norm_ip(t, range[0], range[1]) | ||
else: | ||
norm_ip(t, t.min(), t.max()) | ||
|
||
if scale_each is True: | ||
for t in tensor: # loop over mini-batch dimension | ||
norm_range(t, range) | ||
else: | ||
norm_range(tensor, range) | ||
|
||
# make the mini-batch of images into a grid | ||
nmaps = tensor.size(0) | ||
xmaps = min(nrow, nmaps) | ||
ymaps = int(math.ceil(float(nmaps) / xmaps)) | ||
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding) | ||
grid = tensor.new(3, height * ymaps + padding, width * xmaps + padding).fill_(pad_value) | ||
k = 0 | ||
for y in irange(ymaps): | ||
for x in irange(xmaps): | ||
if k >= nmaps: | ||
break | ||
grid.narrow(1, y * height + padding, height - padding)\ | ||
.narrow(2, x * width + padding, width - padding)\ | ||
.copy_(tensor[k]) | ||
k = k + 1 | ||
return grid | ||
|
||
|
||
def save_image(tensor, filename, nrow=8, padding=2, | ||
normalize=False, range=None, scale_each=False, pad_value=0): | ||
"""Save a given Tensor into an image file. | ||
Args: | ||
tensor (Tensor or list): Image to be saved. If given a mini-batch tensor, | ||
saves the tensor as a grid of images by calling ``make_grid``. | ||
**kwargs: Other arguments are documented in ``make_grid``. | ||
""" | ||
from PIL import Image | ||
tensor = tensor.cpu() | ||
grid = make_grid(tensor, nrow=nrow, padding=padding, pad_value=pad_value, | ||
normalize=normalize, range=range, scale_each=scale_each) | ||
ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).numpy() | ||
im = Image.fromarray(ndarr) | ||
im.save(filename) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters