In [1]:
import glob
import os
import time

import torch
from matplotlib import pyplot as plt

from met3r import MEt3R


class Checker:
    """Thin wrapper around MEt3R that scores view pairs for 3D consistency.

    Args:
        model_path: Weights passed to MEt3R as ``dust3r_weights``.
        is_mast3r: Forwarded as ``use_mast3r_dust3r``. Use ``False`` for the
            original DUSt3R, and make sure ``model_path`` matches that choice.
        device: Device the metric is moved to and inputs are sent to.
            Defaults to ``"cuda"``, which was the previous hard-coded behaviour.
    """

    def __init__(self, model_path, is_mast3r=False, device="cuda"):
        self.device = torch.device(device)
        self.metric = MEt3R(
            img_size=None,  # None -> use the input resolution on the fly
            use_norm=True,
            feat_backbone="dino16",
            featup_weights="mhamilton723/FeatUp",
            dust3r_weights=model_path,
            use_mast3r_dust3r=is_mast3r,  # False -> original DUSt3R; weights must match
        ).to(self.device)
        # Pure inference: make sure no submodule is left in training mode.
        self.metric.eval()

    def __call__(self, inputs_B_V_C_H_W):
        """Score a batch of view pairs.

        Args:
            inputs_B_V_C_H_W: Tensor of shape (batch, views, channels, height,
                width) with exactly 2 views per item.

        Returns:
            ``score.tolist()`` of the MEt3R score tensor. This is presumably one
            float per pair; verify against the MEt3R version in use.

        Raises:
            ValueError: If the input is not 5-D with a view dimension of 2.
        """
        if inputs_B_V_C_H_W.ndim != 5 or inputs_B_V_C_H_W.shape[1] != 2:
            raise ValueError(
                "expected input of shape (batch, 2, channels, height, width), "
                f"got {tuple(inputs_B_V_C_H_W.shape)}"
            )

        # Evaluation only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            score, *_ = self.metric(
                images=inputs_B_V_C_H_W.to(self.device),
                return_overlap_mask=False,
                return_score_map=False,
                return_projections=False,
            )
        scores = score.tolist()

        # Release cached allocator blocks between batches (no-op if CUDA is unused).
        torch.cuda.empty_cache()

        return scores



In [2]:
# Build the MASt3R-backed MEt3R checker once; it is reused for every sequence.
ckr = Checker(model_path="chkpts/mast3r.pth", is_mast3r=True)
# Directories holding the frame sequences to score.
img_dirs = [
    "imgs/seq_bad/",
    "imgs/seq_good/",
]


Using cache found in /users/bangya/.cache/torch/hub/mhamilton723_FeatUp_main
  from .autonotebook import tqdm as notebook_tqdm
Using cache found in /users/bangya/.cache/torch/hub/facebookresearch_dino_main


... loading model from chkpts/mast3r.pth


  ckpt = torch.load(model_path, map_location='cpu')


