# Visualize features and matches


In [None]:
import torch
import numpy as np
import open3d as o3d
import plotly.graph_objects as go
from model.model import Backbone
from model.matcher import MatchingHead
from utils.config import get_default_cfg

# ----------------------------
#        Load Config
# ----------------------------
cfg = get_default_cfg()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Which level of the encoder's multi-level output to visualize. Below, element
# [0] of that level is used as point coordinates and element [1] as features.
layer_id = 0

# ----------------------------
#         Load Models
# ----------------------------
encoder = Backbone(cfg).to(device)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load("model_archive/simclr_encoder_curvature_weighted_0630.pth", map_location=device)
encoder.load_state_dict(ckpt["encoder"])
encoder.eval()

# No trained checkpoint exists for the matcher, so a freshly constructed
# (untrained) MatchingHead is used as-is.
matching_head = MatchingHead().to(device)
matching_head.eval()

# ----------------------------
#        Load Point Cloud
# ----------------------------
ply_path = "data/ply/000002_2014-05-26_14-23-37_260595134347_rgbf000103-resize_0000103.ply"
pcd = o3d.io.read_point_cloud(ply_path)
# Open3D does not raise on a missing/unreadable file: it logs a warning and
# returns an empty cloud, which would otherwise only blow up later inside
# normalization with an unrelated-looking error.
if pcd.is_empty():
    raise FileNotFoundError(f"Could not read any points from {ply_path!r}")
# The sampled model input concatenates xyz with per-point colors (see
# normalize_then_sample), so a colorless cloud cannot be used.
if not pcd.has_colors():
    raise ValueError(f"Point cloud {ply_path!r} has no per-point colors")
points_np = np.asarray(pcd.points, dtype=np.float32)
colors_np = np.asarray(pcd.colors, dtype=np.float32)

# Normalizing and Sampling
def normalize_points(points):
    """Center a point set at its centroid and scale it into the unit ball.

    Args:
        points: array-like of shape (N, D), one point per row.

    Returns:
        A new array of the same shape: ``points`` minus their mean, divided by
        the largest resulting row norm, so the farthest point has norm 1.
        The input array is never modified. If all points coincide, the
        (all-zero) centered array is returned instead of dividing by zero.

    Raises:
        ValueError: if ``points`` has no rows.
    """
    points = np.asarray(points)
    if len(points) == 0:
        raise ValueError("normalize_points() needs at least one point")
    # Out-of-place subtraction: an in-place `-=` would silently mutate the
    # caller's array (and fails outright for integer input).
    centered = points - np.mean(points, axis=0)
    scale = np.max(np.linalg.norm(centered, axis=1))
    # All points identical (e.g. a single point): there is nothing to scale,
    # and dividing by zero would turn every coordinate into NaN.
    if scale == 0:
        return centered
    return centered / scale

def normalize_then_sample(points_np, colors_np, num_points):
    """Normalize the full cloud, then draw ``num_points`` random rows.

    Normalization uses every point, not only the sampled ones. Rows are drawn
    without replacement when the cloud has at least ``num_points`` points and
    with replacement otherwise, so exactly ``num_points`` rows are returned.

    Returns:
        Array whose rows are the sampled normalized xyz followed by the
        color columns of the same points.
    """
    xyz_unit = normalize_points(points_np)
    total = len(points_np)
    need_repeats = total < num_points
    picks = np.random.choice(total, num_points, replace=need_repeats)
    return np.concatenate([xyz_unit[picks], colors_np[picks]], axis=1)

# Draw two independent random subsets of the same normalized cloud: the first
# is kept as-is, the second is passed through augment() further down.
original_input, augmented_input = (
    normalize_then_sample(points_np, colors_np, cfg.num_point) for _ in range(2)
)

