In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import cv2
import OpenEXR, Imath
import depth_pro
import laspy
from PIL import Image
from scipy import ndimage
from scipy.stats import skew
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything.utils.transforms import ResizeLongestSide
import matplotlib.pyplot as plt

# ── Config ──
# Dataset selection and paths
DATASET = "depth4"
INPUT_FOLDER = f"./data/{DATASET}"
OUTPUT_FOLDER = "./output"

# Raw GT depth values are multiplied by this to get centimetres (then / 100 -> metres).
GT_TO_CENTIMETERS = 10000.0
# Field of view (degrees) spanning the image width, used to back-project the point cloud.
CAMERA_FOV = 90.0
# Depth range (metres) of points kept in the exported point cloud.
MIN_DEPTH, MAX_DEPTH = 0.1, 50.0

os.makedirs(OUTPUT_FOLDER, exist_ok=True)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device.upper()}, Dataset: {DATASET}")

In [None]:
def load_exr_rgb(path):
    """Load the R, G, B channels of an OpenEXR file as an 8-bit RGB PIL image.

    Channels are read as float32 and multiplied by 255, then clipped to
    [0, 255] and cast to uint8. No tone mapping or gamma is applied, so the
    values are assumed to already be in display range [0, 1] (TODO confirm for
    the renderer that produced the files). Non-finite values are mapped
    explicitly (NaN -> 0, +inf -> 255, -inf -> 0); casting them directly to
    uint8 is undefined.

    Args:
        path: Path to the .exr file.

    Returns:
        PIL.Image.Image in mode 'RGB' with the size of the EXR data window.
    """
    exr = OpenEXR.InputFile(path)
    try:
        dw = exr.header()['dataWindow']
        w, h = dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1
        pixel_type = Imath.PixelType(Imath.PixelType.FLOAT)
        rgb = np.stack(
            [np.frombuffer(exr.channel(c, pixel_type), np.float32).reshape(h, w) for c in 'RGB'],
            axis=-1,
        )
    finally:
        # Release the file handle even if a channel is missing or malformed.
        exr.close()
    rgb = np.nan_to_num(rgb * 255, nan=0.0, posinf=255.0, neginf=0.0)
    return Image.fromarray(np.clip(rgb, 0, 255).astype(np.uint8))

def load_exr_depth(path):
    """Load a depth channel from an OpenEXR file and return it in metres.

    The first channel present among 'R', 'SceneDepth', 'Z' (checked in that
    order) is used. Raw values are multiplied by the module constant
    GT_TO_CENTIMETERS to get centimetres and then divided by 100.

    Args:
        path: Path to the .exr depth file.

    Returns:
        float32 numpy array of shape (height, width) with depth in metres.

    Raises:
        ValueError: if the file contains none of the candidate channels.
    """
    exr = OpenEXR.InputFile(path)
    try:
        header = exr.header()
        dw = header['dataWindow']
        w, h = dw.max.x - dw.min.x + 1, dw.max.y - dw.min.y + 1
        pixel_type = Imath.PixelType(Imath.PixelType.FLOAT)
        channels = header['channels']
        name = next((c for c in ('R', 'SceneDepth', 'Z') if c in channels), None)
        if name is None:
            # Without this check `depth` would be unbound and fail with an opaque error.
            raise ValueError(
                f"{path}: no depth channel among ('R', 'SceneDepth', 'Z'); "
                f"available channels: {sorted(channels)}"
            )
        # copy(): frombuffer returns a read-only view on the channel bytes.
        depth = np.frombuffer(exr.channel(name, pixel_type), np.float32).reshape(h, w).copy()
    finally:
        exr.close()
    depth_m = (depth * GT_TO_CENTIMETERS) / 100.0
    return depth_m

def load_image(path):
    """Load any supported image as an RGB PIL image.

    '.exr' files (case-insensitive extension) are decoded with load_exr_rgb.
    Everything else is opened with PIL and converted to 'RGB'. The file is opened
    in a context manager, so its handle is released once the converted copy exists.

    Args:
        path: Path to the image file.

    Returns:
        PIL.Image.Image in mode 'RGB'.
    """
    if path.lower().endswith('.exr'):
        return load_exr_rgb(path)
    with Image.open(path) as img:
        # convert() loads the pixel data and returns a new image, so it stays
        # valid after the source file is closed.
        return img.convert('RGB')

