In [1]:
%load_ext autoreload
%autoreload 2

from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
import os
import numpy as np
import trimesh
import copy
from scipy.spatial.transform import Rotation
import tempfile
import shutil
import torch
import glob
from mast3r.cloud_opt.sparse_ga import sparse_global_alignment
from mast3r.cloud_opt.tsdf_optimizer import TSDFPostProcess
from mast3r.image_pairs import make_pairs
from mast3r.retrieval.processor import Retriever
from mast3r.utils.misc import mkdir_for
from cust3r.utils.image import load_images
from dust3r.dust3r.utils.device import to_numpy
from dust3r.dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
from dust3r.dust3r.demo import get_args_parser as dust3r_get_args_parser
import matplotlib.pyplot as pl
import imageio.v2 as iio
from cust3r.utils.camera import pose_encoding_to_camera
from cust3r.post_process import estimate_focal_knowing_depth
from cust3r.utils.geometry import geotrf
from cust3r.model import ARCroco3DStereo
from cust3r.inference import inference as inference_cust3r
import time
from boq.boq_infer import get_trained_boq, boq_sort_topk
import json

In [None]:
def _convert_scene_output_to_glb(imgs, pts3d, mask, focals, cams2world, cam_size=0.05,
                                 cam_color=None, as_pointcloud=False,
                                 transparent_cams=False, silent=False):
    """
    Build a trimesh scene (geometry plus camera frusta) from per-view reconstruction outputs.

    Args:
        imgs: per-view RGB images; used as point/vertex colours and as the
            frustum texture.
        pts3d: per-view 3D points, one entry per view that has geometry. In the
            mesh branch each entry is reshaped to the shape of its image.
        mask: per-view boolean masks selecting the points to keep.
        focals: per-camera focal lengths, forwarded to ``add_scene_cam``.
        cams2world: per-camera camera-to-world poses. ``cams2world[0]`` is used
            to re-express the final scene.
        cam_size: frustum screen width forwarded to ``add_scene_cam``.
        cam_color: a list with one colour per camera, a single colour shared by
            all cameras, or None to cycle through ``CAM_COLORS``.
        as_pointcloud: if True add one merged point cloud, otherwise one merged mesh.
        transparent_cams: if True the frusta are added without their image.
        silent: accepted for signature compatibility; not used in this function.

    Returns:
        trimesh.Scene: the assembled scene.
    """
    # There may be more images/cameras than point maps, never fewer.
    assert len(pts3d) == len(mask) <= len(imgs) <= len(cams2world) == len(focals)
    pts3d = to_numpy(pts3d)
    imgs = to_numpy(imgs)
    focals = to_numpy(focals)
    cams2world = to_numpy(cams2world)
    scene = trimesh.Scene()
    # full pointcloud
    if as_pointcloud:
        # zip() stops at the shortest input, so only views with geometry are used.
        pts = np.concatenate([p[m.ravel()] for p, m in zip(pts3d, mask)]).reshape(-1, 3)
        col = np.concatenate([p[m] for p, m in zip(imgs, mask)]).reshape(-1, 3)
        # Drop points containing NaN/inf coordinates.
        valid_msk = np.isfinite(pts.sum(axis=1))
        pct = trimesh.PointCloud(pts[valid_msk], colors=col[valid_msk])
        scene.add_geometry(pct)
    else:
        meshes = []
        # Iterate over the views that have geometry. The assert above allows
        # len(imgs) > len(pts3d), so looping over len(imgs) could index past
        # the end of pts3d/mask.
        for i in range(len(pts3d)):
            pts3d_i = pts3d[i].reshape(imgs[i].shape)
            msk_i = mask[i] & np.isfinite(pts3d_i.sum(axis=-1))
            meshes.append(pts3d_to_trimesh(imgs[i], pts3d_i, msk_i))
        mesh = trimesh.Trimesh(**cat_meshes(meshes))
        scene.add_geometry(mesh)
    # add each camera
    for i, pose_c2w in enumerate(cams2world):
        if isinstance(cam_color, list):
            camera_edge_color = cam_color[i]
        else:
            camera_edge_color = cam_color or CAM_COLORS[i % len(CAM_COLORS)]
        add_scene_cam(scene, pose_c2w, camera_edge_color,
                      None if transparent_cams else imgs[i], focals[i],
                      imsize=imgs[i].shape[1::-1], screen_width=cam_size)
    # Re-express the scene relative to the first camera, combined with the
    # OPENGL matrix and a 180 degree rotation about the y axis.
    rot = np.eye(4)
    rot[:3, :3] = Rotation.from_euler('y', np.deg2rad(180)).as_matrix()
    scene.apply_transform(np.linalg.inv(cams2world[0] @ OPENGL @ rot))
    return scene

