In [3]:
%load_ext autoreload
%autoreload 2

import os, sys
from pathlib import Path

from matplotlib import pyplot as plt
# from tqdm import tqdm
from tqdm.notebook import tqdm
from pathlib import Path

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Load a classifier

In [6]:
# Known trained splats, one row per species:
#   (species key, output directory of the run, config file relative to that directory)
_SPLAT_RUNS = (
    (
        'rats',
        '/workspace/fieldwork-data/rats/2024-07-11/environment/C0119/rade-features',
        "2025-07-25_074037/config.yml",
    ),
    (
        'birds',
        '/workspace/fieldwork-data/birds/2024-02-06/environment/C0043/rade-features',
        "2025-07-25_040743/config.yml",
    ),
)

# Lookup table used by the cells below: species -> {'base_dir', 'load_config'}.
SPLATS = {
    name: {'base_dir': run_dir, 'load_config': config_file}
    for name, run_dir, config_file in _SPLAT_RUNS
}

In [7]:
import torch
from collab_splats.utils.grouping import GroupingClassifier, GroupingConfig

# Pick the SPLATS entry to work on and resolve its run config on disk.
species = 'birds'
splat_entry = SPLATS[species]
base_dir = Path(splat_entry['base_dir'])
load_config = base_dir / splat_entry['load_config']

# Disabled: resume from a saved checkpoint instead of building a fresh
# classifier. Kept for reference only; the code below always runs.
# saved_model = Path(base_dir) / "grouping" / "checkpoints" / "grouping-classifier-v1.ckpt"
#
# if saved_model.exists():
#     print (f"Loading model from {saved_model}")
#     grouping_classifier = GroupingClassifier.load_from_checkpoint(saved_model)
#     grouping_classifier.load_pipeline()

# Grouping settings gathered in one mapping, then unpacked into GroupingConfig.
grouping_options = dict(
    segmentation_backend='mobilesamv2',
    segmentation_strategy='object',
    front_percentage=0.2,
    iou_threshold=0.1,
    num_patches=32,
    identity_dim=8,
    # lr=5e-5
)
grouping_config = GroupingConfig(**grouping_options)

# Per the recorded cell output, constructing the classifier reports loading an
# existing memory bank from <base_dir>/grouping/memory_bank.pkl.
grouping_classifier = GroupingClassifier(config=grouping_config, load_config=load_config)

# grouping_classifier.identities

# # Step 2: Load checkpoint state_dict only
# if saved_model.exists():
#     checkpoint = torch.load(saved_model)
#     state_dict = checkpoint['state_dict']
#     grouping_classifier.load_state_dict(state_dict, strict=False)

# # Step 3: Inject runtime pipeline & model
# grouping_classifier.load_pipeline()        # loads the NeRF pipeline at runtime
# grouping_classifier.load_segmentation()    # loads the segmentation backend

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
[Taichi] version 1.7.4, llvm 15.0.4, commit b4b956fd, linux, python 3.10.18


[I 09/25/25 16:30:56.763 126422] [shell.py:_shell_pop_print@23] Graphical python shell detected, using wrapped sys.stdout


Memory bank loaded from /workspace/fieldwork-data/birds/2024-02-06/environment/C0043/rade-features/grouping/memory_bank.pkl with 2569 masks


### Associate masks

In [None]:
# Create the masks on the classifier (first cell of the "Associate masks" section).
# NOTE(review): presumably runs the configured segmentation backend over the
# dataset and may be slow — confirm in GroupingClassifier.create_masks.
grouping_classifier.create_masks()

In [None]:
# Associate the masks on the classifier.
# NOTE(review): presumably requires create_masks() to have been run first and
# produces what `associated_mask_dir` points at — confirm in GroupingClassifier.
grouping_classifier.associate()

### Try the PyTorch Lightning datamodule

Train identity embeddings to lift objects from 2D to 3D.

In [8]:
# Load the pipeline into the classifier; the recorded output for this call is
# "Loading NeRF pipeline and model...".
grouping_classifier.load_pipeline()

Loading NeRF pipeline and model...


In [9]:
import wandb
from pytorch_lightning.loggers import WandbLogger

# Weights & Biases logger for this run, named after the selected species.
run_name = f"grouping_{species}"
logger = WandbLogger(project="collab-splats", name=run_name, log_model=False)