instantiating : AsymmetricMASt3R(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100',img_size=(512, 512), head_type='catmlp+dpt', output_mode='pts3d+desc24', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), patch_embed_cls='PatchEmbedDust3R', two_confs=True, desc_conf_mode=('exp', 0, inf), landscape_only=False)
<All keys matched successfully>


In [3]:
def _load_rgb_C_H_W(path):
    """Read an image file as a float32 (3, H, W) tensor with values in [0, 1]."""
    arr = plt.imread(path)
    # plt.imread yields floats in [0, 1] for PNG but unsigned ints (e.g. uint8 JPEG)
    # for other formats; rescale integer images so every file shares one range.
    if arr.dtype.kind == "u":
        arr = arr.astype("float32") / float(2 ** (8 * arr.dtype.itemsize) - 1)
    img = torch.from_numpy(arr).float()
    if img.ndim == 2:  # grayscale (H, W) -> replicate into 3 channels
        img = img.unsqueeze(-1).expand(-1, -1, 3)
    img = img[..., :3]  # drop an alpha channel if present (RGBA PNGs)
    return img.permute(2, 0, 1)


def _crop_to_multiple(img_C_H_W, border, multiple):
    """Trim `border` px from each side, then center-crop H and W to a multiple of `multiple` (>= 1)."""
    if border:
        img_C_H_W = img_C_H_W[:, border:-border, border:-border]
    h, w = img_C_H_W.shape[-2:]
    top, left = (h % multiple) // 2, (w % multiple) // 2
    return img_C_H_W[:, top:top + h - h % multiple, left:left + w - w % multiple]


def check_dir(img_dir, checker=None, max_bs=8, border=8, multiple=16):
    """Score the 3D consistency of consecutive frame pairs in `img_dir`.

    Frames are the ``*.png`` / ``*.jpg`` files in `img_dir`, sorted by path, and
    are paired as (f, f+1). The function prints the RGB range, the input shape,
    the wall time and summary statistics of the scores.

    Args:
        img_dir: Directory containing the frames.
        checker: Callable mapping a (batch, 2, C, H, W) tensor to a list of scores.
            Defaults to the notebook-global ``ckr``.
        max_bs: Maximum number of pairs per checker call.
        border: Pixels trimmed from every image side.
        multiple: H and W are center-cropped down to a multiple of this.

    Returns:
        list: The concatenated scores returned by `checker`, one entry per pair.

    Raises:
        ValueError: If fewer than two frames are found.
    """
    if checker is None:
        checker = ckr

    img_files = sorted(
        glob.glob(os.path.join(img_dir, "*.png")) + glob.glob(os.path.join(img_dir, "*.jpg"))
    )
    if len(img_files) < 2:
        raise ValueError(f"need at least 2 frames in {img_dir!r}, found {len(img_files)}")

    # (frames, channels, height, width); all frames must share one size.
    frames_N_C_H_W = torch.stack(
        [_crop_to_multiple(_load_rgb_C_H_W(f), border, multiple) for f in img_files]
    )
    # Pair frame f with f+1 along a new view axis: (batch=N-1, views=2, C, H, W).
    imgs_B_V_C_H_W = torch.stack([frames_N_C_H_W[:-1], frames_N_C_H_W[1:]], dim=1)

    # Map RGB from [0, 1] to [-1, 1].
    imgs_B_V_C_H_W = imgs_B_V_C_H_W * 2 - 1
    print(f"rgb range: {imgs_B_V_C_H_W.min():.2f} - {imgs_B_V_C_H_W.max():.2f}")
    print(f"Input shape: {imgs_B_V_C_H_W.shape}")  # (batch, 2, channels, height, width)

    scores = []
    start = time.perf_counter()
    for i in range(0, imgs_B_V_C_H_W.shape[0], max_bs):
        # Move only the current batch to the GPU, not the whole sequence.
        scores.extend(checker(imgs_B_V_C_H_W[i:i + max_bs].cuda()))

    print(f"Time taken: {time.perf_counter() - start:.2f} seconds")
    scores_t = torch.tensor(scores)
    print(f"""Scores:
            avg: {sum(scores) / len(scores):.4f}
            min: {min(scores):.4f}
            max: {max(scores):.4f}
            std: {torch.std(scores_t).item():.4f}
            median: {torch.median(scores_t).item():.4f}""")
    return scores


# Evaluate every configured sequence in turn; check_dir prints its own report.
for img_dir in img_dirs:
    banner = f"eval 3d consistency for: {img_dir} "
    print(banner)
    check_dir(img_dir)


eval 3d consistency for: imgs/seq_bad/ 
rgb range: -1.00 - 1.00
Input shape: torch.Size([32, 2, 3, 384, 624])


  with torch.cuda.amp.autocast(enabled=False):
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Time taken: 32.38 seconds
Scores:
            avg: 0.1407
            min: 0.1009
            max: 0.2124
            std: 0.0292
            median: 0.1296
eval 3d consistency for: imgs/seq_good/ 
rgb range: -1.00 - 1.00
Input shape: torch.Size([32, 2, 3, 384, 624])
Time taken: 31.88 seconds
Scores:
            avg: 0.1475
            min: 0.1043
            max: 0.2161
            std: 0.0270
            median: 0.1387
