In [1]:
import time
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import torch
import pytorch3d.ops
from plyfile import PlyData, PlyElement
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from argparse import ArgumentParser, Namespace
import cv2

from arguments import ModelParams, PipelineParams
from scene import Scene, GaussianModel, FeatureGaussianModel
from gaussian_renderer import render, render_contrastive_feature


from utils.sh_utils import SH2RGB

def get_combined_args(parser : ArgumentParser, model_path, target_cfg_file = None):
    cmdlne_string = ['--model_path', model_path]
    cfgfile_string = "Namespace()"
    args_cmdline = parser.parse_args(cmdlne_string)
    
    if target_cfg_file is None:
        if args_cmdline.target == 'seg':
            target_cfg_file = "seg_cfg_args"
        elif args_cmdline.target == 'scene' or args_cmdline.target == 'xyz':
            target_cfg_file = "cfg_args"
        elif args_cmdline.target == 'feature' or args_cmdline.target == 'coarse_seg_everything' or args_cmdline.target == 'contrastive_feature' :
            target_cfg_file = "feature_cfg_args"

    try:
        cfgfilepath = os.path.join(model_path, target_cfg_file)
        print("Looking for config file in", cfgfilepath)
        with open(cfgfilepath) as cfg_file:
            print("Config file found: {}".format(cfgfilepath))
            cfgfile_string = cfg_file.read()
    except TypeError:
        print("Config file found: {}".format(cfgfilepath))
        pass
    args_cfgfile = eval(cfgfile_string)

    merged_dict = vars(args_cfgfile).copy()
    for k,v in vars(args_cmdline).items():
        if v != None:
            merged_dict[k] = v

    return Namespace(**merged_dict)

In [2]:
import os
FEATURE_DIM = 32 # fixed

# MODEL_PATH = './output/lerf-fruit_aisle/'
MODEL_PATH = './output/lund_1024' # 30000

FEATURE_GAUSSIAN_ITERATION = 10000

SCALE_GATE_PATH = os.path.join(MODEL_PATH, f'point_cloud/iteration_{str(FEATURE_GAUSSIAN_ITERATION)}/scale_gate.pt')

FEATURE_PCD_PATH = os.path.join(MODEL_PATH, f'point_cloud/iteration_{str(FEATURE_GAUSSIAN_ITERATION)}/contrastive_feature_point_cloud.ply')
SCENE_PCD_PATH = os.path.join(MODEL_PATH, f'point_cloud/iteration_{str(FEATURE_GAUSSIAN_ITERATION)}/scene_point_cloud.ply')

In [3]:
scale_gate = torch.nn.Sequential(
    torch.nn.Linear(1, 32, bias=True),
    torch.nn.Sigmoid()
)

scale_gate.load_state_dict(torch.load(SCALE_GATE_PATH))
scale_gate = scale_gate.cuda()

parser = ArgumentParser(description="Testing script parameters")
model = ModelParams(parser, sentinel=True)
pipeline = PipelineParams(parser)
parser.add_argument('--target', default='scene', type=str)

args = get_combined_args(parser, MODEL_PATH)

dataset = model.extract(args)

# If use language-driven segmentation, load clip feature and original masks
dataset.need_features = True

# To obtain mask scales
dataset.need_masks = True

scene_gaussians = GaussianModel(dataset.sh_degree)

feature_gaussians = FeatureGaussianModel(FEATURE_DIM)
scene = Scene(dataset, scene_gaussians, feature_gaussians, load_iteration=-1, feature_load_iteration=FEATURE_GAUSSIAN_ITERATION, shuffle=False, mode='eval', target='contrastive_feature')


Looking for config file in ./output/lund_1024/cfg_args
Config file found: ./output/lund_1024/cfg_args
Loading trained model at iteration 30000, 10000
Allow Camera Principle Point Shift: False
Reading camera 1196/1196
✅ Loaded 1196 cameras for this GPU (start_idx=0, end_idx=None)
Loading Training Cameras
Loading Test Cameras


