In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
from pathlib import Path

from matplotlib import pyplot as plt

from nerfstudio.utils.eval_utils import eval_setup
# from ns_extension.utils.grouping import GroupingClassifier

### Load configuration

In [2]:
# Path to the config for a trained model
load_config = '/workspace/fieldwork-data/rats/2024-07-11/environment/C0119/rade-features/2025-07-11_171420/config.yml'
load_config = Path(load_config)

config, pipeline, checkpoint_path, step = eval_setup(load_config)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


KeyboardInterrupt: 

### Question is whether to build grouping on top of the existing data?

Initialize segmentation model

In [1]:
camera, batch = pipeline.datamanager.next_train(0)

fn = pipeline.datamanager.train_dataset.image_filenames[0]

NameError: name 'pipeline' is not defined

In [7]:
from PIL import Image
import numpy as np

from ns_extension.utils.features import resize_image
from ns_extension.utils.segmentation import Segmentation

segmentation = Segmentation(
    backend='mobilesamv2',
    strategy='object',
    device='cuda',
)
segmentation.strategy = 'auto'

image = Image.open(fn)
H, W = image.height, image.width

# # Prepare image for segmentation
# image = resize_image(image, longest_edge=1024) # Resize image to SAM resolution

# Apply segmentation masks over features
image = np.asarray(image) # Convert to numpy array
masks, results = segmentation.segment(image)

Using cache found in /workspace/models/hub/RogerQi_MobileSAMV2_main


checkpoint_load_scucess


NameError: name 'fn' is not defined

### Start making our functions here

In [5]:
"""
Gaga --> gaussian grouping via multiview association + memory bank

Steps:
1. Create masks --> for each view within the dataset, create masks
    - Original implementation saves them out as images, but we could just save them out as tensors

2. Associate masks --> creates the memory bank?
    - Front percentage (0.2)
    - Overlap threshold (0.1)
    - For each camera --> 
        - If no masks, initialize a memory bank for the first view's masks
        - Get gaussian idxs and zcoords (for depth grouping) for the current view
        - Find front gaussians:
            - Create Spatial patch mask --> divides image into patch grid
            - Object masks --> goes through each mask in the image
            - Combines the two masks (i.e., find overlap between patch and object mask)
            - Find frontmost gaussians within each patch for each object
        - Based on this:
            - Stores the indices of the front gaussians
            - Mask ID = tensor of ALL indices of that mask (i.e., all gaussians in that mask)
            - Num masks == number of masks in the memory bank

"""
import torch
from torch import nn

import numpy as np
import math

from typing import Dict
from tqdm import tqdm

from ns_extension.utils.utils import project_gaussians

