In [1]:
%load_ext autoreload
%autoreload 2

import os, sys
from pathlib import Path

from matplotlib import pyplot as plt

from nerfstudio.utils.eval_utils import eval_setup
# from ns_extension.utils.grouping import GroupingClassifier

[Taichi] version 1.7.3, llvm 15.0.4, commit 5ec301be, linux, python 3.10.18


[I 07/21/25 14:13:50.135 13956] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


### Load configuration

In [2]:
# Path to the config for a trained model
load_config = '/workspace/fieldwork-data/rats/2024-07-11/environment/C0119/rade-features/2025-07-11_171420/config.yml'
load_config = Path(load_config)

config, pipeline, checkpoint_path, step = eval_setup(load_config)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Question is whether to build grouping on top of the existing data?

Initialize segmentation model

In [3]:
camera, batch = pipeline.datamanager.next_train(0)

fn = pipeline.datamanager.train_dataset.image_filenames[0]

Output()

In [4]:
from PIL import Image
import numpy as np

from ns_extension.utils.features import resize_image
from ns_extension.utils.segmentation import Segmentation

segmentation = Segmentation(
    backend='mobilesamv2',
    strategy='object',
    device='cuda',
)
segmentation.strategy = 'auto'

image = Image.open(fn)
H, W = image.height, image.width

# # Prepare image for segmentation
# image = resize_image(image, longest_edge=1024) # Resize image to SAM resolution

# Apply segmentation masks over features
image = np.asarray(image) # Convert to numpy array
masks, results = segmentation.segment(image)

Using cache found in /workspace/models/hub/RogerQi_MobileSAMV2_main