In [4]:
from sklearn.preprocessing import QuantileTransformer
# Borrowed from GARField, but modified
def get_quantile_func(scales: torch.Tensor, distribution="normal"):
    """
    Use 3D scale statistics to normalize scales -- use quantile transformer.
    """
    scales = scales.flatten()

    scales = scales.detach().cpu().numpy()
    print(scales.max(), '?')

    # Calculate quantile transformer
    quantile_transformer = QuantileTransformer(output_distribution=distribution)
    quantile_transformer = quantile_transformer.fit(scales.reshape(-1, 1))

    
    def quantile_transformer_func(scales):
        scales_shape = scales.shape

        scales = scales.reshape(-1,1)
        
        return torch.Tensor(
            quantile_transformer.transform(scales.detach().cpu().numpy())
        ).to(scales.device).reshape(scales_shape)

    return quantile_transformer_func, quantile_transformer
    
all_scales = []
for cam in scene.getTrainCameras():
    scale_path = os.path.join(dataset.source_path, 'mask_scales', cam.image_name + '.pt')
    scales = torch.load(scale_path)
    all_scales.append(scales)

all_scales = torch.cat(all_scales)

upper_bound_scale = all_scales.max().item()
# upper_bound_scale = np.percentile(all_scales.detach().cpu().numpy(), 75)

# all_scales = []
# for cam in scene.getTrainCameras():
#     cam.mask_scales = torch.clamp(cam.mask_scales, 0, upper_bound_scale).detach()
#     all_scales.append(cam.mask_scales)
# all_scales = torch.cat(all_scales)

# quantile transformer
q_trans, q_trans_ = get_quantile_func(all_scales, 'uniform')

18.245272 ?


In [5]:
all_scales.max()

tensor(18.2453, grad_fn=<MaxBackward1>)

In [6]:
upper_bound_scale

18.245271682739258

In [7]:
q_trans(torch.Tensor([70]))

tensor([1.])

In [8]:
'''with torch.no_grad():
    # If the q_trans is normal
    # scale = 2.
    # scale = torch.full((1,), scale).cuda()
    # scale = q_trans(scale)

    # If the q_trans is uniform, the scale can be any value between 0 and 1
    # scale = torch.tensor([0]).cuda()
    # scale = torch.tensor([0.5]).cuda()
    scale = torch.tensor([1.5]).cuda()

    gates = scale_gate(scale)

    #feature_with_scale = rendered_feature
    #feature_with_scale = feature_with_scale * gates.unsqueeze(-1).unsqueeze(-1)
    #scale_conditioned_feature = feature_with_scale.permute([1,2,0])

    #plt.imshow(scale_conditioned_feature[:,:,:3].detach().cpu().numpy())'''

'with torch.no_grad():\n    # If the q_trans is normal\n    # scale = 2.\n    # scale = torch.full((1,), scale).cuda()\n    # scale = q_trans(scale)\n\n    # If the q_trans is uniform, the scale can be any value between 0 and 1\n    # scale = torch.tensor([0]).cuda()\n    # scale = torch.tensor([0.5]).cuda()\n    scale = torch.tensor([1.5]).cuda()\n\n    gates = scale_gate(scale)\n\n    #feature_with_scale = rendered_feature\n    #feature_with_scale = feature_with_scale * gates.unsqueeze(-1).unsqueeze(-1)\n    #scale_conditioned_feature = feature_with_scale.permute([1,2,0])\n\n    #plt.imshow(scale_conditioned_feature[:,:,:3].detach().cpu().numpy())'

In [9]:
with torch.no_grad():
    # Get all scales from all cameras (or just one if batching)
    all_mask_scales = []
    for cam in scene.getTrainCameras():
        scale_path = os.path.join(dataset.source_path, 'mask_scales', cam.image_name + '.pt')
        mask_scale = torch.load(scale_path).reshape(-1)
        all_mask_scales.append(mask_scale)

    all_mask_scales = torch.cat(all_mask_scales).cuda()  # Shape: [N]
    normed_scales = q_trans(all_mask_scales).reshape(-1, 1)  # Shape: [N, 1]

    # Apply scale gate
    gate_output = scale_gate(normed_scales)               # [N, 32]
    gate_scores = gate_output.mean(dim=1)                 # [N]


In [10]:
from copy import deepcopy
cameras = scene.getTrainCameras()
print("There are",len(cameras),"views in the dataset.")
print(upper_bound_scale)

