-
Notifications
You must be signed in to change notification settings - Fork 395
/
Copy pathvisualization.py
41 lines (32 loc) · 1.48 KB
/
visualization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
'''
-----------------------------------------------------------------------------
Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
'''
import wandb
import torch
import torchvision
from matplotlib import pyplot as plt
from torchvision.transforms import functional as torchvision_F
def wandb_image(images, from_range=(0, 1)):
    """Wrap a batch of images as a single ``wandb.Image`` grid.

    Args:
        images: image batch handed to ``preprocess_image`` for rescaling to
            [0, 1] (and colorizing when single-channel). Presumably shaped
            [N, C, H, W] -- TODO confirm against callers.
        from_range: (low, high) value range of ``images``.

    Returns:
        A ``wandb.Image`` built from a PIL rendering of the grid.
    """
    normalized = preprocess_image(images, from_range=from_range)
    # nrow=1: one image per grid row, so the batch is stacked vertically;
    # pad_value=1 fills the padding with the top of the [0, 1] range.
    grid = torchvision.utils.make_grid(normalized, nrow=1, pad_value=1)
    return wandb.Image(torchvision_F.to_pil_image(grid))
def preprocess_image(images, from_range=(0, 1), cmap="gray"):
    """Rescale an image batch to [0, 1] on the CPU, colorizing 1-channel input.

    Args:
        images: image tensor indexed as [N, C, H, W] (the channel check reads
            ``images.shape[1]``).
        from_range: (low, high) pair giving the value range of ``images``;
            ``low`` maps to 0 and ``high`` maps to 1, anything outside is
            clamped.
        cmap: matplotlib colormap name applied when ``images`` has exactly
            one channel.

    Returns:
        A detached float32 CPU tensor with values in [0, 1]. Single-channel
        input is turned into a 3-channel heatmap via ``get_heatmap``; other
        inputs keep their shape.

    Raises:
        ValueError: if ``from_range`` has equal endpoints, which would make
            the rescale divide by zero.
    """
    # Named low/high rather than min/max to avoid shadowing the builtins.
    low, high = from_range
    if high == low:
        raise ValueError(
            f"from_range must span a non-empty interval, got {from_range!r}"
        )
    images = (images - low) / (high - low)
    # The arithmetic above already allocated a new tensor, so the in-place
    # clamp_ below never touches the caller's tensor.
    images = images.detach().cpu().float().clamp_(min=0, max=1)
    if images.shape[1] == 1:
        images = get_heatmap(images[:, 0], cmap=cmap)
    return images
def get_heatmap(gray, cmap="gray"):  # [N,H,W] -> [N,3,H,W]
    """Colorize grayscale values with a matplotlib colormap.

    Args:
        gray: tensor of grayscale values, typically [N, H, W]. For float
            input, matplotlib colormaps expect values in [0, 1]
            (``preprocess_image`` clamps to that range before calling).
        cmap: matplotlib colormap name.

    Returns:
        float32 RGB tensor with the channel axis placed just before the last
        two (spatial) axes: [N, H, W] -> [N, 3, H, W], [H, W] -> [3, H, W].
    """
    # The colormap operates on NumPy arrays; .numpy() requires a CPU tensor
    # that does not require grad, so detach and move it first (no-op when
    # the tensor is already detached and on the CPU).
    values = gray.detach().cpu().numpy()
    rgba = plt.get_cmap(cmap)(values)  # [..., H, W, 4]
    # Drop the alpha channel and move RGB in front of the spatial axes;
    # for 3-D input this is the same as permute(0, 3, 1, 2).
    color = torch.from_numpy(rgba[..., :3]).movedim(-1, -3).float()
    return color