-
Notifications
You must be signed in to change notification settings - Fork 62
/
depth_utils.py
39 lines (36 loc) · 1.55 KB
/
depth_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import math
def bin_depths(depth_map, mode, depth_min, depth_max, num_bins, target=False):
    """
    Converts a depth map into depth-bin indices.

    Args:
        depth_map [torch.Tensor(H, W)]: Depth map. All operations are elementwise,
            so tensors of any shape are accepted.
        mode [string]: Discretization mode (see https://arxiv.org/pdf/2005.13423.pdf for more details)
            UD: Uniform discretization
            LID: Linear increasing discretization
            SID: Spacing increasing discretization
        depth_min [float]: Minimum depth value
        depth_max [float]: Maximum depth value
        num_bins [int]: Number of depth bins
        target [bool]: Whether the depth bin indices will be used as a target tensor in
            loss comparison. If True, indices that are negative, greater than num_bins,
            or non-finite are set to num_bins, and the result is cast to torch.int64.
    Returns:
        indices [torch.Tensor(H, W)]: Depth bin indices. Fractional (floating point)
            when target is False; torch.int64 values in [0, num_bins] when target is True.
    Raises:
        NotImplementedError: If mode is not one of "UD", "LID" or "SID".
    """
    if mode == "UD":
        bin_size = (depth_max - depth_min) / num_bins
        indices = (depth_map - depth_min) / bin_size
    elif mode == "LID":
        # Bin i (0-based) has width bin_size * (i + 1); bin_size is chosen so that
        # num_bins bins exactly span [depth_min, depth_max]. Solving the quadratic
        # cumulative width bin_size * k * (k + 1) / 2 = depth - depth_min for k
        # yields the fractional index below.
        bin_size = 2 * (depth_max - depth_min) / (num_bins * (1 + num_bins))
        indices = -0.5 + 0.5 * torch.sqrt(1 + 8 * (depth_map - depth_min) / bin_size)
    elif mode == "SID":
        # Bins are uniform in log(1 + depth) space.
        indices = num_bins * (torch.log(1 + depth_map) - math.log(1 + depth_min)) / \
            (math.log(1 + depth_max) - math.log(1 + depth_min))
    else:
        raise NotImplementedError(
            f"Unsupported depth discretization mode: {mode!r} "
            f"(expected 'UD', 'LID' or 'SID')"
        )

    if target:
        # Remove indices outside of bounds. NaN/inf can arise, e.g., from sqrt (LID)
        # or log (SID) of depths below the valid range; NaN compares False to
        # everything, hence the explicit isfinite check.
        # NOTE(review): num_bins presumably acts as an extra "invalid" class for
        # the loss — confirm against the caller.
        mask = (indices < 0) | (indices > num_bins) | (~torch.isfinite(indices))
        indices[mask] = num_bins

        # Convert to integer. All remaining values are >= 0, so truncation
        # toward zero is equivalent to floor.
        indices = indices.type(torch.int64)
    return indices