class GroupingClassifier(nn.Module):
    def __init__(self, num_masks: int, num_gaussians: int):
        super(GroupingClassifier, self).__init__()

        # eval_setup(load_config)
        self.num_masks = num_masks
        self.num_gaussians = num_gaussians
        self.classifier = nn.Conv2d(in_channels=num_masks, out_channels=num_gaussians, kernel_size=1)

    #########################################################
    ############## Mask initialization ######################
    #########################################################

    def set_patch_mask(self, image, num_patches: int = 32):
        """
        Provided an image of given dimensions, create an array of patches.
        """
        # Get image dimensions
        H, W = image.shape[:2]

        # Get patch dimensions
        patch_width = math.ceil(W / num_patches)
        patch_height = math.ceil(H / num_patches)
        
        # Create flattened coordinates
        total_pixels = H * W
        y_coords = torch.arange(H).unsqueeze(1).expand(-1, W).flatten()
        x_coords = torch.arange(W).unsqueeze(0).expand(H, -1).flatten()
        
        # Calculate patch indices for all pixels at once
        patch_y_indices = torch.clamp(y_coords // patch_height, 0, num_patches - 1)
        patch_x_indices = torch.clamp(x_coords // patch_width, 0, num_patches - 1)
        
        # Create sparse representation
        flatten_patch_mask = torch.zeros((num_patches, num_patches, total_pixels), 
                                    dtype=torch.bool)
        
        # Use indexing to set values
        pixel_indices = torch.arange(total_pixels)
        flatten_patch_mask[patch_y_indices, patch_x_indices, pixel_indices] = True
        
        return flatten_patch_mask
    
    def create_composite_mask(self, results, confidence_threshold=0.85):
        """
        Creates a composite mask from the results of the segmentation model.
        
        Inputs:
            results: list of dicts, each containing a mask and a confidence score
            confidence_threshold: float, the minimum confidence score for a mask to be included in the composite mask

        Outputs:
            composite_mask: numpy array, the composite mask
        """

        selected_masks = []
        for mask in results:
            if mask['predicted_iou'] < confidence_threshold:
                continue

            selected_masks.append(
                (mask['segmentation'], mask['predicted_iou'])
            )
        
        # Store the masks and confidences
        masks, confs = zip(*selected_masks)

        # Create empty image to store mask ids
        H, W = masks[0].shape[:2]
        mask_id = np.zeros((H, W), dtype=np.uint8)

        sorted_idxs = np.argsort(confs)
        for i, idx in enumerate(sorted_idxs, start=1):
            current_mask = masks[idx - 1]
            mask_id[current_mask == 1] = i

        # Find mask indices after having calculated overlap based on ranked confidence
        mask_indices = np.unique(mask_id)
        mask_indices = np.setdiff1d(mask_indices, [0]) # remove 0 item

        composite_mask = np.zeros((H, W), dtype=np.uint8)

        for i, idx in enumerate(mask_indices, start=1):
            mask = (mask_id == idx)
            if mask.sum() > 0 and (mask.sum() / masks[idx-1].sum()) > 0.1:
                composite_mask[mask] = i

        return composite_mask

    def mask_id_to_binary_mask(self, composite_mask: np.ndarray) -> np.ndarray:
        """
        Convert an image with integer mask IDs to a binary mask array.

        Args:
            mask_id (np.ndarray): An (H, W) array where each unique positive integer 
                                represents a separate object mask.

        Returns:
            np.ndarray: A (N, H, W) boolean array where N is the number of masks and each 
                        slice contains a binary mask.
        """
        unique_ids = np.unique(composite_mask)
        unique_ids = unique_ids[unique_ids > 0]  # Ignore background (assumed to be 0)

        binary_masks = (composite_mask[None, ...] == unique_ids[:, None, None])
        return binary_masks

    #########################################################
    ############## Gaussian selection #######################
    #########################################################

    def select_front_gaussians(self, model, camera, composite_mask, front_percentage: float = 0.5):
        """
        JIT-compiled version using torch.compile (PyTorch 2.0+).
        Maintains original structure and comments while adding compilation optimization.
        Now with separated helper functions for better code organization.
        """
        
        # Project gaussians onto 2d image
        proj_results = project_gaussians(model, camera)
        
        # Prepare masks = Decimate the composite mask into individual masks
        binary_masks = self.mask_id_to_binary_mask(composite_mask)
        flattened_masks = torch.tensor(binary_masks).flatten(start_dim=1)  # (N, H*W)

        # Pre-extract proj_results for compiled function
        proj_flattened = proj_results['proj_flattened']
        proj_depths = proj_results['proj_depths']

        # Compute the gaussian lookup table
        max_gaussian_id = proj_results['gaussian_ids'].max() if len(proj_results['gaussian_ids']) > 0 else 0
        valid_gaussian_mask = torch.zeros(max_gaussian_id + 1, dtype=torch.bool, device=proj_results['gaussian_ids'].device)
        valid_gaussian_mask[proj_results['gaussian_ids']] = True

        front_gaussians = []

        for mask in tqdm(flattened_masks, total=len(flattened_masks), desc="Processing masks"):
            # Use compiled function for main processing
            result = self.process_mask_gaussians(
                mask, 
                proj_results, 
                valid_gaussian_mask, 
                front_percentage=front_percentage
            )
            
            front_gaussians.append(result)

        return front_gaussians

    @torch.compile(mode="max-autotune")
    def process_mask_gaussians(self, mask, proj_results: Dict[str, torch.Tensor], valid_gaussian_mask: torch.Tensor, front_percentage: float = 0.5):
        """
        JIT-compiled function for processing a single mask.
        Optimized for performance with torch.compile.
        """
        # Find intersection between object mask and patch masks
        patch_intersections = mask.unsqueeze(0).unsqueeze(0) & self.patch_mask

        # Find non-empty patches
        patch_sums = patch_intersections.sum(dim=2)
        non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)

        if len(non_empty_patches) == 0:
            return torch.tensor([], dtype=torch.long, device=mask.device)
        
        # Extract all patches at once
        mask_gaussians = []
        patches_data = patch_intersections[non_empty_patches[:, 0], non_empty_patches[:, 1]]

        # Go through each non-empty patch and get the front gaussians
        for patch_idx, current_patch in enumerate(patches_data):
            # Projected flattened are the pixel coordinates of each gaussian --> current patch is the pixels of the mask
            # Grab gaussians in the current patch
            patch_gaussians = current_patch[proj_results['proj_flattened']].nonzero().squeeze(-1)
            
            if len(patch_gaussians) == 0:
                continue

            # Filter valid gaussians using pre-computed mask
            overlap_mask = valid_gaussian_mask[patch_gaussians]

            if not overlap_mask.all():
                invalid_count = (~overlap_mask).sum()
                print(f"Found {invalid_count} gaussians not in the IDs")
                print("Gaussians not in the IDs: ", patch_gaussians[~overlap_mask])

            # Note: Error checking moved outside compiled function for better performance
            patch_gaussians = patch_gaussians[overlap_mask]

            if len(patch_gaussians) == 0:
                continue
            
            # Grab the depths of the gaussians in the patch
            num_front_gaussians = max(int(front_percentage * len(patch_gaussians)), 1)
            
            if num_front_gaussians < len(patch_gaussians):
                # Use partial sorting for better performance
                patch_depths = proj_results['proj_depths'][patch_gaussians]
                _, front_indices = torch.topk(patch_depths, num_front_gaussians, largest=False)
                selected_gaussians = patch_gaussians[front_indices]
            else:
                selected_gaussians = patch_gaussians
            
            mask_gaussians.append(selected_gaussians)

        if len(mask_gaussians) > 0:
            mask_gaussians = torch.cat(mask_gaussians)
            return mask_gaussians
        else:
            return torch.tensor([], dtype=torch.long, device=mask.device)

    def associate_masks(self):
        pass

# def get_n_different_colors(n: int) -> np.ndarray:
#     np.random.seed(0)
#     return np.random.randint(1, 256, (n, 3), dtype=np.uint8)

# def visualize_mask(mask: np.ndarray) -> np.ndarray:
#     color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
#     num_masks = np.max(mask)
#     random_colors = get_n_different_colors(num_masks)
#     for i in range(num_masks):
#         color_mask[mask == i+1] = random_colors[i]
#     return color_mask



### Test the grouping process

In [6]:
model = pipeline.model
proj_results = project_gaussians(model, camera)

NameError: name 'pipeline' is not defined

In [43]:
results[0]['segmentation'].shape

(960, 540)

In [35]:
composite_mask = create_composite_mask(results)
patch_mask = create_patch_mask(image, num_patches=32)
front_gaussians = select_front_gaussians(model, camera, composite_mask, patch_mask, front_percentage = 0.5)

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()
Processing masks:   0%|          | 0/65 [00:00<?, ?it/s]skipping cudagraphs for unknown reason
Processing masks:   8%|▊         | 5/65 [00:02<00:24,  2.49it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1195031])