There are 1196 views in the dataset.
18.245271682739258


In [11]:
# Get 3D Gaussian point features
point_features = feature_gaussians.get_point_features  # [N, 32]

# Get per-Gaussian scale: use norm of the 3D scale vector
scales_3d = scene_gaussians.get_scaling  # [N, 3]
mean_scale = scales_3d.norm(dim=1)       # [N]

# Normalize using quantile transformer
normalized_scale = q_trans(mean_scale)   # [N]

# Compute scale gate values
with torch.no_grad():
    gate_scores = scale_gate(normalized_scale.unsqueeze(1)).mean(dim=1)  # [N]

# Compute scale-aware features
scale_conditioned_point_features = torch.nn.functional.normalize(point_features, dim=-1, p=2) * gate_scores.unsqueeze(1)

normed_point_features = torch.nn.functional.normalize(scale_conditioned_point_features, dim = -1, p = 2)

sample_mask = torch.rand(scale_conditioned_point_features.shape[0]) > 0.98
sampled_point_features = scale_conditioned_point_features[sample_mask]
normed_sampled_point_features = sampled_point_features / torch.norm(sampled_point_features, dim = -1, keepdim = True)

print(len(sampled_point_features))

20254


In [12]:
import hdbscan

clusterer = hdbscan.HDBSCAN(min_cluster_size=50, cluster_selection_epsilon=0.15)

'''cluster_labels = clusterer.fit_predict(normed_sampled_point_features.detach().cpu().numpy())
print(np.unique(cluster_labels))

cluster_centers = torch.zeros(len(np.unique(cluster_labels))-1, normed_sampled_point_features.shape[-1])
for i in range(1, len(np.unique(cluster_labels))):
    cluster_centers[i-1] = torch.nn.functional.normalize(normed_sampled_point_features[cluster_labels == i-1].mean(dim = 0), dim = -1)'''

'cluster_labels = clusterer.fit_predict(normed_sampled_point_features.detach().cpu().numpy())\nprint(np.unique(cluster_labels))\n\ncluster_centers = torch.zeros(len(np.unique(cluster_labels))-1, normed_sampled_point_features.shape[-1])\nfor i in range(1, len(np.unique(cluster_labels))):\n    cluster_centers[i-1] = torch.nn.functional.normalize(normed_sampled_point_features[cluster_labels == i-1].mean(dim = 0), dim = -1)'

In [13]:
from sklearn.decomposition import PCA
from sklearn.neighbors import NearestNeighbors
from collections import Counter
import numpy as np

# Step 1: Extract features and apply PCA
feat_array = normed_sampled_point_features.detach().cpu().numpy()
pca_feats = PCA(n_components=20).fit_transform(feat_array)

# Step 2: Run HDBSCAN on PCA-reduced features
import hdbscan
clusterer = hdbscan.HDBSCAN(min_cluster_size=80, cluster_selection_epsilon=0.1)
cluster_labels = clusterer.fit_predict(pca_feats)

# Step 3: Merge small clusters using PCA features
min_size = 50
label_counts = Counter(cluster_labels)

# Identify small and large cluster masks
big_mask = np.array([label_counts[lbl] >= min_size and lbl != -1 for lbl in cluster_labels])
small_mask = ~big_mask

big_feats = pca_feats[big_mask]
big_labels = cluster_labels[big_mask]
small_feats = pca_feats[small_mask]

# Nearest neighbor matching
nn = NearestNeighbors(n_neighbors=1).fit(big_feats)
_, indices = nn.kneighbors(small_feats)
nearest_labels = big_labels[indices[:, 0]]

# Apply merged labels
merged_labels = cluster_labels.copy()
merged_labels[small_mask] = nearest_labels
full_labels = -np.ones(scale_conditioned_point_features.shape[0], dtype=int)
full_labels[sample_mask.cpu().numpy()] = merged_labels



In [14]:
cluster_centers = torch.zeros(len(np.unique(merged_labels))-1, normed_sampled_point_features.shape[-1])
for i in range(1, len(np.unique(merged_labels))):
    cluster_centers[i-1] = torch.nn.functional.normalize(normed_sampled_point_features[merged_labels == i-1].mean(dim = 0), dim = -1)

