In [1]:
import numpy as np, cv2, torch
# Environment report: library versions, then the first CUDA device name (None without CUDA).
versions = {
    "numpy": np.__version__,
    "cv2": cv2.__version__,
    "torch": torch.__version__,
    "cuda": torch.version.cuda,
}
for label, value in versions.items():
    print(f"{label}:", value)
gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else None
print("gpu:", gpu_name)

numpy: 2.2.6
cv2: 4.13.0
torch: 2.10.0+cu128
cuda: 12.8
gpu: NVIDIA RTX A6000


In [3]:
# ============================================================
# Batch RGB-TIR Registration (SuperGlue + RANSAC + Optional Undistort)
# Outputs (same filename as input):
#  - out_root/warped_rgb/<name>
#  - out_root/overlay/<name>
#  - out_root/heatmap_inliers/<name>
#  - out_root/info_txt/<stem>.txt
# Also writes: out_root/SUMMARY.txt
# ============================================================

# If needed, run once in its own cell (%pip installs into this kernel's environment; pin versions for reproducibility):
# %pip install -q opencv-python transformers torch torchvision pillow matplotlib tqdm

import os
from pathlib import Path
import numpy as np
import cv2
import torch
from PIL import Image
from tqdm import tqdm
from transformers import AutoImageProcessor, SuperGlueForKeypointMatching

# -----------------------------
# OST reader (optional undistort)
# -----------------------------
def read_ost(ost_path: str):
    """Parse camera intrinsics from an ``ost.txt``-style calibration text file.

    The file must contain a line reading exactly ``camera matrix`` (compared
    case-insensitively, surrounding whitespace ignored) followed by three rows
    of whitespace-separated numbers, and a ``distortion`` line followed by one
    row of coefficients. Only the first occurrence of each header is used.

    Args:
        ost_path: path to the calibration file (decoded as UTF-8, undecodable
            bytes ignored).

    Returns:
        (K, dist): K built from the three rows after ``camera matrix`` as a
        float64 array (3x3 for a well-formed file), and dist as a float64
        array of shape (1, n_coeffs).

    Raises:
        ValueError: if either header line is missing (message is in Spanish),
            or if a value row contains non-numeric tokens.
    """
    raw_rows = Path(ost_path).read_text(encoding="utf-8", errors="ignore").splitlines()
    normalized = [row.strip().lower() for row in raw_rows]

    def header_index(label):
        # index of the first line whose whole normalized content equals ``label``
        target = label.strip().lower()
        return next((idx for idx, row in enumerate(normalized) if row == target), None)

    k_at = header_index("camera matrix")
    d_at = header_index("distortion")
    if k_at is None or d_at is None:
        raise ValueError(f"No encontré 'camera matrix' o 'distortion' en {ost_path}")

    def parse_row(row):
        return [float(tok) for tok in row.split()]

    K = np.array([parse_row(raw_rows[k_at + offset]) for offset in (1, 2, 3)], dtype=np.float64)
    dist = np.array(parse_row(raw_rows[d_at + 1]), dtype=np.float64).reshape(1, -1)
    return K, dist

def undistort_if_available(img_bgr, K=None, dist=None, alpha=0.0):
    """Undistort ``img_bgr`` using intrinsics ``K``/``dist``; return it unchanged if either is None.

    ``alpha`` is forwarded to ``cv2.getOptimalNewCameraMatrix`` (0 crops to
    valid pixels only, 1 keeps every source pixel). The new camera matrix is
    computed for the image's own (width, height), so the output keeps the
    input size.

    NOTE(review): K/dist are applied as-is; confirm they were calibrated at
    this image resolution.
    """
    if K is not None and dist is not None:
        height, width = img_bgr.shape[:2]
        size = (width, height)
        new_K = cv2.getOptimalNewCameraMatrix(K, dist, size, alpha, size)[0]
        return cv2.undistort(img_bgr, K, dist, None, new_K)
    return img_bgr

# -----------------------------
# Cross-modal preprocessing
# -----------------------------
def to_gray(img_bgr):
    """Return a single-channel image: 3-D (BGR) input is converted, anything else is returned as-is."""
    if img_bgr.ndim != 3:
        return img_bgr
    return cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)

