## Load modules

In [1]:
%load_ext autoreload
%autoreload 2

# Verify conda environment
import sys
print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version}")

# Check if we're in the nerfstudio environment
if 'nerfstudio' not in sys.executable:
    print("\n⚠️  WARNING: Not running in nerfstudio conda environment!")
    print("Please activate with: conda activate nerfstudio")
else:
    print("\n✓ Running in nerfstudio environment")

from pathlib import Path
import cv2

# Import internal nerfstudio utilities
from nerfstudio.process_data import vggt_utils
from nerfstudio.process_data import colmap_utils
from nerfstudio.process_data.process_data_utils import (
    convert_video_to_images,
    CameraModel,
)

print("✓ Imports complete")

Python executable: /opt/conda/envs/nerfstudio/bin/python
Python version: 3.10.18 | packaged by conda-forge | (main, Jun  4 2025, 14:45:41) [GCC 13.3.0]

✓ Running in nerfstudio environment
✓ Imports complete


Process ForkProcess-1:
Process ForkProcess-2:
Process ForkProcess-3:
Process ForkProcess-24:
Process ForkProcess-26:
Process ForkProcess-15:
Process ForkProcess-28:
Process ForkProcess-6:
Process ForkProcess-23:
Process ForkProcess-20:
Process ForkProcess-29:
Process ForkProcess-17:
Process ForkProcess-8:
Process ForkProcess-9:
Process ForkProcess-5:
Process ForkProcess-19:
Process ForkProcess-14:
Process ForkProcess-7:
Process ForkProcess-30:
Process ForkProcess-4:
Process ForkProcess-21:
Process ForkProcess-25:
Process ForkProcess-27:
Process ForkProcess-31:
Process ForkProcess-16:
Process ForkProcess-13:
Process ForkProcess-22:
Process ForkProcess-18:
Process ForkProcess-32:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last)

## Run preproc

In [2]:
# Option 1: Use Splatter wrapper for training
from collab_splats.wrapper import Splatter, SplatterConfig

# Configuration
config_dir = Path("/workspace/collab-splats/docs/splats/configs/")
# dataset_name = "bicycle"
dataset_name = "birds_date-02062024_video-C0043"

# Create splatter from config
splatter = Splatter.from_config_file(
    dataset=dataset_name,
    config_dir=config_dir,
    # overrides={
    #     "frame_proportion": 0.1,
    # }
)

splatter.preprocess()

# splatter.preprocess(
#     sfm_tool='vggt',
#     overwrite=False, 
#     kwargs={
#         "refine-vggt": "",
#         "camera-type": "pinhole",
#         "verbose": "",
#         "num_downscales": 0,
#         "vggt_conf_threshold": 35.0,
#         "save_vggt_checkpoint": "",
#         # "skip_image_processing": "",
#     }  # Enable bundle adjustment
# )

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
✓ Valid video file with 2388 frames
transforms.json already exists at /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/preproc/transforms.json
To rerun preprocessing, set overwrite=True


In [None]:
# Load the aligned poses in case we want them for visualization --> by default they are aligned to the splat
aligned_cameras = splatter.load_aligned_cameras(align_mesh=False)

In [None]:
## Step 4: Visualize the Sparse Point Cloud
import pyvista as pv
from collab_splats.utils.visualization import (
    CAMERA_KWARGS,
    MESH_KWARGS,
    VIZ_KWARGS,
    visualize_splat,
)

# Load the sparse point cloud
pcd_fn = splatter.config["preproc_data_path"] / "sparse_pc.ply"
splat = pv.PolyData(str(pcd_fn))

pcd_kwargs = MESH_KWARGS.copy()
pcd_kwargs.update(
    {
        "point_size": 2,
        "render_points_as_spheres": True,
        "ambient": 0.3,
        "diffuse": 0.8,
        "specular": 0.1,
    }
)

camera_kwargs = CAMERA_KWARGS.copy()
camera_kwargs.update(
    {
        "n_poses": 1,
    }
)

plotter = visualize_splat(
    mesh=splat,
    mesh_kwargs=pcd_kwargs,
    camera_kwargs=camera_kwargs,
    viz_kwargs=VIZ_KWARGS,
    aligned_cameras=aligned_cameras,
)

plotter.show()

## Run VGGT

In [3]:
import torch
from nerfstudio.process_data import vggt_utils

image_dir = splatter.config["preproc_data_path"] / "images"

