In [None]:
import sys
from pathlib import Path
import torch
import numpy as np
import pycolmap
import cv2
import trimesh

In [None]:
# The 'dust3r' and 'gaussian-splatting' folders are put on sys.path so that the
# imports below can be resolved from them.  This has to happen BEFORE those
# imports run: with the appends placed after them, a fresh kernel
# (Restart & Run All) fails with ModuleNotFoundError.  The membership test
# keeps repeated runs of this cell from stacking duplicate entries.
for _repo_root in ('dust3r', 'gaussian-splatting'):
    if _repo_root not in sys.path:
        sys.path.append(_repo_root)

from scene.colmap_loader import rotmat2qvec
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.image_pairs import make_pairs
from dust3r.utils.device import to_numpy
from dust3r.utils.image import load_images
from dust3r.inference import inference as run_depth_inference, load_model

In [None]:
def invert_4x4(matrix):
    """Return the inverse of ``matrix``.

    ``torch.Tensor`` inputs are inverted with ``torch.inverse`` (the result is
    a tensor); every other input is handed to ``np.linalg.inv``.
    """
    if isinstance(matrix, torch.Tensor):
        return torch.inverse(matrix)
    return np.linalg.inv(matrix)

In [None]:
def tensor_to_numpy(tensor: torch.Tensor):
    """Convert ``tensor`` to a NumPy array.

    The tensor is moved to the CPU, detached from the autograd graph, and then
    converted with ``.numpy()``.
    """
    on_host = tensor.cpu()
    graph_free = on_host.detach()
    return graph_free.numpy()

In [None]:
def locate_image_files(directory: Path):
    """Return the image files directly inside ``directory`` in frame order.

    Regular files with a ``.png``, ``.jpg`` or ``.jpeg`` suffix (compared
    case-insensitively) are collected.  Files whose stem parses as an integer
    come first, ordered by that integer (so ``2.jpg`` precedes ``10.png``);
    any other image names follow in alphabetical order instead of aborting the
    run with a ``ValueError``.

    Args:
        directory: Folder to scan (not recursive).

    Returns:
        A non-empty list of ``Path`` objects.

    Raises:
        FileNotFoundError: If no matching image file is found.
    """
    allowed_suffixes = {'.png', '.jpg', '.jpeg'}

    def _frame_order(path: Path):
        # Numeric stems keep the integer ordering; the file name is only a
        # tie-breaker so the result does not depend on directory listing order.
        try:
            return (0, int(path.stem), path.name)
        except ValueError:
            return (1, 0, path.name)

    imgs = sorted(
        (f for f in directory.iterdir()
         if f.is_file() and f.suffix.lower() in allowed_suffixes),
        key=_frame_order)
    if not imgs:
        raise FileNotFoundError("No image files found in the given directory.")
    return imgs

In [None]:
def run_inference_on_depth(img_files, model_ckpt, dev, img_resolution, pair_strateg, batch_sz):
    """Load the model and the images, pair the images up, and run inference.

    Args:
        img_files: Iterable of image paths; each one is converted with ``str``.
        model_ckpt: Checkpoint handed to ``load_model``.
        dev: Device handed to ``load_model`` and to the inference call.
        img_resolution: Forwarded to ``load_images`` as ``size``.
        pair_strateg: Forwarded to ``make_pairs`` as ``scene_graph``.
        batch_sz: Forwarded to the inference call as ``batch_size``.

    Returns:
        Tuple ``(loaded images, inference output)``.
    """
    network = load_model(model_ckpt, dev)
    image_paths = list(map(str, img_files))
    loaded_views = load_images(image_paths, size=img_resolution)
    # Pairs are built without prefiltering and in both directions.
    view_pairs = make_pairs(
        loaded_views,
        scene_graph=pair_strateg,
        prefilter=None,
        symmetrize=True,
    )
    predictions = run_depth_inference(
        view_pairs, network, dev, batch_size=batch_sz)
    return loaded_views, predictions

