In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
from pathlib import Path

from matplotlib import pyplot as plt

from nerfstudio.utils.eval_utils import eval_setup
# from ns_extension.utils.grouping import GroupingClassifier

[Taichi] version 1.7.3, llvm 15.0.4, commit 5ec301be, linux, python 3.10.18


[I 07/15/25 20:36:41.723 64208] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


### Load configuration

In [2]:
# Config file written by the training run whose model we want to inspect
load_config = Path(
    '/workspace/fieldwork-data/rats/2024-07-11/environment/C0119/rade-features/2025-07-11_171420/config.yml'
)

# Rebuild the run from its config; eval_setup hands back four values that we unpack here
config, pipeline, checkpoint_path, step = eval_setup(load_config)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Open question: should grouping be built on top of the existing data?

Initialize segmentation model

In [3]:
# Pull one training camera and its batch from the datamanager (0 is passed straight to next_train)
camera, batch = pipeline.datamanager.next_train(0)

# Filename of the first training image.
# NOTE(review): this always takes index 0, independent of which camera next_train returned above —
# confirm that `camera` and `fn` refer to the same image before comparing masks with projections.
fn = pipeline.datamanager.train_dataset.image_filenames[0]

Output()

In [7]:
from PIL import Image
import numpy as np

from ns_extension.utils.features import resize_image
from ns_extension.utils.segmentation import Segmentation

# Build the segmentation wrapper on the GPU with the MobileSAMv2 backend
segmentation = Segmentation(
    backend='mobilesamv2',
    strategy='object',
    device='cuda',
)
# NOTE(review): the strategy is constructed as 'object' and immediately overwritten with 'auto';
# confirm whether constructing with 'object' first is required or whether 'auto' can be passed directly.
segmentation.strategy = 'auto'

# Load the training image; H and W are module-level globals read by later cells
image = Image.open(fn)
H, W = image.height, image.width

# Disabled: resizing the image to SAM resolution before segmentation
# # Prepare image for segmentation
# image = resize_image(image, longest_edge=1024) # Resize image to SAM resolution

# Run the segmentation model on the full-resolution image
image = np.asarray(image) # Convert to numpy array
masks, results = segmentation.segment(image)

Using cache found in /workspace/models/hub/RogerQi_MobileSAMV2_main