# ── Find files ──
def _find_one(files, predicate, what):
    """Return INPUT_FOLDER joined with the first name in `files` matching `predicate`.

    Raises FileNotFoundError mentioning `what` when nothing matches; a bare
    next() would raise an uninformative StopIteration instead.
    """
    match = next((f for f in files if predicate(f)), None)
    if match is None:
        raise FileNotFoundError(f"No {what} file found in {INPUT_FOLDER}")
    return os.path.join(INPUT_FOLDER, match)

# Sorted so the selection is deterministic (os.listdir order is arbitrary).
files = sorted(os.listdir(INPUT_FOLDER))
# The original render is an .exr that is neither a depth pass nor the edited image.
# The edited image may itself be an .exr, so it must be excluded explicitly.
gt_rgb_path = _find_one(
    files,
    lambda f: f.endswith('.exr') and 'depth' not in f.lower() and 'edit' not in f.lower(),
    "original RGB (.exr)",
)
edited_path = _find_one(
    files,
    lambda f: 'edit' in f.lower() and f.endswith(('.png', '.jpg', '.exr')),
    "edited image",
)
gt_depth_path = _find_one(
    files,
    lambda f: 'SceneDepth' in f and 'WorldUnits' not in f and f.endswith('.exr'),
    "GT depth (SceneDepth .exr)",
)

print(f"Original: {os.path.basename(gt_rgb_path)}")
print(f"Edited:   {os.path.basename(edited_path)}")
print(f"GT Depth: {os.path.basename(gt_depth_path)}")

In [None]:
# ── Load all data ──
original_img = load_image(gt_rgb_path)
edited_img = load_image(edited_path)
gt_depth = load_exr_depth(gt_depth_path)
h_gt, w_gt = gt_depth.shape

# Change detection compares the two images position-by-position, so bring the
# edited image to the original's size when the two differ.
if edited_img.size == original_img.size:
    edited_resized = edited_img
else:
    edited_resized = edited_img.resize(original_img.size, Image.BILINEAR)

print(f"Original: {original_img.size}, Edited: {edited_img.size}")
print(f"GT depth: {gt_depth.shape}, range: {gt_depth.min():.2f}m - {gt_depth.max():.2f}m")

In [None]:
# ── Change detection (GeSCF: SAM Q/K/V attention features + adaptive threshold) ──
# Load SAM ViT-B, downloading the checkpoint on first use.
weights_dir = "./weights"
os.makedirs(weights_dir, exist_ok=True)
weights_path = os.path.join(weights_dir, "sam_vit_b_01ec64.pth")

if not os.path.exists(weights_path):
    print("Downloading SAM ViT-B weights (~375 MB)...")
    import urllib.request
    # Download to a temporary name and rename only on success. Otherwise an
    # interrupted download would leave a truncated checkpoint that the
    # exists() check above accepts on every later run.
    tmp_weights_path = weights_path + ".part"
    try:
        urllib.request.urlretrieve("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth", tmp_weights_path)
        os.replace(tmp_weights_path, weights_path)
    finally:
        if os.path.exists(tmp_weights_path):
            os.remove(tmp_weights_path)

sam = sam_model_registry["vit_b"](checkpoint=weights_path).to(device).eval()
h_img, w_img = np.array(original_img).shape[:2]

# Hook into block 8 (global attention) to capture the output of its qkv projection.
captured = {}
def _hook_qkv(module, input, output):
    """Forward hook: store the (detached) qkv projection output in `captured`."""
    captured["qkv"] = output.detach()

# Preprocess both images for SAM
img_size = sam.image_encoder.img_size
sam_transform = ResizeLongestSide(img_size)
patch_size = 16
sam_scale = img_size / max(h_img, w_img)
# Token-grid extent covered by image content; the rest covers the zero padding
# that sam.preprocess adds to reach img_size x img_size.
feat_h, feat_w = int(h_img * sam_scale + 0.5) // patch_size, int(w_img * sam_scale + 0.5) // patch_size

def prepare_sam(pil_img):
    """Resize a PIL image with SAM's ResizeLongestSide and normalize/pad it via
    sam.preprocess, returning a (1, 3, H, W) float tensor on `device`."""
    t = sam_transform.apply_image(np.array(pil_img))
    return sam.preprocess(torch.as_tensor(t, device=device).permute(2, 0, 1).unsqueeze(0).float())

def _cosine_distance(a, b):
    """Per-position cosine distance 1 - cos(a, b) along the last (feature) axis."""
    return 1 - (F.normalize(a, dim=-1) * F.normalize(b, dim=-1)).sum(dim=-1)