checkpoint_load_scucess


  return F.conv2d(input, weight, bias, self.stride,


### Start making our functions here

In [None]:
"""
Gaga --> gaussian grouping via multiview association + memory bank

Steps:
1. Create masks --> for each view within the dataset, create masks
    - Original implementation saves them out as images, but we could just save them out as tensors

2. Associate masks --> creates the memory bank?
    - Front percentage (0.2)
    - Overlap threshold (0.1)
    - For each camera --> 
        - If no masks, initialize a memory bank for the first view's masks
        - Get gaussian idxs and zcoords (for depth grouping) for the current view
        - Find front gaussians:
            - Create Spatial patch mask --> divides image into patch grid
            - Object masks --> goes through each mask in the image
            - Combines the two masks (i.e., find overlap between patch and object mask)
            - Find frontmost gaussians within each patch for each object
        - Based on this:
            - Stores the indices of the front gaussians
            - Mask ID = tensor of ALL indices of that mask (i.e., all gaussians in that mask)
            - Num masks == number of masks in the memory bank

"""
import torch
from torch import nn

import numpy as np
import math

from typing import Dict
from tqdm import tqdm

from nerfstudio.models.splatfacto import SplatfactoModel
from ns_extension.utils.segmentation import Segmentation, create_patch_mask, create_composite_mask, mask_id_to_binary_mask
from ns_extension.utils.utils import project_gaussians

class GroupingClassifier(nn.Module):
    def __init__(self, load_config: str, segmentation_backend: str, segmentation_strategy: str):
        super(GroupingClassifier, self).__init__()

        self.load_config = load_config
        self.segmentation_backend = segmentation_backend
        self.segmentation_strategy = segmentation_strategy

        # eval_setup(load_config)
        # self.num_masks = num_masks
        # self.num_gaussians = num_gaussians
        # self.classifier = nn.Conv2d(in_channels=num_masks, out_channels=num_gaussians, kernel_size=1)

    #########################################################
    ############## Mask initialization ######################
    #########################################################

    def associate(self):
        """
        Creates a memory bank that associates gaussians within masks across fields-of-view.
        """

        _, pipeline, _, _ = eval_setup(self.load_config)

        assert isinstance(pipeline.model, SplatfactoModel)

        model: SplatfactoModel = pipeline.model

        # Load the segmentation model
        segmentation = Segmentation(
            backend=self.segmentation_backend,
            strategy=self.segmentation_strategy,
            device=model.device
        )

        with torch.no_grad():
            cameras: Cameras = pipeline.datamanager.train_dataset.cameras  # type: ignore
            # TODO: do eval dataset as well

            for image_idx, data in tqdm(
                enumerate(pipeline.datamanager.train_dataset),  # type: ignore
                desc="Processing frames",
                total=len(pipeline.datamanager.train_dataset)
            ):
                # Grab camera and forward pass through model
                camera = cameras[image_idx : image_idx + 1]
                image = data["image"]

                # Forward pass through model to get metadata
                _ = model.get_outputs(camera=camera)

                # Segment the image
                patch_mask = create_patch_mask(image)
                _, results = segmentation.segment(image)

                # Create composite mask
                composite_mask = create_composite_mask(results)

                # Select front gaussians
                front_gaussians = self.select_front_gaussians(
                    meta=model.info, 
                    composite_mask=composite_mask, 
                    patch_mask=patch_mask
                )
                
    #########################################################
    ############## Gaussian selection #######################
    #########################################################

    def select_front_gaussians(self, meta: Dict[str, torch.Tensor], composite_mask: torch.Tensor, patch_mask: torch.Tensor, front_percentage: float = 0.5):
        """
        JIT-compiled version using torch.compile (PyTorch 2.0+).
        Maintains original structure and comments while adding compilation optimization.
        Now with separated helper functions for better code organization.
        """

        proj_results = project_gaussians(meta)
                
        # Prepare masks = Decimate the composite mask into individual masks
        binary_masks = self.mask_id_to_binary_mask(composite_mask)
        flattened_masks = torch.tensor(binary_masks).flatten(start_dim=1)  # (N, H*W)

        # Compute the gaussian lookup table
        max_gaussian_id = proj_results['gaussian_ids'].max() if len(proj_results['gaussian_ids']) > 0 else 0
        valid_gaussian_mask = torch.zeros(max_gaussian_id + 1, dtype=torch.bool, device=proj_results['gaussian_ids'].device)
        valid_gaussian_mask[proj_results['gaussian_ids']] = True

        front_gaussians = []

        for mask in tqdm(flattened_masks, total=len(flattened_masks), desc="Processing masks"):

            if patch_mask is not None:

            # Use compiled function for main processing
            result = self.process_mask_gaussians(
                proj_results, 
                mask, 
                patch_mask,
                valid_gaussian_mask, 
                front_percentage=front_percentage
            )
            
            front_gaussians.append(result)

        return front_gaussians

    @torch.compile(mode="max-autotune")
    def process_mask_gaussians(self,  proj_results: Dict[str, torch.Tensor], mask: torch.Tensor, patch_mask: torch.Tensor, valid_gaussian_mask: torch.Tensor, front_percentage: float = 0.5):
        """
        JIT-compiled function for processing a single mask.
        Optimized for performance with torch.compile.
        """

        ### TLB THIS SECTION COULD BE REFACTORED OUT FOR MORE FLEXIBILITY (PATCH VS NO PATCH)
        # Find intersection between object mask and patch masks
        patch_intersections = mask.unsqueeze(0).unsqueeze(0) & patch_mask

        # Find non-empty patches
        patch_sums = patch_intersections.sum(dim=2)
        non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)

        if len(non_empty_patches) == 0:
            return torch.tensor([], dtype=torch.long, device=mask.device)
        
        # Extract all patches at once
        mask_gaussians = []
        patches_data = patch_intersections[non_empty_patches[:, 0], non_empty_patches[:, 1]]

        # Go through each non-empty patch and get the front gaussians
        for patch_idx, current_patch in enumerate(patches_data):
            # Projected flattened are the pixel coordinates of each gaussian --> current patch is the pixels of the mask
            # Grab gaussians in the current patch
            patch_gaussians = current_patch[proj_results['proj_flattened']].nonzero().squeeze(-1)
            
            if len(patch_gaussians) == 0:
                continue

            # Filter valid gaussians using pre-computed mask
            overlap_mask = valid_gaussian_mask[patch_gaussians]

            if not overlap_mask.all():
                invalid_count = (~overlap_mask).sum()
                print(f"Found {invalid_count} gaussians not in the IDs")
                print("Gaussians not in the IDs: ", patch_gaussians[~overlap_mask])

            # Note: Error checking moved outside compiled function for better performance
            patch_gaussians = patch_gaussians[overlap_mask]

            if len(patch_gaussians) == 0:
                continue
            
            # Grab the depths of the gaussians in the patch
            num_front_gaussians = max(int(front_percentage * len(patch_gaussians)), 1)
            
            if num_front_gaussians < len(patch_gaussians):
                # Use partial sorting for better performance
                patch_depths = proj_results['proj_depths'][patch_gaussians]
                _, front_indices = torch.topk(patch_depths, num_front_gaussians, largest=False)
                selected_gaussians = patch_gaussians[front_indices]
            else:
                selected_gaussians = patch_gaussians
            
            mask_gaussians.append(selected_gaussians)

        if len(mask_gaussians) > 0:
            mask_gaussians = torch.cat(mask_gaussians)
            return mask_gaussians
        else:
            return torch.tensor([], dtype=torch.long, device=mask.device)

    def associate_masks(self):
        pass

# def get_n_different_colors(n: int) -> np.ndarray:
#     np.random.seed(0)
#     return np.random.randint(1, 256, (n, 3), dtype=np.uint8)

