From 2213bb3733267c021f8cbf7651bcb39b4804d4d9 Mon Sep 17 00:00:00 2001
From: Max Li
Date: Fri, 18 Aug 2023 12:00:47 -0400
Subject: [PATCH] fix ag training divergence, fix ng radius

---
 projects/neuralangelo/model.py         |  2 +-
 projects/neuralangelo/utils/modules.py | 21 ++++++++++-----------
 2 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/projects/neuralangelo/model.py b/projects/neuralangelo/model.py
index 685e60b..8d769ba 100644
--- a/projects/neuralangelo/model.py
+++ b/projects/neuralangelo/model.py
@@ -163,7 +163,7 @@ def render_rays_object(self, center, ray_unit, near, far, outside, app, stratifi
         sdfs[outside[..., None].expand_as(sdfs)] = self.outside_val
         # Compute 1st- and 2nd-order gradients.
         rays_unit = ray_unit[..., None, :].expand_as(points).contiguous()  # [B,R,N,3]
-        gradients, hessians = self.neural_sdf.compute_gradients(points, compute_hessian=self.training, sdf=sdfs)
+        gradients, hessians = self.neural_sdf.compute_gradients(points, training=self.training, sdf=sdfs)
         normals = torch_F.normalize(gradients, dim=-1)  # [B,R,N,3]
         rgbs = self.neural_rgb.forward(points, normals, rays_unit, feats, app=app)  # [B,R,N,3]
         # SDF volume rendering.
diff --git a/projects/neuralangelo/utils/modules.py b/projects/neuralangelo/utils/modules.py
index b6bc039..a360494 100644
--- a/projects/neuralangelo/utils/modules.py
+++ b/projects/neuralangelo/utils/modules.py
@@ -111,28 +111,25 @@ def _get_coarse2fine_mask(self, points_enc, feat_dim):
         mask[..., :(self.active_levels * feat_dim)] = 1
         return mask
 
-    def compute_gradients(self, x, compute_hessian=False, sdf=None):
+    def compute_gradients(self, x, training=False, sdf=None):
         # Note: hessian is not fully hessian but diagonal elements
         if self.cfg_sdf.gradient.mode == "analytical":
             requires_grad = x.requires_grad
             with torch.enable_grad():
                 # 1st-order gradient
                 x.requires_grad_(True)
-                y = self.sdf(x)
-                gradient = torch.autograd.grad(y.sum(), x, create_graph=True)[0]
+                sdf = self.sdf(x)
+                gradient = torch.autograd.grad(sdf.sum(), x, create_graph=True)[0]
                 # 2nd-order gradient (hessian)
-                if compute_hessian:
+                if training:
                     hessian = torch.autograd.grad(gradient.sum(), x, create_graph=True)[0]
                 else:
                     hessian = None
+                    gradient = gradient.detach()
             x.requires_grad_(requires_grad)
-            if not requires_grad:
-                gradient = gradient.detach()
-                if compute_hessian:
-                    hessian = hessian.detach()
         elif self.cfg_sdf.gradient.mode == "numerical":
-            eps = self.normal_eps
             if self.cfg_sdf.gradient.taps == 6:
+                eps = self.normal_eps
                 # 1st-order gradient
                 eps_x = torch.tensor([eps, 0., 0.], dtype=x.dtype, device=x.device)  # [3]
                 eps_y = torch.tensor([0., eps, 0.], dtype=x.dtype, device=x.device)  # [3]
@@ -148,7 +145,7 @@ def compute_gradients(self, x, compute_hessian=False, sdf=None):
                 gradient_z = (sdf_z_pos - sdf_z_neg) / (2 * eps)
                 gradient = torch.cat([gradient_x, gradient_y, gradient_z], dim=-1)  # [...,3]
                 # 2nd-order gradient (hessian)
-                if compute_hessian:
+                if training:
                     assert sdf is not None  # computed when feed-forwarding through the network
                     hessian_xx = (sdf_x_pos + sdf_x_neg - 2 * sdf) / (eps ** 2)  # [...,1]
                     hessian_yy = (sdf_y_pos + sdf_y_neg - 2 * sdf) / (eps ** 2)  # [...,1]
@@ -157,6 +154,7 @@ def compute_gradients(self, x, compute_hessian=False, sdf=None):
                 else:
                     hessian = None
             elif self.cfg_sdf.gradient.taps == 4:
+                eps = self.normal_eps / np.sqrt(3)
                 k1 = torch.tensor([1, -1, -1], dtype=x.dtype, device=x.device)  # [3]
                 k2 = torch.tensor([-1, -1, 1], dtype=x.dtype, device=x.device)  # [3]
                 k3 = torch.tensor([-1, 1, -1], dtype=x.dtype, device=x.device)  # [3]
@@ -166,7 +164,8 @@ def compute_gradients(self, x, compute_hessian=False, sdf=None):
                 sdf3 = self.sdf(x + k3 * eps)  # [...,1]
                 sdf4 = self.sdf(x + k4 * eps)  # [...,1]
                 gradient = (k1*sdf1 + k2*sdf2 + k3*sdf3 + k4*sdf4) / (4.0 * eps)
-                if compute_hessian:
+                if training:
+                    assert sdf is not None  # computed when feed-forwarding through the network
                     # the result of 4 taps is directly trace, but we assume they are individual components
                     # so we use the same signature as 6 taps
                     hessian_xx = ((sdf1 + sdf2 + sdf3 + sdf4) / 2.0 - 2 * sdf) / eps ** 2  # [N,1]
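
Note: after this change, the analytical branch keeps the autograd graph only
while training; at eval time the gradient is detached, instead of keying the
detach on the input's requires_grad flag as before. A minimal self-contained
sketch of the patched branch, with a toy MLP standing in for the repo's SDF
network (the network, shapes, and asserts below are illustrative assumptions,
not repo code):

import torch

# Toy stand-in for self.sdf; the real network is the hash-encoded MLP in modules.py.
net = torch.nn.Sequential(torch.nn.Linear(3, 16), torch.nn.Softplus(), torch.nn.Linear(16, 1))

def compute_gradients_analytical(x, training=False):
    requires_grad = x.requires_grad
    with torch.enable_grad():
        x.requires_grad_(True)
        sdf = net(x)
        # 1st-order gradient; create_graph=True keeps it differentiable.
        gradient = torch.autograd.grad(sdf.sum(), x, create_graph=True)[0]
        if training:
            # 2nd-order gradient: grad of the summed gradient
            # (used as the hessian's diagonal elements per the note in modules.py).
            hessian = torch.autograd.grad(gradient.sum(), x, create_graph=True)[0]
        else:
            hessian = None
            gradient = gradient.detach()  # drop the graph outside training
    x.requires_grad_(requires_grad)
    return gradient, hessian

g, h = compute_gradients_analytical(torch.randn(4, 3), training=False)
assert not g.requires_grad and h is None
g, h = compute_gradients_analytical(torch.randn(4, 3), training=True)
assert g.requires_grad and h.requires_grad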
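
The taps == 6 branch is plain central differences: gradient components are
(f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps) and each hessian diagonal entry is
(f(x + eps*e_i) + f(x - eps*e_i) - 2*f(x)) / eps**2, which is why the center
value sdf must be passed in when training. A quick standalone check against a
closed-form sphere SDF (an illustrative stand-in, not repo code):

import torch

def f(x):  # sphere SDF, radius 1: gradient is x/|x|, Laplacian is 2/|x|
    return x.norm(dim=-1, keepdim=True) - 1.0

x = torch.tensor([[0.3, -0.5, 0.8]], dtype=torch.float64)
eps = 1e-4
eps_x = torch.tensor([eps, 0., 0.], dtype=x.dtype)
eps_y = torch.tensor([0., eps, 0.], dtype=x.dtype)
eps_z = torch.tensor([0., 0., eps], dtype=x.dtype)

sdf = f(x)
sdf_x_pos, sdf_x_neg = f(x + eps_x), f(x - eps_x)
sdf_y_pos, sdf_y_neg = f(x + eps_y), f(x - eps_y)
sdf_z_pos, sdf_z_neg = f(x + eps_z), f(x - eps_z)

# 1st-order: central differences.
gradient = torch.cat([sdf_x_pos - sdf_x_neg,
                      sdf_y_pos - sdf_y_neg,
                      sdf_z_pos - sdf_z_neg], dim=-1) / (2 * eps)
# 2nd-order: second central differences, reusing the center value sdf.
hessian_diag = (torch.cat([sdf_x_pos + sdf_x_neg,
                           sdf_y_pos + sdf_y_neg,
                           sdf_z_pos + sdf_z_neg], dim=-1) - 2 * sdf) / eps ** 2

r = x.norm(dim=-1, keepdim=True)
print((gradient - x / r).abs().max())                            # close to 0: exact normal
print((hessian_diag.sum(-1) - 2.0 / r.squeeze(-1)).abs().max())  # close to 0: trace matches Laplacian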
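
The taps == 4 branch samples at four tetrahedron vertices k_i * eps. The k_i
sum to zero and satisfy sum_i k_i k_i^T = 4*I, so the weighted tap sum recovers
the gradient, and sum_i f(x + k_i*eps) = 4*f(x) + 2*eps**2*tr(H) + higher-order
terms, hence "the result of 4 taps is directly trace". Each k_i has length
sqrt(3), so the added eps = self.normal_eps / np.sqrt(3) places the taps at
distance normal_eps from x, matching the 6-tap sampling radius; this appears to
be the "fix ng radius" in the subject. The same illustrative sphere-SDF check:

import numpy as np
import torch

def f(x):  # sphere SDF, radius 1
    return x.norm(dim=-1, keepdim=True) - 1.0

x = torch.tensor([[0.3, -0.5, 0.8]], dtype=torch.float64)
eps = 1e-4 / np.sqrt(3)  # taps k_i * eps then lie at distance 1e-4 from x

k1 = torch.tensor([1., -1., -1.], dtype=x.dtype)
k2 = torch.tensor([-1., -1., 1.], dtype=x.dtype)
k3 = torch.tensor([-1., 1., -1.], dtype=x.dtype)
k4 = torch.tensor([1., 1., 1.], dtype=x.dtype)

sdf = f(x)
sdf1, sdf2, sdf3, sdf4 = f(x + k1 * eps), f(x + k2 * eps), f(x + k3 * eps), f(x + k4 * eps)

# sum_i k_i = 0 and sum_i k_i k_i^T = 4*I, so this recovers the gradient.
gradient = (k1 * sdf1 + k2 * sdf2 + k3 * sdf3 + k4 * sdf4) / (4.0 * eps)
# sum_i k_i^T H k_i = 4*tr(H), so this recovers the hessian trace;
# the center value sdf must be available, hence the added assert when training.
trace = ((sdf1 + sdf2 + sdf3 + sdf4) / 2.0 - 2 * sdf) / eps ** 2

r = x.norm(dim=-1, keepdim=True)
print((gradient - x / r).abs().max())  # close to 0: exact normal is x/|x|
print((trace - 2.0 / r).abs().max())   # close to 0: exact Laplacian is 2/|x|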