def normalize_u8(img):
    """Min-max stretch ``img`` to [0, 255] and return it as uint8 (fractional part truncated)."""
    as_float = img.astype(np.float32)
    stretched = cv2.normalize(as_float, None, 0, 255, cv2.NORM_MINMAX)
    return stretched.astype(np.uint8)

def clahe_u8(img_u8, clip=3.0, grid=(16, 16)):
    """Contrast-limited adaptive histogram equalization of a uint8 single-channel image.

    Args:
        img_u8: uint8 grayscale image.
        clip: CLAHE clip limit.
        grid: (cols, rows) tile grid size.
    """
    equalizer = cv2.createCLAHE(clipLimit=clip, tileGridSize=grid)
    equalized = equalizer.apply(img_u8)
    return equalized

def gradient_mag_u8(img_u8):
    """3x3 Sobel gradient magnitude of ``img_u8``, min-max scaled to [0, 255] uint8."""

    def sobel(order_x, order_y):
        return cv2.Sobel(img_u8, cv2.CV_32F, order_x, order_y, ksize=3)

    magnitude = cv2.magnitude(sobel(1, 0), sobel(0, 1))
    scaled = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return scaled.astype(np.uint8)

def preprocess_cross_modal(img_bgr, do_grad=True, clip=3.0, grid=(16, 16)):
    """Map an RGB or thermal image into a common uint8 representation for matching.

    Steps: grayscale -> min-max stretch to [0, 255] -> CLAHE -> optional Sobel
    gradient magnitude. Matching on gradients rather than raw intensity is
    presumably meant to narrow the RGB/TIR appearance gap (edges tend to
    co-occur across modalities even when intensities do not).

    Args:
        img_bgr: BGR (3-D) or already single-channel image.
        do_grad: if True, return the gradient-magnitude map instead of the
            equalized grayscale image.
        clip: CLAHE clip limit (previously hard-coded; default unchanged).
        grid: CLAHE tile grid size (previously hard-coded; default unchanged).

    Returns:
        uint8 single-channel image with the same height/width as the input.
    """
    gray = normalize_u8(to_gray(img_bgr))
    gray = clahe_u8(gray, clip=clip, grid=grid)
    return gradient_mag_u8(gray) if do_grad else gray

def as_pil_rgb(img_bgr_or_u8):
    """Convert a 2-D grayscale or 3-D BGR array into an RGB ``PIL.Image``."""
    conversion = cv2.COLOR_GRAY2RGB if img_bgr_or_u8.ndim == 2 else cv2.COLOR_BGR2RGB
    return Image.fromarray(cv2.cvtColor(img_bgr_or_u8, conversion))

# -----------------------------
# SuperGlue (loaded once)
# -----------------------------
class SuperGlueHF:
    """Thin wrapper around the Hugging Face SuperGlue keypoint matcher.

    The processor and model are loaded once in ``__init__`` and reused for
    every call to ``match``, so create a single instance per batch.
    """

    def __init__(self, ckpt="magic-leap-community/superglue_outdoor", device=None):
        """Load the image processor and model from ``ckpt``.

        Args:
            ckpt: Hugging Face model id or local path of a SuperGlue checkpoint.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".
        """
        self.processor = AutoImageProcessor.from_pretrained(ckpt)
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model = SuperGlueForKeypointMatching.from_pretrained(ckpt).to(self.device)
        self.model.eval()  # inference only

    @torch.no_grad()
    def match(self, img0_pil, img1_pil, threshold=0.35, max_matches=1500):
        """Match keypoints between two PIL images.

        Args:
            img0_pil, img1_pil: images exposing ``.height`` / ``.width`` (PIL).
            threshold: minimum matching score forwarded to the processor's
                ``post_process_keypoint_matching``.
            max_matches: keep at most this many matches, highest score first;
                ``None`` disables the cap.

        Returns:
            (k0, k1, scores): float32 arrays; ``k0[i]`` in img0 matches ``k1[i]``
            in img1 with confidence ``scores[i]``. Each image's (height, width)
            is passed to post-processing so keypoints are reported in that
            image's pixel coordinates (per the HF post-processing API).
        """
        images = [[img0_pil, img1_pil]]
        inputs = self.processor(images, return_tensors="pt").to(self.device)
        outputs = self.model(**inputs)

        image_sizes = [[(img.height, img.width) for img in pair] for pair in images]
        results = self.processor.post_process_keypoint_matching(
            outputs, image_sizes, threshold=threshold
        )
        r = results[0]

        k0 = r["keypoints0"].detach().cpu().numpy().astype(np.float32)  # img0
        k1 = r["keypoints1"].detach().cpu().numpy().astype(np.float32)  # img1
        sc = r["matching_scores"].detach().cpu().numpy().astype(np.float32)

        if max_matches is not None and len(sc) > max_matches:
            # Highest scores first; a stable sort keeps tie-breaking deterministic
            # (original match order), so repeated runs keep the same subset.
            idx = np.argsort(-sc, kind="stable")[:max_matches]
            k0, k1, sc = k0[idx], k1[idx], sc[idx]

        return k0, k1, sc