In [None]:
def execute_global_alignment(inferred_data, dev, iterations, sched, lr_val, conf_threshold):
    """Build a global aligner for the inference output and optimise it.

    Args:
        inferred_data: Output of the pairwise inference step.
        dev: Device the aligner is created on.
        iterations: Forwarded as ``niter``.
        sched: Forwarded as ``schedule``.
        lr_val: Forwarded as ``lr``.
        conf_threshold: Raw confidence threshold; it is passed through the
            aligner's ``conf_trf`` before being stored as ``min_conf_thr``.

    Returns:
        The aligner object after ``compute_global_alignment`` has run.
    """
    optimizer_mode = GlobalAlignerMode.PointCloudOptimizer
    scene = global_aligner(inferred_data, device=dev, mode=optimizer_mode)

    transformed_thr = scene.conf_trf(torch.tensor(conf_threshold))
    scene.min_conf_thr = float(transformed_thr)

    # Alignment is always initialised with "mst".
    scene.compute_global_alignment(
        init="mst",
        niter=iterations,
        schedule=sched,
        lr=lr_val,
    )
    return scene

In [None]:
def extract_scene_info(align_obj):
    """Pull the optimised scene quantities out of the aligner.

    Returns:
        An 8-tuple, in this order: intrinsics, camera-to-world poses,
        world-to-camera poses (the inverse of the former), principal points,
        focals, images (as one ``np.array``), per-view 3D points (a list of
        detached tensors, left on whatever device the aligner used), and masks.
    """
    intrinsics_np = tensor_to_numpy(align_obj.get_intrinsics())

    poses_c2w = tensor_to_numpy(align_obj.get_im_poses())
    poses_w2c = invert_4x4(poses_c2w)

    pp_np = tensor_to_numpy(align_obj.get_principal_points())
    focals_np = tensor_to_numpy(align_obj.get_focals())
    images_np = np.array(align_obj.imgs)

    # Only detached here, not converted to NumPy.
    points_per_view = []
    for view_points in align_obj.get_pts3d():
        points_per_view.append(view_points.detach())

    masks_np = to_numpy(align_obj.get_masks())

    return (
        intrinsics_np,
        poses_c2w,
        poses_w2c,
        pp_np,
        focals_np,
        images_np,
        points_per_view,
        masks_np,
    )

In [None]:
def normalize_geometry(pts3d_arrays, masks_arrays, c2w_matrices):
    """Centre the scene on its masked points and scale it to unit radius.

    The centre is the mean of all mask-selected points and the scale is the
    largest distance from that centre to any selected point.  Every point map
    (masked or not) and every camera position is shifted and scaled with the
    same two values.

    Args:
        pts3d_arrays: Per-view point tensors with a trailing dimension of 3.
            They may live on any device and may track gradients.
        masks_arrays: Per-view boolean masks (NumPy arrays or tensors) that
            select the points used for the centre/scale statistics.
        c2w_matrices: Per-view 4x4 NumPy camera-to-world matrices.  They are
            copied; only the translation column of the copy is changed.

    Returns:
        Tuple ``(normalized point tensors, adjusted 4x4 matrices)``, both lists.

    Raises:
        ValueError: If the masks select no point, or the selected points have
            no spread (the scale would be zero or not finite).
    """
    selected = []
    for pts, msk in zip(pts3d_arrays, masks_arrays):
        # Build the mask on the same device as the points.
        keep = torch.as_tensor(msk, dtype=torch.bool, device=pts.device)
        selected.append(pts[keep].reshape(-1, 3))

    if sum(chunk.shape[0] for chunk in selected) == 0:
        raise ValueError("normalize_geometry: the masks select no points.")

    combined_pts = torch.cat(selected, dim=0)
    midpoint = combined_pts.mean(dim=0)
    scale_factor = torch.norm(combined_pts - midpoint, dim=1).max()

    scale_np = scale_factor.item()
    if not scale_np > 0 or scale_np == float('inf'):
        raise ValueError(
            "normalize_geometry: selected points have no usable extent "
            f"(scale={scale_np}).")

    # A bare .numpy() raises for CUDA tensors and for tensors that require
    # grad; detach and move to the host first.
    midpoint_np = midpoint.detach().cpu().numpy()

    normalized_pts3d = []
    mod_c2w = []
    for pts, mat in zip(pts3d_arrays, c2w_matrices):
        normalized_pts3d.append((pts - midpoint) / scale_factor)
        mat_copy = mat.copy()
        mat_copy[:3, 3] = (mat_copy[:3, 3] - midpoint_np) / scale_np
        mod_c2w.append(mat_copy)

    return normalized_pts3d, mod_c2w