IndexError: The shape of the mask [1007732] at index 0 does not match the shape of the indexed tensor [20254, 32] at index 0

In [None]:
seg_score = torch.einsum('nc,bc->bn', cluster_centers.cpu(), normed_point_features.cpu())

In [None]:
print(seg_score.max().item(), seg_score.mean().item(), seg_score.min().item())


In [None]:
label_to_color = np.random.rand(1000, 3)
point_colors = label_to_color[seg_score.argmax(dim = -1).cpu().numpy()]
point_colors[seg_score.max(dim = -1)[0].detach().cpu().numpy() < 0.2] = (0,0,0)

In [None]:
import numpy as np
from plyfile import PlyData, PlyElement

# Inputs: assume these are numpy arrays
positions = scene_gaussians.get_xyz.detach().cpu().numpy()      # [N, 3]
labels = cluster_labels                                          # [N]

# Optional mask
if 'selected_mask' in locals():
    selected_mask_np = selected_mask.cpu().numpy()
    positions = positions[selected_mask_np]
    labels = labels[selected_mask_np]

# Create random color for each cluster
num_clusters = int(labels.max()) + 1
label_to_color = (np.random.rand(num_clusters + 1, 3) * 255).astype(np.uint8)
colors = label_to_color[labels]

# Build structured array for PlyElement
vertex_data = np.array(
    [(*pos, *color) for pos, color in zip(positions, colors)],
    dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'),
           ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]
)

ply_element = PlyElement.describe(vertex_data, 'vertex')
PlyData([ply_element], text=True).write('clustered_output.ply')

print("Wrote clustered_output.ply with", len(positions), "points.")


In [None]:
try:
    scene_gaussians.roll_back()
except:
    pass

In [None]:
bg_color = [0 for i in range(FEATURE_DIM)]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")


rendered_seg_map = render(cameras[17], scene_gaussians, pipeline.extract(args), background, override_color=torch.from_numpy(point_colors).cuda().float())['render']

In [None]:
plt.imshow(rendered_seg_map.permute([1,2,0]).detach().cpu().numpy())

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Count how many points are in each cluster
labels = merged_labels  # already computed from HDBSCAN
unique_labels, counts = np.unique(labels, return_counts=True)

# Skip noise (-1), if needed
valid = unique_labels != -1
unique_labels = unique_labels[valid]
counts = counts[valid]

# Sort by size (optional)
sorted_indices = np.argsort(-counts)
sorted_labels = unique_labels[sorted_indices]
sorted_counts = counts[sorted_indices]

# Plot
plt.figure(figsize=(12,5))
plt.bar(range(len(sorted_counts)), sorted_counts)
plt.xlabel("Cluster Index (sorted)")
plt.ylabel("Number of Gaussians")
plt.title("Number of Gaussians per Cluster")
plt.show()


In [None]:
#matplotlib notebook  # Use this only in a notebook, not a script
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
img = rendered_seg_map.permute([1, 2, 0]).detach().cpu().numpy()
ax.imshow(img)

def onclick(event):
    if event.xdata is None or event.ydata is None:
        return  # Ignore clicks outside image
    x = int(event.xdata)
    y = int(event.ydata)
    print(f"Clicked at: (x={x}, y={y})")

    # Compute flat index for merged_labels if it's 1D
    idx = y * img.shape[1] + x
    cluster_id = full_labels[idx]
    print(f"Cluster ID at that point: {cluster_id.item()}")

cid = fig.canvas.mpl_connect('button_press_event', onclick)


In [None]:
unique_labels = torch.unique(seg_score.argmax(dim=-1))  # [num_clusters]
normed_point_features = torch.nn.functional.normalize(feature_gaussians.get_point_features, dim=-1, p=2)

cluster_features = []
for label in unique_labels:
    mask = (seg_score.argmax(dim=-1) == label)
    cluster_feat = normed_point_features[mask].mean(dim=0)
    cluster_feat = torch.nn.functional.normalize(cluster_feat, dim=0, p=2)
    cluster_features.append(cluster_feat)