vggt_output_dir = splatter.config["preproc_data_path"] / "vggt"

colmap_dir = splatter.config["preproc_data_path"] / "colmap"
vggt_ckpt_path = colmap_dir / "vggt_checkpoint.pt"

vggt_data_loaded =  torch.load(vggt_ckpt_path)

# Extract data from inference results
images = vggt_data_loaded["images"]
extrinsic = vggt_data_loaded["extrinsic"]
intrinsic = vggt_data_loaded["instrinsics"] # Upsampled to native resolution
intrinsic_downsampled = vggt_data_loaded["instrinsics_downsampled"] # Downsampled to VGGT resolution
depth_map = vggt_data_loaded["depth"]
depth_conf = vggt_data_loaded["depth_conf"]
image_paths = vggt_data_loaded["image_paths"]

image_basenames = [image_path.name for image_path in image_paths]

matches_fn = splatter.config["preproc_data_path"] / "colmap" / "matches.pt"
matches = torch.load(matches_fn)


In [12]:
torch.corrcoef(depth_map[0], depth_map[1])

TypeError: corrcoef() takes 1 positional argument but 2 were given

In [4]:
extrinsic_refined, intrinsic_refined = vggt_utils.pose_optimization(
    match_outputs=matches,
    extrinsic=extrinsic,
    intrinsic=intrinsic_downsampled,
    images=images,
    depth_conf=depth_conf,
    depth_maps=depth_map,
    base_image_path_list=image_basenames,
    target_scene_dir=colmap_dir,
    shared_intrinsics=True,
    lambda_depth=0.0,
)


Pose Optimization...: 100%|██████████| 300/300 [06:58<00:00,  1.39s/it]


In [9]:
# Filter and prepare 3D points in Facebook's format (points3d, points_xyf, points_rgb)
# Note: Facebook uses conf_threshold as a value (e.g., 5.0), not percentile
# For compatibility, we convert percentile to value if > 1
# Interpret as value directly
conf_threshold_value = 1.0

# Unproject depth map to point map
points3d = vggt_utils.unproject_depth_map_to_point_map(
    depth_map, 
    extrinsic_refined, 
    intrinsic_refined
)

In [10]:
# Filter points for pycolmap reconstruction using VGGTX logic
points3d, points_xyf, points_rgb = vggt_utils._filter_and_prepare_points_for_pycolmap(
    points3d=points3d,
    depth_map=depth_map,
    depth_conf=depth_conf,
    images=images,
    image_paths=image_paths,
    conf_thres_value=conf_threshold_value,
    use_global_alignment=True,
    match_outputs=matches,
)

In [13]:
import numpy as np

# Grab image size from depth map (N, H, W) --> make as width and height
image_size = np.array([depth_map.shape[2], depth_map.shape[1]])

reconstruction = vggt_utils._build_pycolmap_reconstruction_without_tracks(
    points3d=points3d,
    points_xyf=points_xyf,
    points_rgb=points_rgb,
    extrinsic=extrinsic_refined,
    intrinsic=intrinsic_refined,
    image_size=image_size,
    image_paths=image_paths,
    shared_camera=True,
    camera_type="PINHOLE",
    verbose=True,
)

In [16]:
reconstruction_resolution

(294, 518)

In [21]:
# # Step 2: Rescale reconstruction to original dimensions
reconstruction_resolution = (image_size[0], image_size[1]) # Reverse as it expects width and height

original_image_sizes = np.repeat([[0, 0, 294, 518, 1080, 1920]], len(image_paths), axis=0)

reconstruction = vggt_utils._rescale_reconstruction_to_original_dimensions(
    reconstruction=reconstruction,
    image_paths=image_paths,
    original_image_sizes=original_image_sizes,
    image_size=reconstruction_resolution,
    shift_point2d_to_original_res=True,
    shared_camera=True,
    verbose=True,
)

In [34]:
import copy

filtered_reconstruction = copy.deepcopy(reconstruction)
removed_image_ids = []
removal_reasons = {
    'few_matches': [],
    'few_points': [],
    'low_depth_conf': [],
    'high_reproj_error': [],
    'pose_discontinuity': [],
    'rotation_discontinuity': [],
}

num_images = len(reconstruction.images)