# Extract features. The hook is removed even if a forward pass fails, so a re-run
# cannot leave a stale hook on the encoder. `captured` is cleared so it does not
# keep an extra reference to the GPU feature tensor.
hook = sam.image_encoder.blocks[8].attn.qkv.register_forward_hook(_hook_qkv)
try:
    with torch.no_grad():
        sam.image_encoder(prepare_sam(original_img))
        qkv1 = captured["qkv"]
        sam.image_encoder(prepare_sam(edited_resized))
        qkv2 = captured["qkv"]
finally:
    hook.remove()
    captured.clear()

with torch.no_grad():
    # squeeze(0) drops the batch axis; the indexing below assumes the remaining
    # layout is (grid_h, grid_w, channels) — TODO confirm against the SAM version.
    # No module-level views of qkv1/qkv2 are kept, so deleting those two names
    # later actually releases the feature memory.
    dist_map = _cosine_distance(qkv1.squeeze(0), qkv2.squeeze(0)).cpu().numpy()[:feat_h, :feat_w]

# Upsample, smooth, normalize
dist_map = F.interpolate(torch.tensor(dist_map)[None, None].float(), size=(h_img, w_img), mode='bilinear', align_corners=False).squeeze().numpy()
dist_map = ndimage.gaussian_filter(dist_map, sigma=4)
dist_map = (dist_map - dist_map.min()) / (dist_map.max() - dist_map.min() + 1e-8)

# Adaptive threshold (skewness-based, from GeSCF paper): the skewness of the
# distance distribution, clipped to [1, 3], sets how many standard deviations
# above the mean a pixel must lie to count as changed.
sk = skew(dist_map.ravel())
k = np.clip(sk, 1.0, 3.0)
threshold = dist_map.mean() + k * dist_map.std()
initial_mask = dist_map > threshold

# SAM segment-level refinement: run SAM's automatic mask generator on the edited
# image and mark every segment with more than 30% of its area inside the
# threshold mask as changed.
mask_gen = SamAutomaticMaskGenerator(sam, points_per_side=16, pred_iou_thresh=0.80, stability_score_thresh=0.85, min_mask_region_area=200)
sam_masks = mask_gen.generate(np.array(edited_resized))
changed_mask = np.zeros((h_img, w_img), dtype=bool)
for segment in sam_masks:
    seg_mask = segment["segmentation"]
    seg_area = seg_mask.sum()
    if seg_area == 0:
        continue
    overlap = np.logical_and(seg_mask, initial_mask).sum()
    if overlap / seg_area > 0.3:
        changed_mask |= seg_mask
# If no segment qualified, fall back to the raw threshold mask.
if not changed_mask.any():
    changed_mask = initial_mask

# Free SAM memory before the depth model is loaded
del sam, mask_gen, qkv1, qkv2
torch.cuda.empty_cache()

# Bring the change mask to the GT depth resolution. Nearest-neighbour keeps it
# binary; cv2.resize takes (width, height).
if changed_mask.shape != gt_depth.shape:
    mask_u8 = cv2.resize(changed_mask.astype(np.uint8), (w_gt, h_gt), interpolation=cv2.INTER_NEAREST)
    changed_mask = mask_u8 > 0
unchanged_mask = np.logical_not(changed_mask)

print(f"Changed pixels: {changed_mask.sum():,} ({changed_mask.mean()*100:.1f}%)")
print(f"Threshold: {threshold:.4f} (skew={sk:.2f}, k={k:.2f})")

In [None]:
# ── Run Depth Pro on edited image ──
model, transform = depth_pro.create_model_and_transforms(device=device)
model.eval()

# f_px=None lets Depth Pro estimate the focal length itself.
model_input = transform(edited_img)
with torch.no_grad():
    prediction = model.infer(model_input, f_px=None)
depth_native = prediction["depth"].squeeze().cpu().numpy()
# Resample to the GT depth resolution (cv2.resize takes (width, height)).
pred_depth = cv2.resize(depth_native, (w_gt, h_gt), interpolation=cv2.INTER_LINEAR)

# Release the depth model before the remaining cells run.
del model
torch.cuda.empty_cache()
print(f"Predicted depth: {pred_depth.shape}, range: {pred_depth.min():.2f}m - {pred_depth.max():.2f}m")