checkpoint_load_scucess


  return F.conv2d(input, weight, bias, self.stride,


### Helper functions for grouping

In [26]:
import math
import torch

def create_composite_mask(results, confidence_threshold=0.85, min_visible_fraction=0.1):
    """
    Creates a composite mask from the results of the segmentation model.

    Masks that pass the confidence threshold are painted in order of ascending
    confidence, so wherever two masks overlap the more confident one wins. A mask
    is then dropped if too little of its original area is still visible.

    Inputs:
        results: list of dicts, each containing an (H, W) mask under 'segmentation'
            and a confidence score under 'predicted_iou'
        confidence_threshold: float, the minimum confidence score for a mask to be included in the composite mask
        min_visible_fraction: float, a mask is kept only if more than this fraction of
            its original area survives the overlap resolution (default 0.1)

    Outputs:
        composite_mask: (H, W) unsigned-integer numpy array; 0 is background and every
            positive value labels one mask. Labels follow confidence rank, so a dropped
            mask leaves a gap in the numbering. The dtype is the smallest unsigned
            type that can hold the number of selected masks (uint8 up to 255 masks).

    Raises:
        ValueError: if results is empty, because the image size cannot be inferred
    """
    if len(results) == 0:
        raise ValueError("results is empty; cannot infer the mask shape")

    # Take the image size from the masks themselves instead of notebook globals
    height, width = np.asarray(results[0]['segmentation']).shape[:2]

    selected = [r for r in results if r['predicted_iou'] >= confidence_threshold]
    masks = [np.asarray(r['segmentation']) == 1 for r in selected]
    confs = [r['predicted_iou'] for r in selected]

    # uint8 would silently wrap around with more than 255 masks
    label_dtype = np.min_scalar_type(len(masks))

    # Paint masks from least to most confident; later (more confident) masks overwrite earlier ones
    order = np.argsort(confs, kind="stable")
    mask_id = np.zeros((height, width), dtype=label_dtype)
    for rank, idx in enumerate(order, start=1):
        mask_id[masks[idx]] = rank

    # Ranks that still own at least one pixel after overlap resolution
    surviving_ranks = np.setdiff1d(np.unique(mask_id), [0])

    composite_mask = np.zeros((height, width), dtype=label_dtype)
    for label, rank in enumerate(surviving_ranks, start=1):
        visible = mask_id == rank
        # mask_id stores the rank, so map it back through `order` to find the original mask
        original_area = masks[order[int(rank) - 1]].sum()
        if visible.sum() / original_area > min_visible_fraction:
            composite_mask[visible] = label

    return composite_mask

def mask_id_to_binary_mask(composite_mask: np.ndarray) -> np.ndarray:
    """
    Split a label image into one boolean mask per label.

    Args:
        composite_mask (np.ndarray): An (H, W) array of integer labels. Zero (and
                            anything below it) is treated as background; every
                            distinct positive value identifies one mask.

    Returns:
        np.ndarray: A (N, H, W) boolean array with one slice per positive label,
                    ordered by ascending label value. N is 0 when the image
                    holds no positive labels.
    """
    labels = np.unique(composite_mask)
    foreground_labels = labels[labels > 0]

    # Broadcast (1, H, W) against (N, 1, 1) to compare every pixel with every label
    return np.equal(composite_mask[np.newaxis, ...], foreground_labels.reshape(-1, 1, 1))

def create_patch_mask(image, num_patches):
    """
    Assign every pixel of an image to one cell of a num_patches x num_patches grid.

    Inputs:
        image: object whose first two `shape` entries are (height, width)
        num_patches: int, number of grid cells along each axis

    Outputs:
        flatten_patch_mask: bool tensor of shape (num_patches, num_patches, height * width);
            entry [r, c, p] is True when the row-major pixel p lies in grid cell (r, c).
            The tensor is dense, so its size grows with num_patches**2 * height * width.
    """
    height, width = image.shape[:2]
    n_pixels = height * width

    # Cell size is rounded up so that num_patches cells always cover the whole axis
    cell_h = math.ceil(height / num_patches)
    cell_w = math.ceil(width / num_patches)

    # Row and column of every pixel, listed in row-major order
    rows = torch.arange(height).repeat_interleave(width)
    cols = torch.arange(width).repeat(height)

    # Grid cell of every pixel, kept inside the grid
    cell_rows = (rows // cell_h).clamp(min=0, max=num_patches - 1)
    cell_cols = (cols // cell_w).clamp(min=0, max=num_patches - 1)

    # One True per pixel, placed in the slice of the cell that owns it
    grid = torch.zeros(num_patches, num_patches, n_pixels, dtype=torch.bool)
    grid[cell_rows, cell_cols, torch.arange(n_pixels)] = True

    return grid

def project_gaussians(model, camera):
    """
    Render `camera` once and collect where each gaussian lands in the image.

    Inputs:
        model: object exposing `get_outputs(camera)` and an `info` dict that, after
            that call, holds 'width', 'height', 'radii', 'means2d' and 'depths'
        camera: camera handed unchanged to `model.get_outputs`

    Outputs:
        dict with
            "proj_flattened": (N,) long tensor with the row-major pixel index
                (x + y * W) of every gaussian's rounded 2D mean, clamped so it
                always lies inside the W x H image
            "proj_depths": `info['depths']`, returned unchanged
            "gaussian_ids": (M,) long tensor with the indices of the gaussians
                that have any radius component greater than 1.0
        Note that "gaussian_ids" is a subset: M <= N.
    """
    # The call's return value is not needed; model.info is read right after it
    # (presumably it is populated by this forward pass — confirm against the model)
    model.get_outputs(camera)
    meta = model.info
    W, H = meta["width"], meta["height"]

    # Original author's note: gaussians whose radius is greater than 1.0 can be seen in the camera frustum.
    # reshape instead of squeeze so that a scene with a single gaussian keeps its (N, k) layout
    radii = meta['radii']
    radii = radii.reshape(-1, radii.shape[-1])
    gaussian_ids = (radii > 1.0).any(dim=1).nonzero(as_tuple=True)[0]

    # Convert 2D coords to flat pixel indices. Valid columns are 0..W-1 and rows 0..H-1:
    # clamping to W / H would let x wrap onto the next row and y index past the image.
    xy_rounded = torch.round(meta['means2d']).reshape(-1, 2).long()
    x = torch.clamp(xy_rounded[:, 0], 0, W - 1)
    y = torch.clamp(xy_rounded[:, 1], 0, H - 1)
    projected_flattened = x + y * W

    return {
        "proj_flattened": projected_flattened,  # (N,) one entry per gaussian
        "proj_depths": meta['depths'],          # as provided by the model
        "gaussian_ids": gaussian_ids,           # (M,) visible subset
    }

### Test the grouping process

In [18]:
# First create a composite mask
composite_mask = create_composite_mask(results, confidence_threshold=0.85)

# Decimate the composite mask into individual masks
binary_masks = mask_id_to_binary_mask(composite_mask)

# Flatten the binary masks
flattened_masks = torch.tensor(binary_masks).flatten(start_dim=1)

# Create a patch mask --> find the intersection between the composite mask and the patch mask
patch_mask = create_patch_mask(image, num_patches=32)

Find the intersection between a given object mask and the patch masks

In [19]:
# Overlap between the first object mask and every grid cell: (num_patches, num_patches, H*W)
patch_intersections = flattened_masks[0][None, None, :] & patch_mask

# Count the overlapping pixels of each cell, then keep the (row, col) of every cell that has any
patch_sums = patch_intersections.sum(dim=2)
non_empty_patches = torch.nonzero(patch_sums > 0, as_tuple=False)


Project the gaussians into pixel space

In [28]:
# Project every gaussian of the trained model into the pixel space of `camera`
# (the recorded run of this cell ended in the AttributeError shown below)
model = pipeline.model
proj_results = project_gaussians(model, camera)

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()


AttributeError: 'RadegsFeaturesModelConfig' object has no attribute 'return_packed_info'

Grab a non-empty patch

In [20]:
# Pick one non-empty grid cell to experiment with and take its flattened pixel mask.
# NOTE(review): index 4 is arbitrary and raises IndexError when fewer than five cells are non-empty.
i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

Given the current patch, find its associated gaussians

In [None]:
# current_patch is a flattened boolean image marking the pixels we are looking for.
# Indexing it with each gaussian's flat pixel index tells us, per gaussian, whether it lands
# on one of those pixels; nonzero() then yields the indices of the gaussians that do.
# The pixel indices are read from proj_results so the cell does not depend on a
# `projected_flattened` global left behind by another scratch cell.
patch_gaussians = current_patch[proj_results["proj_flattened"].cpu()].nonzero().squeeze(-1)

torch.Size([518400])

In [24]:
proj_results

NameError: name 'proj_results' is not defined

In [None]:
# proj_results["proj_flattened"] holds the flat pixel index of each gaussian and current_patch
# marks the pixels of the mask inside the patch, so this yields the indices of the gaussians
# that project into the patch. Reading from proj_results avoids relying on a
# `projected_flattened` global that only exists if another scratch cell ran first.
patch_gaussians = current_patch[proj_results["proj_flattened"].cpu()].nonzero().squeeze(-1)


In [415]:
def project_gaussians(model, camera):
    """
    Render `camera` once and collect where each gaussian lands in the image.

    This notebook defines `project_gaussians` in an earlier cell as well; running
    this cell replaces that definition and adds the reverse mapping to the result.

    Inputs:
        model: object exposing `get_outputs(camera)` and an `info` dict that, after
            that call, holds 'width', 'height', 'radii', 'means2d' and 'depths'
        camera: camera handed unchanged to `model.get_outputs`

    Outputs:
        dict with
            "proj_flattened": (N,) long tensor with the row-major pixel index
                (x + y * W) of every gaussian's rounded 2D mean, clamped so it
                always lies inside the W x H image
            "proj_depths": `info['depths']`, returned unchanged
            "gaussian_ids": (M,) long tensor with the indices of the gaussians
                that have any radius component greater than 1.0
            "gaussian_id_reverse_mapping": dict mapping each id in "gaussian_ids"
                to its position within "gaussian_ids"
        Note that "gaussian_ids" is a subset: M <= N.
    """
    # The call's return value is not needed; model.info is read right after it
    # (presumably it is populated by this forward pass — confirm against the model)
    model.get_outputs(camera)
    meta = model.info
    W, H = meta["width"], meta["height"]

    # Original author's note: gaussians whose radius is greater than 1.0 can be seen in the camera frustum.
    # reshape instead of squeeze so that a scene with a single gaussian keeps its (N, k) layout
    radii = meta['radii']
    radii = radii.reshape(-1, radii.shape[-1])
    gaussian_ids = (radii > 1.0).any(dim=1).nonzero(as_tuple=True)[0]

    # Convert 2D coords to flat pixel indices. Valid columns are 0..W-1 and rows 0..H-1:
    # clamping to W / H would let x wrap onto the next row and y index past the image.
    xy_rounded = torch.round(meta['means2d']).reshape(-1, 2).long()
    x = torch.clamp(xy_rounded[:, 0], 0, W - 1)
    y = torch.clamp(xy_rounded[:, 1], 0, H - 1)
    projected_flattened = x + y * W

    # tolist() converts all ids in one transfer instead of one .item() call per gaussian
    reverse_mapping = {gid: pos for pos, gid in enumerate(gaussian_ids.tolist())}

    return {
        "proj_flattened": projected_flattened,           # (N,) one entry per gaussian
        "proj_depths": meta['depths'],                   # as provided by the model
        "gaussian_ids": gaussian_ids,                    # (M,) visible subset
        "gaussian_id_reverse_mapping": reverse_mapping,  # dict: gaussian id -> position in gaussian_ids
    }

def get_mask_gaussians(model, image, camera, composite_mask, n_patches: int = 32, front_percentage: float = 0.5):
    """
    Collect, for every mask in `composite_mask`, the gaussians that project into it.

    The image is divided into an n_patches x n_patches grid. Inside each grid cell
    that overlaps a mask, the gaussians whose projected pixel lies on the mask are
    sorted by depth and only the front-most share of them is kept.

    Assumptions (TODO confirm against the model):
        - `composite_mask`, `image` and the model's render share the same height and width
        - a smaller value in "proj_depths" means closer to the camera
        - "proj_flattened" and "proj_depths" hold one entry per gaussian, so both can
          be indexed with the ids in "gaussian_ids"

    Inputs:
        model: model passed to `project_gaussians`
        image: image passed to `create_patch_mask` (only its shape is used there)
        camera: camera passed to `project_gaussians`
        composite_mask: (H, W) label image, 0 = background
        n_patches: number of grid cells along each image axis
        front_percentage: fraction in (0, 1] of the nearest gaussians kept per grid cell
            (at least one gaussian is kept for every cell that contains any)

    Outputs:
        list with one 1-D long tensor of gaussian ids per mask, in the order the masks
        come out of `mask_id_to_binary_mask`; the tensor is empty for a mask that no
        visible gaussian projects into

    Raises:
        ValueError: if front_percentage is outside (0, 1]
    """
    if not 0.0 < front_percentage <= 1.0:
        raise ValueError(f"front_percentage must be in (0, 1], got {front_percentage}")

    # Project the gaussians to 2d. The patch masks are CPU tensors, so bring the lookups to the CPU too.
    proj_results = project_gaussians(model, camera)
    proj_flattened = proj_results["proj_flattened"].detach().reshape(-1).cpu()
    proj_depths = proj_results["proj_depths"].detach().reshape(-1).cpu()
    visible_ids = proj_results["gaussian_ids"].detach().reshape(-1).cpu()

    # Split the composite mask into individual masks, one flattened row each
    binary_masks = mask_id_to_binary_mask(composite_mask)
    flattened_masks = torch.as_tensor(binary_masks).flatten(start_dim=1)

    # Create a patch mask --> find the intersection between the composite mask and the patch mask
    patch_mask = create_patch_mask(image, n_patches)
    n_pixels = patch_mask.shape[-1]

    # Candidates are the gaussians flagged by project_gaussians; discard any whose
    # pixel index does not address a pixel of the image
    candidate_pixels = proj_flattened[visible_ids]
    in_image = (candidate_pixels >= 0) & (candidate_pixels < n_pixels)
    candidate_ids = visible_ids[in_image]
    candidate_pixels = candidate_pixels[in_image]
    candidate_depths = proj_depths[candidate_ids]

    gaussians = []

    for mask in flattened_masks:

        # Find the intersection between the mask and the patch mask
        patch_intersections = mask.unsqueeze(0).unsqueeze(0) & patch_mask

        # Find the non-empty patches
        non_empty_patches = patch_intersections.any(dim=2).nonzero(as_tuple=False)

        # If there are no non-empty patches, add an empty tensor to the gaussians list
        if len(non_empty_patches) == 0:
            gaussians.append(torch.tensor([], dtype=torch.long))
            continue

        # Find the gaussian ids that are inside the non-empty patches
        mask_gaussians = []

        for i, j in non_empty_patches:
            current_patch = patch_intersections[i, j]

            # True for every candidate gaussian whose projected pixel lies on the mask within this patch
            inside = current_patch[candidate_pixels]
            if not bool(inside.any()):
                continue

            patch_ids = candidate_ids[inside]
            patch_depths = candidate_depths[inside]

            # Keep only the front-most share of the gaussians in this patch
            n_keep = max(1, math.ceil(front_percentage * patch_ids.numel()))
            nearest = torch.argsort(patch_depths)[:n_keep]
            mask_gaussians.append(patch_ids[nearest])

        if mask_gaussians:
            gaussians.append(torch.cat(mask_gaussians))
        else:
            gaussians.append(torch.tensor([], dtype=torch.long))

    return gaussians

In [494]:
# Re-run the projection with the redefined project_gaussians and re-select a patch.
# NOTE(review): this repeats two earlier cells and reuses `non_empty_patches` / `patch_intersections`
# from kernel state; index 4 is arbitrary and fails when fewer than five patches are non-empty.
proj_results = project_gaussians(model, camera)

i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()


In [501]:
# proj_results["proj_flattened"] holds the flat pixel index of each gaussian and current_patch
# marks the pixels of the mask inside the patch, so this yields the indices of the gaussians
# that project into the patch. Reading from proj_results avoids relying on a
# `projected_flattened` global that only exists if another scratch cell ran first.
patch_gaussians = current_patch[proj_results["proj_flattened"].cpu()].nonzero().squeeze(-1)


In [504]:
proj_results['gaussian_ids']

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [493]:
current_patch[projected_flattened.cpu()].nonzero()

tensor([[ 112617],
        [ 119493],
        [ 128966],
        [ 131631],
        [ 150154],
        [ 166959],
        [ 191581],
        [ 206662],
        [ 227541],
        [ 228650],
        [ 258188],
        [ 281990],
        [ 332490],
        [ 356528],
        [ 387293],
        [ 389121],
        [ 396813],
        [ 434758],
        [ 537677],
        [ 639315],
        [ 666595],
        [ 780970],
        [ 917891],
        [ 919133],
        [1046450],
        [1047148],
        [1086774],
        [1087433],
        [1229077],
        [1230453],
        [1237335]])

In [476]:
projected_flattened.unique()

tensor([     0,      1,      2,  ..., 517438, 517439, 517440], device='cuda:0')

In [465]:
proj_results['gaussian_ids'].shape

torch.Size([181014])

In [472]:
proj_results['proj_flattened'].unique().shape

torch.Size([130238])

In [None]:
# patch_gaussians already contains gaussian indices (it is the output of .nonzero()).
# Passing it through torch.where again would return positions 0..K-1 within that list
# rather than the gaussian ids, so the depths of the wrong gaussians would be read.
patch_gaussian_ids = patch_gaussians
patch_gaussian_depths = proj_results['proj_depths'].squeeze()[patch_gaussian_ids]

In [445]:
# Convert 2D coords to flat pixel indices
xy_rounded = torch.round(meta['means2d']).squeeze().long()
x = torch.clamp(xy_rounded[:, 0], 0, W)
y = torch.clamp(xy_rounded[:, 1], 0, H)
projected_flattened = x + y * W                      # (M,)

In [269]:
# Find the IDs of the gaussians within the current camera frustum
gaussian_ids

tensor([      8,      13,      14,  ..., 1339596, 1339602, 1339611], device='cuda:0')