# -----------------------------
# Robust estimation + QC
# -----------------------------
def estimate_affine_ransac(k_src, k_dst, thr=5.0):
    """Fit a 2x3 transform mapping ``k_src`` onto ``k_dst`` with RANSAC.

    Uses ``cv2.estimateAffinePartial2D``, i.e. a 4-DOF similarity (rotation,
    uniform scale, translation), not a full 6-DOF affine.

    Args:
        k_src, k_dst: (N, 2) float arrays of corresponding points.
        thr: RANSAC reprojection threshold in pixels.

    Returns:
        (M, inlier_ratio, inlier_mask): M as float32 (2x3), the fraction of
        correspondences kept by RANSAC, and a boolean mask over the inputs.
        ``(None, 0.0, None)`` when fewer than 8 points are given (a
        conservative minimum) or no model is found.
    """
    if k_src is None or len(k_src) < 8:
        return None, 0.0, None
    model, inlier_flags = cv2.estimateAffinePartial2D(
        k_src,
        k_dst,
        method=cv2.RANSAC,
        ransacReprojThreshold=thr,
        maxIters=5000,
        confidence=0.995,
    )
    if model is None or inlier_flags is None:
        return None, 0.0, None
    inlier_ratio = float(inlier_flags.mean())
    return model.astype(np.float32), inlier_ratio, inlier_flags.astype(bool).ravel()

def estimate_homography_ransac(k_src, k_dst, thr=5.0):
    """Fit a 3x3 homography mapping ``k_src`` onto ``k_dst`` with RANSAC.

    Args:
        k_src, k_dst: (N, 2) float arrays of corresponding points.
        thr: RANSAC reprojection threshold in pixels.

    Returns:
        (H, inlier_ratio, inlier_mask): H as float32, the fraction of
        correspondences kept by RANSAC, and a boolean mask over the inputs.
        ``(None, 0.0, None)`` when fewer than 4 points are given or no
        model is found.
    """
    too_few = k_src is None or len(k_src) < 4
    if too_few:
        return None, 0.0, None
    homography, inlier_flags = cv2.findHomography(
        k_src,
        k_dst,
        method=cv2.RANSAC,
        ransacReprojThreshold=thr,
        maxIters=5000,
        confidence=0.995,
    )
    if homography is None or inlier_flags is None:
        return None, 0.0, None
    inlier_ratio = float(inlier_flags.mean())
    return homography.astype(np.float32), inlier_ratio, inlier_flags.astype(bool).ravel()

def reprojection_error_homography(H, k_src, k_dst, inlier_mask=None):
    """Mean and median pixel error of ``k_src`` projected by ``H`` against ``k_dst``.

    Args:
        H: 3x3 homography mapping src -> dst, or None.
        k_src, k_dst: (N, 2) arrays of corresponding points.
        inlier_mask: optional length-N mask selecting which correspondences to
            score. It is coerced to bool, so a raw 0/1 uint8 mask (as returned by
            OpenCV) is treated as a mask rather than as integer row indices.

    Returns:
        (mean_err, median_err) as Python floats. ``(inf, inf)`` when H is None
        or fewer than 4 correspondences are selected. Points that project to
        infinity (w == 0) count as an error of inf instead of producing nan.
    """
    if H is None:
        return np.inf, np.inf

    src = np.asarray(k_src, dtype=np.float64).reshape(-1, 2)
    dst = np.asarray(k_dst, dtype=np.float64).reshape(-1, 2)
    if inlier_mask is not None:
        mask = np.asarray(inlier_mask).astype(bool).ravel()
        src, dst = src[mask], dst[mask]
    if len(src) < 4:
        return np.inf, np.inf

    # Project in float64 so a float32 H does not add rounding noise to the QC metric.
    H64 = np.asarray(H, dtype=np.float64).reshape(3, 3)
    src_h = np.hstack([src, np.ones((len(src), 1))])
    proj = src_h @ H64.T
    with np.errstate(divide="ignore", invalid="ignore", over="ignore"):
        proj_xy = proj[:, :2] / proj[:, 2:3]
        err = np.linalg.norm(proj_xy - dst, axis=1)
    err[~np.isfinite(err)] = np.inf
    return float(np.mean(err)), float(np.median(err))

