In [None]:
import torch
# Restore the stage-1 checkpoint. It must unpack into exactly two items; only
# the first (the packed Gaussian model state) is used below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_path = "output/colmap-geo_stage1/chkpnt10000.pth"
model_params, _ = torch.load(checkpoint_path)

# Unpack the 12-field model state (order must match how it was saved).
(
    active_sh_degree,
    xyz,
    _features_dc,
    _features_rest,
    _scaling,
    _rotation,
    opacity,
    max_radii2D,
    xyz_gradient_accum,
    denom,
    opt_dict,
    spatial_lr_scale,
) = model_params

In [None]:
from utils.general_utils import build_scaling_rotation
# Build per-Gaussian 3D covariances as M @ M^T, where M comes from
# build_scaling_rotation applied to the (exponentiated) scales and rotations.
scaling_modifier = 1
scaling = torch.exp(_scaling)
scale_rot = build_scaling_rotation(scaling_modifier * scaling, _rotation)
covs3D = torch.matmul(scale_rot, scale_rot.transpose(1, 2))

# Add a tiny diagonal jitter; important when a covariance is near-singular.
jitter = 1e-8 * torch.eye(3, device=scaling.device)
covs3D += jitter.view(1, 3, 3)

In [None]:
import open3d as o3d

In [None]:
from scene.gaussian_geo_model_mlp_flex_colmap import GaussianGeoModel
import mcubes
device = 'cuda'
# Samples per axis of the density grid used for marching cubes.
coarse_res = 160
# Reconstruction box: a cube of side `scene_scale` centred at the origin.
scene_scale_radius = 2. * 4.965
scene_scale = scene_scale_radius*2
xyz_min = scene_scale * torch.tensor([[-0.5, -0.5, -0.5]], device=device)
xyz_max = scene_scale * torch.tensor([[0.5, 0.5, 0.5]], device=device)

# Keep only Gaussians whose centres lie strictly inside the box.
inbox_mask = ((xyz > xyz_min) & (xyz < xyz_max)).all(dim=1)
xyz, covs3D, opacity, scaling = [
    attr[inbox_mask] for attr in (xyz, covs3D, opacity, scaling)
]

# Project the Gaussians onto a density grid over [xyz_min, xyz_max] and extract
# the 0.3 iso-surface with marching cubes.
# NOTE(review): `opacity` is taken straight from the checkpoint tuple; confirm
# whether project_gs_grid expects raw values or activated opacities.
density_grid = GaussianGeoModel.project_gs_grid(coarse_res, xyz, covs3D, opacity, scaling, xyz_min, xyz_max)
mc_grid = density_grid.squeeze()
index_vertices, triangles = mcubes.marching_cubes(mc_grid.cpu().numpy(), 0.3)
# Map grid indices [0, coarse_res - 1] back to world space [xyz_min, xyz_max].
vertices = (index_vertices / (coarse_res - 1.0) - 0.5) * scene_scale

In [None]:
import trimesh
# Visual check of the raw marching-cubes surface.
mesh = trimesh.Trimesh(faces=triangles, vertices=vertices)
mesh.show()

In [None]:
import point_cloud_utils as pcu

# Convert the marching-cubes mesh with pcu.make_mesh_watertight.
# Presumably the third argument is the resolution of the repair — confirm in the pcu docs.
watertight_resolution = 100_000
vw, fw = pcu.make_mesh_watertight(vertices, triangles, watertight_resolution)

import trimesh

# Preview the repaired mesh.
mesh = trimesh.Trimesh(faces=fw, vertices=vw)
mesh.show()

In [None]:
# post proc
# NOTE(review): everything below the imports is commented out. It is an
# alternative path (Open3D Poisson reconstruction from the Gaussian centres
# `xyz`, followed by low-density vertex trimming and mesh cleanup), kept for
# reference only. Running this cell currently just imports open3d and numpy.
import open3d as o3d
import numpy as np

# fg_pcd = o3d.geometry.PointCloud()
# fg_pcd.points = o3d.utility.Vector3dVector(xyz.detach().double().cpu().numpy())
# cl, ind = fg_pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=20.)
# fg_pcd = fg_pcd.select_by_index(ind)
# fg_pcd.estimate_normals(
#     o3d.geometry.KDTreeSearchParamKNN(knn=30), fast_normal_computation=True)
# o3d_mesh, o3d_densities = \
#     o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(fg_pcd, depth=9)