In [None]:
# ── Least squares scaling on unchanged regions ──
# Fit gt ≈ scale * pred + shift, using only pixels that the edit left unchanged
# and whose GT and predicted depths are finite, with GT inside (GT_FIT_MIN, GT_FIT_MAX).
GT_FIT_MIN, GT_FIT_MAX = 0.1, 100.0  # metres
valid = (
    unchanged_mask
    & (gt_depth > GT_FIT_MIN) & (gt_depth < GT_FIT_MAX)
    & np.isfinite(pred_depth) & np.isfinite(gt_depth)
)
n_valid = int(valid.sum())
if n_valid < 2:
    raise ValueError(f"Only {n_valid} valid unchanged pixels; cannot fit scale and shift")
A = np.column_stack([pred_depth[valid], np.ones(n_valid)])
scale, shift = np.linalg.lstsq(A, gt_depth[valid], rcond=None)[0]
depth_scaled = pred_depth * scale + shift

# Evaluate on the same pixels used for the fit. unchanged_mask alone can include
# non-finite or out-of-range GT values, which would make the metrics NaN or
# dominated by those pixels.
residual = depth_scaled[valid] - gt_depth[valid]
mae = np.mean(np.abs(residual))
rmse = np.sqrt(np.mean(residual ** 2))
print(f"LS fit: scale={scale:.4f}, shift={shift:.4f}")
print(f"Unchanged regions — MAE: {mae:.4f}m, RMSE: {rmse:.4f}m")

In [None]:
# ── Visualization ──
# Shared colour range for GT and prediction, taken from finite GT values only so
# that a single inf/NaN pixel cannot break the colour mapping.
finite_gt = gt_depth[np.isfinite(gt_depth)]
vmax = float(finite_gt.max()) if finite_gt.size else 1.0
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(edited_img); axes[0].set_title('Edited Image')
axes[1].imshow(gt_depth, cmap='turbo', vmin=0, vmax=vmax); axes[1].set_title('GT Depth')
axes[2].imshow(depth_scaled, cmap='turbo', vmin=0, vmax=vmax); axes[2].set_title('Scaled Prediction')

# Absolute error on unchanged pixels; changed pixels are shown as 0. np.where is
# used instead of multiplying by the mask, because inf * 0 would give NaN.
error = np.where(unchanged_mask, np.abs(depth_scaled - gt_depth), 0.0)
axes[3].imshow(error, cmap='inferno', vmin=0, vmax=0.5); axes[3].set_title(f'Error (MAE={mae:.3f}m)')

for ax in axes:
    ax.axis('off')

plt.tight_layout(); plt.show()

In [None]:
# ── Generate point cloud ──
# Back-project the scaled depth to 3D camera coordinates with a pinhole model whose
# focal length comes from CAMERA_FOV spanning the image width.
w_edit, h_edit = edited_img.size
depth_full = cv2.resize(depth_scaled, (w_edit, h_edit), interpolation=cv2.INTER_LINEAR)

focal = w_edit / (2 * np.tan(np.radians(CAMERA_FOV / 2)))
cx, cy = w_edit / 2, h_edit / 2
# Sample each pixel at its centre (u + 0.5, v + 0.5). With the principal point at
# (w/2, h/2), using the bare indices would shift every ray by half a pixel.
xx, yy = np.meshgrid(np.arange(w_edit) + 0.5, np.arange(h_edit) + 0.5)

z = depth_full.flatten()
x = (xx.flatten() - cx) * z / focal
y = (yy.flatten() - cy) * z / focal
rgb = np.array(edited_img).reshape(-1, 3)

# Keep points inside the configured depth range (NaN depths fail both comparisons).
mask = (z >= MIN_DEPTH) & (z <= MAX_DEPTH)
x, y, z, rgb = x[mask], y[mask], z[mask], rgb[mask]
print(f"Point cloud: {len(z):,} points")

In [None]:
# ── Save as LAS ──
output_path = f"{OUTPUT_FOLDER}/room_edited.las"

header = laspy.LasHeader(point_format=3, version="1.2")
# Coordinates are in metres, so a 0.001 scale stores them at 1 mm resolution.
header.scales = np.array([0.001, 0.001, 0.001])
las = laspy.LasData(header=header)

# Camera frame (x right, y down, z forward) -> LAS frame (X forward, Y left, Z up).
las.x = z
las.y = -x
las.z = -y

# 8-bit colour channels scaled into the 16-bit range stored by LAS (255 -> 65280).
for channel_idx, channel_name in enumerate(("red", "green", "blue")):
    setattr(las, channel_name, rgb[:, channel_idx].astype(np.uint16) * 256)

las.write(output_path)
print(f"Saved {len(z):,} points to {output_path}")