def qc_pass(num_matches, inlier_ratio, reproj_median,
            min_matches=40, min_inlier=0.40, max_median_err=4.0):
    """Quality gate for a fitted model: every threshold must be satisfied.

    Passes when there are at least ``min_matches`` matches, the inlier ratio
    is at least ``min_inlier``, and the median reprojection error is at most
    ``max_median_err`` pixels. A NaN metric fails its comparison.
    """
    checks = (
        num_matches >= min_matches,
        inlier_ratio >= min_inlier,
        reproj_median <= max_median_err,
    )
    return all(checks)

# -----------------------------
# Heatmap utilities
# -----------------------------
def make_keypoint_heatmap(h, w, pts_xy, sigma=18, weights=None):
    """Render a smooth, peak-normalized density map of keypoints on an (h, w) grid.

    Each point is rounded to the nearest pixel and points outside the grid are
    ignored. Every in-bounds point adds its weight (1.0 when ``weights`` is
    None) to its pixel. The map is then Gaussian-blurred with standard
    deviation ``sigma`` (kernel side int(6*sigma+1), bumped to odd) and divided
    by its maximum when that maximum is positive.

    Args:
        h, w: output height and width in pixels.
        pts_xy: (N, 2) array of (x, y) coordinates, or None / empty.
        sigma: Gaussian standard deviation in pixels.
        weights: optional per-point weights; points and weights are paired up
            to the shorter of the two lengths.

    Returns:
        float32 array of shape (h, w); all zeros when there are no points.
    """
    heat = np.zeros((h, w), dtype=np.float32)
    if pts_xy is None or len(pts_xy) == 0:
        return heat

    pixels = np.round(pts_xy).astype(int)
    if weights is None:
        wts = np.ones((len(pixels),), dtype=np.float32)
    else:
        wts = weights.astype(np.float32)

    n = min(len(pixels), len(wts))
    xs, ys = pixels[:n, 0], pixels[:n, 1]
    in_bounds = (xs >= 0) & (xs < w) & (ys >= 0) & (ys < h)
    # Unbuffered add: repeated pixels accumulate every contribution, in order.
    np.add.at(heat, (ys[in_bounds], xs[in_bounds]), wts[:n][in_bounds])

    ksize = int(6 * sigma + 1)
    ksize += 1 - ksize % 2  # GaussianBlur needs an odd kernel side
    heat = cv2.GaussianBlur(heat, (ksize, ksize), sigmaX=sigma, sigmaY=sigma)

    peak = heat.max()
    return heat / peak if peak > 0 else heat

def overlay_heatmap_on_image(img_bgr, heat01, alpha=0.45):
    """Blend a JET-colorized heatmap over ``img_bgr``.

    Args:
        img_bgr: image to blend onto; it must match the colorized heatmap's
            size and type (3-channel uint8) for ``cv2.addWeighted``.
        heat01: 2-D float map expected in [0, 1].
        alpha: weight of the heatmap; the image gets ``1 - alpha``.

    Returns:
        (blended, heat_color): the blended BGR image and the colorized heatmap.
    """
    heat_color = cv2.applyColorMap((heat01 * 255).astype(np.uint8), cv2.COLORMAP_JET)
    blended = cv2.addWeighted(img_bgr, 1 - alpha, heat_color, alpha, 0)
    return blended, heat_color

# -----------------------------
# Per-pair processing (save outputs)
# -----------------------------
def _prefer_homography(ok_A, ok_H, inl_aff, inl_h, med_err, max_median_err):
    """Decide whether to warp with the homography (True) or fall back to the similarity model (False)."""
    if ok_A and ok_H:
        # Both pass QC: H has more degrees of freedom, so only prefer it when its
        # median error is well under the limit AND it explains clearly more matches.
        return (med_err <= max_median_err * 0.7) and (inl_h > inl_aff + 0.05)
    # Only H passes -> H; only A passes, or neither -> not H.
    return bool(ok_H)


