In [2]:
import trimesh
import torch
import mcubes
import numpy as np
from skgstat import Variogram
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from scipy.spatial.distance import pdist, squareform
from ipywidgets import interact, FloatSlider
import math

In [3]:
# Model: Suzanne (Monkey head). 100 images normal, 100 images with right ear removed.
device = 'cuda'
pth_file = 'experiments/suzanne/set100/models/M0.pth'
nerf_model = torch.load(pth_file).to(device)

  nerf_model = torch.load(pth_file).to(device)


In [4]:
# Grid of points on scne p_XYZ
N = 35
scale = 1.5
x = torch.linspace(-scale, scale, N, device=device)
y = torch.linspace(-scale, scale, N, device=device)
z = torch.linspace(-scale, scale, N, device=device)
x, y, z = torch.meshgrid((x, y, z))
xyz = torch.cat((x.reshape(-1, 1), y.reshape(-1, 1), z.reshape(-1, 1)), dim=1).to(device)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
# Canonical perspective view directions
view_directions = torch.tensor([
    [0, 0, 1],  # Top-down
    [math.cos(math.radians(45)) * math.cos(math.radians(30)), 
     math.sin(math.radians(45)) * math.cos(math.radians(30)), 
     math.sin(math.radians(30))],  # Front-right (30° elevation, 45° azimuth)
    [math.cos(math.radians(-45)) * math.cos(math.radians(30)), 
     math.sin(math.radians(-45)) * math.cos(math.radians(30)), 
     math.sin(math.radians(30))],  # Front-left (30° elevation, -45° azimuth)
    [math.cos(math.radians(135)) * math.cos(math.radians(30)), 
     math.sin(math.radians(135)) * math.cos(math.radians(30)), 
     math.sin(math.radians(30))],  # Back-right (30° elevation, 135° azimuth)
    [math.cos(math.radians(-135)) * math.cos(math.radians(30)), 
     math.sin(math.radians(-135)) * math.cos(math.radians(30)), 
     math.sin(math.radians(30))]   # Back-left (30° elevation, -135° azimuth)
], device=device)

# Compute averaged RGB from 5 views
averaged_rgb = torch.zeros(xyz.shape[0], 3, device=device)
for direction in view_directions:
    rgb, _ = nerf_model.forward(xyz, direction.expand(xyz.shape[0], -1))
    averaged_rgb += rgb

averaged_rgb /= len(view_directions)

xyz = xyz.cpu().detach().numpy()
averaged_rgb = averaged_rgb.cpu().detach().numpy()

### Masking: How we filter the points we will consider
Helps with restricting points we are interested in estimating change for.

In [6]:
threshold_value = np.quantile(averaged_rgb.mean(axis=1), 0.5)
mask = averaged_rgb.mean(axis=1) >= threshold_value

# Filter points and RGB values
filtered_xyz = xyz[mask]
filtered_rgb = averaged_rgb[mask]

### Computing the Variogram: color
Core idea: correlation between samples decreases over distance.
A measure of spatial continuity. Finds the semi-variance for all points in our space

In [7]:
color_variogram = Variogram(
    filtered_xyz,
    filtered_rgb.mean(axis=1),
    model='spherical', # The spherical model increases linearly at short distances and levels off at a certain range, indicating that beyond this range, data points are uncorrelated.
    normalize=False,
    nugget=0.1
)

### Determining point-wise uncertanties from the Variogram

In [8]:
pairwise_distances = squareform(pdist(filtered_xyz)) # distance between every point for each point
bin_edges = color_variogram.bins
bin_uncertainties = color_variogram.experimental # gets the semi-variance (dismilarity measure) for points in each distance based bin of the variogram
point_uncertainties = np.zeros(filtered_xyz.shape[0])

# determine uncertainty for each point based on the average semi-variance of neighboring points
for i in range(filtered_xyz.shape[0]):
    distances = pairwise_distances[i]
    bin_indices = np.digitize(distances, bin_edges, right=True)
    neighbor_uncertainties = bin_uncertainties[bin_indices - 1]
    point_uncertainties[i] = np.mean(neighbor_uncertainties)

# Normalizing uncertainties [0,1]
point_uncertainties = (point_uncertainties - np.min(point_uncertainties)) / (
    np.max(point_uncertainties) - np.min(point_uncertainties)
)

### Visualization Code

In [9]:
colormap = cm.get_cmap('inferno')
colors = colormap(point_uncertainties)[:, :3]  # Get RGB values

  colormap = cm.get_cmap('inferno')


In [10]:
def update_scene(threshold):
    # Apply dynamic threshold
    threshold_mask = point_uncertainties >= threshold
    filtered_xyz_thresholded = filtered_xyz[threshold_mask]
    point_uncertainties_thresholded = point_uncertainties[threshold_mask]
    colors_thresholded = colors[threshold_mask]

    # Define sphere sizes for the remaining points
    sphere_sizes_thresholded = 0.05 + point_uncertainties_thresholded * 0.001

    # Create spheres for the thresholded points
    spheres = []
    for point, size, color in zip(filtered_xyz_thresholded, sphere_sizes_thresholded, colors_thresholded):
        sphere = trimesh.primitives.Sphere(
            radius=size, center=point, subdivisions=2  # Subdivisions for smoothness
        )
        sphere.visual.vertex_colors = (color * 255).astype(np.uint8)
        spheres.append(sphere)

    # Add mesh for spatial context
    density_np = averaged_rgb.mean(axis=1).reshape(N, N, N)  # Use the full averaged RGB for mesh visualization
    vertices, triangles = mcubes.marching_cubes(density_np, 3 * np.mean(density_np))
    vertices_scaled = (vertices / N) * (2 * scale) - scale
    mesh = trimesh.Trimesh(vertices_scaled, triangles)

    # Draw view direction vectors
    center = np.array([0, 0, 0])  # Assume object is centered at the origin
    view_lines = []
    for direction in view_directions.cpu().numpy():
        arrow_start = center
        arrow_end = center + 3 * direction  # Make vectors 3x longer
        line = trimesh.load_path(np.array([arrow_start, arrow_end]))
        
        # Assign red color to the path
        line_colors = np.array([[255, 0, 0, 255]] * len(line.entities))  # RGBA for red, fully opaque
        line.colors = line_colors  # Assign per-entity colors
        line.width = 2.0  # Set line thickness
        view_lines.append(line)

    # Combine the mesh, spheres, and view vectors into a single scene
    scene = trimesh.Scene([mesh] + spheres + view_lines)

    # Show the scene
    scene.export("scene.glb")

In [None]:
interact(update_scene, threshold=FloatSlider(value=0.5, min=0.0, max=1.0, step=0.01))

interactive(children=(FloatSlider(value=0.5, description='threshold', max=1.0, step=0.01), Output()), _dom_cla…

<function __main__.update_scene(threshold)>

: 