
Commit

fix fox dataset
ashawkey committed Mar 29, 2022
1 parent 4d9806c commit 873c267
Showing 15 changed files with 281 additions and 51 deletions.
11 changes: 6 additions & 5 deletions main_nerf.py
@@ -57,6 +57,7 @@
model = NeRFNetwork(
bound=opt.bound,
cuda_ray=opt.cuda_ray,
+ density_scale=1 if opt.mode == 'colmap' else 25,
)

print(model)
@@ -74,7 +75,7 @@

else:
test_dataset = NeRFDataset(opt.path, type='test', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=not opt.preload)

if opt.mode == 'blender':
trainer.evaluate(test_loader) # blender has gt, so evaluate it.
@@ -97,23 +98,23 @@

if opt.gui:
train_dataset = NeRFDataset(opt.path, type='all', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=not opt.preload)
trainer.train_loader = train_loader # attach dataloader to trainer

gui = NeRFGUI(opt, trainer)
gui.render()

else:
train_dataset = NeRFDataset(opt.path, type='train', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=not opt.preload)
valid_dataset = NeRFDataset(opt.path, type='val', mode=opt.mode, downscale=2, scale=opt.scale, preload=opt.preload)
- valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, pin_memory=True)
+ valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, pin_memory=not opt.preload)

trainer.train(train_loader, valid_loader, 300)

# also test
test_dataset = NeRFDataset(opt.path, type='test', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=not opt.preload)

if opt.mode == 'blender':
trainer.evaluate(test_loader) # blender has gt, so evaluate it.
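In this file, `density_scale` is now chosen per dataset mode (1 for `colmap`, 25 otherwise), and `pin_memory` is enabled only when the dataset is not preloaded — presumably because `--preload` already places the data on the GPU, where pinning host memory buys nothing and the DataLoader may refuse to pin CUDA tensors. A minimal sketch of the loader pattern (the helper name is ours, purely illustrative):

```python
import torch

def make_loader(dataset, batch_size=1, shuffle=False, preload=False):
    # pin_memory speeds up host-to-device copies, but it only applies to CPU tensors;
    # with preload=True the samples are assumed to already live on the GPU.
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        pin_memory=not preload,
    )

# e.g. test_loader = make_loader(test_dataset, batch_size=1, preload=opt.preload)
```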
17 changes: 11 additions & 6 deletions main_tensoRF.py
@@ -25,6 +25,7 @@
parser.add_argument('--max_ray_batch', type=int, default=4096)
### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
+ parser.add_argument('--cp', action='store_true', help="use TensorCP instead of TensorVMSplit")
parser.add_argument('--resolution0', type=int, default=128)
parser.add_argument('--resolution1', type=int, default=300)
parser.add_argument("--upsample_model_steps", type=int, action="append", default=[2000, 3000, 4000, 5500, 7000])
@@ -47,12 +48,16 @@

seed_everything(opt.seed)

- from tensoRF.network import NeRFNetwork
+ if opt.cp:
+ from tensoRF.network_cp import NeRFNetwork
+ else:
+ from tensoRF.network import NeRFNetwork

model = NeRFNetwork(
resolution=[opt.resolution0] * 3,
bound=opt.bound,
cuda_ray=opt.cuda_ray,
+ density_scale=1 if opt.mode == 'colmap' else 25,
)

print(model)
@@ -70,7 +75,7 @@

else:
test_dataset = NeRFDataset(opt.path, type='test', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=not opt.preload)

if opt.mode == 'blender':
trainer.evaluate(test_loader) # blender has gt, so evaluate it.
@@ -93,23 +98,23 @@

if opt.gui:
train_dataset = NeRFDataset(opt.path, type='all', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=not opt.preload)
trainer.train_loader = train_loader # attach dataloader to trainer

gui = NeRFGUI(opt, trainer)
gui.render()

else:
train_dataset = NeRFDataset(opt.path, type='train', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=True)
+ train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opt.batch_size, shuffle=True, pin_memory=not opt.preload)
valid_dataset = NeRFDataset(opt.path, type='val', mode=opt.mode, downscale=2, scale=opt.scale, preload=opt.preload)
- valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, pin_memory=True)
+ valid_loader = torch.utils.data.DataLoader(valid_dataset, batch_size=1, pin_memory=not opt.preload)

trainer.train(train_loader, valid_loader, 300)

# also test
test_dataset = NeRFDataset(opt.path, type='test', mode=opt.mode, scale=opt.scale, preload=opt.preload)
- test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=True)
+ test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, pin_memory=not opt.preload)

if opt.mode == 'blender':
trainer.evaluate(test_loader) # blender has gt, so evaluate it.
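The new `--cp` flag swaps the TensoRF backbone from the VM-split factorization to a CP decomposition by importing a different `NeRFNetwork`. A sketch of how the flag drives the selection, mirroring the diff above (treat it as illustrative rather than the exact file contents):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--cp', action='store_true', help="use TensorCP instead of TensorVMSplit")
parser.add_argument('--resolution0', type=int, default=128)
opt, _ = parser.parse_known_args()

# choose the backbone at import time; both modules expose the same class name
if opt.cp:
    from tensoRF.network_cp import NeRFNetwork
else:
    from tensoRF.network import NeRFNetwork

model = NeRFNetwork(
    resolution=[opt.resolution0] * 3,
    bound=1,
    cuda_ray=False,
    density_scale=1,
)
```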
4 changes: 2 additions & 2 deletions nerf/network.py
@@ -16,9 +16,9 @@ def __init__(self,
num_layers_color=3,
hidden_dim_color=64,
bound=1,
- cuda_ray=False,
+ **kwargs,
):
- super().__init__(bound, cuda_ray)
+ super().__init__(bound, **kwargs)

# sigma network
self.num_layers = num_layers
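All the network variants drop the explicit `cuda_ray` argument and forward `**kwargs` to the base renderer instead, so a new renderer option such as `density_scale` reaches `NeRFRenderer.__init__` without every subclass having to declare it. A stripped-down sketch of the pattern (bodies elided):

```python
class NeRFRenderer:
    def __init__(self, bound=1, cuda_ray=False, density_scale=1):
        self.bound = bound
        self.cuda_ray = cuda_ray
        self.density_scale = density_scale

class NeRFNetwork(NeRFRenderer):
    def __init__(self, bound=1, **kwargs):
        # any renderer option (cuda_ray, density_scale, ...) passes straight through
        super().__init__(bound, **kwargs)

model = NeRFNetwork(bound=2, cuda_ray=True, density_scale=25)
```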
4 changes: 2 additions & 2 deletions nerf/network_ff.py
@@ -18,9 +18,9 @@ def __init__(self,
num_layers_color=3,
hidden_dim_color=64,
bound=1,
- cuda_ray=False,
+ **kwargs
):
- super().__init__(bound, cuda_ray)
+ super().__init__(bound, **kwargs)

# sigma network
self.num_layers = num_layers
4 changes: 2 additions & 2 deletions nerf/network_tcnn.py
@@ -18,9 +18,9 @@ def __init__(self,
num_layers_color=3,
hidden_dim_color=64,
bound=1,
- cuda_ray=False,
+ **kwargs
):
- super().__init__(bound, cuda_ray)
+ super().__init__(bound, **kwargs)

# sigma network
self.num_layers = num_layers
15 changes: 7 additions & 8 deletions nerf/renderer.py
@@ -85,7 +85,7 @@ class NeRFRenderer(nn.Module):
def __init__(self,
bound=1,
cuda_ray=False,
- density_scale=25, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
+ density_scale=1, # scale up deltas (or sigmas), to make the density grid more sharp. larger value than 1 usually improves performance.
):
super().__init__()
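
For context, `density_scale` multiplies the raw densities (or, equivalently, the step sizes) before alpha compositing, so values above 1 sharpen the opacity falloff; per this commit, the synthetic (blender) runs keep 25 while real (colmap) captures such as the fox scene go back to 1. A short sketch of where it enters the standard compositing weights (names are ours, assuming the usual alpha-compositing formulation):

```python
import torch

def composite_weights(sigmas, deltas, density_scale=1.0):
    # alpha_i = 1 - exp(-density_scale * sigma_i * delta_i)
    alphas = 1.0 - torch.exp(-density_scale * sigmas * deltas)
    # transmittance T_i = prod_{j < i} (1 - alpha_j)
    trans = torch.cumprod(
        torch.cat([torch.ones_like(alphas[:1]), 1.0 - alphas + 1e-10]), dim=0
    )[:-1]
    return alphas * trans  # per-sample weights used to blend colors and depth
```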

@@ -261,7 +261,7 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb)
counter.zero_() # set to 0
self.local_step += 1

- xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, perturb, 128, True)
+ xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, perturb, 128, False)

density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
sigmas = density_outputs['sigma']
@@ -333,7 +333,7 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb)

raymarching.composite_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image)

- #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}')
+ #print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

step += n_step
i += 1
@@ -348,7 +348,7 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb)
return depth, image


- def update_extra_state(self, decay=0.95):
+ def update_extra_state(self, decay=0.9):
# call before each epoch to update extra states.

if not self.cuda_ray:
@@ -376,9 +376,8 @@ def update_extra_state(self, decay=0.95):
xyzs += (torch.rand_like(xyzs) * 2 - 1) * half_grid_size
# query density
sigmas = self.density(xyzs.to(tmp_grid.device))['sigma'].reshape(lx, ly, lz).detach()
- # change density to alpha in [0, 1]
- alphas = 1 - torch.exp(-self.density_scale * sigmas) # [B, N, T], fake deltas to 1 (it doesn't really matter)
- tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = alphas
+ # the magic scale number is from `scalbnf(MIN_CONE_STEPSIZE(), level)`, don't ask me why...
+ tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = sigmas * self.density_scale * 0.001691

# maxpool to smooth
#tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1))
@@ -395,7 +394,7 @@ def update_extra_state(self, decay=0.95):
self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step)
self.local_step = 0

- #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}')
+ #print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f}, occ_rate={(self.density_grid > 0.01).sum() / (128**3):.3f} | [step counter] mean={self.mean_count}')


def render(self, rays_o, rays_d, num_steps=128, upsample_steps=128, staged=False, max_ray_batch=4096, bg_color=None, perturb=False, **kwargs):
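The occupancy grid now caches scaled raw densities (`sigma * density_scale * 0.001691`, the instant-ngp step-size constant referenced in the comment) instead of alphas, and the per-epoch decay drops from 0.95 to 0.9. A simplified sketch of the update step, assuming the decayed-max style refresh performed by `update_extra_state` (helper name and details are illustrative):

```python
import torch

@torch.no_grad()
def update_density_grid(grid, density_fn, coords, density_scale=1.0, decay=0.9):
    # grid:   [128, 128, 128] persistent occupancy buffer
    # coords: [128**3, 3] (jittered) cell centers in world space
    sigmas = density_fn(coords)['sigma'].reshape(grid.shape)
    # 0.001691 ~ scalbnf(MIN_CONE_STEPSIZE(), level) in instant-ngp
    fresh = sigmas * density_scale * 0.001691
    # decay old values so stale occupancy fades, but never drop below the fresh estimate
    grid.copy_(torch.maximum(grid * decay, fresh))
    return grid.mean().item()  # mean density, later used as an occupancy threshold
```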
2 changes: 1 addition & 1 deletion nerf/utils.py
@@ -176,7 +176,7 @@ def __init__(self,
workspace='workspace', # workspace to save logs & ckpts
best_mode='min', # the smaller/larger result, the better
use_loss_as_metric=True, # use loss as the first metric
- report_metric_at_train=True, # also report metrics at training
+ report_metric_at_train=False, # also report metrics at training
use_checkpoint="latest", # which ckpt to use at init time
use_tensorboardX=True, # whether to use tensorboard for logging
scheduler_update_every_step=False, # whether to call scheduler.step() after every train step
2 changes: 2 additions & 0 deletions raymarching/raymarching.py
@@ -59,6 +59,8 @@ def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density
dirs = dirs[:m]
deltas = deltas[:m]

+ torch.cuda.empty_cache()

return xyzs, dirs, deltas, rays

march_rays_train = _march_rays_train.apply
15 changes: 6 additions & 9 deletions readme.md
@@ -115,22 +115,19 @@ check the `scripts` directory for more provided examples.
Tested with the default settings on the Lego test dataset. Here the speed refers to the `iterations per second` on a TITAN RTX.
| Model | PSNR | Train Speed | Test Speed |
| - | - | - | - |
- | HashNeRF (`fp16`) | 32.22 | 24 | 0.56 |
- | HashNeRF (`fp16 + ff`) | 32.81 | 24 | 0.79 |
- | HashNeRF (`fp16 + tcnn`) | 32.72 | 20 | 0.37 |
- | HashNeRF (`fp16 + cuda_ray`) | 32.54 | 65 | 6.4 |
- | HashNeRF (`fp16 + cuda_ray + ff`) | 33.24 | 72 | 6.9 |
- | HashNeRF (`fp16 + cuda_ray + tcnn`) | 33.11 | 60 | 5.8 |
- | TensoRF (`fp16`) | 33.79 | 18 | 0.53 |
- | TensoRF (`fp16 + cuda_ray`) | 34.05 | 13 | 0.43 |
+ | HashNeRF (`fp16 + ff`) | 32.84 | 22 | 0.54 |
+ | HashNeRF (`fp16 + cuda_ray + ff`) | 32.81 | 80 | 7.0 |
+ | TensoRF (`fp16`) | 33.81 | 18 | 0.53 |
+ | TensoRF (`fp16 + cuda_ray`) | 33.83 | 46 | 3.4 |

# Difference from the original implementation
* Instead of assuming the scene is bounded in the unit box `[0, 1]` and centered at `(0.5, 0.5, 0.5)`, this repo assumes **the scene is bounded in the box `[-bound, bound]` and centered at `(0, 0, 0)`**. The functionality of `aabb_scale` is therefore replaced by `bound` here (a small illustrative mapping follows this list).
* For the hashgrid encoder, this repo only implements the linear interpolation mode.
* For the voxel pruning in the ray marching kernels, this repo doesn't implement the multi-scale density grid (check the `mip` keyword) and only uses a single `128x128x128` grid for simplicity. Instead of updating the grid every 16 steps, we update it once per epoch, which may make the first few epochs slower when using `--cuda_ray`.
* For the blender dataset, the default mode in instant-ngp is to load all data (train/val/test) for training. Instead, we only use the specified split for training in CMD mode, for easy evaluation. For GUI mode, however, we follow instant-ngp and use all data to train (check `type='all'` for `NeRFDataset`).
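To make the `bound` convention concrete, here is a tiny illustrative mapping from this repo's `[-bound, bound]` box to the `[0, 1]` cube a grid/hash encoder typically expects (the helper is ours, not part of the repo):

```python
import torch

def to_unit_cube(x, bound=1.0):
    # [-bound, bound]^3  ->  [0, 1]^3
    return (x + bound) / (2 * bound)

pts = torch.tensor([[-1.0, 0.0, 1.0]])
print(to_unit_cube(pts, bound=1.0))  # -> [[0.0, 0.5, 1.0]]
```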

# Update Logs
+ * 3.29: fix training speed for the fox dataset (balanced speed with performance...).
* 3.27: major update. basically improve performance, and support tensoRF model.
* 3.22: reverted from pre-generating rays as it takes too much CPU memory, still the PSNR for Lego can reach ~33 now.
* 3.14: fixed the precision related issue for `fp16` mode, and it renders much better quality. Added PSNR metric for NeRF.
4 changes: 2 additions & 2 deletions scripts/run_gui_nerf.sh
@@ -3,5 +3,5 @@
#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf --cuda_ray --gui
#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui

- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray --gui
- OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2_gui --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui
+ OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray --gui
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui
14 changes: 7 additions & 7 deletions scripts/run_nerf.sh
@@ -5,13 +5,13 @@
# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn --fp16 --tcnn

# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf2 --fp16 --cuda_ray
- # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray
+ OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray
# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn2 --fp16 --tcnn --cuda_ray

- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego --fp16 --bound 1.5 --scale 1.0 --mode blender
- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff --fp16 --ff --bound 1.5 --scale 1.0 --mode blender
- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn --fp16 --tcnn --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego --fp16 --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff --fp16 --ff --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn --fp16 --tcnn --bound 1.5 --scale 1.0 --mode blender

- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender --test
- OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender
- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn2 --fp16 --tcnn --cuda_ray --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender
+ # OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn2 --fp16 --tcnn --cuda_ray --bound 1.5 --scale 1.0 --mode blender
14 changes: 10 additions & 4 deletions scripts/run_tensoRF.sh
@@ -1,7 +1,13 @@
#! /bin/bash

- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/fox --workspace trial_tensoRF --fp16
- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/fox --workspace trial_tensoRF2 --fp16 --cuda_ray
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF --fp16
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF2 --fp16 --cuda_ray

- #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego --fp16 --bound 1.5 --scale 1.0 --mode blender
- OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF_CP --fp16 --cp
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/fox --workspace trial_tensoRF_CP2 --fp16 --cuda_ray --cp
+
+ OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego --fp16 --bound 1.5 --scale 1.0 --mode blender
+ OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender
+
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_CP_lego --cp --fp16 --bound 1.5 --scale 1.0 --mode blender
+ #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_CP_lego2 --fp16 --cuda_ray --cp --bound 1.5 --scale 1.0 --mode blender
4 changes: 2 additions & 2 deletions tensoRF/network.py
@@ -17,9 +17,9 @@ def __init__(self,
num_layers=3,
hidden_dim=128,
bound=1,
- cuda_ray=False,
+ **kwargs
):
- super().__init__(bound, cuda_ray)
+ super().__init__(bound, **kwargs)

self.resolution = resolution