def _imwrite_checked(path, img):
    """Write ``img`` with cv2.imwrite, raising instead of silently returning False."""
    if not cv2.imwrite(str(path), img):
        raise OSError(f"cv2.imwrite failed for {path}")


def process_one_pair(
    tir_path: Path,
    rgb_path: Path,
    out_warped_dir: Path,
    out_overlay_dir: Path,
    out_heat_dir: Path,
    out_txt_dir: Path,
    matcher: SuperGlueHF,
    K_tir=None, d_tir=None, K_rgb=None, d_rgb=None,
    score_thr=0.35, ransac_thr=5.0,
    min_matches=40, min_inlier=0.40, max_median_err=4.0,
    max_matches=1500,
):
    """Register one RGB image onto its TIR counterpart and save the per-pair outputs.

    Pipeline: read both images (3-channel BGR) -> optional undistortion ->
    cross-modal preprocessing -> SuperGlue matching -> RANSAC fit of a
    similarity (``estimateAffinePartial2D``) and a homography mapping RGB ->
    TIR -> QC -> warp RGB into the TIR frame -> save the warped image, a 50/50
    overlay, an inlier heatmap and an info ``.txt``.

    All outputs are named after the RGB file (``<name>`` for images,
    ``<stem>.txt`` for the report).

    Args:
        tir_path, rgb_path: input image paths.
        out_warped_dir, out_overlay_dir, out_heat_dir, out_txt_dir: existing
            output directories.
        matcher: loaded ``SuperGlueHF`` instance.
        K_tir, d_tir, K_rgb, d_rgb: optional intrinsics; a modality is
            undistorted only when both its K and dist are given.
        score_thr: SuperGlue match-score threshold.
        ransac_thr: RANSAC reprojection threshold in pixels (both models).
        min_matches, min_inlier, max_median_err: QC thresholds.
        max_matches: cap on matches kept, highest scores first.

    Returns:
        (accepted, model_used, num_matches, inlier_ratio_H, median_err_H) where
        ``model_used`` is "H", "A" or "NONE".

    Raises:
        ValueError: if either input image cannot be read.
        OSError: if an output image cannot be written.
    """
    name = rgb_path.name
    stem = rgb_path.stem

    tir_bgr = cv2.imread(str(tir_path), cv2.IMREAD_COLOR)
    rgb_bgr = cv2.imread(str(rgb_path), cv2.IMREAD_COLOR)
    if tir_bgr is None or rgb_bgr is None:
        raise ValueError(f"Error leyendo: {tir_path} o {rgb_path}")

    # Optional undistortion; undistort_if_available returns the input unchanged when K or dist is None.
    tir_bgr = undistort_if_available(tir_bgr, K_tir, d_tir, alpha=0.0)
    rgb_bgr = undistort_if_available(rgb_bgr, K_rgb, d_rgb, alpha=0.0)

    tir_pil = as_pil_rgb(preprocess_cross_modal(tir_bgr, do_grad=True))
    rgb_pil = as_pil_rgb(preprocess_cross_modal(rgb_bgr, do_grad=True))

    # keypoints0 = TIR, keypoints1 = RGB (order of the images passed to match).
    k_tir, k_rgb, sc = matcher.match(tir_pil, rgb_pil, threshold=score_thr, max_matches=max_matches)
    num_matches = int(len(sc))

    # Fit RGB (src) -> TIR (dst) so the RGB image can be warped into the TIR frame.
    M_aff, inl_aff, _ = estimate_affine_ransac(k_rgb, k_tir, thr=ransac_thr)
    H, inl_h, mask_h = estimate_homography_ransac(k_rgb, k_tir, thr=ransac_thr)
    # Errors are measured on the homography's own RANSAC inliers only.
    mean_err, med_err = reprojection_error_homography(H, k_rgb, k_tir, mask_h)

    ok_H = qc_pass(num_matches, inl_h, med_err, min_matches, min_inlier, max_median_err)
    # The similarity model is gated on match count and inlier ratio only (no error check).
    ok_A = (num_matches >= min_matches) and (inl_aff >= min_inlier)
    use_H = _prefer_homography(ok_A, ok_H, inl_aff, inl_h, med_err, max_median_err)
    accepted = bool(ok_A or ok_H)

    # Warp RGB onto the TIR pixel grid.
    h, w = tir_bgr.shape[:2]
    warped_rgb = None
    model_used = "NONE"
    if use_H and H is not None:
        warped_rgb = cv2.warpPerspective(rgb_bgr, H, (w, h), flags=cv2.INTER_LINEAR)
        model_used = "H"
    elif M_aff is not None:
        # NOTE: also reached when neither model passed QC (accepted=False); the warp is still saved.
        warped_rgb = cv2.warpAffine(rgb_bgr, M_aff, (w, h), flags=cv2.INTER_LINEAR)
        model_used = "A"

    # 50/50 blend of grayscale TIR with the warped RGB (TIR alone when there is no warp).
    tir_vis = cv2.cvtColor(cv2.cvtColor(tir_bgr, cv2.COLOR_BGR2GRAY), cv2.COLOR_GRAY2BGR)
    overlay = cv2.addWeighted(tir_vis, 0.5, warped_rgb, 0.5, 0.0) if warped_rgb is not None else tir_vis

    # Heatmap of the homography's inliers in TIR pixel space (independent of model_used),
    # weighted by SuperGlue match score.
    if (mask_h is not None) and (len(k_tir) > 0) and (len(mask_h) == len(k_tir)):
        pts_inl = k_tir[mask_h]
        w_inl = sc[mask_h] if len(sc) == len(mask_h) else None
    else:
        pts_inl = np.zeros((0, 2), dtype=np.float32)
        w_inl = None
    heat_inl = make_keypoint_heatmap(h, w, pts_inl, sigma=18, weights=w_inl)
    heat_overlay_inl, _ = overlay_heatmap_on_image(overlay.copy(), heat_inl, alpha=0.45)

    out_warp_path = out_warped_dir / name
    out_ovr_path = out_overlay_dir / name
    out_heat_path = out_heat_dir / name
    out_txt_path = out_txt_dir / f"{stem}.txt"

    # Keep filename parity even if the warp failed (black placeholder).
    if warped_rgb is None:
        warped_rgb = np.zeros((h, w, 3), dtype=np.uint8)

    _imwrite_checked(out_warp_path, warped_rgb)
    _imwrite_checked(out_ovr_path, overlay)
    _imwrite_checked(out_heat_path, heat_overlay_inl)

    txt = [
        f"name: {name}",
        f"rgb_path: {rgb_path}",
        f"tir_path: {tir_path}",
        "---- thresholds ----",
        f"score_thr: {score_thr}",
        f"ransac_thr: {ransac_thr}",
        f"min_matches: {min_matches}",
        f"min_inlier: {min_inlier}",
        f"max_median_err: {max_median_err}",
        "---- metrics ----",
        f"num_matches: {num_matches}",
        f"inlier_aff: {inl_aff:.6f}",
        f"inlier_H: {inl_h:.6f}",
        f"H_median_err_px: {med_err:.6f}",
        f"H_mean_err_px: {mean_err:.6f}",
        f"ok_A: {ok_A}",
        f"ok_H: {ok_H}",
        f"accepted: {accepted}",
        f"model_used_for_warp: {model_used}",
        "---- outputs ----",
        f"warped_rgb: {out_warp_path}",
        f"overlay: {out_ovr_path}",
        f"heatmap_inliers: {out_heat_path}",
    ]
    out_txt_path.write_text("\n".join(txt), encoding="utf-8")

    return accepted, model_used, num_matches, inl_h, med_err