min_points_3d: int = 100
min_matches: int = 200
min_avg_depth_conf: float = 0.3
reprojection_error_percentile: float = 95  # Filter top 5% worst
pose_distance_std_factor: float = 3.0 # Outlier if >3 std from mean
rotation_angle_threshold: float = 45.0  # Degrees change between neighbors

# ========================================================================
# METRIC 5: Pose Discontinuity Detection (NEW - MOST IMPORTANT)
# Detects cameras with positions far from their temporal neighbors
# ========================================================================
if extrinsic is not None:
    camera_positions = extrinsic[:, :3, 3]  # Extract translation vectors
    
    # Compute distances to temporal neighbors
    position_distances = []
    for i in range(1, len(camera_positions) - 1):
        img_id = i + 1  # COLMAP uses 1-indexed
        
        if img_id in removed_image_ids:
            continue
        
        # Distance to previous and next frame
        dist_prev = np.linalg.norm(camera_positions[i] - camera_positions[i-1])
        dist_next = np.linalg.norm(camera_positions[i] - camera_positions[i+1])
        avg_dist = (dist_prev + dist_next) / 2
        
        position_distances.append(avg_dist)
    
    if position_distances:
        mean_dist = np.mean(position_distances)
        std_dist = np.std(position_distances)
        threshold_dist = mean_dist + pose_distance_std_factor * std_dist
        
        for i in range(1, len(camera_positions) - 1):
            img_id = i + 1
            
            if img_id in removed_image_ids:
                continue
            
            dist_prev = np.linalg.norm(camera_positions[i] - camera_positions[i-1])
            dist_next = np.linalg.norm(camera_positions[i] - camera_positions[i+1])
            avg_dist = (dist_prev + dist_next) / 2
            
            if avg_dist > threshold_dist:
                removed_image_ids.append(img_id)
                removal_reasons['pose_discontinuity'].append(img_id)
                # if verbose:
                print(f"[yellow]Image {img_id}: pose jump {avg_dist:.3f} > {threshold_dist:.3f}")


[yellow]Image 233: pose jump 0.188 > 0.170
[yellow]Image 370: pose jump 0.195 > 0.170
[yellow]Image 371: pose jump 0.201 > 0.170
[yellow]Image 374: pose jump 0.244 > 0.170
[yellow]Image 375: pose jump 0.296 > 0.170
[yellow]Image 376: pose jump 0.334 > 0.170
[yellow]Image 377: pose jump 0.259 > 0.170
[yellow]Image 417: pose jump 0.186 > 0.170
[yellow]Image 418: pose jump 0.188 > 0.170


In [60]:
reconstruction_filtered, removed_image_ids, removal_reasons = _filter_outlier_cameras_enhanced(
    reconstruction=reconstruction,
    match_outputs=matches,
    depth_conf=depth_conf,
    extrinsic=extrinsic_refined,
    images=images,
    min_points_3d=200,
    min_matches=2000,
    verbose=True,
)

In [66]:
depth_map.shape

(597, 518, 294, 1)

In [59]:
from typing import Any, Dict, List, Optional, Tuple
import copy
import numpy as np
import torch
from nerfstudio.utils.rich_utils import CONSOLE