# Augment
def augment(xyzrgb):
    """Randomly rotate the xyz columns and add Gaussian jitter; keep colors.

    Three angles are drawn uniformly from [0, 2*pi) and combined as
    Rx @ Ry @ Rz (rotations about x, y and z). Noise with standard deviation
    0.01 is then added to every rotated coordinate. Columns 3 onward are
    passed through unchanged. The rotation matrices are float64, so the
    returned array is float64.
    """
    coords = xyzrgb[:, :3]
    rgb = xyzrgb[:, 3:]
    ax, ay, az = np.random.uniform(0, 2 * np.pi, size=3)
    cx, sx = np.cos(ax), np.sin(ax)
    cy, sy = np.cos(ay), np.sin(ay)
    cz, sz = np.cos(az), np.sin(az)
    rot_x = np.array([[1, 0, 0], [0, cx, -sx], [0, sx, cx]])
    rot_y = np.array([[cy, 0, sy], [0, 1, 0], [-sy, 0, cy]])
    rot_z = np.array([[cz, -sz, 0], [sz, cz, 0], [0, 0, 1]])
    rotation = rot_x @ rot_y @ rot_z
    # Row vectors, so the rotation is applied as p @ R.T.
    rotated = coords @ rotation.T + np.random.normal(scale=0.01, size=coords.shape)
    return np.concatenate([rotated, rgb], axis=1)

augmented_input = augment(augmented_input)

# To tensor: [N, C_in] -> [1, N, C_in] on the compute device. Both views are
# cast to float32 explicitly (augment() returns float64 because its rotation
# matrices are float64), so the encoder always sees the same dtype.
original_tensor = torch.from_numpy(original_input.astype(np.float32)).unsqueeze(0).to(device)
augmented_tensor = torch.from_numpy(augmented_input.astype(np.float32)).unsqueeze(0).to(device)

# ----------------------------
#      Forward through model
# ----------------------------
with torch.no_grad():
    _, orig_feats = encoder(original_tensor)
    _, aug_feats = encoder(augmented_tensor)

# At level `layer_id`, element [0] is used as per-point coordinates and
# element [1] as per-point features; squeeze(0) drops the batch dimension.
f1 = orig_feats[layer_id][1].squeeze(0).cpu().numpy()
f2 = aug_feats[layer_id][1].squeeze(0).cpu().numpy()
coords1 = orig_feats[layer_id][0].squeeze(0).cpu().numpy()
coords2 = aug_feats[layer_id][0].squeeze(0).cpu().numpy()

# ----------------------------
#       Run Matching Head
# ----------------------------
# Re-batch the features straight from the encoder output rather than
# rebuilding them from the NumPy copies above, which would cost a
# device -> host -> device round trip. squeeze/unsqueeze reproduces the same
# [N, C] -> [1, N, C] shape handling as the NumPy path.
f1_tensor = orig_feats[layer_id][1].squeeze(0).unsqueeze(0).to(device)  # [1, N1, C]
f2_tensor = aug_feats[layer_id][1].squeeze(0).unsqueeze(0).to(device)  # [1, N2, C]

with torch.no_grad():
    soft_corr, _ = matching_head(f1_tensor, f2_tensor)  # [1, N1, N2], _
    soft_corr = soft_corr.squeeze(0)  # [N1, N2]

# For every original point, the index of its highest-scoring augmented point.
top_idx = torch.argmax(soft_corr, dim=1).cpu().numpy()  # [N1]

# Top-K selection: the k original-cloud points whose feature vectors have the
# largest L2 norm (np.argsort is ascending, so the last k indices are kept).
# NOTE(review): points are ranked by feature norm, not by match confidence in
# soft_corr -- confirm this is the intended notion of a "top" correspondence.
k = 100
significance_1 = np.linalg.norm(f1, axis=1)
significance_2 = np.linalg.norm(f2, axis=1)
top_k_idx = np.argsort(significance_1)[-k:]

# Shift the augmented cloud +2 along x so both clouds are visible side by side.
coords2_offset = coords2.copy()
coords2_offset[:, 0] += 2

# One flat coordinate list per axis; the None after each (start, end) pair
# breaks the polyline so every correspondence is drawn as its own segment.
lines_x, lines_y, lines_z = [], [], []
for src_i in top_k_idx:
    start = coords1[src_i]
    end = coords2_offset[top_idx[src_i]]
    for axis, axis_coords in enumerate((lines_x, lines_y, lines_z)):
        axis_coords.extend((start[axis], end[axis], None))