# -----------------------------
# Batch runner
# -----------------------------
def run_batch(
    rgb_dir: str,
    tir_dir: str,
    out_root: str,
    ost_rgb: str = None,
    ost_tir: str = None,
    score_thr: float = 0.35,
    ransac_thr: float = 5.0,
    min_matches: int = 40,
    min_inlier: float = 0.40,
    max_median_err: float = 4.0,
    exts=(".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"),
):
    """Register every RGB image in ``rgb_dir`` to the same-named image in ``tir_dir``.

    For each RGB file, the TIR counterpart is ``tir_dir / <same filename>``.
    Missing counterparts and per-pair exceptions are recorded in
    ``info_txt/<stem>.txt`` and counted rather than raised. A ``SUMMARY.txt``
    with counts and thresholds is written to ``out_root``.

    Args:
        rgb_dir, tir_dir: input directories.
        out_root: output root; ``warped_rgb/``, ``overlay/``,
            ``heatmap_inliers/`` and ``info_txt/`` are created under it.
        ost_rgb, ost_tir: optional calibration files. Undistortion for a
            modality is enabled only if its file exists; a path that is set but
            missing triggers a warning.
        score_thr, ransac_thr, min_matches, min_inlier, max_median_err:
            forwarded to ``process_one_pair``.
        exts: image extensions to pick up from ``rgb_dir``, matched
            case-insensitively (".PNG" counts as ".png"); the leading dot is
            optional.

    Returns:
        dict with output paths and counters. ``warps_used_A`` in the summary
        counts every similarity warp, including pairs that failed QC.

    Raises:
        ValueError: if no matching image is found in ``rgb_dir``.
    """
    import warnings

    rgb_dir = Path(rgb_dir)
    tir_dir = Path(tir_dir)
    out_root = Path(out_root)

    out_warped_dir = out_root / "warped_rgb"
    out_overlay_dir = out_root / "overlay"
    out_heat_dir = out_root / "heatmap_inliers"
    out_txt_dir = out_root / "info_txt"

    for d in (out_warped_dir, out_overlay_dir, out_heat_dir, out_txt_dir):
        d.mkdir(parents=True, exist_ok=True)

    def _load_intrinsics(ost_path, label):
        """Return (K, dist) from ``ost_path``; (None, None) if unset or missing (warns when missing)."""
        if not ost_path:
            return None, None
        if not Path(ost_path).exists():
            # Warn so a mistyped config path does not silently disable undistortion.
            warnings.warn(f"{label} intrinsics file not found, undistortion disabled: {ost_path}")
            return None, None
        return read_ost(ost_path)

    K_rgb, d_rgb = _load_intrinsics(ost_rgb, "RGB")
    K_tir, d_tir = _load_intrinsics(ost_tir, "TIR")

    # Case-insensitive suffix match; only regular files are considered.
    wanted = {(e if e.startswith(".") else f".{e}").lower() for e in exts}
    rgb_files = []
    if rgb_dir.is_dir():
        rgb_files = sorted(
            p for p in rgb_dir.iterdir() if p.is_file() and p.suffix.lower() in wanted
        )

    if not rgb_files:
        raise ValueError(f"No encontré imágenes en {rgb_dir}")

    matcher = SuperGlueHF()

    missing = 0
    processed = 0
    accepted_count = 0
    used_H = 0
    used_A = 0
    failed = 0

    for rgb_path in tqdm(rgb_files, desc="Processing pairs"):
        tir_path = tir_dir / rgb_path.name
        if not tir_path.exists():
            missing += 1
            (out_txt_dir / f"{rgb_path.stem}.txt").write_text(
                f"name: {rgb_path.name}\nstatus: MISSING_TIR\nexpected_tir_path: {tir_path}\n",
                encoding="utf-8"
            )
            continue

        try:
            accepted, model_used, num_matches, inl_h, med_err = process_one_pair(
                tir_path=tir_path,
                rgb_path=rgb_path,
                out_warped_dir=out_warped_dir,
                out_overlay_dir=out_overlay_dir,
                out_heat_dir=out_heat_dir,
                out_txt_dir=out_txt_dir,
                matcher=matcher,
                K_tir=K_tir, d_tir=d_tir,
                K_rgb=K_rgb, d_rgb=d_rgb,
                score_thr=score_thr,
                ransac_thr=ransac_thr,
                min_matches=min_matches,
                min_inlier=min_inlier,
                max_median_err=max_median_err,
            )
            processed += 1
            if accepted:
                accepted_count += 1
            if model_used == "H":
                used_H += 1
            elif model_used == "A":
                used_A += 1
        except Exception as e:
            # Best-effort batch: record the failure for this pair and keep going.
            failed += 1
            (out_txt_dir / f"{rgb_path.stem}.txt").write_text(
                f"name: {rgb_path.name}\nstatus: ERROR\nerror: {repr(e)}\n",
                encoding="utf-8"
            )

    accept_rate = accepted_count / max(1, processed)
    summary = [
        f"rgb_dir: {rgb_dir}",
        f"tir_dir: {tir_dir}",
        f"total_rgb_files: {len(rgb_files)}",
        f"missing_tir: {missing}",
        f"processed_pairs: {processed}",
        f"errors: {failed}",
        f"accepted_pairs: {accepted_count}",
        f"accept_rate: {accept_rate:.4f}",
        f"warps_used_H: {used_H}",
        f"warps_used_A: {used_A}",
        "---- thresholds ----",
        f"score_thr: {score_thr}",
        f"ransac_thr: {ransac_thr}",
        f"min_matches: {min_matches}",
        f"min_inlier: {min_inlier}",
        f"max_median_err: {max_median_err}",
        "---- undistortion ----",
        f"ost_rgb: {ost_rgb}",
        f"ost_tir: {ost_tir}",
        f"undistort_rgb: {K_rgb is not None}",
        f"undistort_tir: {K_tir is not None}",
    ]
    (out_root / "SUMMARY.txt").write_text("\n".join(summary), encoding="utf-8")

    return {
        "out_root": str(out_root),
        "warped_rgb_dir": str(out_warped_dir),
        "overlay_dir": str(out_overlay_dir),
        "heatmap_inliers_dir": str(out_heat_dir),
        "info_txt_dir": str(out_txt_dir),
        "summary_file": str(out_root / "SUMMARY.txt"),
        "total_rgb_files": len(rgb_files),
        "missing_tir": missing,
        "processed_pairs": processed,
        "errors": failed,
        "accepted_pairs": accepted_count,
        "accept_rate": accept_rate,
    }