def get_3D_model_from_scene(silent, scene, min_conf_thr=2, as_pointcloud=False, mask_sky=False,
                            clean_depth=False, transparent_cams=False, cam_size=0.05, TSDF_thresh=0):
    """
    Convert an optimized scene into a trimesh scene.

    Focals, camera poses and dense points are read from ``scene``; points whose
    confidence is not above ``min_conf_thr`` are masked out. When
    ``TSDF_thresh`` is positive the dense points come from a ``TSDFPostProcess``
    wrapper around the scene instead of from the scene itself.

    ``mask_sky`` is accepted but not used by this function.
    """
    # Optimized per-view quantities.
    rgbimg = scene.imgs
    focals = scene.get_focals().cpu()
    cams2world = scene.get_im_poses().cpu()

    # Pick the provider of dense 3D points (TSDF post-processed or raw scene).
    if TSDF_thresh > 0:
        dense_source = TSDFPostProcess(scene, TSDF_thresh=TSDF_thresh)
    else:
        dense_source = scene
    pts3d, _, confs = to_numpy(dense_source.get_dense_pts3d(clean_depth=clean_depth))

    # Per-view confidence masks.
    msk = to_numpy([conf_map > min_conf_thr for conf_map in confs])
    return _convert_scene_output_to_glb(
        rgbimg, pts3d, msk, focals, cams2world,
        as_pointcloud=as_pointcloud,
        transparent_cams=transparent_cams,
        cam_size=cam_size,
        silent=silent,
    )
    

def get_reconstructed_scene(model, device, filelist,
                            cache_path,
                            retrieval_model=None,
                            silent=False,
                            optim_level="refine+depth",
                            lr1=0.07, niter1=200, lr2=0.01, niter2=200,
                            min_conf_thr=1.5,
                            matching_conf_thr=0.0,
                            as_pointcloud=True, mask_sky=False, clean_depth=True,
                            transparent_cams=False, cam_size=0.2,
                            scenegraph_type="complete", winsize=1, win_cyclic=False, refid=0,
                            TSDF_thresh=0.0, shared_intrinsics=False,
                            **kw):
    """
    Reconstruct a scene from a list of image paths and return it as a trimesh scene.

    Steps: load the images, build the image pairs for the requested scene
    graph, run ``sparse_global_alignment`` (results cached under
    ``cache_path``), then convert the result with ``get_3D_model_from_scene``.

    A similarity matrix is computed with ``retrieval_model`` when
    ``scenegraph_type`` contains "retrieval". ``boq_topk.json`` is read from
    ``cache_path`` when it contains "boq". Extra keyword arguments are
    forwarded to ``sparse_global_alignment``.
    """
    imgs, imgs_id_dict = load_images(filelist, size=512, verbose=not silent)
    if len(imgs) == 1:
        # A lone image is paired with a copy of itself so that a pair exists.
        twin = copy.deepcopy(imgs[0])
        twin['idx'] = 1
        imgs = [imgs[0], twin]
        filelist = [filelist[0], filelist[0] + '_2']

    # Encode the scene graph as "<type>[-<param>...]".
    graph_tokens = [scenegraph_type]
    if scenegraph_type in ("swin", "logwin"):
        graph_tokens.append(str(winsize))
        if not win_cyclic:
            graph_tokens.append('noncyclic')
    elif scenegraph_type == "oneref":
        graph_tokens.append(str(refid))
    elif scenegraph_type == "retrieval":
        graph_tokens += [str(winsize), str(refid)]  # Na, k
    scene_graph = '-'.join(graph_tokens)

    # Optional retrieval-based similarity matrix.
    sim_matrix = None
    if 'retrieval' in scenegraph_type:
        assert retrieval_model is not None
        retriever = Retriever(retrieval_model, backbone=model, device=device)
        with torch.no_grad():
            sim_matrix = retriever(filelist)
        # Release the retriever before the alignment stage.
        del retriever
        torch.cuda.empty_cache()

    # Optional precomputed BoQ top-k dictionary.
    boq_topks = None
    if 'boq' in scenegraph_type:
        boq_json = os.path.join(cache_path, "boq_topk.json")
        with open(boq_json, "r", encoding="utf-8") as f:
            boq_topks = json.load(f)

    pairs = make_pairs(imgs, imgs_id_dict, scene_graph=scene_graph, prefilter=None,
                       symmetrize=False, sim_mat=sim_matrix, boq_topk_dict=boq_topks)

    # 'coarse' disables the second optimisation stage.
    if optim_level == 'coarse':
        niter2 = 0

    # Sparse GA (forward mast3r -> matching -> 3D optim -> 2D refinement -> triangulation)
    scene = sparse_global_alignment(
        filelist, pairs, cache_path, model,
        lr1=lr1, niter1=niter1, lr2=lr2, niter2=niter2, device=device,
        opt_depth='depth' in optim_level, shared_intrinsics=shared_intrinsics,
        matching_conf_thr=matching_conf_thr, **kw)

    return get_3D_model_from_scene(silent, scene, min_conf_thr, as_pointcloud, mask_sky,
                                   clean_depth, transparent_cams, cam_size, TSDF_thresh)