In [None]:
def build_scene_output_environment(dest_dir: Path):
    """Create the output folder layout under ``dest_dir``.

    Two folders are ensured to exist: ``dest_dir/images`` and
    ``dest_dir/sparse/0``.  Missing parents (including ``dest_dir`` itself) are
    created along the way and existing folders are left alone.

    Returns:
        Tuple ``(images folder, sparse/0 folder)``.
    """
    out_img_dir = dest_dir / 'images'
    recon_dir = dest_dir / 'sparse' / '0'
    # parents=True also creates dest_dir when it is missing.
    for folder in (out_img_dir, recon_dir):
        folder.mkdir(parents=True, exist_ok=True)
    return out_img_dir, recon_dir

In [None]:
def export_image_data(scene_imgs, img_dir):
    """Write every image in ``scene_imgs`` to ``img_dir`` as ``<index>.png``.

    Args:
        scene_imgs: Iterable of image arrays.  They are scaled by 255 before
            the 8-bit conversion, so values are presumably floats in [0, 1]
            with RGB channel order -- TODO confirm against the aligner output.
        img_dir: ``Path`` of the destination folder (must already exist).

    Raises:
        OSError: If OpenCV reports that a file could not be written.
    """
    for i, img_arr in enumerate(scene_imgs):
        img_save_path = img_dir / f"{i}.png"
        # Round to nearest and clip before casting: a bare astype(uint8)
        # truncates (e.g. 254.999 -> 254) and wraps out-of-range values.
        scaled = np.rint(np.asarray(img_arr) * 255.0)
        img_8bit = np.clip(scaled, 0, 255).astype(np.uint8)
        # cv2.imwrite expects BGR channel order, so swap from RGB.
        final_img = cv2.cvtColor(img_8bit, cv2.COLOR_RGB2BGR)
        # imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(str(img_save_path), final_img):
            raise OSError(f"Failed to write image file: {img_save_path}")

In [None]:
def create_pointcloud(scene_imgs, pts3d_arr, masks_arr):
    """Build a coloured, subsampled ``trimesh.PointCloud`` from the masked points.

    For each view the mask selects both the 3D points and the matching image
    pixels (used as colours).  Points from all views are concatenated, then
    every third point is kept.  Colours are scaled by 255 and cast to uint8.

    Returns:
        The ``trimesh.PointCloud`` with a constant ``[1, 0, 0]`` normal per
        point stored on its ``vertices_normal`` attribute.
    """
    images = to_numpy(scene_imgs)
    clouds = [to_numpy(view_pts) for view_pts in pts3d_arr]
    masks = to_numpy(masks_arr)

    xyz_parts = []
    for cloud, mask in zip(clouds, masks):
        xyz_parts.append(cloud[mask].reshape(-1, 3))
    all_xyz = np.concatenate(xyz_parts)

    rgb_parts = []
    for image, mask in zip(images, masks):
        rgb_parts.append(image[mask].reshape(-1, 3))
    all_rgb = np.concatenate(rgb_parts)

    # Keep every third point (and its colour).
    step = 3
    xyz_kept = all_xyz[::step]
    rgb_kept = all_rgb[::step]

    colors_8bit = (rgb_kept * 255).astype(np.uint8)
    cloud_obj = trimesh.PointCloud(xyz_kept, colors=colors_8bit)
    # NOTE(review): plain attribute assignment -- confirm downstream code
    # actually reads `vertices_normal`.
    cloud_obj.vertices_normal = np.tile([1, 0, 0], (xyz_kept.shape[0], 1))
    return cloud_obj

In [None]:
def project_initial_view(xyz_coords, focal_list, princ_points, w2c_mats, imgs_arr):
    """Project the points into the first camera and keep the visible ones.

    Only index 0 of ``focal_list``, ``princ_points`` and ``w2c_mats`` is used.
    The first focal value serves as both fx and fy.  A point is kept when its
    camera-space depth is positive and its projection lies inside the image
    (``0 <= x < width`` and ``0 <= y < height``, with the size taken from
    ``imgs_arr.shape[1]`` and ``imgs_arr.shape[2]``).

    Returns:
        Tuple ``(kept 3D points, their 2D projections as an (N, 2) array)``.
    """
    img_h = imgs_arr.shape[1]
    img_w = imgs_arr.shape[2]

    focal = focal_list[0][0]
    fx, fy = focal, focal
    cx, cy = princ_points[0]

    # The rotation goes through a quaternion and back via pycolmap, which
    # takes (x, y, z, w) while rotmat2qvec yields (w, x, y, z).
    qw, qx, qy, qz = rotmat2qvec(w2c_mats[0, :3, :3])
    R_mat = pycolmap.Rotation3d(np.array([qx, qy, qz, qw])).matrix()

    tx, ty, tz = w2c_mats[0, :3, 3]
    t_vec = np.array([tx, ty, tz]).reshape(3, 1)

    cam_space = (R_mat @ xyz_coords.T + t_vec).T
    depth = cam_space[:, 2]
    x_proj = (cam_space[:, 0]*fx / depth) + cx
    y_proj = (cam_space[:, 1]*fy / depth) + cy

    in_front = depth > 0
    inside_x = (x_proj >= 0) & (x_proj < img_w)
    inside_y = (y_proj >= 0) & (y_proj < img_h)
    keep = in_front & (inside_x & inside_y)

    pixels = np.stack([x_proj[keep], y_proj[keep]], axis=1)
    return xyz_coords[keep], pixels

