-
Notifications
You must be signed in to change notification settings - Fork 30
/
dataset_llff.py
112 lines (88 loc) · 4.72 KB
/
dataset_llff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
# Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import os
import glob
import torch
import numpy as np
from render import util
from dataset import Dataset
def _load_mask(fn):
    """Load a mask image from ``fn`` as a float32 torch tensor.

    The pixel data comes from ``util.load_image`` (its value range is whatever
    that helper produces -- presumably [0, 1]; confirm in render/util).
    A 2-D (single-channel) result is given a trailing channel axis and tiled
    to three identical channels; any other shape is returned unchanged.
    """
    mask = torch.tensor(util.load_image(fn), dtype=torch.float32)
    if mask.dim() != 2:
        return mask
    # Grayscale mask: add a channel axis and replicate it to 3 channels so
    # callers can always index mask[..., c].
    return mask.unsqueeze(-1).repeat(1, 1, 3)
def _load_img(fn):
    """Load a color image from ``fn`` as a float32 torch tensor in linear color.

    Integer-typed images (LDR) are normalized to [0, 1] by the maximum value
    of their dtype (255 for uint8, 65535 for uint16, ...). Their first three
    channels are then converted from sRGB to linear via ``util.srgb_to_rgb``.
    Any extra channel (e.g. alpha) is only normalized.

    Floating-point images (HDR) are returned as-is, assumed already linear.

    A 2-D (single-channel) image is replicated to three identical channels,
    matching ``_load_mask``. This ensures ``[..., 0:3]`` selects color
    channels rather than image columns.
    """
    img = util.load_image_raw(fn)
    if img.ndim == 2:
        # Grayscale: add a channel axis so slicing below addresses channels.
        img = np.repeat(img[..., None], 3, axis=-1)
    if np.issubdtype(img.dtype, np.integer):  # LDR image
        # Normalize by the dtype's full range instead of a fixed 255, so
        # 16-bit images land in [0, 1] as well.
        img = torch.tensor(img / np.iinfo(img.dtype).max, dtype=torch.float32)
        img[..., 0:3] = util.srgb_to_rgb(img[..., 0:3])
    else:  # HDR / floating-point image
        img = torch.tensor(img, dtype=torch.float32)
    return img