In [3]:
# MASt3R demo cell: reconstruct the ETs2 training scene and display it.
# Load the MASt3R checkpoint onto the first GPU.
device = 'cuda:0'
model = AsymmetricMASt3R.from_pretrained("ckpts/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth").to(device)
# Collect every PNG of the scene; the list keeps whatever order glob returns (not sorted).
image_list = []
for filename in glob.glob('data/image-matching-challenge-2025/train/ETs2/*.png'):  # PNG files only
    image_list.append(filename)

# Compute the BoQ top-k result for the image list.
boq_model = get_trained_boq(backbone_name="dinov2", output_dim=12288).to(device)
boq_model.eval()
boq_topks = boq_sort_topk(image_list, boq_model, device, vis=False)

# Store the top-k result as boq_topk.json inside the cache directory used below.
os.makedirs("outputs/ETs2", exist_ok=True)
with open(os.path.join("outputs/ETs2", "boq_topk.json"), "w", encoding="utf-8") as f:
    json.dump(boq_topks, f, ensure_ascii=False, indent=4)
    
# NOTE(review): get_reconstructed_scene only reads boq_topk.json when 'boq' is in
# scenegraph_type. This call keeps the default "complete", so the file written
# above is not consumed here -- confirm whether a BoQ scene graph was intended.
trimesh_scene = get_reconstructed_scene(model, device, image_list, "outputs/ETs2")
# Drop the model references and release cached GPU memory before showing the scene.
del model, boq_model
torch.cuda.empty_cache()
trimesh_scene.show()

