In [1]:
import numpy as np


In [3]:
import os.path as osp
import cv2
import numpy as np

import sys
sys.path.append('../../../') 

from dust3r.datasets.base.base_stereo_view_dataset import BaseStereoViewDataset
from dust3r.utils.image import imread_cv2


In [18]:
from itertools import combinations
from collections import defaultdict
from tqdm.auto import tqdm
from scipy.spatial import cKDTree
from sklearn.neighbors import NearestNeighbors

def compute_pointcloud_overlaps_scikit(pointcloud1, pointcloud2, distance_threshold=0.05, compute_symmetric=False):
    """
    Compute 'overlapping' metrics based on a distance threshold between two point clouds.
    """
    nbrs = NearestNeighbors(n_neighbors=1, algorithm = 'kd_tree').fit(pointcloud2)
    distances, indices = nbrs.kneighbors(pointcloud1)
    intersection = np.count_nonzero(distances.flatten() < distance_threshold)

    return intersection

def simulate_occlusion(point_cloud, radius):
    """
    Simulates occlusion by attaching a ball of fixed radius to each 3D point.
    Points occluded by another point within the radius are removed.

    Parameters:
    - point_cloud: numpy.ndarray of shape (N, 3) or (H, W, 3), the input 3D point cloud.
    - radius: float, the radius of the occlusion ball.

    Returns:
    - visible_points: numpy.ndarray of shape (M, 3), the filtered point cloud
      with occluded points removed.
    """
    # Ensure point cloud is 2D
    if len(point_cloud.shape) == 3:
        point_cloud = point_cloud.reshape(-1, 3)

    # Use KDTree for efficient neighborhood queries
    tree = cKDTree(point_cloud)
    
    # Find all neighbors within the radius
    indices = tree.query_ball_tree(tree, radius)
    
    # Mark visible points (only the first point in each neighborhood is visible)
    visible_mask = np.full(len(point_cloud), True, dtype=bool)
    for i, neighbors in enumerate(indices):
        if visible_mask[i]:
            # Mark all neighbors as occluded except the first point
            for neighbor in neighbors[1:]:
                visible_mask[neighbor] = False
    
    # Filter point cloud
    visible_points = point_cloud[visible_mask]
    return visible_points

def compute_iou_with_occlusion(pc1, pc2, radius=0.0005):
    """
    Computes the IoU between two 3D point clouds with artificial occlusion.

    Parameters:
    - pc1: numpy.ndarray of shape (N, 3) or (H, W, 3), first 3D point cloud.
    - pc2: numpy.ndarray of shape (M, 3) or (H, W, 3), second 3D point cloud.
    - radius: float, the radius of the occlusion ball.

    Returns:
    - iou: float, the Intersection over Union (IoU) value.
    """
    # # Simulate occlusion
    visible_pc1 = simulate_occlusion(pc1, radius)
    visible_pc2 = simulate_occlusion(pc2, radius)
    
    print(pc1.shape, visible_pc1.shape)
    
    # visible_pc1 = pc1
    # visible_pc2 = pc2
    
    print(len(visible_pc1))
    print(len(visible_pc2))
    
    pc1_set = set(map(tuple, np.round(visible_pc1, decimals=5))) 
    pc2_set = set(map(tuple, np.round(visible_pc2, decimals=5)))
    
    # Compute intersection and union
    intersection = pc1_set.intersection(pc2_set)
    
    # intersection = compute_pointcloud_overlaps_scikit(visible_pc1, visible_pc2)
    
    if intersection == 0:
        return 0
    
    print('intersection', len(intersection))
    print('pc1_set', list(pc1_set)[:10])
    print('pc2_set', list(pc2_set)[:10])
    
    union = pc1_set.union(pc2_set)
    
    print('union', len(union))
    
    # IoU calculation
    iou = len(intersection) / len(union) if len(union) > 0 else 0.0
    return iou

def quality_pair_score(iou, alpha):
    """
    Compute the quality pair score s = IoU × 4 cos(α)(1 - cos(α))
    """
    angle_term = 4 * np.cos(alpha) * (1 - np.cos(alpha))
    return iou * angle_term if angle_term > 0 else 0