# Optional overrides, currently disabled:
# grouping_classifier.config.identity_dim = 16

# # Use simulated data (10 total mask types)
# grouping_classifier.total_masks = 10

# Prepare the classifier. The recorded output shows this prints
# "Loading NeRF pipeline and model..." again, as load_pipeline() did earlier.
grouping_classifier.setup()

Loading NeRF pipeline and model...


In [None]:
# Lift the 2D segmentation into 3D, logging through the W&B `logger` built above.
# The recorded output shows this ran with 16-bit mixed precision on a CUDA GPU
# and printed "Loading NeRF pipeline and model..." once more.
grouping_classifier.lift_segmentation(logger=logger)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A5000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


None


[34m[1mwandb[0m: Currently logged in as: [33mtbotch[0m ([33mfinnlab[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Loading NeRF pipeline and model...


### Random junk

In [None]:
from collab_splats.utils.grouping import GroupingDataModule 

# Stand-alone datamodule built from the classifier's pipeline datamanager and
# its associated-mask directory, with simulated data switched on.
datamodule_kwargs = {
    "datamanager": grouping_classifier.pipeline.datamanager,
    "mask_dir": grouping_classifier.associated_mask_dir,
    "device": "cuda",
    "train_num_workers": 0,
    "val_num_workers": 0,
    "use_simulated": True,
}
datamodule = GroupingDataModule(**datamodule_kwargs)

In [None]:
# NOTE(review): `x` is not defined in any visible cell — presumably a batch
# pulled from `datamodule`; confirm, otherwise this fails on a fresh kernel.
# Show the 'segmentation' entry of the second element of `x` as an image.
segmentation_np = x[1]['segmentation'].detach().cpu().numpy()
plt.imshow(segmentation_np)

In [None]:
# Distinct values present in the 'segmentation' entry of x[1].
# NOTE(review): `x` is not defined in any visible cell — confirm its origin.
segmentation_ids = torch.unique(x[1]['segmentation'])
segmentation_ids

In [None]:
# Invoke training_step by hand on `x`, passing 0 as the second positional argument.
# NOTE(review): `x` is not defined in any visible cell — confirm its origin.
grouping_classifier.training_step(x, 0)

In [None]:
# Call the classifier directly on the first element of `x`; the cells below
# read the 'identities' entry of the result.
# NOTE(review): `x` is not defined in any visible cell — confirm its origin.
outs = grouping_classifier(x[0])

In [None]:
# Show the first three channels of the 'identities' output as an image.
identity_preview = outs['identities'][..., :3].detach().cpu().numpy()
plt.imshow(identity_preview)

In [None]:
# Display the raw 'identities' entry of the forward-pass output.
outs['identities']

In [None]:
import numpy as np

# NOTE(review): `camera` and `data` are not defined in any visible cell — confirm
# where they come from, otherwise this cell fails on a fresh kernel.
# NOTE(review): an earlier cell indexes the classifier output as a dict
# (outs['identities']), while this one calls .argmax on it directly — confirm
# which return type is current.
grouping_classifier.eval()
logits = grouping_classifier(camera)

# Per-pixel label = index of the largest value along dim 0.
labels = logits.argmax(0).detach().cpu().numpy()
unique_labels = np.unique(labels)

print(unique_labels)

# Use one axes per image: two plt.imshow calls in the same cell draw onto the
# same axes, so the second image would cover the first.
fig, (ax_pred, ax_target) = plt.subplots(1, 2, figsize=(12, 5))
ax_pred.imshow(labels)
ax_pred.set_title("Predicted labels")
ax_target.imshow(data['segmentation'])
ax_target.set_title("Segmentation")
plt.show()

In [None]:
from torch.nn import CrossEntropyLoss

# NOTE(review): `identities` and `data` are not defined in any visible cell —
# confirm where they come from, otherwise this cell fails on a fresh kernel.

# Add the leading dimension under a NEW name. Rebinding `identities` to
# `identities.unsqueeze(0)` made the cell non-idempotent: every re-run stacked
# another leading dimension onto it.
identities_batched = identities.unsqueeze(0)
segmentation = data['segmentation'].unsqueeze(0).to(grouping_classifier.model.device)

# Unreduced loss, so the per-element values can be inspected.
CrossEntropyLoss(reduction="none")(identities_batched, segmentation)

# grouping_classifier.loss_fn(identities, data['segmentation'])

### Try to map onto the mesh?

In [None]:
import pickle
import open3d as o3d
from collab_splats.utils.mesh import features2vertex


# Mesh artifacts live in a 'mesh' directory next to the classifier's output directory.
mesh_dir = grouping_classifier.output_dir.parent / 'mesh'

mesh_path = mesh_dir / 'mesh.ply'
transforms_path = mesh_dir / 'transforms.pkl'

# SECURITY: pickle.load can execute arbitrary code from the file. Only load a
# transforms.pkl that this project produced itself.
with open(transforms_path, 'rb') as f:
    transforms = pickle.load(f)

# Hand Open3D a plain string path (as the pyvista cell does with as_posix()):
# not every Open3D release accepts pathlib.Path objects here.
mesh = o3d.io.read_triangle_mesh(mesh_path.as_posix())

# read_triangle_mesh only logs a warning and returns an empty mesh when the file
# cannot be read, so fail loudly here instead of in a later cell.
if not mesh.has_vertices():
    raise RuntimeError(f"Failed to load a mesh with vertices from {mesh_path}")


In [None]:
# Transform the means to the mesh
means = grouping_classifier.model.means.clone()
means = means @ transforms["mesh_transform"][:3, :3].T + transforms["mesh_transform"][:3, 3]

# Get the classes for each point
classes = grouping_classifier.per_gaussian_forward(grouping_classifier.identities)
classes = classes.argmax(-1).unsqueeze(-1)

Map to the mesh

In [None]:
# Transfer the per-point classes onto the mesh vertices.
# `points` are the transformed means and `features` the argmax classes from the
# previous cell; `mesh.vertices` comes from the Open3D mesh loaded above.
# NOTE(review): categorical=True presumably selects a vote/nearest-style transfer
# rather than averaging — confirm in collab_splats.utils.mesh.features2vertex.
mesh_classes = features2vertex(
    mesh_vertices=mesh.vertices,
    points=means,
    features=classes,
    categorical=True
)

In [None]:
import torch
import matplotlib.pyplot as plt

# Flatten the class ids and get, for every entry, the position of its class in
# the sorted unique ids. This replaces a Python loop over every entry, and it
# also works when `mesh_classes` holds a single element (where squeeze() would
# produce a 0-d tensor that cannot be iterated).
unique_classes, class_index = torch.unique(mesh_classes.reshape(-1), return_inverse=True)
n_classes = len(unique_classes)

# Sample one color per class at evenly spaced positions of the colormap.
# NOTE: 'tab10' is a qualitative palette with only 10 entries, so with more than
# 10 classes several classes end up sharing a color.
cmap = plt.get_cmap('tab10')  # or 'viridis', 'plasma', etc.

# One RGB row per unique class (alpha dropped). The explicit reshape keeps the
# (0, 3) shape when there are no classes at all.
palette = torch.tensor(
    [cmap(i / max(1, n_classes - 1))[:3] for i in range(n_classes)],
    dtype=torch.float32,
).reshape(n_classes, 3)

# The same mapping as a dict, kept for ad-hoc lookups by class id.
class_to_rgb = {class_id.item(): palette[i] for i, class_id in enumerate(unique_classes)}

# Look every entry's color up in one indexing operation; the result lives on the CPU.
rgb_colors = palette[class_index.cpu()]

rgb_colors

In [None]:
import pyvista as pv

# NOTE(review): this rebinds `mesh` from the Open3D mesh loaded earlier to a
# PyVista object, so re-running the features2vertex cell afterwards sees a
# different type — consider a separate name.
mesh = pv.read(mesh_path.as_posix())

# Quick summary of what was loaded: point count, cell count, bounds.
for summary_attr in ("n_points", "n_cells", "bounds"):
    print(getattr(mesh, summary_attr))

# Render using `rgb_colors` as direct RGB scalars and capture the frame.
plot_kwargs = dict(scalars=rgb_colors, rgb=True, screenshot=True)
image = mesh.plot(**plot_kwargs)

plt.imshow(image)