... loading model from ckpts/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth
instantiating : AsymmetricMASt3R(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100',img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), patch_embed_cls='PatchEmbedDust3R', two_confs=True, desc_conf_mode=('exp', 0, inf), landscape_only=False)
_IncompatibleKeys(missing_keys=[], unexpected_keys=['mask_token'])


Using cache found in /root/.cache/torch/hub/facebookresearch_dinov2_main
100%|██████████| 10/10 [00:00<00:00, 30.94it/s]


find topk time: 0.02052903175354004 s
>> Loading a list of 10 images
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et006.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et002.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et004.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et001.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et008.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et010.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et003.png with resolution 360x640 --> 224x224
 - adding data/image-matching-challenge-2025/train/ETs2/another_et_another_et009.png

100%|██████████| 45/45 [00:05<00:00,  7.63it/s]
100%|██████████| 10/10 [00:00<00:00, 28.17it/s]


clusters = [1 1 1 1 1 1 1 1 1 1]
clusters_dict = {1: ['another_et_another_et006', 'another_et_another_et002', 'another_et_another_et004', 'another_et_another_et001', 'another_et_another_et008', 'another_et_another_et010', 'another_et_another_et003', 'another_et_another_et009', 'another_et_another_et005', 'another_et_another_et007']}
init focals = [374.71582 396.9805  430.77487 336.8841  386.34192 389.72458 473.49463
 371.7361  391.14825 386.2605 ]


100%|██████████| 200/200 [00:04<00:00, 41.85it/s, lr=0.0000, loss=0.215]


>> final loss = 0.2149742990732193


100%|██████████| 200/200 [00:07<00:00, 28.10it/s, lr=0.0000, loss=0.949]


>> final loss = 0.9493509531021118
Final focals = [389.8515  388.7524  380.88437 397.7855  386.68158 381.34103 383.22336
 383.318   389.00848 388.06528]


In [6]:
def make_all_image_same_size(images):
    """
    Resize every image entry in place to the spatial size of the first one.

    Args:
        images (list of dict): entries with an "img" tensor and a "true_shape"
            array. Entries whose last two "img" dimensions differ from those of
            ``images[0]`` get a bilinearly resized "img" and a copy of the first
            entry's "true_shape".

    Returns:
        None. The list entries are modified in place.
    """
    if len(images) < 1:
        return

    # The first entry defines the reference size and true shape.
    reference = images[0]
    target_hw = reference["img"].shape[-2:]
    target_true_shape = reference["true_shape"]

    for entry in images[1:]:
        if entry["img"].shape[-2:] != target_hw:
            entry["img"] = torch.nn.functional.interpolate(
                entry["img"], size=target_hw, mode="bilinear", align_corners=False
            )
            entry["true_shape"] = target_true_shape.copy()
    
    

def prepare_input(
    img_paths, img_mask, size, raymaps=None, raymap_mask=None, revisit=1, update=True
):
    """
    Prepare input views for inference from a list of image paths.

    Args:
        img_paths (list): List of image file paths.
        img_mask (list of bool): Flags indicating valid images. Only consulted
            when ray maps are supplied; otherwise every view is an image view.
        size (int): Target image size passed to ``load_images``.
        raymaps (list, optional): List of ray maps.
        raymap_mask (list, optional): Flags indicating valid ray maps.
        revisit (int): How many times to revisit each view.
        update (bool): Whether to update the state on revisits.

    Returns:
        list: A list of view dictionaries.
    """

    loaded = load_images(img_paths, size=size)
    # Elsewhere in this notebook (get_reconstructed_scene) load_images is
    # unpacked as (images, id_dict). Take the image list from that tuple, and
    # keep accepting a bare list of image dicts as well.
    images = loaded[0] if isinstance(loaded, tuple) else loaded
    make_all_image_same_size(images)
    views = []

    if raymaps is None and raymap_mask is None:
        # Only images are provided.
        for i in range(len(images)):
            view = {
                "img": images[i]["img"],
                # No ray map for image-only views: fill with NaN.
                "ray_map": torch.full(
                    (
                        images[i]["img"].shape[0],
                        6,
                        images[i]["img"].shape[-2],
                        images[i]["img"].shape[-1],
                    ),
                    torch.nan,
                ),
                "true_shape": torch.from_numpy(images[i]["true_shape"]),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(True).unsqueeze(0),
                "ray_mask": torch.tensor(False).unsqueeze(0),
                "update": torch.tensor(True).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            views.append(view)
    else:
        # Combine images and raymaps.
        num_views = len(images) + len(raymaps)
        assert len(img_mask) == len(raymap_mask) == num_views
        assert sum(img_mask) == len(images) and sum(raymap_mask) == len(raymaps)

        # j / k walk through the images / ray maps as the masks consume them.
        j = 0
        k = 0
        for i in range(num_views):
            view = {
                "img": (
                    images[j]["img"]
                    if img_mask[i]
                    else torch.full_like(images[0]["img"], torch.nan)
                ),
                "ray_map": (
                    raymaps[k]
                    if raymap_mask[i]
                    else torch.full_like(raymaps[0], torch.nan)
                ),
                "true_shape": (
                    torch.from_numpy(images[j]["true_shape"])
                    if img_mask[i]
                    else torch.from_numpy(np.int32([raymaps[k].shape[1:-1][::-1]]))
                ),
                "idx": i,
                "instance": str(i),
                "camera_pose": torch.from_numpy(np.eye(4, dtype=np.float32)).unsqueeze(
                    0
                ),
                "img_mask": torch.tensor(img_mask[i]).unsqueeze(0),
                "ray_mask": torch.tensor(raymap_mask[i]).unsqueeze(0),
                "update": torch.tensor(img_mask[i]).unsqueeze(0),
                "reset": torch.tensor(False).unsqueeze(0),
            }
            if img_mask[i]:
                j += 1
            if raymap_mask[i]:
                k += 1
            views.append(view)
        assert j == len(images) and k == len(raymaps)

    if revisit > 1:
        # Repeat the whole sequence `revisit` times with fresh idx/instance
        # values. When `update` is False, repeated passes are flagged
        # update=False.
        new_views = []
        for r in range(revisit):
            for i, view in enumerate(views):
                new_view = copy.deepcopy(view)
                new_view["idx"] = r * len(views) + i
                new_view["instance"] = str(r * len(views) + i)
                if r > 0 and not update:
                    new_view["update"] = torch.tensor(False).unsqueeze(0)
                new_views.append(new_view)
        return new_views

    return views

def prepare_output(outputs, outdir, revisit=1, use_pose=True):
    """
    Turn raw inference outputs into visualization inputs and save per-frame results.

    Args:
        outputs (dict): Inference outputs with "pred" and "views" lists. Both
            entries are truncated in place to their last ``len // revisit`` items.
        outdir (str): Directory receiving the ``depth``, ``conf``, ``color`` and
            ``camera`` sub-folders (one file per frame in each).
        revisit (int): Number of revisits per view.
        use_pose (bool): If True, the returned points are the self-view points
            transformed by the predicted pose and the returned confidence is the
            self-view confidence. Otherwise the "pts3d_in_other_view" and "conf"
            predictions are returned.

    Returns:
        tuple: (points, colors, confidence, camera parameters dictionary)
    """
    # Only keep the outputs corresponding to one full pass.
    n_last_pass = len(outputs["pred"]) // revisit
    outputs["pred"] = outputs["pred"][-n_last_pass:]
    outputs["views"] = outputs["views"][-n_last_pass:]
    preds = outputs["pred"]
    views = outputs["views"]

    self_pts_per_view = [pred["pts3d_in_self_view"].cpu() for pred in preds]
    pts3ds_other = [pred["pts3d_in_other_view"].cpu() for pred in preds]
    conf_self = [pred["conf_self"].cpu() for pred in preds]
    conf_other = [pred["conf"].cpu() for pred in preds]
    pts3ds_self = torch.cat(self_pts_per_view, 0)

    # Recover camera poses.
    pr_poses = [
        pose_encoding_to_camera(pred["camera_pose"].clone()).cpu() for pred in preds
    ]
    R_c2w = torch.cat([pose[:, :3, :3] for pose in pr_poses], 0)
    t_c2w = torch.cat([pose[:, :3, 3] for pose in pr_poses], 0)

    if use_pose:
        # Replace the predicted other-view points by pose-transformed self-view points.
        pts3ds_other = [
            geotrf(pose, pself.unsqueeze(0))
            for pose, pself in zip(pr_poses, pts3ds_self)
        ]
        conf_other = conf_self

    # Estimate focal length based on depth, with the principal point at the image centre.
    B, H, W, _ = pts3ds_self.shape
    pp = torch.tensor([W // 2, H // 2], device=pts3ds_self.device).float().repeat(B, 1)
    focal = estimate_focal_knowing_depth(pts3ds_self, pp, focal_mode="weiszfeld")

    # Map image tensors from [-1, 1] to [0, 1] and move channels last.
    colors = [0.5 * (view["img"].permute(0, 2, 3, 1) + 1.0) for view in views]

    cam_dict = {
        "focal": focal.cpu().numpy(),
        "pp": pp.cpu().numpy(),
        "R": R_c2w.cpu().numpy(),
        "t": t_c2w.cpu().numpy(),
        "cams2world": torch.cat(pr_poses).cpu().numpy(),
    }

    # Stacked per-frame tensors written to disk below.
    depths_tosave = pts3ds_self[..., 2]
    pts3ds_other_tosave = torch.cat(pts3ds_other)  # B, H, W, 3 (not written to disk)
    conf_self_tosave = torch.cat(conf_self)  # B, H, W
    conf_other_tosave = torch.cat(conf_other)  # B, H, W (not written to disk)
    colors_tosave = torch.cat(
        [0.5 * (view["img"].permute(0, 2, 3, 1).cpu() + 1.0) for view in views]
    )  # [B, H, W, 3]
    cam2world_tosave = torch.cat(pr_poses)  # B, 4, 4

    # Intrinsics: one focal for both axes, principal point from pp.
    intrinsics_tosave = torch.eye(3).unsqueeze(0).repeat(cam2world_tosave.shape[0], 1, 1)  # B, 3, 3
    focal_cpu = focal.detach().cpu()
    intrinsics_tosave[:, 0, 0] = focal_cpu
    intrinsics_tosave[:, 1, 1] = focal_cpu
    intrinsics_tosave[:, 0, 2] = pp[:, 0]
    intrinsics_tosave[:, 1, 2] = pp[:, 1]

    for subdir in ("depth", "conf", "color", "camera"):
        os.makedirs(os.path.join(outdir, subdir), exist_ok=True)

    for f_id in range(len(pts3ds_self)):
        # Convert everything for this frame first, then write the four files.
        depth = depths_tosave[f_id].cpu().numpy()
        conf = conf_self_tosave[f_id].cpu().numpy()
        color = colors_tosave[f_id].cpu().numpy()
        c2w = cam2world_tosave[f_id].cpu().numpy()
        intrins = intrinsics_tosave[f_id].cpu().numpy()

        np.save(os.path.join(outdir, "depth", f"{f_id:06d}.npy"), depth)
        np.save(os.path.join(outdir, "conf", f"{f_id:06d}.npy"), conf)
        iio.imwrite(
            os.path.join(outdir, "color", f"{f_id:06d}.png"),
            (color * 255).astype(np.uint8),
        )
        np.savez(
            os.path.join(outdir, "camera", f"{f_id:06d}.npz"),
            pose=c2w,
            intrinsics=intrins,
        )

    return pts3ds_other, colors, conf_other, cam_dict

def run_inference(model, image_list, output_dir, device, min_conf_thr = 2.0, size = 512,
                  transparent_cams=False, cam_size=0.2, as_pointcloud = True, silent= True):
    """
    Execute the full inference and visualization pipeline.

    Args:
        model: Model passed to ``inference_cust3r``.
        image_list (list): Image file paths.
        output_dir (str): Directory where ``prepare_output`` writes per-frame files.
        device: Device passed to ``inference_cust3r``; "cuda" falls back to
            "cpu" when CUDA is unavailable.
        min_conf_thr (float): Points with confidence not above this value are masked out.
        size (int): Target image size used when loading the images.
        transparent_cams, cam_size, as_pointcloud, silent: Forwarded to
            ``_convert_scene_output_to_glb``.

    Returns:
        The scene returned by ``_convert_scene_output_to_glb``.
    """
    # Set up the computation device.
    if device == "cuda":
        if not torch.cuda.is_available():
            print("CUDA not available. Switching to CPU.")
            device = "cpu"

    # Every path in image_list is an image view.
    all_images_valid = [True] * len(image_list)

    print("Preparing input views...")
    views = prepare_input(
        img_paths=image_list, img_mask=all_images_valid, size=size, revisit=1, update=True
    )

    print("Running inference...")
    outputs, _ = inference_cust3r(views, model, device)

    print("Preparing output for visualization...")
    world_pts, view_colors, confidences, cam_dict = prepare_output(outputs, output_dir, 1, True)

    # Convert tensors to numpy arrays for visualization.
    pts3ds_to_vis = [pts.reshape((-1, 3)).cpu().numpy() for pts in world_pts]
    colors_to_vis = [rgb.squeeze(0).cpu().numpy() for rgb in view_colors]
    msk = [(conf_map > min_conf_thr).squeeze(0).cpu().numpy() for conf_map in confidences]

    return _convert_scene_output_to_glb(
        colors_to_vis, pts3ds_to_vis, msk, cam_dict["focal"], cam_dict["cams2world"],
        as_pointcloud=as_pointcloud,
        transparent_cams=transparent_cams,
        cam_size=cam_size,
        silent=silent,
    )


In [1]:
# CUT3R demo cell: run run_inference on training scene "11" and display the result.
# Requires the import cell at the top of the notebook to have been run in the
# current kernel (ARCroco3DStereo, glob and torch are imported there), as well as
# the cell defining run_inference; otherwise this cell fails with a NameError.
device = 'cuda'
print(f"Loading model ...")
model = ARCroco3DStereo.from_pretrained("ckpts/cut3r_512_dpt_4_64.pth").to(device)
model.eval()
# Collect every PNG of the scene; the list keeps whatever order glob returns.
image_list = []
for filename in glob.glob('data/image-matching-challenge-2025/train/11/*.png'):
    image_list.append(filename)
# The fifth positional argument (3.0) is min_conf_thr.
trimesh_scene = run_inference(model, image_list, "outputs/cust/1111", device, 3.0)
# Drop the model reference and release cached GPU memory before showing the scene.
del model
torch.cuda.empty_cache()

trimesh_scene.show()

Loading model ...


NameError: name 'ARCroco3DStereo' is not defined