# def visualize_mask(mask: np.ndarray) -> np.ndarray:
#     color_mask = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
#     num_masks = np.max(mask)
#     random_colors = get_n_different_colors(num_masks)
#     for i in range(num_masks):
#         color_mask[mask == i+1] = random_colors[i]
#     return color_mask



IndentationError: expected an indented block after 'if' statement on line 129 (460887998.py, line 132)

In [None]:
# To initialize the classifier, we first need to build the mask association
# This involves 
# 1) segmenting each camera view
# 2) associating the masks to the gaussians


def build_mask_association(model, cameras):
    for camera in tqdm(cameras):
        camera = camera.to(model.device)

        front_gaussians, _ = get_patch_front_gaussian_of_mask(model, camera)
        labels = assign_labels(front_gaussians)
        self._update_gaussian_idx_bank(labels, front_gaussians)

        if self.num_mask == 0:
            self.assigned_gaussians = torch.unique(torch.cat(front_gaussians))


### Test the grouping process

In [9]:
model = pipeline.model

outputs = model.get_outputs(camera)
meta = model.info

proj_results = project_gaussians(meta)

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()
  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()


In [11]:
composite_mask = create_composite_mask(results)
patch_mask = create_patch_mask(image, num_patches=32)
front_gaussians = select_front_gaussians(model, camera, composite_mask, patch_mask, front_percentage = 0.5)

NameError: name 'create_composite_mask' is not defined

torch.Size([54])

Find the intersection between a given object mask and the patch masks

In [43]:
# Start with the current mask, but we loop over these in the final product
current_mask = flattened_masks[0]
patch_intersections = current_mask.unsqueeze(0).unsqueeze(0) & patch_mask

# Find non-empty patches
patch_sums = patch_intersections.sum(dim=2)
non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)


Project the gaussians into pixel space

In [44]:
model = pipeline.model
proj_results = project_gaussians(model, camera)

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()


Grab a non-empty patch

In [45]:
i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

Given the current patch, find its associated gaussians

In [46]:
# Projected flattened are the pixel coordinates of each gaussian --> current patch is the pixels of the mask
projected_flattened = proj_results['proj_flattened'] # (M,)
patch_gaussians = current_patch[projected_flattened.cpu()].nonzero().squeeze(-1)

# This should pass --> need to check all found patch gaussians are in the valid gaussians
assert torch.isin(patch_gaussians.detach().cpu(), proj_results['gaussian_ids'].detach().cpu()).all()


In [48]:
# Grab the depths of the gaussians in the patch
patch_gaussian_depths = proj_results['proj_depths'][patch_gaussians]

In [49]:
front_percentage = 0.2
num_patch_gaussians = len(patch_gaussians)
num_front_gaussians = max(int(front_percentage * num_patch_gaussians), 1)

# Sort the gaussians by depth
sorted_gaussian_ids = torch.argsorxxt(patch_gaussian_depths)

# gaussians sorted by their depths and select the front based on percentage
sorted_gaussians = patch_gaussians[sorted_gaussian_ids]
selected_gaussians = sorted_gaussians[:num_front_gaussians]

In [415]:
def get_mask_gaussians(model, image, camera, composite_mask, n_patches: int = 32, front_percentage: float = 0.5):

    # Project the gaussians to 2d
    proj_results = project_gaussians(model, camera)

    # Decimate the composite mask into individual masks
    binary_masks = mask_id_to_binary_mask(composite_mask)
    flattened_masks = torch.tensor(binary_masks).flatten(start_dim=1)

    # Create a patch mask --> find the intersection between the composite mask and the patch mask
    patch_mask = create_patch_mask(image, n_patches)

    gaussians = []

    for mask in flattened_masks:

        # Find the intersection between the mask and the patch mask
        patch_intersections = mask.unsqueeze(0).unsqueeze(0) & patch_mask

        # Find the non-empty patches
        patch_sums = patch_intersections.sum(dim=2)  # Sum pixels per patch
        non_empty_patches = (patch_sums > 0).nonzero(as_tuple=False)

        # If there are no non-empty patches, add an empty tensor to the gaussians list
        if len(non_empty_patches) == 0:
            gaussians.append(torch.tensor([], dtype=torch.long))
            continue
        
        # Find the gaussian ids that are inside the non-empty patches
        mask_gaussians = []

        for patch_idx in non_empty_patches:
            i, j = patch_idx
            current_patch = patch_intersections[i, j]

            # current_patch

            # Find the gaussian ids that are inside the current patch
            patch_gaussians = torch.where(current_patch)[0]

            # Find the gaussian ids that are inside the current patch
            patch_gaussians = torch.where(current_patch)[0]

In [494]:
proj_results = project_gaussians(model, camera)

i, j = non_empty_patches[4]
current_patch = patch_intersections[i, j]

  torch.tensor(get_world2view_transform(R, T, trans, scale)).transpose(0, 1).cuda()