In [None]:
def construct_reconstruction(w2c_mats, focal_vals, ppoints, imgs_arr, valid_xyz, init_keypoints, clr_data, recon_dir):
    """Assemble a ``pycolmap.Reconstruction`` and write it to ``recon_dir``.

    * One PINHOLE camera per entry of ``focal_vals``/``ppoints`` (ids start at
      1; the first focal value is used for both fx and fy).
    * One image per pose in ``w2c_mats``, named ``"<index>.png"``, with image
      id and camera id equal to ``index + 1``.  Only the first image receives
      keypoints (``init_keypoints``).
    * One 3D point per ``(valid_xyz, clr_data)`` pair, whose track has a single
      element: image 1, 2D point index equal to the pair's position.
    """
    reconstruction = pycolmap.Reconstruction()
    img_h = imgs_arr.shape[1]
    img_w = imgs_arr.shape[2]

    def _add_view(view_idx, **image_extras):
        # Image and camera ids are 1-based, file names are 0-based.
        colmap_id = view_idx + 1
        qw, qx, qy, qz = rotmat2qvec(w2c_mats[view_idx, :3, :3])
        # pycolmap takes the quaternion as (x, y, z, w).
        rotation = pycolmap.Rotation3d(np.array([qx, qy, qz, qw]))
        translation = w2c_mats[view_idx, :3, 3]
        pose = pycolmap.Rigid3d(rotation, translation)
        view = pycolmap.Image(
            name=f"{view_idx}.png",
            cam_from_world=pose,
            camera_id=colmap_id,
            **image_extras
        )
        view.image_id = colmap_id
        reconstruction.add_image(view)

    # --- cameras -----------------------------------------------------------
    for cam_id, (focal, pp) in enumerate(zip(focal_vals, ppoints), start=1):
        f_len = focal[0]
        cx, cy = pp
        camera = pycolmap.Camera(
            model='PINHOLE',
            width=img_w,
            height=img_h,
            params=[f_len, f_len, cx, cy]
        )
        camera.camera_id = cam_id
        reconstruction.add_camera(camera)

    # --- images ------------------------------------------------------------
    _add_view(0, keypoints=init_keypoints)
    for view_idx in range(1, w2c_mats.shape[0]):
        _add_view(view_idx)

    # --- 3D points, each observed only in image 1 ----------------------------
    for kp_idx, (point, colour) in enumerate(zip(valid_xyz, clr_data)):
        xyz_col = point.astype(np.float64).reshape((3, 1))
        rgb_col = colour[:3].astype(np.uint8).reshape((3, 1))
        observation = pycolmap.TrackElement(image_id=1, point2D_idx=kp_idx)
        track = pycolmap.Track(elements=[observation])
        point_id = reconstruction.add_point3D(
            xyz=xyz_col, track=track, color=rgb_col)
        reconstruction.images[1].set_point3D_for_point2D(kp_idx, point_id)

    reconstruction.write(recon_dir)

In [None]:
# All tunables of the run live in this one dictionary.
program_params = {
    "target_object": "husky",
    "base_image_dir": Path('images'),
    "model_checkpoint": "dust3r/checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
    "inference_device": "cuda:0",
    "input_resolution": 512,
    "pair_strat": "complete",
    "inference_batch": 32,
    "global_alignment_iters": 250,
    "init_learning_rate": 0.01,
    "min_confidence_thres": 25,
    "do_normalization": False,
    "learning_rate_policy": "cosine",
    "scene_output_dir": Path('results-dust3r')
}

# Input frames are read from <base_image_dir>/<target_object>.
input_directory = program_params['base_image_dir'] / program_params['target_object']
image_files = locate_image_files(input_directory)