def _filter_outlier_cameras_enhanced(
    reconstruction: Any,
    match_outputs: Optional[Dict[str, Any]] = None,
    depth_conf: Optional[np.ndarray] = None,
    extrinsic: Optional[np.ndarray] = None,
    images: Optional[torch.Tensor] = None,
    min_points_3d: int = 100,
    min_matches: int = 200,
    min_avg_depth_conf: float = 0.3,
    reprojection_error_percentile: float = 95,
    pose_distance_std_factor: float = 2.0,
    rotation_angle_threshold: float = 30.0,
    verbose: bool = False,
) -> Tuple[Any, List[int], Dict[str, List[int]]]:
    """
    Multi-metric outlier detection for camera poses.
    Returns:
        filtered reconstruction, removed image IDs, removal reasons
    """
    filtered_reconstruction = copy.deepcopy(reconstruction)
    removed_image_ids = []
    removal_reasons = {
        'few_matches': [],
        'few_points': [],
        'low_depth_conf': [],
        'high_reproj_error': [],
        'pose_discontinuity': [],
        'rotation_discontinuity': [],
    }
    
    num_images = len(reconstruction.images)

    # # ----------------------------
    # # METRIC 1: Feature Match Count
    # # ----------------------------
    # if match_outputs is not None:
    #     indexes_i = match_outputs["indexes_i_expanded"]
    #     indexes_j = match_outputs["indexes_j_expanded"]

    #     for img_id in reconstruction.images:
    #         mask = (indexes_i == img_id - 1) | (indexes_j == img_id - 1)
    #         match_count = mask.sum()
    #         print (match_count)
    #         if match_count < min_matches:
    #             removed_image_ids.append(img_id)
    #             removal_reasons['few_matches'].append(img_id)
    #             if verbose:
    #                 CONSOLE.print(f"[yellow]Image {img_id}: {match_count} matches < {min_matches}")
    # sys.exit(0)
    # ----------------------------
    # METRIC 2: Number of 3D Points
    # ----------------------------
    for img_id in reconstruction.images:
        if img_id in removed_image_ids:
            continue
        point_count = sum(1 for p in reconstruction.images[img_id].points2D if p.point3D_id != -1)
        if point_count < min_points_3d:
            removed_image_ids.append(img_id)
            removal_reasons['few_points'].append(img_id)
            if verbose:
                CONSOLE.print(f"[yellow]Image {img_id}: {point_count} 3D points < {min_points_3d}")

    # ----------------------------
    # METRIC 3: Average Depth Confidence
    # ----------------------------
    if depth_conf is not None:
        for img_id in reconstruction.images:
            if img_id in removed_image_ids:
                continue
            avg_conf = depth_conf[img_id - 1].mean()
            if avg_conf < min_avg_depth_conf:
                removed_image_ids.append(img_id)
                removal_reasons['low_depth_conf'].append(img_id)
                if verbose:
                    CONSOLE.print(f"[yellow]Image {img_id}: depth conf {avg_conf:.3f} < {min_avg_depth_conf}")

    # # ----------------------------
    # # METRIC 4: Reprojection Error
    # # ----------------------------
    # reprojection_errors = {}
    # for img_id, pyimage in reconstruction.images.items():
    #     if img_id in removed_image_ids:
    #         continue
    #     pycamera = reconstruction.cameras[pyimage.camera_id]

    #     errors = []
    #     # Obtain world->camera transformation
    #     if extrinsic is not None:
    #         W2C = extrinsic[img_id - 1]
    #     else:
    #         # fallback: pyimage.pose() returns world->camera matrix
    #         W2C = pyimage.pose()

    #     for point2D in pyimage.points2D:
    #         if (point3D_id := point2D.point3D_id) != -1:
    #             point3D = reconstruction.points3D[point3D_id]
    #             point_world_h = np.append(point3D.xyz, 1)
    #             point_cam = W2C @ point_world_h
    #             projected = pycamera.project(point_cam[:3])
    #             error = np.linalg.norm(projected - point2D.xy)
    #             errors.append(error)
    #     if errors:
    #         reprojection_errors[img_id] = np.mean(errors)

    # if reprojection_errors:
    #     threshold = np.percentile(list(reprojection_errors.values()), reprojection_error_percentile)
    #     for img_id, error in reprojection_errors.items():
    #         if error > threshold and img_id not in removed_image_ids:
    #             removed_image_ids.append(img_id)
    #             removal_reasons['high_reproj_error'].append(img_id)
    #             if verbose:
    #                 CONSOLE.print(f"[yellow]Image {img_id}: reproj error {error:.2f} > {threshold:.2f}")

    # ----------------------------
    # METRIC 5: Pose Discontinuity (MAD from neighbors)
    # ----------------------------
    if extrinsic is not None:
        camera_positions = extrinsic[:, :3, 3]
        n_neighbors = 3  # number of neighbors on each side
        num_cameras = len(camera_positions)

        for i in range(num_cameras):
            img_id = i + 1
            if img_id in removed_image_ids:
                continue

            # Select neighbor indices
            neighbor_indices = [j for j in range(max(0, i - n_neighbors), min(num_cameras, i + n_neighbors + 1))
                                if j != i]

            # Compute mean absolute deviation from neighbors
            neighbor_positions = camera_positions[neighbor_indices]
            mad = np.mean(np.linalg.norm(neighbor_positions - camera_positions[i], axis=1))

            # Compute global MAD threshold
            all_mads = []
            for k in range(num_cameras):
                neighbor_idx_k = [j for j in range(max(0, k - n_neighbors), min(num_cameras, k + n_neighbors + 1))
                                if j != k]
                neighbor_pos_k = camera_positions[neighbor_idx_k]
                all_mads.append(np.mean(np.linalg.norm(neighbor_pos_k - camera_positions[k], axis=1)))
            threshold_mad = np.mean(all_mads) + pose_distance_std_factor * np.std(all_mads)

            if mad > threshold_mad:
                removed_image_ids.append(img_id)
                removal_reasons['pose_discontinuity'].append(img_id)
                if verbose:
                    CONSOLE.print(f"[yellow]Image {img_id}: pose MAD {mad:.3f} > {threshold_mad:.3f}")

    # ----------------------------
    # METRIC 6: Rotation Discontinuity
    # ----------------------------
    if extrinsic is not None:
        for i in range(1, len(extrinsic) - 1):
            img_id = i + 1
            if img_id in removed_image_ids:
                continue
            R_curr = extrinsic[i, :3, :3]
            R_prev = extrinsic[i-1, :3, :3]
            R_next = extrinsic[i+1, :3, :3]

            # Relative rotations
            R_rel_prev = R_prev.T @ R_curr
            R_rel_next = R_curr.T @ R_next

            # Rotation angle (degrees)
            angle_prev = np.degrees(np.arccos(np.clip((np.trace(R_rel_prev) - 1)/2, -1, 1)))
            angle_next = np.degrees(np.arccos(np.clip((np.trace(R_rel_next) - 1)/2, -1, 1)))

            if angle_prev > rotation_angle_threshold or angle_next > rotation_angle_threshold:
                removed_image_ids.append(img_id)
                removal_reasons['rotation_discontinuity'].append(img_id)
                if verbose:
                    max_angle = max(angle_prev, angle_next)
                    CONSOLE.print(f"[yellow]Image {img_id}: rotation jump {max_angle:.1f}° > {rotation_angle_threshold}°")

    # ----------------------------
    # Final cleanup: remove images and orphaned points
    # ----------------------------
    removed_image_ids = list(set(removed_image_ids))
    for img_id in removed_image_ids:
        filtered_reconstruction.deregister_image(img_id)
        if img_id in filtered_reconstruction.images:
            del filtered_reconstruction.images[img_id]

    points_to_remove = []
    for point3D_id, point3D in filtered_reconstruction.points3D.items():
        track = point3D.track
        track.elements = [e for e in track.elements if e.image_id not in removed_image_ids]
        if not track.elements:
            points_to_remove.append(point3D_id)

    for point3D_id in points_to_remove:
        del filtered_reconstruction.points3D[point3D_id]

    # ----------------------------
    # Verbose summary
    # ----------------------------
    if verbose:
        CONSOLE.print(f"\n[bold cyan]Outlier Detection Summary:")
        CONSOLE.print(f"  Total images: {num_images}")
        CONSOLE.print(f"  Removed: {len(removed_image_ids)}")
        CONSOLE.print(f"  Remaining: {len(filtered_reconstruction.images)}")
        for reason, ids in removal_reasons.items():
            if ids:
                CONSOLE.print(f"  - {reason}: {len(ids)} images")
        CONSOLE.print(f"  Removed 3D points: {len(points_to_remove)}")

    return filtered_reconstruction, removed_image_ids, removal_reasons

In [30]:
mean_dist + pose_distance_std_factor * std_dist

TypeError: can't multiply sequence by non-int of type 'numpy.float64'

In [None]:
## Step 4: Visualize the Sparse Point Cloud
import pyvista as pv
from collab_splats.utils.visualization import (
    CAMERA_KWARGS,
    MESH_KWARGS,
    VIZ_KWARGS,
    visualize_splat,
)

# Load the sparse point cloud
# pcd_fn = splatter.config["preproc_data_path"] / "sparse_pc.ply"
splat = pv.PolyData(points3d)
splat.point_data["RGB"] = points_rgb

pcd_kwargs = MESH_KWARGS.copy()
pcd_kwargs.update(
    {
        "point_size": 2,
        "render_points_as_spheres": True,
        "ambient": 0.3,
        "diffuse": 0.8,
        "specular": 0.1,
    }
)

camera_kwargs = CAMERA_KWARGS.copy()
camera_kwargs.update(
    {
        "n_poses": 1,
    }
)

plotter = visualize_splat(
    mesh=splat,
    mesh_kwargs=pcd_kwargs,
    camera_kwargs=camera_kwargs,
    viz_kwargs=VIZ_KWARGS,
    aligned_cameras=extrinsic_refined,
)

plotter.show()