Skip to content

Commit

Permalink
Camera pose optimization for Splatfacto (nerfstudio-project#2891)
Browse files Browse the repository at this point in the history
* Add pose optimization to Splatfacto

* Disable Splatfacto pose optimization by default

* Improve apply_to_camera for Gaussian Splatting pose optimization

Do not chain modifications to camera_to_worlds to improve
numerical stability and enable L2 rot/trans penalties.

* Add separate mean and max rot/trans metrics to camera-opt

* Tweak pose optimization hyperparameters

Parameters used in the Gaussian Splatting on the Move paper v1

* Unit test fix for new cameara_optimizer training metrics

* Adjust splatfacto-big camera pose optimization parameters

Same parameters as in normal Splatfacto

---------

Co-authored-by: jh-surh <jh.surh@bucketplace.net>
  • Loading branch information
oseiskar and jh-surh authored Apr 11, 2024
1 parent 2d9bbe5 commit eba72db
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 19 deletions.
32 changes: 21 additions & 11 deletions nerfstudio/cameras/camera_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from dataclasses import dataclass, field
from typing import Literal, Optional, Type, Union

import numpy
import torch
import tyro
from jaxtyping import Float, Int
Expand Down Expand Up @@ -151,15 +152,20 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None:
raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3]
raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze()

def apply_to_camera(self, camera: Cameras) -> None:
"""Apply the pose correction to the raybundle"""
if self.config.mode != "off":
assert camera.metadata is not None, "Must provide id of camera in its metadata"
assert "cam_idx" in camera.metadata, "Must provide id of camera in its metadata"
camera_idx = camera.metadata["cam_idx"]
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj)
def apply_to_camera(self, camera: Cameras) -> torch.Tensor:
"""Apply the pose correction to the world-to-camera matrix in a Camera object"""
if self.config.mode == "off":
return camera.camera_to_worlds

assert camera.metadata is not None, "Must provide id of camera in its metadata"
if "cam_idx" not in camera.metadata:
# Evalutaion cams?
return camera.camera_to_worlds

camera_idx = camera.metadata["cam_idx"]
adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore
adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1)
return torch.bmm(camera.camera_to_worlds, adj)

def get_loss_dict(self, loss_dict: dict) -> None:
"""Add regularization"""
Expand All @@ -176,8 +182,12 @@ def get_correction_matrices(self):
def get_metrics_dict(self, metrics_dict: dict) -> None:
"""Get camera optimizer metrics"""
if self.config.mode != "off":
metrics_dict["camera_opt_translation"] = self.pose_adjustment[:, :3].norm()
metrics_dict["camera_opt_rotation"] = self.pose_adjustment[:, 3:].norm()
trans = self.pose_adjustment[:, :3].detach().norm(dim=-1)
rot = self.pose_adjustment[:, 3:].detach().norm(dim=-1)
metrics_dict["camera_opt_translation_max"] = trans.max()
metrics_dict["camera_opt_translation_mean"] = trans.mean()
metrics_dict["camera_opt_rotation_mean"] = numpy.rad2deg(rot.mean().cpu())
metrics_dict["camera_opt_rotation_max"] = numpy.rad2deg(rot.max().cpu())

def get_param_groups(self, param_groups: dict) -> None:
"""Get camera optimizer parameters"""
Expand Down
12 changes: 8 additions & 4 deletions nerfstudio/configs/method_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,8 +632,10 @@
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
Expand Down Expand Up @@ -684,8 +686,10 @@
},
"quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None},
"camera_opt": {
"optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000),
"optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15),
"scheduler": ExponentialDecaySchedulerConfig(
lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0
),
},
},
viewer=ViewerConfig(num_rays_per_chunk=1 << 15),
Expand Down
29 changes: 25 additions & 4 deletions nerfstudio/models/splatfacto.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from torch.nn import Parameter
from typing_extensions import Literal

from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig
from nerfstudio.cameras.cameras import Cameras
from nerfstudio.data.scene_box import OrientedBox
from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation
Expand Down Expand Up @@ -146,6 +147,8 @@ class SplatfactoModelConfig(ModelConfig):
However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that
were implemented for classic mode can not render antialiased mode PLY properly without modifications.
"""
camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off"))
"""Config of the camera optimizer to use"""


class SplatfactoModel(Model):
Expand Down Expand Up @@ -213,6 +216,10 @@ def populate_modules(self):
}
)

self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup(
num_cameras=self.num_train_data, device="cpu"
)

# metrics
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
Expand Down Expand Up @@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]:
Mapping of different parameter groups
"""
gps = self.get_gaussian_param_groups()
self.camera_optimizer.get_param_groups(param_groups=gps)
return gps

def _get_downscale_factor(self):
Expand Down Expand Up @@ -648,6 +656,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:

# get the background color
if self.training:
optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...]

if self.config.background_color == "random":
background = torch.rand(3, device=self.device)
elif self.config.background_color == "white":
Expand All @@ -657,6 +667,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
else:
background = self.background_color.to(self.device)
else:
optimized_camera_to_world = camera.camera_to_worlds[0, ...]

if renderers.BACKGROUND_COLOR_OVERRIDE is not None:
background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device)
else:
Expand All @@ -674,8 +686,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
camera_downscale = self._get_downscale_factor()
camera.rescale_output_resolution(1 / camera_downscale)
# shift the camera to center of scene looking at center
R = camera.camera_to_worlds[0, :3, :3] # 3 x 3
T = camera.camera_to_worlds[0, :3, 3:4] # 3 x 1
R = optimized_camera_to_world[:3, :3] # 3 x 3
T = optimized_camera_to_world[:3, 3:4] # 3 x 1

# flip the z and y axes to align with gsplat conventions
R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype))
R = R @ R_edit
Expand Down Expand Up @@ -738,7 +751,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]:
self.xys.retain_grad()

if self.config.sh_degree > 0:
viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3)
viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3)
viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree)
rgbs = spherical_harmonics(n, viewdirs, colors_crop)
Expand Down Expand Up @@ -829,6 +842,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]:
metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb)

metrics_dict["gaussian_count"] = self.num_points

self.camera_optimizer.get_metrics_dict(metrics_dict)
return metrics_dict

def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]:
Expand Down Expand Up @@ -867,11 +882,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te
else:
scale_reg = torch.tensor(0.0).to(self.device)

return {
loss_dict = {
"main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss,
"scale_reg": scale_reg,
}

if self.training:
# Add loss from camera optimizer
self.camera_optimizer.get_loss_dict(loss_dict)

return loss_dict

@torch.no_grad()
def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]:
"""Takes in a camera, generates the raybundle, and computes the output of the model.
Expand Down

0 comments on commit eba72db

Please sign in to comment.