# Pairwise inference over the input frames.
scene_images, depth_inference = run_inference_on_depth(
    img_files=image_files,
    model_ckpt=program_params["model_checkpoint"],
    dev=program_params["inference_device"],
    img_resolution=program_params["input_resolution"],
    pair_strateg=program_params["pair_strat"],
    batch_sz=program_params["inference_batch"],
)

# Global alignment of the pairwise predictions.
alignment_op = execute_global_alignment(
    inferred_data=depth_inference,
    dev=program_params["inference_device"],
    iterations=program_params["global_alignment_iters"],
    sched=program_params["learning_rate_policy"],
    lr_val=program_params["init_learning_rate"],
    conf_threshold=program_params["min_confidence_thres"],
)

# From here on `scene_images` holds the aligner's image array rather than the
# images returned by the loading step.
(intrinsics, cam2world, world2cam, p_pts, focs,
 scene_images, scene_pts3d, scene_msks) = extract_scene_info(alignment_op)

# Optional recentring/rescaling; the inverse poses are recomputed afterwards.
if program_params["do_normalization"]:
    scene_pts3d, cam2world = normalize_geometry(
        scene_pts3d, scene_msks, cam2world)
    world2cam = np.linalg.inv(cam2world)

In [None]:
# Output goes to <scene_output_dir>/<target_object>.
out_dir = program_params['scene_output_dir'] / program_params['target_object']
images_dir, reconstruction_dir = build_scene_output_environment(out_dir)
export_image_data(scene_images, images_dir)

pc_obj = create_pointcloud(scene_images, scene_pts3d, scene_msks)
xyz_points = pc_obj.vertices
all_point_colors = np.asarray(pc_obj.colors)

valid_xyzs, kpts_first = project_initial_view(
    xyz_points, focs, p_pts, world2cam, scene_images)

# project_initial_view drops the points that are not visible in the first view
# but returns only the surviving coordinates, not the mask.  Passing the
# unfiltered colours on would pair the i-th surviving point with the colour of
# the i-th ORIGINAL point.  Visibility depends only on a point's coordinates,
# so the mask is recovered by matching rows, and the colours are filtered with
# it.  The result is order-preserving, i.e. aligned with valid_xyzs.
_stacked_rows = np.concatenate(
    [np.asarray(valid_xyzs), np.asarray(xyz_points)], axis=0)
_, _row_ids = np.unique(_stacked_rows, axis=0, return_inverse=True)
_row_ids = _row_ids.reshape(-1)
_n_valid = len(valid_xyzs)
_kept_mask = np.isin(_row_ids[_n_valid:], _row_ids[:_n_valid])
col_points = all_point_colors[_kept_mask]
assert len(col_points) == _n_valid, "colour/point count mismatch after filtering"

In [None]:
# Write the COLMAP-style reconstruction into the sparse/0 folder created above.
construct_reconstruction(
    w2c_mats=world2cam,
    focal_vals=focs,
    ppoints=p_pts,
    imgs_arr=scene_images,
    valid_xyz=valid_xyzs,
    init_keypoints=kpts_first,
    clr_data=col_points,
    recon_dir=reconstruction_dir,
)

In [None]:
# Paths handed to the gaussian-splatting shell commands in the cells below.
gs_path = "gaussian-splatting"
obj_nm = program_params['target_object']
# The closing parenthesis was missing here, which made the whole cell a
# SyntaxError.
d3r_output_dir = str(out_dir)
gs_output_dir = str(Path("results-gaussian-splatting") / obj_nm)

In [None]:
# Train gaussian splatting on the DUSt3R results: `-s` receives d3r_output_dir
# (the scene folder written earlier in this notebook) and `-m` receives
# gs_output_dir, the folder the render and viewer cells below are pointed at.
!python "{gs_path}/train.py" -s "{d3r_output_dir}" -m "{gs_output_dir}"

In [None]:
# Render the model: runs render.py from the gaussian-splatting folder on the
# model folder (gs_output_dir) that the training cell was given as `-m`.
!python "{gs_path}/render.py" -m "{gs_output_dir}"

In [None]:
# Interactive viewer: launches the SIBR_gaussianViewer_app binary found under
# {gs_path}/SIBR_viewers/install/bin on the same model folder (gs_output_dir).
!"./{gs_path}/SIBR_viewers/install/bin/SIBR_gaussianViewer_app" -m "{gs_output_dir}" 