def select_best_pairs(dataset, iou_threshold=0.75, score_threshold=0.1, pairs_per_image=5, pairs_number=10):
    """
    Select the best image pairs using a greedy algorithm
    """
    # pairs = list(combinations(dataset, 2))
    pair_scores = []

    for i in range(100):
        img1 = dataset[i][0]
        img2 = dataset[i][1]
        
        print(img1['idx'], img2['idx'])
        
        iou = compute_iou_with_occlusion(img1['pts3d'], img2['pts3d'])
        # print('IOU', iou)
        if iou == 0:
            continue
        
        pose1, pose2 = img1['camera_pose'], img2['camera_pose']
        rotation_diff = pose1[:3, :3].T @ pose2[:3, :3]  # Relative rotation matrix
        alpha = np.arccos(np.clip((np.trace(rotation_diff) - 1) / 2, -1, 1))  # Angle in radians
        
        score = quality_pair_score(iou, alpha)
        # print('SCORE', img1['idx'], img2['idx'], score)
        if score > score_threshold:
            pair_scores.append((score, img1, img2))

    pair_scores.sort(key=lambda x: x[0], reverse=True)

    selected_pairs = []
    used_images = set()

    for score, img1, img2 in pair_scores:
        if img1['idx'] in used_images or img2['idx'] in used_images:
            continue

        selected_pairs.append((img1, img2))
        used_images.add(img1['idx'])
        used_images.add(img2['idx'])

        for img in dataset:
            if img['idx'] not in used_images:
                if compute_iou(img['pts3d'], img1['pts3d']) > iou_threshold or \
                   compute_iou(img['pts3d'], img2['pts3d']) > iou_threshold:
                    used_images.add(img['idx'])
                    
    # used_images = defaultdict(int)
    # 
    # for score, img1, img2 in pair_scores:
    #     if used_images[img1['idx']] > pairs_per_image or used_images[img2['idx']] > pairs_per_image:
    #         continue
    # 
    #     selected_pairs.append((img1, img2))
    #     used_images[img1['idx']] += 1
    #     used_images[img2['idx']] += 1

    # if len(selected_pairs) < pairs_number: 
    #     for i in range(len(dataset) - 1):
    #         if len(selected_pairs) >= pairs_number:
    #             break
    #         img1, img2 = dataset[i], dataset[i + 1]
    #         selected_pairs.append((img1, img2))

    return selected_pairs

selected_pairs = select_best_pairs(dataset)
for img1, img2 in selected_pairs:
    print(f"Selected pair: Image {img1['idx']} and Image {img2['idx']}")

(0, 0, 0) (0, 0, 1)
(224, 224, 3) (46286, 3)
46286
45807
intersection 0
pc1_set [(np.float32(0.53887), np.float32(19.93385), np.float32(39.68432)), (np.float32(-4.19409), np.float32(18.60422), np.float32(30.90037)), (np.float32(-0.71065), np.float32(19.1427), np.float32(27.55704)), (np.float32(2.37833), np.float32(19.62042), np.float32(28.97518)), (np.float32(-0.1642), np.float32(19.42238), np.float32(29.22748)), (np.float32(-6.26467), np.float32(17.53969), np.float32(33.99681)), (np.float32(2.3409), np.float32(19.61215), np.float32(28.59921)), (np.float32(0.55356), np.float32(19.51973), np.float32(29.55323)), (np.float32(-0.27119), np.float32(19.12891), np.float32(28.79658)), (np.float32(-0.28487), np.float32(19.9083), np.float32(37.32868))]
pc2_set [(np.float32(-1.54691), np.float32(19.02116), np.float32(27.27707)), (np.float32(-3.343), np.float32(19.24456), np.float32(32.07844)), (np.float32(-0.53884), np.float32(19.69632), np.float32(43.17264)), (np.float32(0.34636), np.float32(19.

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0bddba4390>>
Traceback (most recent call last):
  File "/home/aleksandra/miniconda3/envs/mast3r/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0bddba4390>>
Traceback (most recent call last):
  File "/home/aleksandra/miniconda3/envs/mast3r/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


(224, 224, 3) (45807, 3)
45807
43761
intersection 0
pc1_set [(np.float32(-1.54691), np.float32(19.02116), np.float32(27.27707)), (np.float32(-3.343), np.float32(19.24456), np.float32(32.07844)), (np.float32(-0.53884), np.float32(19.69632), np.float32(43.17264)), (np.float32(0.34636), np.float32(19.88206), np.float32(32.3911)), (np.float32(-1.07006), np.float32(19.31424), np.float32(42.57601)), (np.float32(-1.40352), np.float32(19.7374), np.float32(41.56367)), (np.float32(1.55256), np.float32(19.66838), np.float32(30.01482)), (np.float32(-0.85863), np.float32(19.63175), np.float32(31.45906)), (np.float32(-5.67402), np.float32(19.27238), np.float32(41.01636)), (np.float32(-2.84474), np.float32(19.44769), np.float32(39.65714))]
pc2_set [(np.float32(-5.90428), np.float32(20.62486), np.float32(38.56856)), (np.float32(-9.63354), np.float32(20.08081), np.float32(38.83392)), (np.float32(-0.20941), np.float32(21.22225), np.float32(41.17075)), (np.float32(-10.21953), np.float32(19.67044), np.flo

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0bddba4390>>
Traceback (most recent call last):
  File "/home/aleksandra/miniconda3/envs/mast3r/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 
Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f0bddba4390>>
Traceback (most recent call last):
  File "/home/aleksandra/miniconda3/envs/mast3r/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
selected_pairs

In [60]:
for img1, img2 in selected_pairs:
    # print(img1[0])
    print(f"Selected pair: Image {img1[0]['idx']} and Image {img2[0]['idx']}")

In [4]:
from dust3r.datasets import UnderWaterDataset

In [15]:
dataset = UnderWaterDataset(split='train', ROOT='/home/aleksandra/dense_glomap_output', resolution=224)

In [9]:
type(dataset[0][0]['pts3d'])

numpy.ndarray

In [42]:
selected_pairs = select_best_pairs(dataset)
for img1, img2 in selected_pairs:
    print(f"Selected pair: Image {img1['idx']} and Image {img2['idx']}")

float32 float32


TypeError: unhashable type: 'numpy.ndarray'