# v_to_remove = o3d_densities < np.quantile(o3d_densities, 0.01)
# o3d_mesh.remove_vertices_by_mask(v_to_remove)

# o3d_mesh.remove_degenerate_triangles()
# o3d_mesh.remove_duplicated_triangles()
# o3d_mesh.remove_duplicated_vertices()
# o3d_mesh.remove_non_manifold_edges()

In [None]:
import numpy as np
from geo.flexicubes import FlexiCubes

# 3. Initialise FlexiCubes (differentiable iso-surface extractor) and seed its
#    scalar field from the Gaussian density grid `mc_grid` computed earlier.
marching_geo = FlexiCubes()

# construct_voxel_grid(R) yields (R + 1)^3 corner vertices; the original
# padding of an (R + 1)^3 field relied on this. `mc_grid` has coarse_res
# samples per axis, and the marching-cubes cell maps sample i to
# (i / (coarse_res - 1) - 0.5) * scene_scale.
# Choosing R = coarse_res - 1 makes vertex j coincide with density sample j,
# so the field needs no padding. The previous R = coarse_res version had two
# problems:
#   - it zero-padded the last plane of each axis, injecting artificial values
#     into the field;
#   - it stretched every sample position by coarse_res / (coarse_res - 1).
grid_res = coarse_res - 1
marching_verts, indices = marching_geo.construct_voxel_grid(grid_res)
# Presumably the vertices lie in [-0.5, 0.5]^3, so scaling by scene_scale
# matches the [xyz_min, xyz_max] box of the density grid — TODO confirm.
marching_verts *= scene_scale
# sdf, fid, bc = pcu.signed_distance_to_mesh(marching_verts.cpu().numpy(), vw.astype(np.float32), fw)

# Iso-level subtracted from the density to form the scalar field.
# NOTE: the marching-cubes preview above uses 0.3; 0.1 is kept from the original.
sdf_iso = 0.1
assert mc_grid.shape == (coarse_res,) * 3, f"unexpected density grid shape {tuple(mc_grid.shape)}"
# The C-order flatten assumes vertex ordering x-major, then y, then z; this is
# the same assumption the original padding made.
# NOTE(review): this field is positive where density > sdf_iso; confirm it
# matches the sign FlexiCubes expects for "inside" (affects face orientation).
sdf_init = (mc_grid.detach() - sdf_iso).flatten().to(device=device, dtype=torch.float32)
assert sdf_init.numel() == marching_verts.shape[0], "scalar field / grid vertex count mismatch"

# Wrap the initial field directly; torch.tensor(existing_tensor) is discouraged.
sdf = torch.nn.Parameter(sdf_init, requires_grad=True)
deform = torch.nn.Parameter(torch.zeros_like(marching_verts), requires_grad=True)
# 21 per-cube weights: the forward call slices them as [:12], [12:20], [20].
per_cube_weights = torch.nn.Parameter(
    torch.ones((indices.shape[0], 21), dtype=torch.float32, device=device), requires_grad=True)

In [None]:
import trimesh

# Vertex deformation is disabled (scaled by 0). It is kept in the expression,
# so `deform` still participates in the graph, though its gradient is zero.
deform_offset = 0 * torch.tanh(deform)
v_deformed = marching_verts + deform_offset

# Split the 21 per-cube weights into the three groups FlexiCubes takes
# positionally: 12, 8 and 1 per cube.
weights_12 = per_cube_weights[:, :12]
weights_8 = per_cube_weights[:, 12:20]
weight_1 = per_cube_weights[:, 20]
verts, faces, reg_loss = marching_geo(
    v_deformed, sdf, indices, grid_res,
    weights_12, weights_8, weight_1,
    training=True,
)

# Preview the extracted surface.
mesh = trimesh.Trimesh(faces=faces.detach().cpu(), vertices=verts.detach().cpu())
mesh.show()

In [None]:
(marching_verts.shape, indices.shape)  # (grid vertices, cube index table)

In [None]:
po