In [None]:
# Download the example point cloud (PittsburghBridge/pointcloud.npz) into data/PittsburghBridge.
!mkdir -p data/PittsburghBridge
!wget -P data/PittsburghBridge https://dl.fbaipublicfiles.com/pytorch3d/data/PittsburghBridge/pointcloud.npz

In [1]:
import sys
import torch

need_pytorch3d = False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d = True
if need_pytorch3d:
    pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
    version_str = "".join([
        f"py3{sys.version_info.minor}_cu",
        torch.version.cuda.replace(".", ""),
        f"_pyt{pyt_version_str}"
    ])
    !pip install iopath
    if sys.platform.startswith("linux"):
        print("Trying to install wheel for PyTorch3D")
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
        pip_list = !pip freeze
        need_pytorch3d = not any(i.startswith("pytorch3d==") for i in pip_list)
    if need_pytorch3d:
        print(f"failed to find/install wheel for {version_str}")
if need_pytorch3d:
    print("Installing PyTorch3D from source")
    !pip install ninja
    !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

Collecting iopath
  Downloading iopath-0.1.10.tar.gz (42 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/42.2 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath)
  Downloading portalocker-3.1.1-py3-none-any.whl.metadata (8.6 kB)
Downloading portalocker-3.1.1-py3-none-any.whl (19 kB)
Building wheels for collected packages: iopath
  Building wheel for iopath (setup.py) ... [?25l[?25hdone
  Created wheel for iopath: filename=iopath-0.1.10-py3-none-any.whl size=31528 sha256=ffdf7af4b26d736da94b7323af1bdf19b6f7927a98537e15d8d854b5101067a4
  Stored in directory: /root/.cache/pip/wheels/9a/a3/b6/ac0fcd1b4ed5cfeb3db92e6a0e476cfd48ed0df92b91080c1d
Successfully built iopath
Installing collected packages: portalocker, iopath
Successfully installed iopath-0.1.10 portaloc

In [2]:
from pytorch3d.structures import Pointclouds

def bounding_sphere_normalize(points: torch.Tensor) -> torch.Tensor:
    """
    Centre a point set on its mean and scale it into the unit sphere.

    points: (N,3) tensor of point coords
    Return normalized points in a unit sphere centered at origin.

    Degenerate inputs are handled without producing NaNs: an empty tensor is
    returned as an (empty) copy, and a cloud whose points all coincide is
    mapped to the origin instead of dividing by zero.
    """
    # Nothing to centre or scale; .max() would raise on an empty tensor.
    if points.shape[0] == 0:
        return points.clone()

    center = points.mean(dim=0, keepdim=True)
    centered = points - center
    max_dist = centered.norm(p=2, dim=1).max()
    # Guard against a zero radius (single point / identical points): clamping to
    # the smallest positive normal float leaves every non-degenerate result unchanged.
    safe_radius = max_dist.clamp_min(torch.finfo(centered.dtype).tiny)
    return centered / safe_radius


def load_3d_data(file_path, num_points=10000, device="cuda", do_normalize=True):
    """
    Load a coloured point cloud from an NPZ archive as a PyTorch3D Pointclouds.

    file_path: path to an NPZ file holding 'points' and 'colors' arrays
               (assumed to be (N,3) each -- TODO confirm for other inputs).
    num_points: maximum number of points kept; larger clouds are randomly subsampled.
    device: torch device the tensors are moved to.
    do_normalize: when True, points are passed through bounding_sphere_normalize.
    Return a Pointclouds batch containing a single cloud, with the colours as features.
    """
    import numpy as np  # local import: this cell does not import numpy itself

    # np.load returns a lazily-read NpzFile that keeps the archive open, so read
    # both arrays inside the context manager and let it close the handle.
    with np.load(file_path) as pointcloud:
        verts = torch.Tensor(pointcloud['points']).to(device)
        rgb = torch.Tensor(pointcloud['colors']).to(device)

    # Subsample if needed (same random index for points and colours)
    if len(verts) > num_points:
        idx = torch.randperm(len(verts))[:num_points]
        verts = verts[idx]
        rgb = rgb[idx]

    if do_normalize:
        verts = bounding_sphere_normalize(verts)

    # Only the Pointclouds object is returned; the raw tensors stay available
    # through its accessors.
    point_cloud = Pointclouds(points=[verts], features=[rgb])
    return point_cloud



In [3]:
from itertools import islice

import torch
from pytorch3d.structures import Pointclouds
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras,
    FoVPerspectiveCameras,
    PointsRasterizationSettings,
    PointsRenderer,
    PointsRasterizer,
    AlphaCompositor
)
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torchvision.transforms as T


class MultiViewPointCloudRenderer:
    """
    Render a PyTorch3D point cloud from a fixed set of camera views.

    The views are defined relative to a base (dist, elev, azim) camera: the
    base view itself, three azimuth rotations (+90, +180, -90 degrees) and two
    elevation rotations (+90, -90 degrees). Every camera looks at the mean of
    the cloud's points.
    """

    def __init__(self, image_size=512, base_dist=20, base_elev=10, base_azim=0,
                 device=torch.device("cuda" if torch.cuda.is_available() else "cpu")):
        """
        image_size: side length in pixels of the square rendered images.
        base_dist, base_elev, base_azim: camera parameters of the 'Default' view.
        device: torch device used for cameras and background images.
        """
        self.device = device
        self.image_size = image_size
        self.base_dist = base_dist
        self.base_elev = base_elev
        self.base_azim = base_azim
        # Background images are resized to the render resolution before compositing.
        self.to_tensor = T.Compose([
            T.Resize((image_size, image_size)),
            T.ToTensor()
        ])

        # Define the settings for rasterization
        self.raster_settings = PointsRasterizationSettings(
            image_size=image_size,
            radius=0.008,
            points_per_pixel=20
        )

        # Define all views relative to base view; insertion order decides which
        # views are used when fewer than all of them are requested.
        self.views = {
            'Default': (base_dist, base_elev, base_azim),
            'Y_90deg': (base_dist, base_elev, base_azim + 90),
            'Y_180deg': (base_dist, base_elev, base_azim + 180),
            'Y_-90deg': (base_dist, base_elev, base_azim - 90),
            'X_90deg': (base_dist, base_elev + 90, base_azim),
            'X_-90deg': (base_dist, base_elev - 90, base_azim),
        }

    def get_center_point(self, point_cloud):
        """Return the mean of all packed points as a (1, 3) tensor."""
        points = point_cloud.points_packed()
        center = torch.mean(points, dim=0)
        return center.unsqueeze(0)  # Add batch dimension

    def create_renderer(self, dist, elev, azim, center_point, background_color=(0, 0, 0)):
        """
        Create a PointsRenderer whose camera sits at (dist, elev, azim) and looks
        at center_point; empty pixels are filled with background_color.
        """
        # Named rotation/translation (not R, T) so the local does not shadow the
        # module-level torchvision.transforms alias `T`.
        rotation, translation = look_at_view_transform(
            dist=dist,
            elev=elev,
            azim=azim,
            at=center_point,  # Look at the center of the point cloud
        )
        cameras = FoVPerspectiveCameras(
            device=self.device,
            R=rotation,
            T=translation
        )

        rasterizer = PointsRasterizer(cameras=cameras, raster_settings=self.raster_settings)
        renderer = PointsRenderer(
            rasterizer=rasterizer,
            compositor=AlphaCompositor(background_color=background_color)
        )
        return renderer

    def load_background(self, background_path):
        """
        Load an image file as an (image_size, image_size, 3) float tensor on self.device.
        """
        # Force 3 channels: RGBA or greyscale files would otherwise not line up
        # with the 3-channel render during compositing. The context manager
        # closes the file handle once the pixels are converted.
        with Image.open(background_path) as bg_image:
            bg_tensor = self.to_tensor(bg_image.convert("RGB")).to(self.device)
        return bg_tensor.permute(1, 2, 0)  # Convert to HWC format

    @staticmethod
    def _foreground_mask(rgb, background_color):
        """
        Return an (H, W) bool mask of pixels whose colour differs from background_color.

        rgb: (H, W, 3) rendered image.
        For the default black background and non-negative colours this equals
        the former `rgb > 0` test; unlike that test it also works for other
        background colours. Points rendered in exactly the background colour
        are still indistinguishable from empty pixels.
        """
        bg = torch.as_tensor(background_color, dtype=rgb.dtype, device=rgb.device)
        return torch.any(rgb != bg, dim=-1)

    def render_all_views(self, point_cloud, n_views=6, background_path=None,background_color=(0, 0, 0)):
        """
        Render the first n_views configured views of point_cloud.

        background_path: optional image composited behind the points.
        background_color: colour the renderer uses for empty pixels.
        Return a dict mapping view name to an (H, W, 3) image tensor.
        """
        images = {}
        center_point = self.get_center_point(point_cloud)

        if background_path:
            background = self.load_background(background_path)
        else:
            background = None

        for view_name, (dist, elev, azim) in islice(self.views.items(), n_views):
            renderer = self.create_renderer(dist, elev, azim, center_point,background_color=background_color)
            image = renderer(point_cloud)
            rgb = image[0, ..., :3]

            if background is not None:
                # Binary mask of pixels covered by points; (H, W, 1) broadcasts over RGB.
                mask = self._foreground_mask(rgb, background_color).float().unsqueeze(-1)
                images[view_name] = (rgb * mask) + (background * (1 - mask))
            else:
                images[view_name] = rgb

        return images

In [4]:
import os
import numpy as np
import torch
import torchvision

def save_results(point_cloud, renderer,n_views,device,output_dir,output_name):
    """
    Render `n_views` views of `point_cloud` on a white background and write
    them to `output_dir/output_name` as one image file.

    renderer: object exposing render_all_views(point_cloud=..., n_views=..., background_color=...)
              and returning a dict of (H, W, C) image tensors.
    device: device the rendered views are moved to before stacking.
    """
    # Make sure the destination folder is there before any rendering happens.
    os.makedirs(output_dir, exist_ok=True)

    views = renderer.render_all_views(point_cloud=point_cloud, n_views=n_views,background_color = (1,1,1))

    # Stack the per-view images in dict order: [B, H, W, C] -> [B, C, H, W].
    batch_hwc = torch.stack([view.to(device) for view in views.values()])
    batch_chw = batch_hwc.permute(0, 3, 1, 2)

    # Quantise to 8-bit and go back to [0, 1] floats, which is what save_image expects.
    batch_u8 = (batch_chw * 255).clamp(0, 255).to(torch.uint8)
    destination = os.path.join(output_dir, output_name)

    torchvision.utils.save_image(
        batch_u8.float() / 255.0,
        destination,
        normalize=False  # values are already in [0, 1]
    )

In [100]:
# Use the GPU when one is available, otherwise fall back to the CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"

point_cloud = load_3d_data(
    "/content/highlighted_points.npz",
    num_points=100000,
    device=device,  # keep the cloud on the same device as the renderer
)


renderer = MultiViewPointCloudRenderer(
    image_size=1024,
    base_dist=2.5,  # Your default view distance
    base_elev=10,  # Your default elevation
    base_azim=0,  # Your default azimuth
    device=device
)

# Writes all six views into ./output/point_cloud2.png.
save_results(
    point_cloud=point_cloud,
    renderer=renderer,
    n_views=6,
    output_dir="./output",
    output_name="point_cloud2.png",
    device=device
)

In [73]:
# Interactive 3D view of the loaded point cloud using PyTorch3D's Plotly helper.
# Requires `point_cloud` from the rendering cell above.
from pytorch3d.vis.plotly_vis import plot_scene
plot_scene({
    "Pointcloud": {
        "person": point_cloud
    }
})

The ground truth was saved as a PLY file, so we convert it to NPZ.

In [None]:
# Open3D is used below to read the ground-truth PLY file.
!pip install open3d numpy

In [143]:
# Convert the ground-truth PLY into an NPZ archive with 'points' and 'colors' arrays.
import open3d as o3d
import numpy as np

# Load the .ply file
pcd = o3d.io.read_point_cloud("/content/gt_pointcloud_cut.ply")

# Extract points and colors as NumPy arrays
points = np.asarray(pcd.points)
colors = np.asarray(pcd.colors)

# Lighten the base colors (adjust factor as needed)
# NOTE(review): the brightened colours are what gets written to output.npz, so any
# colour-threshold rule applied to that file sees scaled values, not the PLY originals.
colors = colors * 1.5  # Increase brightness by 50%
colors = np.clip(colors, 0, 1)  # Clip values to the valid range [0, 1]

# Save to .npz (written to the current working directory)
np.savez_compressed("output.npz", points=points, colors=colors)

In [144]:
import numpy as np

# Load both files and check their contents
highlighted_data = np.load('highlighted_points.npz')
gt_data = np.load('output.npz')

# Print keys/contents of each file
print("Keys in highlighted_points.npz:", highlighted_data.files)
print("\nKeys in output.npz:", gt_data.files)

# Let's also look at the shapes of the arrays in each file
print("\nShapes in highlighted_points.npz:")
for key in highlighted_data.files:
    print(f"{key}: {highlighted_data[key].shape}")

print("\nShapes in output.npz:")
for key in gt_data.files:
    print(f"{key}: {gt_data[key].shape}")

# Access the 'points' array written by the PLY conversion cell
gt_points = gt_data['points']

# Access the 'colors' array (raises KeyError if the archive has no such key)
gt_colors = gt_data['colors']

# Print the keys available in the loaded data (repeats the listing above)
print("Keys in output.npz:", gt_data.files)

# Indices of points with at least one non-zero colour channel.
# (gt_colors is re-read here; it is the same array as above.)
gt_colors = gt_data['colors']
colored_points_indices = np.where(np.any(gt_colors > 0, axis=1))[0]
print("Colored Points Indices:", colored_points_indices)

Keys in highlighted_points.npz: ['points', 'colors', 'probabilities']

Keys in output.npz: ['points', 'colors']

Shapes in highlighted_points.npz:
points: (2048, 3)
colors: (2048, 3)
probabilities: (2048, 2)

Shapes in output.npz:
points: (2048, 3)
colors: (2048, 3)
Keys in output.npz: ['points', 'colors']
Colored Points Indices: [   0    1    2 ... 2045 2046 2047]


Calculation of IoU

In [145]:
import numpy as np

def print_npz_details(file_path):
    """
    Print the keys of an NPZ file and, for each array, its shape, dtype and
    first five elements.

    file_path: path to the NPZ archive to inspect.
    """
    # np.load returns an NpzFile that holds the archive open; the context
    # manager closes it once every array has been printed.
    with np.load(file_path) as data:
        print(f"Details for: {file_path}")
        print("-" * 30)
        print("Keys:", data.files)
        for key in data.files:
            array = data[key]
            print(f"\nArray: {key}")
            print("Shape:", array.shape)
            print("Data Type:", array.dtype)
            print("First 5 elements:", array[:5])  # Print a few elements for inspection
            print("-" * 20)

# Print details for both files (paths are relative to the working directory)
print_npz_details('output.npz')  # Ground truth file
print_npz_details('highlighted_points.npz')  # Highlighted file

Details for: output.npz
------------------------------
Keys: ['points', 'colors']

Array: points
Shape: (2048, 3)
Data Type: float64
First 5 elements: [[-0.0195685  -0.59234661  0.05510337]
 [-0.00985435  0.99639744  0.08423236]
 [-0.03765397  0.20444049 -0.07541861]
 [ 0.03902398 -0.21181168 -0.11270168]
 [ 0.00905969  0.58440924  0.08427081]]
--------------------

Array: colors
Shape: (2048, 3)
Data Type: float64
First 5 elements: [[0.45294118 0.45294118 0.45294118]
 [0.76470588 0.31764706 0.31764706]
 [0.45294118 0.45294118 0.45294118]
 [0.45294118 0.45294118 0.45294118]
 [0.70588235 0.34117647 0.34117647]]
--------------------
Details for: highlighted_points.npz
------------------------------
Keys: ['points', 'colors', 'probabilities']

Array: points
Shape: (2048, 3)
Data Type: float32
First 5 elements: [[-0.0195685  -0.5923466   0.05510337]
 [-0.00985435  0.99639744  0.08423236]
 [-0.03765397  0.20444049 -0.07541861]
 [ 0.03902398 -0.21181168 -0.11270168]
 [ 0.00905969  0.58440924

In [188]:
import open3d as o3d
import numpy as np
import os

# --------------------- 1) LOAD GROUND-TRUTH FROM .PLY ---------------------
def load_ply_point_cloud(ply_path):
    """
    Read a PLY file with Open3D.

    Returns a tuple (cloud, points, colors): the Open3D PointCloud plus its
    points and colors converted to NumPy arrays.
    """
    cloud = o3d.io.read_point_cloud(ply_path)
    return cloud, np.asarray(cloud.points), np.asarray(cloud.colors)

def color_to_gt_labels(colors):
    """
    Turn per-point RGB colours into binary highlight labels.

    A point is labelled 1 when its red channel exceeds 0.7 and is more than
    0.2 above both the green and the blue channel; otherwise 0.
    Returns a (N,) uint8 array.
    """
    red, green, blue = colors[:, 0], colors[:, 1], colors[:, 2]
    is_red = np.logical_and.reduce([
        red > 0.7,
        (red - green) > 0.2,
        (red - blue) > 0.2,
    ])
    return is_red.astype(np.uint8)

# --------------------- 2) LOAD PREDICTED NPZ ---------------------
def threshold_predictions(probabilities, threshold=0.3, highlight_idx=1):
    """
    Binarise predicted probabilities.

    probabilities: 2-D array; column `highlight_idx` holds the highlight probability.
    threshold: a point is labelled 1 when its highlight probability is >= threshold.
    Returns a (N,) uint8 array of 0/1 labels.
    """
    is_highlight = np.greater_equal(probabilities[:, highlight_idx], threshold)
    return is_highlight.astype(np.uint8)

# --------------------- 3) MESHING THE HIGHLIGHT POINTS ---------------------
def alpha_shape_mesh_from_points(pts, alpha=0.03):
    """
    Build an alpha-shape TriangleMesh from Nx3 points with Open3D.

    Returns the cleaned mesh, or None when there are fewer than 4 points, when
    the alpha shape has no triangles, or when Open3D raises.
    - A too-small 'alpha' yields many disconnected pieces or an empty mesh.
    - A too-large 'alpha' yields a loose hull around all points.
    """
    # Fewer than 4 points cannot form a 3D mesh.
    if len(pts) < 4:
        return None

    cloud = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(pts))
    try:
        surface = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(cloud, alpha=alpha)
        if len(surface.triangles) == 0:
            return None
        # Clean-up passes, applied in this order.
        for cleanup in (
            surface.remove_duplicated_vertices,
            surface.remove_unreferenced_vertices,
            surface.remove_duplicated_triangles,
        ):
            cleanup()
        return surface.compute_triangle_normals()
    except Exception as e:
        print("Alpha shape error:", e)
        return None

def mesh_volume(mesh):
    """
    Volume of a mesh as reported by its get_volume() method.

    Returns 0.0 when `mesh` is None or when get_volume() returns None.
    """
    if mesh is None:
        return 0.0
    volume = mesh.get_volume()
    if volume is None:
        return 0.0
    return volume

# --------------------- 4) BOOLEAN IoU: Intersection, Union ---------------------
def mesh_intersection_union_iou(mesh_a, mesh_b):
    """
    3D IoU = Volume( A ∩ B ) / Volume( A ∪ B ).

    mesh_a, mesh_b: legacy open3d.geometry.TriangleMesh objects (as returned by
                    alpha_shape_mesh_from_points), or None.
    We compute volA, volB and volAB (intersection); union = volA + volB - volAB.
    Returns 0.0 when either mesh is None, when the boolean intersection fails,
    or when the union volume is (numerically) zero.
    """
    if mesh_a is None or mesh_b is None:
        return 0.0

    # volumes
    volA = mesh_volume(mesh_a)
    volB = mesh_volume(mesh_b)

    # intersection mesh
    try:
        # Boolean operations are only provided by the tensor-based mesh class
        # (open3d.t.geometry.TriangleMesh) and take no `method` argument; calling
        # them on a legacy mesh always raised and the IoU was silently 0.0.
        # Convert, intersect, then convert back so get_volume() is available.
        tensor_a = o3d.t.geometry.TriangleMesh.from_legacy(mesh_a)
        tensor_b = o3d.t.geometry.TriangleMesh.from_legacy(mesh_b)
        inter_mesh = tensor_a.boolean_intersection(tensor_b, tolerance=1e-6).to_legacy()
        volAB      = mesh_volume(inter_mesh)
    except Exception as e:
        print("Intersection error:", e)
        return 0.0

    union_vol = volA + volB - volAB
    # Avoid dividing by a vanishing union volume.
    iou = volAB / union_vol if union_vol > 1e-12 else 0.0
    return iou


# --------------------- MAIN SCRIPT EXAMPLE ---------------------
# End-to-end example: GT highlight from PLY colours vs. thresholded prediction,
# compared through alpha-shape meshes and a volumetric IoU.
if __name__ == "__main__":
    # 1) Load GT .ply, get highlight points
    gt_ply = "gt_pointcloud_cut.ply"   # <-- path to your GT ply
    _, gt_points, gt_colors = load_ply_point_cloud(gt_ply)

    # Highlight = points passing the red-colour rule in color_to_gt_labels
    gt_mask = color_to_gt_labels(gt_colors)
    gt_highlight_pts = gt_points[gt_mask == 1]
    print(f"GT: total {len(gt_points)} pts, highlight = {len(gt_highlight_pts)}")

    # 2) Load predicted .npz, threshold highlight
    pred_npz = "highlighted_points.npz"
    data     = np.load(pred_npz)
    pred_pts = data["points"]           # shape (N,3)
    probs    = data["probabilities"]    # shape (N,2) or ...
    # Column 1 is taken as the highlight probability here.
    pred_mask = threshold_predictions(probs, threshold=0.3, highlight_idx=1)
    pred_highlight_pts = pred_pts[pred_mask == 1]
    print(f"Pred: total {len(pred_pts)} pts, highlight = {len(pred_highlight_pts)}")

    # 3) Build alpha-shape meshes for each highlight set
    #    Tune 'alpha' to get a watertight mesh that matches your geometry scale
    # NOTE(review): alpha_shape_mesh_from_points defaults to alpha=0.03; 4 is far
    # larger -- confirm it suits the scale of these point clouds.
    alpha_gt   = 4
    alpha_pred = 4

    mesh_gt   = alpha_shape_mesh_from_points(gt_highlight_pts, alpha=alpha_gt)
    mesh_pred = alpha_shape_mesh_from_points(pred_highlight_pts, alpha=alpha_pred)

    # These messages are informational only; execution continues and
    # mesh_intersection_union_iou returns 0.0 when either mesh is None.
    if mesh_gt is None:
        print("GT highlight mesh is empty. IoU=0.0")
    if mesh_pred is None:
        print("Pred highlight mesh is empty. IoU=0.0")

    # 4) Compute 3D IoU
    iou_3d = mesh_intersection_union_iou(mesh_gt, mesh_pred)
    print(f"3D IoU (Alpha Shape + Boolean): {iou_3d:.4f}")


GT: total 2048 pts, highlight = 16
Pred: total 2048 pts, highlight = 771
GT highlight mesh is empty. IoU=0.0
3D IoU (Alpha Shape + Boolean): 0.0000


In [110]:
import numpy as np

# Load prediction data
highlighted_data = np.load('highlighted_points.npz')
highlighted_points = highlighted_data['points']  # Predicted points (2048, 3)
highlighted_probs = highlighted_data['probabilities']  # Probabilities (2048, 2)

# Load ground truth data
gt_data = np.load('output.npz')
gt_points = gt_data['points']  # Ground truth points (2048, 3)

# Mock `labels_dict` for ground truth affordance labels (if not already loaded).
# WARNING: these labels are random, so the IoU values below are meaningless until
# they are replaced with the actual ground truth labels for your dataset.
# A fixed seed keeps the mock reproducible; the size follows the prediction array.
labels_dict = {
    "cut": np.random.default_rng(0).choice([0, 1], size=len(highlighted_probs))  # Replace this with actual labels for "cut"
}

# Specify affordance to compute IoU (e.g., 'cut')
affordance = "cut"
gt_labels = labels_dict[affordance]  # Binary ground truth labels for the affordance

# Extract affordance probabilities
affordance_probs = highlighted_probs[:, 1]  # Probabilities for affordance class

# Evaluate IoU across a range of thresholds: 0.00 to 1.00 inclusive in steps of 0.01.
# linspace avoids the floating-point overshoot of arange and never exceeds 1.0.
thresholds = np.linspace(0.0, 1.0, 101)
iou_scores = []

for threshold in thresholds:
    # Generate binary predictions based on the current threshold
    highlighted_labels = (affordance_probs >= threshold).astype(int)

    # Calculate IoU
    intersection = np.sum((highlighted_labels == 1) & (gt_labels == 1))
    union = np.sum((highlighted_labels == 1) | (gt_labels == 1))
    iou = intersection / union if union > 0 else 0.0

    iou_scores.append(iou)
    print(f"Threshold: {threshold:.2f}, IoU: {iou:.4f}")

# Find the best threshold (first one reaching the maximum IoU)
best_threshold = thresholds[np.argmax(iou_scores)]
best_iou = max(iou_scores)

print(f"Best Threshold: {best_threshold:.2f}, Best IoU: {best_iou:.4f}")


Threshold: 0.00, IoU: 0.4810
Threshold: 0.01, IoU: 0.4810
Threshold: 0.02, IoU: 0.4810
Threshold: 0.03, IoU: 0.4810
Threshold: 0.04, IoU: 0.4810
Threshold: 0.05, IoU: 0.4810
Threshold: 0.06, IoU: 0.4810
Threshold: 0.07, IoU: 0.4810
Threshold: 0.08, IoU: 0.4810
Threshold: 0.09, IoU: 0.4810
Threshold: 0.10, IoU: 0.4810
Threshold: 0.11, IoU: 0.4810
Threshold: 0.12, IoU: 0.4810
Threshold: 0.13, IoU: 0.4810
Threshold: 0.14, IoU: 0.4810
Threshold: 0.15, IoU: 0.4810
Threshold: 0.16, IoU: 0.4810
Threshold: 0.17, IoU: 0.4810
Threshold: 0.18, IoU: 0.4810
Threshold: 0.19, IoU: 0.4810
Threshold: 0.20, IoU: 0.4810
Threshold: 0.21, IoU: 0.4810
Threshold: 0.22, IoU: 0.4810
Threshold: 0.23, IoU: 0.4810
Threshold: 0.24, IoU: 0.4810
Threshold: 0.25, IoU: 0.4810
Threshold: 0.26, IoU: 0.4810
Threshold: 0.27, IoU: 0.4810
Threshold: 0.28, IoU: 0.4810
Threshold: 0.29, IoU: 0.4810
Threshold: 0.30, IoU: 0.4810
Threshold: 0.31, IoU: 0.4810
Threshold: 0.32, IoU: 0.4810
Threshold: 0.33, IoU: 0.4810
Threshold: 0.3

In [162]:
!git clone https://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@github.com/amiralichangizi/Affordance3DHighlighter.git


Cloning into 'Affordance3DHighlighter'...
remote: Enumerating objects: 407, done.[K
remote: Counting objects: 100% (169/169), done.[K
remote: Compressing objects: 100% (129/129), done.[K
remote: Total 407 (delta 97), reused 92 (delta 40), pack-reused 238 (from 1)[K
Receiving objects: 100% (407/407), 5.40 MiB | 20.94 MiB/s, done.
Resolving deltas: 100% (237/237), done.


In [163]:
# Download the full-shape dataset archive from Google Drive and unpack it into the
# cloned repository's data/ folder.
# NOTE(review): `--id` is reported as deprecated by some gdown releases -- confirm
# against the installed version.
!pip install gdown
!gdown --id 1siZtGusB1LfQVapTvNOiYi8aeKKAgcDF
!unzip full-shape.zip -d /content/Affordance3DHighlighter/data/

Downloading...
From (original): https://drive.google.com/uc?id=1siZtGusB1LfQVapTvNOiYi8aeKKAgcDF
From (redirected): https://drive.google.com/uc?id=1siZtGusB1LfQVapTvNOiYi8aeKKAgcDF&confirm=t&uuid=febf834d-2f83-48de-bc0a-22d23222a853
To: /content/full-shape.zip
100% 558M/558M [00:04<00:00, 112MB/s]
Archive:  full-shape.zip
  inflating: /content/Affordance3DHighlighter/data/full_shape_train_data.pkl  
  inflating: /content/Affordance3DHighlighter/data/full_shape_val_data.pkl  


In [169]:
# Show the working directory; the relative paths used in this notebook resolve against it.
!pwd

/content


In [None]:
import sys
import os
sys.path.append('/content/Affordance3DHighlighter/src/') # Add the 'src' directory to your Python path
from data_loader_fullshape import FullShapeDataset, create_dataset_splits # Import the class
import os
import pickle
import torch
import numpy as np
from torch.utils.data import DataLoader
from scipy.spatial import cKDTree


# --------------------- 1) Load the dataset ---------------------
# Path of the training pickle unpacked by the download cell; change it here only.
pkl_path = "/content/Affordance3DHighlighter/data/full_shape_train_data.pkl"
device   = "cpu"               # or "cuda"

# Instantiate the dataset from pkl_path (previously the path was duplicated inline
# and pkl_path was an unused placeholder).
dataset = FullShapeDataset(pkl_path, device=device)
train_data, val_data, test_data = create_dataset_splits(dataset, val_ratio=0.1, test_ratio=0.1)


# Print all available shape IDs for inspection
print("Available shape IDs in the dataset:")
for i in range(len(dataset)):
    print(dataset[i]['shape_id'])

# --------------------- 2) Find the desired shape entry by ID ---------------------
target_shape_id = "df0a8c7d1629313915538488147db324"  # Replace with an ID from the printed list if needed
affordance      = "cut"  # key looked up in the entry's labels_dict below (or "contain", depending on your data)

def find_shape_index(ds, shape_id):
    """Return the index of the first entry of `ds` whose 'shape_id' equals `shape_id`, or None."""
    matches = (pos for pos in range(len(ds)) if ds[pos]['shape_id'] == shape_id)
    return next(matches, None)

# Locate the target shape; stop early with a clear error if it is not in the dataset.
idx = find_shape_index(dataset, target_shape_id)
if idx is None:
    raise ValueError(f"Shape ID {target_shape_id} not found in dataset!")


# --------------------- 3) Extract coords & GT labels from dataset ---------------------
entry      = dataset[idx]
coords     = entry['coords'].cpu().numpy()                    # shape (N, 3)
gt_labels  = entry['labels_dict'][affordance].cpu().numpy()   # shape (N,)

# Optional sanity check
print("Found shape:", entry['shape_id'])
print("Coordinates shape:", coords.shape)
print("GT label shape:", gt_labels.shape)
print(f"Number of positives = {gt_labels.sum()}")


# --------------------- 4) Load the highlight NPZ (prediction) ---------------------
pred_path  = "highlighted_points.npz"  # <-- your model's highlight file
pred_data  = np.load(pred_path)
pred_points = pred_data['points']         # shape (N, 3) - hopefully same order
pred_probs  = pred_data['probabilities']  # shape (N, 2) or similar

# We assume, from your code snippet, that `pred_class = net(points)` gave:
#   pred_class[:,0] => highlight
#   pred_class[:,1] => background
# If it's reversed for your net, flip the index below.
# NOTE(review): the earlier IoU cells in this notebook read the highlight
# probability from column 1 (highlight_idx=1 / highlighted_probs[:, 1]), whereas
# this cell uses column 0 -- confirm which column is the highlight class.
highlight_prob = pred_probs[:, 0]

# Threshold to produce a 0/1 predicted label
threshold   = 0.5
pred_labels = (highlight_prob >= threshold).astype(int)

print("Prediction shape:", pred_points.shape, pred_labels.shape)


# --------------------- 5) Compare & Compute IoU ---------------------
# Point-wise IoU = |pred AND gt| / |pred OR gt| over binary labels (0.0 if the union is empty).
# -- (A) Direct index-to-index approach (only valid if coords match exactly) --
if coords.shape == pred_points.shape and np.allclose(coords, pred_points, atol=1e-6):
    # They appear to match in ordering
    intersection = np.sum((pred_labels == 1) & (gt_labels == 1))
    union        = np.sum((pred_labels == 1) | (gt_labels == 1))
    iou_direct   = intersection / union if union > 0 else 0.0
    print(f"IoU (direct index comparison): {iou_direct:.4f}")
else:
    print("WARNING: The dataset coords and pred coords do NOT match index by index.")
    print("         Attempting nearest-neighbor alignment below...")

    # -- (B) If they don't line up exactly, do nearest-neighbor approach: --

    # Build k-d tree on your dataset coords
    tree = cKDTree(coords)
    # For each predicted point, find the nearest dataset point
    # (`dist` is returned alongside the indices but is not used below)
    dist, nn_indices = tree.query(pred_points)

    # Now dataset GT label for that pred point is gt_labels[nn_indices]
    aligned_gt = gt_labels[nn_indices]

    intersection = np.sum((pred_labels == 1) & (aligned_gt == 1))
    union        = np.sum((pred_labels == 1) | (aligned_gt == 1))
    iou_nn       = intersection / union if union > 0 else 0.0
    print(f"IoU (nearest-neighbor alignment): {iou_nn:.4f}")