cluster_features = torch.stack(cluster_features).cuda()  # [num_clusters, F]


In [None]:
import clip
import torch

def load_clip():
    model, _ = clip.load("ViT-B/32", device="cuda")
    return model

def encode_text(model, text):
    tokens = clip.tokenize([text]).cuda()
    with torch.no_grad():
        return model.encode_text(tokens).squeeze(0)

# Load CLIP
clip_model = load_clip().eval()

# Positive prompts
positive_prompts = ["house", "home", "residential building", "villa", "apartment", "building"]
positive_feats = [encode_text(clip_model, p).float() for p in positive_prompts]
positive_feat = torch.stack(positive_feats).mean(dim=0)

# Negative prompts
negative_prompts = ["trees", "green","roads", "sky", "car", "grass","plants","bush","hill","garden","water","path","street"]
negative_feats = [encode_text(clip_model, p).float() for p in negative_prompts]
negative_feat = torch.stack(negative_feats).mean(dim=0)

# Normalize both
positive_feat = torch.nn.functional.normalize(positive_feat, dim=0, p=2)
negative_feat = torch.nn.functional.normalize(negative_feat, dim=0, p=2)

# Final composite feature
text_feat = torch.nn.functional.normalize(positive_feat, dim=0)


In [None]:
projection = torch.nn.Linear(32, 512).cuda()
# Project cluster features to CLIP space if needed
projected = projection(cluster_features.float())  # Make sure cluster_features.shape = [C, 32]
projected = torch.nn.functional.normalize(projected, dim=-1)
# Compute similarity with prompt
clip_scores = torch.einsum('cf,f->c', projected, text_feat)

topk = 5
topk_values, topk_indices = torch.topk(clip_scores, topk)

selected_mask = torch.zeros_like(full_labels, dtype=torch.bool)
for idx in topk_indices:
    selected_mask |= (full_labels == idx)
    count = (full_labels == idx).sum().item()
    print(f"Cluster {idx.item()} → Count: {count} Gaussians")




In [None]:
#selected_mask = (seg_score.argmax(dim=-1) == best_label)
scene_gaussians.segment(selected_mask)
torch.save(selected_mask, './segmentation_res/clip_guided_cluster_segment.pt')


In [None]:
bg_color = [1 for _ in range(FEATURE_DIM)]
background = torch.tensor(bg_color, dtype=torch.float32, device="cuda")
rendered = render(cameras[17], scene_gaussians, pipeline.extract(args), background)['render']
plt.imshow(rendered.permute(1, 2, 0).detach().cpu())


In [None]:
import torch
import open3d as o3d
import numpy as np

# Paths
PLY_PATH = "path/to/contrastive_feature_point_cloud.ply"
SCALE_GATE_PATH = "path/to/scale_gate.pt"
OUTPUT_PATH = "filtered_point_cloud.ply"

# Load point cloud
pcd = o3d.io.read_point_cloud(PLY_PATH)
points = np.asarray(pcd.points)

# Prepare dummy "scales" or any relevant feature (this may differ if original .ply had attributes)
# Here we use distances from origin as an example feature. Replace as needed.
distances = np.linalg.norm(points, axis=1).reshape(-1, 1)

# Load scale gate
scale_gate = torch.nn.Sequential(
    torch.nn.Linear(1, 32, bias=True),
    torch.nn.Sigmoid()
)
scale_gate.load_state_dict(torch.load(SCALE_GATE_PATH))
scale_gate.eval()

# Compute gate values
with torch.no_grad():
    input_tensor = torch.tensor(distances, dtype=torch.float32)
    gate_output = scale_gate(input_tensor).mean(dim=1).numpy()  # mean over 32-dim output

# Thresholding
lower_thresh = 0.3
upper_thresh = 0.95
mask = (gate_output >= lower_thresh) & (gate_output <= upper_thresh)
filtered_points = points[mask]

# Save filtered cloud
filtered_pcd = o3d.geometry.PointCloud()
filtered_pcd.points = o3d.utility.Vector3dVector(filtered_points)
o3d.io.write_point_cloud(OUTPUT_PATH, filtered_pcd)
print(f"Filtered point cloud saved to {OUTPUT_PATH}")