# ----------------------------
#       Plot with Plotly
# ----------------------------
fig = go.Figure(data=[
    go.Scatter3d(
        x=coords1[:, 0], y=coords1[:, 1], z=coords1[:, 2],
        mode='markers',
        marker=dict(size=3, color=significance_1, colorscale='plasma'),
        name='Original',
    ),
    go.Scatter3d(
        x=coords2_offset[:, 0], y=coords2_offset[:, 1], z=coords2_offset[:, 2],
        mode='markers',
        marker=dict(size=3, color=significance_2, colorscale='plasma', symbol='cross'),
        name='Augmented',
    ),
    go.Scatter3d(
        x=lines_x, y=lines_y, z=lines_z,
        mode='lines',
        line=dict(color='gray', width=1),
        name=f'Top-{k} Matches',
    ),
])
fig.update_layout(
    title=f"Top-{k} Correspondences via Matching Head (Layer {layer_id})",
    scene=dict(aspectmode='data'),
    legend=dict(itemsizing='constant'),
)
fig.show()

In [14]:
from scipy.spatial.transform import Rotation as R
from sklearn.utils import check_array

def rigid_alignment(src, tgt):
    """Least-squares rigid fit (Kabsch): find R, t with tgt ≈ src @ R.T + t.

    Args:
        src: (N, D) array of points to be moved.
        tgt: (N, D) array of corresponding destination points (row i of
            ``src`` pairs with row i of ``tgt``).

    Returns:
        (R, t): a (D, D) proper rotation (det +1, reflections are corrected)
        and a (D,) translation such that R @ p + t maps a src point onto tgt.
    """
    mu_src = np.mean(src, axis=0)
    mu_tgt = np.mean(tgt, axis=0)
    cross_cov = (src - mu_src).T @ (tgt - mu_tgt)
    u, _, vt = np.linalg.svd(cross_cov)
    v = vt.T
    rot = v @ u.T
    if np.linalg.det(rot) < 0:
        # Optimal orthogonal map is a reflection: flip the axis of the
        # smallest singular value to get the closest proper rotation.
        v[:, -1] = -v[:, -1]
        rot = v @ u.T
    trans = mu_tgt - rot @ mu_src
    return rot, trans

# Extract correspondences: each selected original point paired with its
# argmax match in the (un-offset) augmented cloud.
src_corr = coords1[top_k_idx]
tgt_corr = coords2[top_idx[top_k_idx]]

# Fit the rigid transform that maps the augmented points onto the original
# ones. rigid_alignment(src, tgt) fits tgt ≈ src @ R.T + t, so the augmented
# matches (tgt_corr) are deliberately passed as its `src` argument here.
# NOTE(review): all k argmax matches are used unweighted and without outlier
# rejection, so a handful of wrong matches can noticeably skew the fit.
R_est, t_est = rigid_alignment(tgt_corr, src_corr)

# Apply transformation to the entire augmented cloud (row vectors: p @ R.T + t).
coords2_rigid_aligned = (coords2 @ R_est.T) + t_est

# The aligned cloud is intentionally overlaid on the original (no side-by-side
# x-offset) so alignment quality is visible directly. The variable name is
# kept unchanged for any later code that reads it.
coords2_aligned_offset = coords2_rigid_aligned.copy()

# Plot aligned result
fig = go.Figure()
fig.add_trace(go.Scatter3d(x=coords1[:, 0], y=coords1[:, 1], z=coords1[:, 2],
    mode='markers', marker=dict(size=3, color='red'), name='Original'))
fig.add_trace(go.Scatter3d(x=coords2_aligned_offset[:, 0], y=coords2_aligned_offset[:, 1], z=coords2_aligned_offset[:, 2],
    mode='markers', marker=dict(size=3, color='blue'), name='Aligned Augmented'))

# Re-draw correspondence lines between each selected original point and its
# match in the aligned cloud; None separates consecutive segments.
lines_x, lines_y, lines_z = [], [], []
for idx in top_k_idx:
    p1 = coords1[idx]
    p2 = coords2_aligned_offset[top_idx[idx]]
    lines_x += [p1[0], p2[0], None]
    lines_y += [p1[1], p2[1], None]
    lines_z += [p1[2], p2[2], None]

fig.add_trace(go.Scatter3d(x=lines_x, y=lines_y, z=lines_z,
    mode='lines', line=dict(color='gray', width=1), name='Aligned Matches'))

fig.update_layout(title=f"Top-{k} Matches After Rigid Alignment",
    scene=dict(aspectmode='data'), legend=dict(itemsizing='constant'))
fig.show()