###############################################################################
# LLFF datasets (real world camera lightfields)
###############################################################################
class DatasetLLFF(Dataset):
    """Dataset over an LLFF-style capture (real-world camera lightfield).

    Files read from ``base_dir``:

    * ``images/*`` -- color images whose names end in png/jpg/jpeg
      (case-insensitive), taken in sorted order.
    * ``masks/*`` -- one mask per image, paired with ``images/`` by sorted
      order.
    * ``poses_bounds.npy`` -- one row per camera. The last two values of each
      row are dropped; the rest is reshaped to a 3x5 pose matrix. The dropped
      values are presumably the LLFF near/far depth bounds and are not used
      here.

    Args:
        base_dir: Root directory of the capture.
        FLAGS: Config object; this class reads ``pre_load``, ``cam_near_far``
            and ``spp`` from it.
        examples: Optional virtual dataset length. When given, ``__len__``
            returns it and item indices wrap around the real frame count.
    """

    @staticmethod
    def _list_images(folder):
        """Return sorted paths in ``folder`` whose names end in png/jpg/jpeg (case-insensitive)."""
        return [f for f in sorted(glob.glob(os.path.join(folder, "*")))
                if f.lower().endswith(('png', 'jpg', 'jpeg'))]

    def __init__(self, base_dir, FLAGS, examples=None):
        self.FLAGS = FLAGS
        self.base_dir = base_dir
        self.examples = examples

        # Enumerate image and mask files once so _parse_frame does not re-glob
        # both directories for every frame it loads.
        img_dir = os.path.join(self.base_dir, "images")
        self._img_files = self._list_images(img_dir)
        self._mask_files = self._list_images(os.path.join(self.base_dir, "masks"))
        if not self._img_files:
            raise FileNotFoundError("DatasetLLFF: no png/jpg/jpeg images found in %s" % img_dir)

        # Resolution (height, width) is taken from the first image only; all
        # images are assumed to share it.
        self.resolution = _load_img(self._img_files[0]).shape[0:2]
        print("DatasetLLFF: %d images with shape [%d, %d]" % (len(self._img_files), self.resolution[0], self.resolution[1]))

        # Load camera poses
        poses_bounds = np.load(os.path.join(self.base_dir, 'poses_bounds.npy'))
        poses = poses_bounds[:, :-2].reshape([-1, 3, 5]).transpose([1, 2, 0])  # [3, 5, N]
        # Taken from nerf: swizzle from the LLFF axis convention to the expected
        # one (new column 0 = old column 1, new column 1 = -old column 0).
        poses = np.concatenate([poses[:, 1:2, :], -poses[:, 0:1, :], poses[:, 2:, :]], 1)
        poses = np.moveaxis(poses, -1, 0).astype(np.float32)  # [N, 3, 5]

        # Append the homogeneous row [0, 0, 0, 1] to each 3x4 pose -> [N, 4, 4].
        lcol = np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(poses.shape[0], 0)
        self.imvs = torch.tensor(np.concatenate((poses[:, :, 0:4], lcol), axis=1), dtype=torch.float32)
        self.aspect = self.resolution[1] / self.resolution[0]  # width / height
        # Column 4 of each pose carries per-camera scalars. Rows 2 and 0 are
        # passed to util.focal_length_to_fovy -- presumably focal length and
        # image height per the LLFF [h, w, f] convention; confirm against util.
        self.fovy = util.focal_length_to_fovy(poses[:, 2, 4], poses[:, 0, 4])

        # Recenter scene so the look-at point computed by util.lines_focal from
        # the camera positions and -z axes becomes the origin.
        center = util.lines_focal(self.imvs[..., :3, 3], -self.imvs[..., :3, 2])
        self.imvs[..., :3, 3] = self.imvs[..., :3, 3] - center[None, ...]
        print("DatasetLLFF: auto-centering at %s" % (center.cpu().numpy()))

        # Pre-load from disc to avoid slow png parsing
        if self.FLAGS.pre_load:
            self.preloaded_data = [self._parse_frame(i) for i in range(self.imvs.shape[0])]

    def _parse_frame(self, idx):
        """Load frame ``idx`` from disk and build its camera transforms.

        Returns:
            Tuple ``(img, mv, mvp, campos)``, each with a leading batch axis of
            size 1. ``img`` is the color image with the mask's first channel
            appended as an extra last channel.

        Raises:
            ValueError: if the number of images or masks differs from the
                number of camera poses.
        """
        n_frames = self.imvs.shape[0]
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self._img_files) != n_frames or len(self._mask_files) != n_frames:
            raise ValueError("DatasetLLFF: found %d images and %d masks, expected %d (one per camera pose)"
                             % (len(self._img_files), len(self._mask_files), n_frames))

        # Load image & mask data
        img = _load_img(self._img_files[idx])
        mask = _load_mask(self._mask_files[idx])
        img = torch.cat((img, mask[..., 0:1]), dim=-1)

        # Setup transforms. mv is the inverse of the stored pose matrix
        # self.imvs[idx]; campos is the translation column of inv(mv).
        proj = util.perspective(self.fovy[idx, ...], self.aspect, self.FLAGS.cam_near_far[0], self.FLAGS.cam_near_far[1])
        mv = torch.linalg.inv(self.imvs[idx, ...])
        campos = torch.linalg.inv(mv)[:3, 3]
        mvp = proj @ mv
        return img[None, ...], mv[None, ...], mvp[None, ...], campos[None, ...]  # Add batch dimension

    def getMesh(self):
        """Return None: an LLFF capture provides no reference mesh."""
        return None  # There is no mesh

    def __len__(self):
        """Return ``examples`` if it was given, otherwise the number of camera poses."""
        return self.imvs.shape[0] if self.examples is None else self.examples

    def __getitem__(self, itr):
        """Return frame ``itr`` (wrapped modulo the number of poses) as a dict."""
        frame_idx = itr % self.imvs.shape[0]
        if self.FLAGS.pre_load:
            img, mv, mvp, campos = self.preloaded_data[frame_idx]
        else:
            img, mv, mvp, campos = self._parse_frame(frame_idx)

        return {
            'mv' : mv,
            'mvp' : mvp,
            'campos' : campos,
            'resolution' : self.resolution,
            'spp' : self.FLAGS.spp,
            'img' : img
        }