-
Notifications
You must be signed in to change notification settings - Fork 8
/
inference_RDC.py
87 lines (69 loc) · 4.06 KB
/
inference_RDC.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from path import Path
import numpy as np
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from tqdm import tqdm
import torch
from torch.nn import functional as F
from models import DepthNet
from skimage.transform import resize
from evaluation_toolkit.inference_toolkit import inferenceFramework
from inverse_warp import inverse_rotate
@torch.no_grad()
def main():
    """Run DepthNet depth inference over an evaluation list and save the maps.

    For each sample provided by the inference toolkit, picks the previous
    frame whose displacement is closest to the nominal displacement,
    rotation-stabilizes it, feeds the (stabilized, target) pair to DepthNet,
    rescales the prediction by the actual/nominal displacement ratio, and
    resizes the depth map before handing it back to the engine.

    Side effects: reads the checkpoint and evaluation list from disk, writes
    the predicted depth maps to ``args.depth_output`` (npz), prints timing.
    """
    parser = ArgumentParser(description='Example usage of Inference toolkit for RDC. See https://github.com/ClementPinard/depth-dataset-builder',
                            formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument('--dataset_root', metavar='DIR', type=Path, required=True)
    parser.add_argument('--depth_output', metavar='FILE', type=Path, required=True,
                        help='where to store the estimated depth maps, must be a npy file')
    parser.add_argument('--evaluation_list_path', metavar='PATH', type=Path, required=True,
                        help='File with list of images to test for depth evaluation')
    parser.add_argument('--pretrained_depthnet', metavar='FILE', type=Path, required=True)
    parser.add_argument('--no-resize', action='store_true')
    parser.add_argument("--img-height", default=128, type=int, help="Image height")
    parser.add_argument("--img-width", default=416, type=int, help="Image width")
    parser.add_argument("--nominal-displacement", "-D", type=float, default=0.3)
    args = parser.parse_args()

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # FIX: splitlines() instead of line[:-1] — the old slicing chopped the
    # last character of the final entry when the file had no trailing newline.
    with open(args.evaluation_list_path) as f:
        evaluation_list = f.read().splitlines()

    def preprocessing(frame):
        """Resize (optionally) and normalize an HxWx3 frame to a [1,3,h,w] tensor."""
        h, w, _ = frame.shape
        if (not args.no_resize) and (h != args.img_height or w != args.img_width):
            # NOTE(review): skimage resize returns floats in [0,1] for uint8
            # input, so the /255 below double-normalizes on this path —
            # confirm the dtype/range of frames the engine delivers.
            frame = resize(frame, (args.img_height, args.img_width))
        # Map pixel values to roughly [-1, 1] and add a batch dimension.
        frame_np = (frame.transpose(2, 0, 1).astype(np.float32)[None]/255 - 0.5)/0.5
        return torch.from_numpy(frame_np).to(device)

    engine = inferenceFramework(args.dataset_root, evaluation_list, frame_transform=preprocessing, max_shift=10)
    depth_net = DepthNet().to(device)
    # FIX: map_location so a GPU-saved checkpoint loads on a CPU-only machine.
    weights = torch.load(args.pretrained_depthnet, map_location=device)
    depth_net.load_state_dict(weights['state_dict'])
    depth_net.eval()

    for sample in tqdm(engine):
        tgt_img, latest_intrinsics, poses = sample.get_frame()
        # TODO Here, we don't apply the Algotihm from
        # "Multi range Real-time depth inference from a monocular stabilized footage
        # using a Fully Convolutional Neural Network"
        #
        # Instead, we just want to take the frame with a displacement close to nominal displacement (usually 0.3m)
        ref_img, _, previous_pose = sample.get_previous_frame(displacement=args.nominal_displacement)
        # Inverse rotation = transpose of the 3x3 rotation part of the pose.
        inv_rot = torch.from_numpy(previous_pose).to(ref_img)[:, :3].T
        latest_intrinsics = torch.from_numpy(latest_intrinsics).to(ref_img)
        stab_img = inverse_rotate(ref_img, inv_rot[None], latest_intrinsics[None])
        pair = torch.cat([stab_img, tgt_img], dim=1)  # [1, 6, H, W]
        pred_depth = depth_net(pair)
        # DepthNet is trained for the nominal displacement; rescale depth by
        # the actual translation magnitude relative to it.
        scale_factor = np.linalg.norm(previous_pose[:, 3]) / args.nominal_displacement
        pred_depth *= scale_factor
        if (not args.no_resize) and (pred_depth.shape[0] != args.img_height or pred_depth.shape[1] != args.img_width):
            out_shape = (args.img_height, args.img_width)
        else:
            # FIX: tgt_img is a [1, 3, H, W] tensor (see preprocessing), so
            # shape[:2] was (1, 3) — an invalid size for F.interpolate.
            # The spatial dimensions are the last two.
            out_shape = tgt_img.shape[-2:]
        # pred_depth is assumed 2-D (H, W) here, given the view below — TODO confirm DepthNet's output shape.
        pred_depth_zoomed = F.interpolate(pred_depth.view(1, 1, *pred_depth.shape),
                                          out_shape,
                                          mode='bilinear',
                                          align_corners=False)
        pred_depth_zoomed = pred_depth_zoomed.cpu().numpy()[0, 0]
        engine.finish_frame(pred_depth_zoomed)

    mean_inference_time, output_depth_maps = engine.finalize(output_path=args.depth_output)
    print("Mean time per sample : {:.2f}us".format(1e6 * mean_inference_time))
    np.savez(args.depth_output, **output_depth_maps)
# Script entry point: run inference only when executed directly, not on import.
if __name__ == '__main__':
    main()