# ============================================================
# CONFIG: EDIT THESE PATHS
# ============================================================
DATASET_ROOT = "/workspace/dataset/dataset_2025"
RGB_DIR  = f"{DATASET_ROOT}/rgb_org/"
TIR_DIR  = f"{DATASET_ROOT}/tir_v2/"
OUT_ROOT = f"{DATASET_ROOT}/final_folder_II/"

# Optional calibration files with 'camera matrix' / 'distortion' sections;
# set to None to skip undistortion.
OST_RGB = "/workspace/ost_rgb.txt"  # e.g. "/content/ost_rgb.txt"
OST_TIR = "/workspace/ost_tir.txt"  # e.g. "/content/ost_tir.txt"

# Matching / RANSAC / QC thresholds
SCORE_THR = 0.35     # SuperGlue match-score cutoff
RANSAC_THR = 5.0     # RANSAC reprojection threshold (px)
MIN_MATCHES = 40     # QC: minimum number of matches
MIN_INLIER = 0.40    # QC: minimum RANSAC inlier ratio
MAX_MED_ERR = 4.0    # QC: maximum median homography reprojection error (px)

# ============================================================
# RUN
# ============================================================
# Collect every tunable in one mapping so the call below stays a single line.
batch_kwargs = dict(
    rgb_dir=RGB_DIR,
    tir_dir=TIR_DIR,
    out_root=OUT_ROOT,
    ost_rgb=OST_RGB,
    ost_tir=OST_TIR,
    score_thr=SCORE_THR,
    ransac_thr=RANSAC_THR,
    min_matches=MIN_MATCHES,
    min_inlier=MIN_INLIER,
    max_median_err=MAX_MED_ERR,
)
results = run_batch(**batch_kwargs)

print("DONE")
print(results)


Loading weights: 100%|██████████| 363/363 [00:00<00:00, 489.26it/s, Materializing param=keypoint_encoder.encoder.4.weight]                            
Processing pairs: 100%|██████████| 500/500 [04:21<00:00,  1.91it/s]

DONE
{'out_root': '/workspace/dataset/dataset_2025/final_folder_II', 'warped_rgb_dir': '/workspace/dataset/dataset_2025/final_folder_II/warped_rgb', 'overlay_dir': '/workspace/dataset/dataset_2025/final_folder_II/overlay', 'heatmap_inliers_dir': '/workspace/dataset/dataset_2025/final_folder_II/heatmap_inliers', 'info_txt_dir': '/workspace/dataset/dataset_2025/final_folder_II/info_txt', 'summary_file': '/workspace/dataset/dataset_2025/final_folder_II/SUMMARY.txt', 'total_rgb_files': 500, 'missing_tir': 0, 'processed_pairs': 500, 'errors': 0, 'accepted_pairs': 365, 'accept_rate': 0.73}