Processing masks:  11%|█         | 7/65 [00:02<00:22,  2.64it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1317422])


Processing masks:  37%|███▋      | 24/65 [00:08<00:14,  2.90it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([730857])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([169239])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([281163])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([424230])


Processing masks:  60%|██████    | 39/65 [00:14<00:09,  2.85it/s]

Found 2 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1103701, 1178042])


Processing masks:  68%|██████▊   | 44/65 [00:16<00:08,  2.58it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([897836])
Found 2 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1089400, 1096504])


Processing masks:  78%|███████▊  | 51/65 [00:18<00:04,  2.95it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1236157])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([369113])


Processing masks:  91%|█████████ | 59/65 [00:20<00:01,  3.02it/s]

Found 3 gaussians not in the IDs
Gaussians not in the IDs:  tensor([ 333050,  718835, 1230981])


Processing masks:  94%|█████████▍| 61/65 [00:21<00:01,  2.76it/s]

Found 2 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1264189, 1296009])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1309851])
Found 4 gaussians not in the IDs
Gaussians not in the IDs:  tensor([466000, 526218, 720453, 992983])
Found 2 gaussians not in the IDs
Gaussians not in the IDs:  tensor([392888, 787331])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([895690])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1048921])


Processing masks:  97%|█████████▋| 63/65 [00:22<00:00,  2.58it/s]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1303073])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([719282])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([813548])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1103224])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1231555])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1118953])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1089382])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1254684])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([115346])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1191680])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([150277])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([36077])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([19

Processing masks:  98%|█████████▊| 64/65 [00:25<00:01,  1.19s/it]

Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([482238])
Found 1 gaussians not in the IDs
Gaussians not in the IDs:  tensor([1133756])


Processing masks: 100%|██████████| 65/65 [00:26<00:00,  2.50it/s]


torch.Size([54])

Find the intersection between a given object mask and the patch masks

In [43]:
# Start with the current mask, but we loop over these in the final product
current_mask = flattened_masks[0]
patch_intersections = current_mask.unsqueeze(0).unsqueeze(0) & patch_mask

# Find non-empty patches
patch_sums = patch_intersections.sum(dim=2)
non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)


Project the gaussians into pixel space

In [44]:
model = pipeline.model
proj_results = project_gaussians(model, camera)

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()


Grab a non-empty patch

In [45]:
i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

Given the current patch, find its associated gaussians

In [46]:
# Projected flattened are the pixel coordinates of each gaussian --> current patch is the pixels of the mask
projected_flattened = proj_results['proj_flattened'] # (M,)
patch_gaussians = current_patch[projected_flattened.cpu()].nonzero().squeeze(-1)

# This should pass --> need to check all found patch gaussians are in the valid gaussians
assert torch.isin(patch_gaussians.detach().cpu(), proj_results['gaussian_ids'].detach().cpu()).all()


In [48]:
# Grab the depths of the gaussians in the patch
patch_gaussian_depths = proj_results['proj_depths'][patch_gaussians]

In [49]:
front_percentage = 0.2
num_patch_gaussians = len(patch_gaussians)
num_front_gaussians = max(int(front_percentage * num_patch_gaussians), 1)

# Sort the gaussians by depth
sorted_gaussian_ids = torch.argsort(patch_gaussian_depths)

# gaussians sorted by their depths and select the front based on percentage
sorted_gaussians = patch_gaussians[sorted_gaussian_ids]
selected_gaussians = sorted_gaussians[:num_front_gaussians]

In [415]:
def get_mask_gaussians(model, image, camera, composite_mask, n_patches: int = 32, front_percentage: float = 0.5):

    # Project the gaussians to 2d
    proj_results = project_gaussians(model, camera)

    # Decimate the composite mask into individual masks
    binary_masks = mask_id_to_binary_mask(composite_mask)
    flattened_masks = torch.tensor(binary_masks).flatten(start_dim=1)

    # Create a patch mask --> find the intersection between the composite mask and the patch mask
    patch_mask = create_patch_mask(image, n_patches)

    gaussians = []

    for mask in flattened_masks:

        # Find the intersection between the mask and the patch mask
        patch_intersections = mask.unsqueeze(0).unsqueeze(0) & patch_mask

        # Find the non-empty patches
        patch_sums = patch_intersections.sum(dim=2)  # Sum pixels per patch
        non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)

        # If there are no non-empty patches, add an empty tensor to the gaussians list
        if len(non_empty_patches) == 0:
            gaussians.append(torch.tensor([], dtype=torch.long))
            continue
        
        # Find the gaussian ids that are inside the non-empty patches
        mask_gaussians = []

        for patch_idx in non_empty_patches:
            i, j = patch_idx
            current_patch = patch_intersections[i, j]

            # current_patch

            # Find the gaussian ids that are inside the current patch
            patch_gaussians = torch.where(current_patch)[0]

            # Find the gaussian ids that are inside the current patch
            patch_gaussians = torch.where(current_patch)[0]

In [494]:
proj_results = project_gaussians(model, camera)

i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()
