From ee2035b9c0550c7fd35061e827c0870a6e335a62 Mon Sep 17 00:00:00 2001 From: Otto Seiskari Date: Thu, 11 Apr 2024 21:23:17 +0300 Subject: [PATCH] Camera pose optimization for Splatfacto (#2891) * Add pose optimization to Splatfacto * Disable Splatfacto pose optimization by default * Improve apply_to_camera for Gaussian Splatting pose optimization Do not chain modifications to camera_to_worlds to improve numerical stability and enable L2 rot/trans penalties. * Add separate mean and max rot/trans metrics to camera-opt * Tweak pose optimization hyperparameters Parameters used in the Gaussian Splatting on the Move paper v1 * Unit test fix for new cameara_optimizer training metrics * Adjust splatfacto-big camera pose optimization parameters Same parameters as in normal Splatfacto --------- Co-authored-by: jh-surh --- nerfstudio/cameras/camera_optimizers.py | 32 ++++++++++++++++--------- nerfstudio/configs/method_configs.py | 12 ++++++---- nerfstudio/models/splatfacto.py | 29 ++++++++++++++++++---- 3 files changed, 54 insertions(+), 19 deletions(-) diff --git a/nerfstudio/cameras/camera_optimizers.py b/nerfstudio/cameras/camera_optimizers.py index 6ab8cc2a05..9b47b7e53b 100644 --- a/nerfstudio/cameras/camera_optimizers.py +++ b/nerfstudio/cameras/camera_optimizers.py @@ -22,6 +22,7 @@ from dataclasses import dataclass, field from typing import Literal, Optional, Type, Union +import numpy import torch import tyro from jaxtyping import Float, Int @@ -151,15 +152,20 @@ def apply_to_raybundle(self, raybundle: RayBundle) -> None: raybundle.origins = raybundle.origins + correction_matrices[:, :3, 3] raybundle.directions = torch.bmm(correction_matrices[:, :3, :3], raybundle.directions[..., None]).squeeze() - def apply_to_camera(self, camera: Cameras) -> None: - """Apply the pose correction to the raybundle""" - if self.config.mode != "off": - assert camera.metadata is not None, "Must provide id of camera in its metadata" - assert "cam_idx" in camera.metadata, "Must provide id 
of camera in its metadata" - camera_idx = camera.metadata["cam_idx"] - adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore - adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1) - camera.camera_to_worlds = torch.bmm(camera.camera_to_worlds, adj) + def apply_to_camera(self, camera: Cameras) -> torch.Tensor: + """Apply the pose correction to the camera-to-world matrix in a Camera object""" + if self.config.mode == "off": + return camera.camera_to_worlds + + assert camera.metadata is not None, "Must provide id of camera in its metadata" + if "cam_idx" not in camera.metadata: + # Evaluation cams? + return camera.camera_to_worlds + + camera_idx = camera.metadata["cam_idx"] + adj = self(torch.tensor([camera_idx], dtype=torch.long, device=camera.device)) # type: ignore + adj = torch.cat([adj, torch.Tensor([0, 0, 0, 1])[None, None].to(adj)], dim=1) + return torch.bmm(camera.camera_to_worlds, adj) def get_loss_dict(self, loss_dict: dict) -> None: """Add regularization""" @@ -176,8 +182,12 @@ def get_correction_matrices(self): def get_metrics_dict(self, metrics_dict: dict) -> None: """Get camera optimizer metrics""" if self.config.mode != "off": - metrics_dict["camera_opt_translation"] = self.pose_adjustment[:, :3].norm() - metrics_dict["camera_opt_rotation"] = self.pose_adjustment[:, 3:].norm() + trans = self.pose_adjustment[:, :3].detach().norm(dim=-1) + rot = self.pose_adjustment[:, 3:].detach().norm(dim=-1) + metrics_dict["camera_opt_translation_max"] = trans.max() + metrics_dict["camera_opt_translation_mean"] = trans.mean() + metrics_dict["camera_opt_rotation_mean"] = numpy.rad2deg(rot.mean().cpu()) + metrics_dict["camera_opt_rotation_max"] = numpy.rad2deg(rot.max().cpu()) def get_param_groups(self, param_groups: dict) -> None: """Get camera optimizer parameters""" diff --git a/nerfstudio/configs/method_configs.py b/nerfstudio/configs/method_configs.py index a78823995d..cbbec0d04f 100644 --- 
a/nerfstudio/configs/method_configs.py +++ b/nerfstudio/configs/method_configs.py @@ -632,8 +632,10 @@ }, "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, "camera_opt": { - "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000), + "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig( + lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 + ), }, }, viewer=ViewerConfig(num_rays_per_chunk=1 << 15), @@ -684,8 +686,10 @@ }, "quats": {"optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), "scheduler": None}, "camera_opt": { - "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), - "scheduler": ExponentialDecaySchedulerConfig(lr_final=5e-5, max_steps=30000), + "optimizer": AdamOptimizerConfig(lr=1e-4, eps=1e-15), + "scheduler": ExponentialDecaySchedulerConfig( + lr_final=5e-7, max_steps=30000, warmup_steps=1000, lr_pre_warmup=0 + ), }, }, viewer=ViewerConfig(num_rays_per_chunk=1 << 15), diff --git a/nerfstudio/models/splatfacto.py b/nerfstudio/models/splatfacto.py index 7d7a78d2d2..0aebc63b44 100644 --- a/nerfstudio/models/splatfacto.py +++ b/nerfstudio/models/splatfacto.py @@ -33,6 +33,7 @@ from torch.nn import Parameter from typing_extensions import Literal +from nerfstudio.cameras.camera_optimizers import CameraOptimizer, CameraOptimizerConfig from nerfstudio.cameras.cameras import Cameras from nerfstudio.data.scene_box import OrientedBox from nerfstudio.engine.callbacks import TrainingCallback, TrainingCallbackAttributes, TrainingCallbackLocation @@ -146,6 +147,8 @@ class SplatfactoModelConfig(ModelConfig): However, PLY exported with antialiased rasterize mode is not compatible with classic mode. Thus many web viewers that were implemented for classic mode can not render antialiased mode PLY properly without modifications. 
""" + camera_optimizer: CameraOptimizerConfig = field(default_factory=lambda: CameraOptimizerConfig(mode="off")) + """Config of the camera optimizer to use""" class SplatfactoModel(Model): @@ -213,6 +216,10 @@ def populate_modules(self): } ) + self.camera_optimizer: CameraOptimizer = self.config.camera_optimizer.setup( + num_cameras=self.num_train_data, device="cpu" + ) + # metrics from torchmetrics.image import PeakSignalNoiseRatio from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity @@ -609,6 +616,7 @@ def get_param_groups(self) -> Dict[str, List[Parameter]]: Mapping of different parameter groups """ gps = self.get_gaussian_param_groups() + self.camera_optimizer.get_param_groups(param_groups=gps) return gps def _get_downscale_factor(self): @@ -648,6 +656,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: # get the background color if self.training: + optimized_camera_to_world = self.camera_optimizer.apply_to_camera(camera)[0, ...] + if self.config.background_color == "random": background = torch.rand(3, device=self.device) elif self.config.background_color == "white": @@ -657,6 +667,8 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: else: background = self.background_color.to(self.device) else: + optimized_camera_to_world = camera.camera_to_worlds[0, ...] 
+ if renderers.BACKGROUND_COLOR_OVERRIDE is not None: background = renderers.BACKGROUND_COLOR_OVERRIDE.to(self.device) else: @@ -674,8 +686,9 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: camera_downscale = self._get_downscale_factor() camera.rescale_output_resolution(1 / camera_downscale) # shift the camera to center of scene looking at center - R = camera.camera_to_worlds[0, :3, :3] # 3 x 3 - T = camera.camera_to_worlds[0, :3, 3:4] # 3 x 1 + R = optimized_camera_to_world[:3, :3] # 3 x 3 + T = optimized_camera_to_world[:3, 3:4] # 3 x 1 + # flip the z and y axes to align with gsplat conventions R_edit = torch.diag(torch.tensor([1, -1, -1], device=self.device, dtype=R.dtype)) R = R @ R_edit @@ -738,7 +751,7 @@ def get_outputs(self, camera: Cameras) -> Dict[str, Union[torch.Tensor, List]]: self.xys.retain_grad() if self.config.sh_degree > 0: - viewdirs = means_crop.detach() - camera.camera_to_worlds.detach()[..., :3, 3] # (N, 3) + viewdirs = means_crop.detach() - optimized_camera_to_world.detach()[:3, 3] # (N, 3) viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True) n = min(self.step // self.config.sh_degree_interval, self.config.sh_degree) rgbs = spherical_harmonics(n, viewdirs, colors_crop) @@ -829,6 +842,8 @@ def get_metrics_dict(self, outputs, batch) -> Dict[str, torch.Tensor]: metrics_dict["psnr"] = self.psnr(predicted_rgb, gt_rgb) metrics_dict["gaussian_count"] = self.num_points + + self.camera_optimizer.get_metrics_dict(metrics_dict) return metrics_dict def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Tensor]: @@ -867,11 +882,17 @@ def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict[str, torch.Te else: scale_reg = torch.tensor(0.0).to(self.device) - return { + loss_dict = { "main_loss": (1 - self.config.ssim_lambda) * Ll1 + self.config.ssim_lambda * simloss, "scale_reg": scale_reg, } + if self.training: + # Add loss from camera optimizer + 
self.camera_optimizer.get_loss_dict(loss_dict) + + return loss_dict + @torch.no_grad() def get_outputs_for_camera(self, camera: Cameras, obb_box: Optional[OrientedBox] = None) -> Dict[str, torch.Tensor]: """Takes in a camera, generates the raybundle, and computes the output of the model.